From f88174c55de565d3092aeff80a531c1bba37c049 Mon Sep 17 00:00:00 2001 From: ParthSareen Date: Wed, 1 Oct 2025 13:08:57 -0700 Subject: [PATCH] routes/client: add web search and fetch --- api/client.go | 18 ++++++++ api/types.go | 34 +++++++++++++++ server/routes.go | 110 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 162 insertions(+) diff --git a/api/client.go b/api/client.go index 0d4c97ba9..2791da270 100644 --- a/api/client.go +++ b/api/client.go @@ -414,6 +414,24 @@ func (c *Client) Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse, return &resp, nil } +// WebSearch performs a web search via the Ollama server. +func (c *Client) WebSearch(ctx context.Context, req *WebSearchRequest) (*WebSearchResponse, error) { + var resp WebSearchResponse + if err := c.do(ctx, http.MethodPost, "/api/web_search", req, &resp); err != nil { + return nil, err + } + return &resp, nil +} + +// WebFetch fetches the contents of a web page via the Ollama server. +func (c *Client) WebFetch(ctx context.Context, req *WebFetchRequest) (*WebFetchResponse, error) { + var resp WebFetchResponse + if err := c.do(ctx, http.MethodPost, "/api/web_fetch", req, &resp); err != nil { + return nil, err + } + return &resp, nil +} + // Embeddings generates an embedding from a model. func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) { var resp EmbeddingResponse diff --git a/api/types.go b/api/types.go index 8cc7752ca..5c7a3e36f 100644 --- a/api/types.go +++ b/api/types.go @@ -453,6 +453,40 @@ type EmbeddingResponse struct { Embedding []float64 `json:"embedding"` } +// WebSearchRequest is the request passed to [Client.WebSearch]. +type WebSearchRequest struct { + // Query is the search query string. + Query string `json:"query"` + + // MaxResults is the optional maximum number of results to return (default 5, max 10). + MaxResults int `json:"max_results,omitempty"` +} + +// WebSearchResult represents a single web search result. +type WebSearchResult struct { + Title string `json:"title"` + URL string `json:"url"` + Content string `json:"content"` +} + +// WebSearchResponse is the response from [Client.WebSearch]. +type WebSearchResponse struct { + Results []WebSearchResult `json:"results"` +} + +// WebFetchRequest is the request passed to [Client.WebFetch]. +type WebFetchRequest struct { + // URL is the address of the page to fetch. + URL string `json:"url"` +} + +// WebFetchResponse is the response from [Client.WebFetch]. +type WebFetchResponse struct { + Title string `json:"title"` + Content string `json:"content"` + Links []string `json:"links"` +} + // CreateRequest is the request passed to [Client.Create]. type CreateRequest struct { // Model is the model name to create. diff --git a/server/routes.go b/server/routes.go index 343411b92..e18881816 100644 --- a/server/routes.go +++ b/server/routes.go @@ -51,6 +51,17 @@ import ( const signinURLStr = "https://ollama.com/connect?name=%s&key=%s" +var ( + webServiceBase = func() *url.URL { + u, err := url.Parse("https://ollama.com") + if err != nil { + panic(err) + } + return u + }() + webServiceClient = api.NewClient(webServiceBase, http.DefaultClient) +) + func shouldUseHarmony(model *Model) bool { if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) { // heuristic to check whether the template expects to be parsed via harmony: @@ -767,6 +778,103 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) { c.JSON(http.StatusOK, resp) } +func (s *Server) WebSearchHandler(c *gin.Context) { + var req api.WebSearchRequest + if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) + return + } else if err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + req.Query = strings.TrimSpace(req.Query) + if req.Query == "" { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "query is required"}) + return + } + + if req.MaxResults != 0 && (req.MaxResults < 1 || req.MaxResults > 10) { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "max_results must be between 1 and 10"}) + return + } + + resp, err := webServiceClient.WebSearch(c.Request.Context(), &req) + if err != nil { + var authError api.AuthorizationError + if errors.As(err, &authError) { + sURL, sErr := signinURL() + if sErr != nil { + slog.Error(sErr.Error()) + c.JSON(http.StatusInternalServerError, gin.H{"error": "error getting authorization details"}) + return + } + + c.JSON(authError.StatusCode, gin.H{"error": "unauthorized", "signin_url": sURL}) + return + } + var apiError api.StatusError + if errors.As(err, &apiError) { + c.JSON(apiError.StatusCode, apiError) + return + } + c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()}) + return + } + + if resp == nil { + resp = &api.WebSearchResponse{} + } + + c.JSON(http.StatusOK, resp) +} + +func (s *Server) WebFetchHandler(c *gin.Context) { + var req api.WebFetchRequest + if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) + return + } else if err != nil { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + req.URL = strings.TrimSpace(req.URL) + if req.URL == "" { + c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "url is required"}) + return + } + + resp, err := webServiceClient.WebFetch(c.Request.Context(), &req) + if err != nil { + var authError api.AuthorizationError + if errors.As(err, &authError) { + sURL, sErr := signinURL() + if sErr != nil { + slog.Error(sErr.Error()) + c.JSON(http.StatusInternalServerError, gin.H{"error": "error getting authorization details"}) + return + } + + c.JSON(authError.StatusCode, gin.H{"error": "unauthorized", "signin_url": sURL}) + return + } + var apiError api.StatusError + if errors.As(err, &apiError) { + c.JSON(apiError.StatusCode, apiError) + return + } + c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()}) + return + } + + if resp == nil { + resp = &api.WebFetchResponse{} + } + + c.JSON(http.StatusOK, resp) +} + func (s *Server) PullHandler(c *gin.Context) { var req api.PullRequest err := c.ShouldBindJSON(&req) @@ -1447,6 +1555,8 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) { r.POST("/api/chat", s.ChatHandler) r.POST("/api/embed", s.EmbedHandler) r.POST("/api/embeddings", s.EmbeddingsHandler) + r.POST("/api/web_search", s.WebSearchHandler) + r.POST("/api/web_fetch", s.WebFetchHandler) // Inference (OpenAI compatibility) r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)