diff --git a/server/routes.go b/server/routes.go index e95badc69..647e9dac5 100644 --- a/server/routes.go +++ b/server/routes.go @@ -76,9 +76,20 @@ var lowVRAMThreshold uint64 = 20 * format.GibiByte var mode string = gin.DebugMode type Server struct { - addr net.Addr - sched *Scheduler - lowVRAM bool + addr net.Addr + sched *Scheduler + lowVRAM bool + cloudBaseURL *url.URL +} + +func (s *Server) webServiceBase() *url.URL { + defaultWebServiceURL := url.URL{Scheme: "https", Host: "ollama.com"} + if s != nil && s.cloudBaseURL != nil { + u := *s.cloudBaseURL + return &u + } + u := defaultWebServiceURL + return &u } func init() { @@ -788,7 +799,7 @@ func (s *Server) WebSearchHandler(c *gin.Context) { return } - webServiceClient := api.NewClient(&url.URL{Scheme: "https", Host: "ollama.com"}, http.DefaultClient) + webServiceClient := api.NewClient(s.webServiceBase(), http.DefaultClient) resp, err := webServiceClient.WebSearch(c.Request.Context(), &req) if err != nil { var authError api.AuthorizationError @@ -835,7 +846,7 @@ func (s *Server) WebFetchHandler(c *gin.Context) { return } - webServiceClient := api.NewClient(&url.URL{Scheme: "https", Host: "ollama.com"}, http.DefaultClient) + webServiceClient := api.NewClient(s.webServiceBase(), http.DefaultClient) resp, err := webServiceClient.WebFetch(c.Request.Context(), &req) if err != nil { var authError api.AuthorizationError diff --git a/server/routes_test.go b/server/routes_test.go index 68b34e418..e8a86ddfb 100644 --- a/server/routes_test.go +++ b/server/routes_test.go @@ -92,10 +92,6 @@ func (t *panicTransport) RoundTrip(r *http.Request) (*http.Response, error) { var panicOnRoundTrip = &http.Client{Transport: &panicTransport{}} -type roundTripFunc func(*http.Request) (*http.Response, error) - -func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) } - func TestRoutes(t *testing.T) { // Disable authentication for tests to avoid issues with missing private keys t.Setenv("OLLAMA_AUTH", "false") @@ -620,41 +616,9 @@ func TestRoutes(t *testing.T) { t.Fatalf("parse remote server URL: %v", err) } - origTransport := http.DefaultTransport - remoteClient := remoteSrv.Client() - remoteTransport := remoteClient.Transport - if remoteTransport == nil { - t.Fatalf("remote client transport is nil") + s := &Server{ + cloudBaseURL: remoteURL, } - - http.DefaultTransport = roundTripFunc(func(req *http.Request) (*http.Response, error) { - if req.URL == nil { - panic("unexpected nil request URL") - } - if req.URL.Host != "ollama.com" { - panic(fmt.Sprintf("unexpected outbound request to %s", req.URL)) - } - - clone := req.Clone(req.Context()) - cloneURL := remoteURL.ResolveReference(&url.URL{ - Path: req.URL.Path, - RawPath: req.URL.RawPath, - RawQuery: req.URL.RawQuery, - }) - clone.URL = cloneURL - clone.Host = cloneURL.Host - - return remoteTransport.RoundTrip(clone) - }) - - t.Cleanup(func() { - if closer, ok := remoteTransport.(interface{ CloseIdleConnections() }); ok { - closer.CloseIdleConnections() - } - http.DefaultTransport = origTransport - }) - - s := &Server{} router, err := s.GenerateRoutes(rc) if err != nil { t.Fatalf("failed to generate routes: %v", err)