diff --git a/api/client_test.go b/api/client_test.go index f0034e02d..b604bfe5f 100644 --- a/api/client_test.go +++ b/api/client_test.go @@ -262,3 +262,135 @@ func TestClientDo(t *testing.T) { }) } } + +func TestClientWebSearch(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Fatalf("expected POST, got %s", r.Method) + } + if !strings.HasSuffix(r.URL.Path, "/api/web_search") { + t.Fatalf("unexpected path: %s", r.URL.Path) + } + + var req WebSearchRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + t.Fatalf("failed to decode request: %v", err) + } + + if req.Query != "what is ollama" { + t.Fatalf("unexpected query: %s", req.Query) + } + + if req.MaxResults != 3 { + t.Fatalf("unexpected max_results: %d", req.MaxResults) + } + + resp := WebSearchResponse{ + Results: []WebSearchResult{{ + Title: "Ollama", + URL: "https://ollama.com", + Content: "Cloud models are now available...", + }}, + } + if err := json.NewEncoder(w).Encode(resp); err != nil { + t.Fatalf("failed to encode response: %v", err) + } + })) + defer ts.Close() + + u, err := url.Parse(ts.URL) + if err != nil { + t.Fatalf("parse server URL: %v", err) + } + + client := NewClient(u, ts.Client()) + + resp, err := client.WebSearch(t.Context(), &WebSearchRequest{Query: "what is ollama", MaxResults: 3}) + if err != nil { + t.Fatalf("WebSearch returned error: %v", err) + } + + if len(resp.Results) != 1 { + t.Fatalf("expected 1 result, got %d", len(resp.Results)) + } + + if resp.Results[0].Title != "Ollama" { + t.Fatalf("unexpected title: %s", resp.Results[0].Title) + } +} + +func TestClientWebSearchUnauthorized(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _ = json.NewEncoder(w).Encode(map[string]string{ + "signin_url": "https://ollama.com/connect", + }) + })) + defer ts.Close() + + u, err := url.Parse(ts.URL) + if err != nil { + t.Fatalf("parse server URL: %v", err) + } + + client := NewClient(u, ts.Client()) + + _, err = client.WebSearch(t.Context(), &WebSearchRequest{Query: "what is ollama"}) + if err == nil { + t.Fatal("expected error, got nil") + } + + if _, ok := err.(AuthorizationError); !ok { + t.Fatalf("expected AuthorizationError, got %T", err) + } +} + +func TestClientWebFetch(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Fatalf("expected POST, got %s", r.Method) + } + if !strings.HasSuffix(r.URL.Path, "/api/web_fetch") { + t.Fatalf("unexpected path: %s", r.URL.Path) + } + + var req WebFetchRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + t.Fatalf("failed to decode request: %v", err) + } + + if req.URL != "https://ollama.com" { + t.Fatalf("unexpected url: %s", req.URL) + } + + resp := WebFetchResponse{ + Title: "Ollama", + Content: "Cloud models are now available...", + Links: []string{"https://ollama.com/models"}, + } + if err := json.NewEncoder(w).Encode(resp); err != nil { + t.Fatalf("failed to encode response: %v", err) + } + })) + defer ts.Close() + + u, err := url.Parse(ts.URL) + if err != nil { + t.Fatalf("parse server URL: %v", err) + } + + client := NewClient(u, ts.Client()) + + resp, err := client.WebFetch(t.Context(), &WebFetchRequest{URL: "https://ollama.com"}) + if err != nil { + t.Fatalf("WebFetch returned error: %v", err) + } + + if resp.Title != "Ollama" { + t.Fatalf("unexpected title: %s", resp.Title) + } + + if len(resp.Links) != 1 || resp.Links[0] != "https://ollama.com/models" { + t.Fatalf("unexpected links: %v", resp.Links) + } +} diff --git a/server/routes_test.go b/server/routes_test.go index bb7e2b7c1..c8947ac6e 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -13,6 +13,7 @@ import ( "net" "net/http" "net/http/httptest" + "net/url" "os" "path/filepath" "reflect" @@ -139,6 +140,11 @@ func TestRoutes(t *testing.T) { } } + var ( + searchRequests []api.WebSearchRequest + fetchRequests []api.WebFetchRequest + ) + testCases := []testCase{ { Name: "Version Handler", @@ -455,6 +461,69 @@ func TestRoutes(t *testing.T) { } }, }, + { + Name: "Web Search Handler", + Method: http.MethodPost, + Path: "/api/web_search", + Setup: func(t *testing.T, req *http.Request) { + searchRequests = nil + payload := api.WebSearchRequest{Query: "cats", MaxResults: 2} + data, err := json.Marshal(payload) + if err != nil { + t.Fatalf("failed to marshal request: %v", err) + } + req.Body = io.NopCloser(bytes.NewReader(data)) + req.Header.Set("Content-Type", "application/json") + }, + Expected: func(t *testing.T, resp *http.Response) { + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status 200, got %d", resp.StatusCode) + } + var out api.WebSearchResponse + if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + if len(out.Results) != 1 || out.Results[0].Title != "Result" { + t.Fatalf("unexpected response: %+v", out) + } + if len(searchRequests) != 1 { + t.Fatalf("expected 1 forwarded request, got %d", len(searchRequests)) + } + if searchRequests[0].Query != "cats" || searchRequests[0].MaxResults != 2 { + t.Fatalf("unexpected forwarded request: %+v", searchRequests[0]) + } + }, + }, + { + Name: "Web Fetch Handler", + Method: http.MethodPost, + Path: "/api/web_fetch", + Setup: func(t *testing.T, req *http.Request) { + fetchRequests = nil + payload := api.WebFetchRequest{URL: "https://example.com"} + data, err := json.Marshal(payload) + if err != nil { + t.Fatalf("failed to marshal request: %v", err) + } + req.Body = io.NopCloser(bytes.NewReader(data)) + req.Header.Set("Content-Type", "application/json") + }, + Expected: func(t *testing.T, resp *http.Response) { + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected status 200, got %d", resp.StatusCode) + } + var out api.WebFetchResponse + if err := json.NewDecoder(resp.Body).Decode(&out); err != nil { + t.Fatalf("failed to decode response: %v", err) + } + if out.Title != "Example" || len(out.Links) != 1 { + t.Fatalf("unexpected response: %+v", out) + } + if len(fetchRequests) != 1 || fetchRequests[0].URL != "https://example.com" { + t.Fatalf("unexpected forwarded request: %+v", fetchRequests) + } + }, + }, { Name: "openai retrieve model handler", Setup: func(t *testing.T, req *http.Request) { @@ -513,6 +582,41 @@ func TestRoutes(t *testing.T) { HTTPClient: panicOnRoundTrip, } + remoteSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/web_search": + var req api.WebSearchRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + searchRequests = append(searchRequests, req) + resp := api.WebSearchResponse{Results: []api.WebSearchResult{{Title: "Result", URL: "https://example.com", Content: "snippet"}}} + _ = json.NewEncoder(w).Encode(resp) + case "/api/web_fetch": + var req api.WebFetchRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + fetchRequests = append(fetchRequests, req) + resp := api.WebFetchResponse{Title: "Example", Content: "content", Links: []string{"https://example.com"}} + _ = json.NewEncoder(w).Encode(resp) + default: + w.WriteHeader(http.StatusNotFound) + } + })) + defer remoteSrv.Close() + + remoteURL, err := url.Parse(remoteSrv.URL) + if err != nil { + t.Fatalf("parse remote server URL: %v", err) + } + + origWebClient := webServiceClient + webServiceClient = api.NewClient(remoteURL, remoteSrv.Client()) + t.Cleanup(func() { webServiceClient = origWebClient }) + s := &Server{} router, err := s.GenerateRoutes(rc) if err != nil {