update tests

This commit is contained in:
ParthSareen 2025-10-03 14:49:49 -07:00
parent 4f45f39bc6
commit b91c1f6749
2 changed files with 18 additions and 43 deletions

View File

@ -76,9 +76,20 @@ var lowVRAMThreshold uint64 = 20 * format.GibiByte
var mode string = gin.DebugMode var mode string = gin.DebugMode
type Server struct { type Server struct {
addr net.Addr addr net.Addr
sched *Scheduler sched *Scheduler
lowVRAM bool lowVRAM bool
cloudBaseURL *url.URL
}
func (s *Server) webServiceBase() *url.URL {
defaultWebServiceURL := url.URL{Scheme: "https", Host: "ollama.com"}
if s != nil && s.cloudBaseURL != nil {
u := *s.cloudBaseURL
return &u
}
u := defaultWebServiceURL
return &u
} }
func init() { func init() {
@ -788,7 +799,7 @@ func (s *Server) WebSearchHandler(c *gin.Context) {
return return
} }
webServiceClient := api.NewClient(&url.URL{Scheme: "https", Host: "ollama.com"}, http.DefaultClient) webServiceClient := api.NewClient(s.webServiceBase(), http.DefaultClient)
resp, err := webServiceClient.WebSearch(c.Request.Context(), &req) resp, err := webServiceClient.WebSearch(c.Request.Context(), &req)
if err != nil { if err != nil {
var authError api.AuthorizationError var authError api.AuthorizationError
@ -835,7 +846,7 @@ func (s *Server) WebFetchHandler(c *gin.Context) {
return return
} }
webServiceClient := api.NewClient(&url.URL{Scheme: "https", Host: "ollama.com"}, http.DefaultClient) webServiceClient := api.NewClient(s.webServiceBase(), http.DefaultClient)
resp, err := webServiceClient.WebFetch(c.Request.Context(), &req) resp, err := webServiceClient.WebFetch(c.Request.Context(), &req)
if err != nil { if err != nil {
var authError api.AuthorizationError var authError api.AuthorizationError

View File

@ -92,10 +92,6 @@ func (t *panicTransport) RoundTrip(r *http.Request) (*http.Response, error) {
var panicOnRoundTrip = &http.Client{Transport: &panicTransport{}} var panicOnRoundTrip = &http.Client{Transport: &panicTransport{}}
type roundTripFunc func(*http.Request) (*http.Response, error)
func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) }
func TestRoutes(t *testing.T) { func TestRoutes(t *testing.T) {
// Disable authentication for tests to avoid issues with missing private keys // Disable authentication for tests to avoid issues with missing private keys
t.Setenv("OLLAMA_AUTH", "false") t.Setenv("OLLAMA_AUTH", "false")
@ -620,41 +616,9 @@ func TestRoutes(t *testing.T) {
t.Fatalf("parse remote server URL: %v", err) t.Fatalf("parse remote server URL: %v", err)
} }
origTransport := http.DefaultTransport s := &Server{
remoteClient := remoteSrv.Client() cloudBaseURL: remoteURL,
remoteTransport := remoteClient.Transport
if remoteTransport == nil {
t.Fatalf("remote client transport is nil")
} }
http.DefaultTransport = roundTripFunc(func(req *http.Request) (*http.Response, error) {
if req.URL == nil {
panic("unexpected nil request URL")
}
if req.URL.Host != "ollama.com" {
panic(fmt.Sprintf("unexpected outbound request to %s", req.URL))
}
clone := req.Clone(req.Context())
cloneURL := remoteURL.ResolveReference(&url.URL{
Path: req.URL.Path,
RawPath: req.URL.RawPath,
RawQuery: req.URL.RawQuery,
})
clone.URL = cloneURL
clone.Host = cloneURL.Host
return remoteTransport.RoundTrip(clone)
})
t.Cleanup(func() {
if closer, ok := remoteTransport.(interface{ CloseIdleConnections() }); ok {
closer.CloseIdleConnections()
}
http.DefaultTransport = origTransport
})
s := &Server{}
router, err := s.GenerateRoutes(rc) router, err := s.GenerateRoutes(rc)
if err != nil { if err != nil {
t.Fatalf("failed to generate routes: %v", err) t.Fatalf("failed to generate routes: %v", err)