add tests

This commit is contained in:
ParthSareen
2025-10-01 13:23:49 -07:00
parent f88174c55d
commit 03e1d64aac
2 changed files with 236 additions and 0 deletions

View File

@@ -13,6 +13,7 @@ import (
"net"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"reflect"
@@ -139,6 +140,11 @@ func TestRoutes(t *testing.T) {
}
}
var (
searchRequests []api.WebSearchRequest
fetchRequests []api.WebFetchRequest
)
testCases := []testCase{
{
Name: "Version Handler",
@@ -455,6 +461,69 @@ func TestRoutes(t *testing.T) {
}
},
},
{
Name: "Web Search Handler",
Method: http.MethodPost,
Path: "/api/web_search",
Setup: func(t *testing.T, req *http.Request) {
searchRequests = nil
payload := api.WebSearchRequest{Query: "cats", MaxResults: 2}
data, err := json.Marshal(payload)
if err != nil {
t.Fatalf("failed to marshal request: %v", err)
}
req.Body = io.NopCloser(bytes.NewReader(data))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, resp *http.Response) {
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status 200, got %d", resp.StatusCode)
}
var out api.WebSearchResponse
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if len(out.Results) != 1 || out.Results[0].Title != "Result" {
t.Fatalf("unexpected response: %+v", out)
}
if len(searchRequests) != 1 {
t.Fatalf("expected 1 forwarded request, got %d", len(searchRequests))
}
if searchRequests[0].Query != "cats" || searchRequests[0].MaxResults != 2 {
t.Fatalf("unexpected forwarded request: %+v", searchRequests[0])
}
},
},
{
Name: "Web Fetch Handler",
Method: http.MethodPost,
Path: "/api/web_fetch",
Setup: func(t *testing.T, req *http.Request) {
fetchRequests = nil
payload := api.WebFetchRequest{URL: "https://example.com"}
data, err := json.Marshal(payload)
if err != nil {
t.Fatalf("failed to marshal request: %v", err)
}
req.Body = io.NopCloser(bytes.NewReader(data))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, resp *http.Response) {
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected status 200, got %d", resp.StatusCode)
}
var out api.WebFetchResponse
if err := json.NewDecoder(resp.Body).Decode(&out); err != nil {
t.Fatalf("failed to decode response: %v", err)
}
if out.Title != "Example" || len(out.Links) != 1 {
t.Fatalf("unexpected response: %+v", out)
}
if len(fetchRequests) != 1 || fetchRequests[0].URL != "https://example.com" {
t.Fatalf("unexpected forwarded request: %+v", fetchRequests)
}
},
},
{
Name: "openai retrieve model handler",
Setup: func(t *testing.T, req *http.Request) {
@@ -513,6 +582,41 @@ func TestRoutes(t *testing.T) {
HTTPClient: panicOnRoundTrip,
}
remoteSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/api/web_search":
var req api.WebSearchRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
searchRequests = append(searchRequests, req)
resp := api.WebSearchResponse{Results: []api.WebSearchResult{{Title: "Result", URL: "https://example.com", Content: "snippet"}}}
_ = json.NewEncoder(w).Encode(resp)
case "/api/web_fetch":
var req api.WebFetchRequest
if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
fetchRequests = append(fetchRequests, req)
resp := api.WebFetchResponse{Title: "Example", Content: "content", Links: []string{"https://example.com"}}
_ = json.NewEncoder(w).Encode(resp)
default:
w.WriteHeader(http.StatusNotFound)
}
}))
defer remoteSrv.Close()
remoteURL, err := url.Parse(remoteSrv.URL)
if err != nil {
t.Fatalf("parse remote server URL: %v", err)
}
origWebClient := webServiceClient
webServiceClient = api.NewClient(remoteURL, remoteSrv.Client())
t.Cleanup(func() { webServiceClient = origWebClient })
s := &Server{}
router, err := s.GenerateRoutes(rc)
if err != nil {