add tests
This commit is contained in:
parent
f88174c55d
commit
03e1d64aac
|
|
@ -262,3 +262,135 @@ func TestClientDo(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientWebSearch(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Fatalf("expected POST, got %s", r.Method)
|
||||
}
|
||||
if !strings.HasSuffix(r.URL.Path, "/api/web_search") {
|
||||
t.Fatalf("unexpected path: %s", r.URL.Path)
|
||||
}
|
||||
|
||||
var req WebSearchRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
t.Fatalf("failed to decode request: %v", err)
|
||||
}
|
||||
|
||||
if req.Query != "what is ollama" {
|
||||
t.Fatalf("unexpected query: %s", req.Query)
|
||||
}
|
||||
|
||||
if req.MaxResults != 3 {
|
||||
t.Fatalf("unexpected max_results: %d", req.MaxResults)
|
||||
}
|
||||
|
||||
resp := WebSearchResponse{
|
||||
Results: []WebSearchResult{{
|
||||
Title: "Ollama",
|
||||
URL: "https://ollama.com",
|
||||
Content: "Cloud models are now available...",
|
||||
}},
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||
t.Fatalf("failed to encode response: %v", err)
|
||||
}
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
u, err := url.Parse(ts.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("parse server URL: %v", err)
|
||||
}
|
||||
|
||||
client := NewClient(u, ts.Client())
|
||||
|
||||
resp, err := client.WebSearch(t.Context(), &WebSearchRequest{Query: "what is ollama", MaxResults: 3})
|
||||
if err != nil {
|
||||
t.Fatalf("WebSearch returned error: %v", err)
|
||||
}
|
||||
|
||||
if len(resp.Results) != 1 {
|
||||
t.Fatalf("expected 1 result, got %d", len(resp.Results))
|
||||
}
|
||||
|
||||
if resp.Results[0].Title != "Ollama" {
|
||||
t.Fatalf("unexpected title: %s", resp.Results[0].Title)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientWebSearchUnauthorized(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
_ = json.NewEncoder(w).Encode(map[string]string{
|
||||
"signin_url": "https://ollama.com/connect",
|
||||
})
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
u, err := url.Parse(ts.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("parse server URL: %v", err)
|
||||
}
|
||||
|
||||
client := NewClient(u, ts.Client())
|
||||
|
||||
_, err = client.WebSearch(t.Context(), &WebSearchRequest{Query: "what is ollama"})
|
||||
if err == nil {
|
||||
t.Fatal("expected error, got nil")
|
||||
}
|
||||
|
||||
if _, ok := err.(AuthorizationError); !ok {
|
||||
t.Fatalf("expected AuthorizationError, got %T", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClientWebFetch(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != http.MethodPost {
|
||||
t.Fatalf("expected POST, got %s", r.Method)
|
||||
}
|
||||
if !strings.HasSuffix(r.URL.Path, "/api/web_fetch") {
|
||||
t.Fatalf("unexpected path: %s", r.URL.Path)
|
||||
}
|
||||
|
||||
var req WebFetchRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
t.Fatalf("failed to decode request: %v", err)
|
||||
}
|
||||
|
||||
if req.URL != "https://ollama.com" {
|
||||
t.Fatalf("unexpected url: %s", req.URL)
|
||||
}
|
||||
|
||||
resp := WebFetchResponse{
|
||||
Title: "Ollama",
|
||||
Content: "Cloud models are now available...",
|
||||
Links: []string{"https://ollama.com/models"},
|
||||
}
|
||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||
t.Fatalf("failed to encode response: %v", err)
|
||||
}
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
u, err := url.Parse(ts.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("parse server URL: %v", err)
|
||||
}
|
||||
|
||||
client := NewClient(u, ts.Client())
|
||||
|
||||
resp, err := client.WebFetch(t.Context(), &WebFetchRequest{URL: "https://ollama.com"})
|
||||
if err != nil {
|
||||
t.Fatalf("WebFetch returned error: %v", err)
|
||||
}
|
||||
|
||||
if resp.Title != "Ollama" {
|
||||
t.Fatalf("unexpected title: %s", resp.Title)
|
||||
}
|
||||
|
||||
if len(resp.Links) != 1 || resp.Links[0] != "https://ollama.com/models" {
|
||||
t.Fatalf("unexpected links: %v", resp.Links)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ import (
|
|||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
|
|
@ -139,6 +140,11 @@ func TestRoutes(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
var (
|
||||
searchRequests []api.WebSearchRequest
|
||||
fetchRequests []api.WebFetchRequest
|
||||
)
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
Name: "Version Handler",
|
||||
|
|
@ -455,6 +461,69 @@ func TestRoutes(t *testing.T) {
|
|||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "Web Search Handler",
|
||||
Method: http.MethodPost,
|
||||
Path: "/api/web_search",
|
||||
Setup: func(t *testing.T, req *http.Request) {
|
||||
searchRequests = nil
|
||||
payload := api.WebSearchRequest{Query: "cats", MaxResults: 2}
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal request: %v", err)
|
||||
}
|
||||
req.Body = io.NopCloser(bytes.NewReader(data))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
},
|
||||
Expected: func(t *testing.T, resp *http.Response) {
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
var out api.WebSearchResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
if len(out.Results) != 1 || out.Results[0].Title != "Result" {
|
||||
t.Fatalf("unexpected response: %+v", out)
|
||||
}
|
||||
if len(searchRequests) != 1 {
|
||||
t.Fatalf("expected 1 forwarded request, got %d", len(searchRequests))
|
||||
}
|
||||
if searchRequests[0].Query != "cats" || searchRequests[0].MaxResults != 2 {
|
||||
t.Fatalf("unexpected forwarded request: %+v", searchRequests[0])
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "Web Fetch Handler",
|
||||
Method: http.MethodPost,
|
||||
Path: "/api/web_fetch",
|
||||
Setup: func(t *testing.T, req *http.Request) {
|
||||
fetchRequests = nil
|
||||
payload := api.WebFetchRequest{URL: "https://example.com"}
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to marshal request: %v", err)
|
||||
}
|
||||
req.Body = io.NopCloser(bytes.NewReader(data))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
},
|
||||
Expected: func(t *testing.T, resp *http.Response) {
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", resp.StatusCode)
|
||||
}
|
||||
var out api.WebFetchResponse
|
||||
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
|
||||
t.Fatalf("failed to decode response: %v", err)
|
||||
}
|
||||
if out.Title != "Example" || len(out.Links) != 1 {
|
||||
t.Fatalf("unexpected response: %+v", out)
|
||||
}
|
||||
if len(fetchRequests) != 1 || fetchRequests[0].URL != "https://example.com" {
|
||||
t.Fatalf("unexpected forwarded request: %+v", fetchRequests)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "openai retrieve model handler",
|
||||
Setup: func(t *testing.T, req *http.Request) {
|
||||
|
|
@ -513,6 +582,41 @@ func TestRoutes(t *testing.T) {
|
|||
HTTPClient: panicOnRoundTrip,
|
||||
}
|
||||
|
||||
remoteSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/api/web_search":
|
||||
var req api.WebSearchRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
searchRequests = append(searchRequests, req)
|
||||
resp := api.WebSearchResponse{Results: []api.WebSearchResult{{Title: "Result", URL: "https://example.com", Content: "snippet"}}}
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
case "/api/web_fetch":
|
||||
var req api.WebFetchRequest
|
||||
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
fetchRequests = append(fetchRequests, req)
|
||||
resp := api.WebFetchResponse{Title: "Example", Content: "content", Links: []string{"https://example.com"}}
|
||||
_ = json.NewEncoder(w).Encode(resp)
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
}
|
||||
}))
|
||||
defer remoteSrv.Close()
|
||||
|
||||
remoteURL, err := url.Parse(remoteSrv.URL)
|
||||
if err != nil {
|
||||
t.Fatalf("parse remote server URL: %v", err)
|
||||
}
|
||||
|
||||
origWebClient := webServiceClient
|
||||
webServiceClient = api.NewClient(remoteURL, remoteSrv.Client())
|
||||
t.Cleanup(func() { webServiceClient = origWebClient })
|
||||
|
||||
s := &Server{}
|
||||
router, err := s.GenerateRoutes(rc)
|
||||
if err != nil {
|
||||
|
|
|
|||
Loading…
Reference in New Issue