update tests
This commit is contained in:
parent
4f45f39bc6
commit
b91c1f6749
|
|
@ -76,9 +76,20 @@ var lowVRAMThreshold uint64 = 20 * format.GibiByte
|
||||||
var mode string = gin.DebugMode
|
var mode string = gin.DebugMode
|
||||||
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
addr net.Addr
|
addr net.Addr
|
||||||
sched *Scheduler
|
sched *Scheduler
|
||||||
lowVRAM bool
|
lowVRAM bool
|
||||||
|
cloudBaseURL *url.URL
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) webServiceBase() *url.URL {
|
||||||
|
defaultWebServiceURL := url.URL{Scheme: "https", Host: "ollama.com"}
|
||||||
|
if s != nil && s.cloudBaseURL != nil {
|
||||||
|
u := *s.cloudBaseURL
|
||||||
|
return &u
|
||||||
|
}
|
||||||
|
u := defaultWebServiceURL
|
||||||
|
return &u
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
|
@ -788,7 +799,7 @@ func (s *Server) WebSearchHandler(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
webServiceClient := api.NewClient(&url.URL{Scheme: "https", Host: "ollama.com"}, http.DefaultClient)
|
webServiceClient := api.NewClient(s.webServiceBase(), http.DefaultClient)
|
||||||
resp, err := webServiceClient.WebSearch(c.Request.Context(), &req)
|
resp, err := webServiceClient.WebSearch(c.Request.Context(), &req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var authError api.AuthorizationError
|
var authError api.AuthorizationError
|
||||||
|
|
@ -835,7 +846,7 @@ func (s *Server) WebFetchHandler(c *gin.Context) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
webServiceClient := api.NewClient(&url.URL{Scheme: "https", Host: "ollama.com"}, http.DefaultClient)
|
webServiceClient := api.NewClient(s.webServiceBase(), http.DefaultClient)
|
||||||
resp, err := webServiceClient.WebFetch(c.Request.Context(), &req)
|
resp, err := webServiceClient.WebFetch(c.Request.Context(), &req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var authError api.AuthorizationError
|
var authError api.AuthorizationError
|
||||||
|
|
|
||||||
|
|
@ -92,10 +92,6 @@ func (t *panicTransport) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||||
|
|
||||||
var panicOnRoundTrip = &http.Client{Transport: &panicTransport{}}
|
var panicOnRoundTrip = &http.Client{Transport: &panicTransport{}}
|
||||||
|
|
||||||
type roundTripFunc func(*http.Request) (*http.Response, error)
|
|
||||||
|
|
||||||
func (f roundTripFunc) RoundTrip(r *http.Request) (*http.Response, error) { return f(r) }
|
|
||||||
|
|
||||||
func TestRoutes(t *testing.T) {
|
func TestRoutes(t *testing.T) {
|
||||||
// Disable authentication for tests to avoid issues with missing private keys
|
// Disable authentication for tests to avoid issues with missing private keys
|
||||||
t.Setenv("OLLAMA_AUTH", "false")
|
t.Setenv("OLLAMA_AUTH", "false")
|
||||||
|
|
@ -620,41 +616,9 @@ func TestRoutes(t *testing.T) {
|
||||||
t.Fatalf("parse remote server URL: %v", err)
|
t.Fatalf("parse remote server URL: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
origTransport := http.DefaultTransport
|
s := &Server{
|
||||||
remoteClient := remoteSrv.Client()
|
cloudBaseURL: remoteURL,
|
||||||
remoteTransport := remoteClient.Transport
|
|
||||||
if remoteTransport == nil {
|
|
||||||
t.Fatalf("remote client transport is nil")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
http.DefaultTransport = roundTripFunc(func(req *http.Request) (*http.Response, error) {
|
|
||||||
if req.URL == nil {
|
|
||||||
panic("unexpected nil request URL")
|
|
||||||
}
|
|
||||||
if req.URL.Host != "ollama.com" {
|
|
||||||
panic(fmt.Sprintf("unexpected outbound request to %s", req.URL))
|
|
||||||
}
|
|
||||||
|
|
||||||
clone := req.Clone(req.Context())
|
|
||||||
cloneURL := remoteURL.ResolveReference(&url.URL{
|
|
||||||
Path: req.URL.Path,
|
|
||||||
RawPath: req.URL.RawPath,
|
|
||||||
RawQuery: req.URL.RawQuery,
|
|
||||||
})
|
|
||||||
clone.URL = cloneURL
|
|
||||||
clone.Host = cloneURL.Host
|
|
||||||
|
|
||||||
return remoteTransport.RoundTrip(clone)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Cleanup(func() {
|
|
||||||
if closer, ok := remoteTransport.(interface{ CloseIdleConnections() }); ok {
|
|
||||||
closer.CloseIdleConnections()
|
|
||||||
}
|
|
||||||
http.DefaultTransport = origTransport
|
|
||||||
})
|
|
||||||
|
|
||||||
s := &Server{}
|
|
||||||
router, err := s.GenerateRoutes(rc)
|
router, err := s.GenerateRoutes(rc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to generate routes: %v", err)
|
t.Fatalf("failed to generate routes: %v", err)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue