update tests
This commit is contained in:
parent
4f45f39bc6
commit
b91c1f6749
|
|
@ -76,9 +76,20 @@ var lowVRAMThreshold uint64 = 20 * format.GibiByte
|
|||
var mode string = gin.DebugMode
|
||||
|
||||
type Server struct {
|
||||
addr net.Addr
|
||||
sched *Scheduler
|
||||
lowVRAM bool
|
||||
addr net.Addr
|
||||
sched *Scheduler
|
||||
lowVRAM bool
|
||||
cloudBaseURL *url.URL
|
||||
}
|
||||
|
||||
func (s *Server) webServiceBase() *url.URL {
|
||||
defaultWebServiceURL := url.URL{Scheme: "https", Host: "ollama.com"}
|
||||
if s != nil && s.cloudBaseURL != nil {
|
||||
u := *s.cloudBaseURL
|
||||
return &u
|
||||
}
|
||||
u := defaultWebServiceURL
|
||||
return &u
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
|
@ -788,7 +799,7 @@ func (s *Server) WebSearchHandler(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
webServiceClient := api.NewClient(&url.URL{Scheme: "https", Host: "ollama.com"}, http.DefaultClient)
|
||||
webServiceClient := api.NewClient(s.webServiceBase(), http.DefaultClient)
|
||||
resp, err := webServiceClient.WebSearch(c.Request.Context(), &req)
|
||||
if err != nil {
|
||||
var authError api.AuthorizationError
|
||||
|
|
@ -835,7 +846,7 @@ func (s *Server) WebFetchHandler(c *gin.Context) {
|
|||
return
|
||||
}
|
||||
|
||||
webServiceClient := api.NewClient(&url.URL{Scheme: "https", Host: "ollama.com"}, http.DefaultClient)
|
||||
webServiceClient := api.NewClient(s.webServiceBase(), http.DefaultClient)
|
||||
resp, err := webServiceClient.WebFetch(c.Request.Context(), &req)
|
||||
if err != nil {
|
||||
var authError api.AuthorizationError
|
||||
|
|
|
|||
|
|
@ -92,10 +92,6 @@ func (t *panicTransport) RoundTrip(r *http.Request) (*http.Response, error) {
|
|||
|
||||
var panicOnRoundTrip = &http.Client{Transport: &panicTransport{}}
|
||||
|
||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
||||
|
||||
func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) }
|
||||
|
||||
func TestRoutes(t *testing.T) {
|
||||
// Disable authentication for tests to avoid issues with missing private keys
|
||||
t.Setenv("OLLAMA_AUTH", "false")
|
||||
|
|
@ -620,41 +616,9 @@ func TestRoutes(t *testing.T) {
|
|||
t.Fatalf("parse remote server URL: %v", err)
|
||||
}
|
||||
|
||||
origTransport := http.DefaultTransport
|
||||
remoteClient := remoteSrv.Client()
|
||||
remoteTransport := remoteClient.Transport
|
||||
if remoteTransport == nil {
|
||||
t.Fatalf("remote client transport is nil")
|
||||
s := &Server{
|
||||
cloudBaseURL: remoteURL,
|
||||
}
|
||||
|
||||
http.DefaultTransport = roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
||||
if req.URL == nil {
|
||||
panic("unexpected nil request URL")
|
||||
}
|
||||
if req.URL.Host != "ollama.com" {
|
||||
panic(fmt.Sprintf("unexpected outbound request to %s", req.URL))
|
||||
}
|
||||
|
||||
clone := req.Clone(req.Context())
|
||||
cloneURL := remoteURL.ResolveReference(&url.URL{
|
||||
Path: req.URL.Path,
|
||||
RawPath: req.URL.RawPath,
|
||||
RawQuery: req.URL.RawQuery,
|
||||
})
|
||||
clone.URL = cloneURL
|
||||
clone.Host = cloneURL.Host
|
||||
|
||||
return remoteTransport.RoundTrip(clone)
|
||||
})
|
||||
|
||||
t.Cleanup(func() {
|
||||
if closer, ok := remoteTransport.(interface{ CloseIdleConnections() }); ok {
|
||||
closer.CloseIdleConnections()
|
||||
}
|
||||
http.DefaultTransport = origTransport
|
||||
})
|
||||
|
||||
s := &Server{}
|
||||
router, err := s.GenerateRoutes(rc)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate routes: %v", err)
|
||||
|
|
|
|||
Loading…
Reference in New Issue