routes/client: add web search and fetch
This commit is contained in:
parent
6b50f2b9cd
commit
f88174c55d
|
|
@ -414,6 +414,24 @@ func (c *Client) Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse,
|
|||
return &resp, nil
|
||||
}
|
||||
|
||||
// WebSearch performs a web search via the Ollama server.
|
||||
func (c *Client) WebSearch(ctx context.Context, req *WebSearchRequest) (*WebSearchResponse, error) {
|
||||
var resp WebSearchResponse
|
||||
if err := c.do(ctx, http.MethodPost, "/api/web_search", req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// WebFetch fetches the contents of a web page via the Ollama server.
|
||||
func (c *Client) WebFetch(ctx context.Context, req *WebFetchRequest) (*WebFetchResponse, error) {
|
||||
var resp WebFetchResponse
|
||||
if err := c.do(ctx, http.MethodPost, "/api/web_fetch", req, &resp); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &resp, nil
|
||||
}
|
||||
|
||||
// Embeddings generates an embedding from a model.
|
||||
func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) {
|
||||
var resp EmbeddingResponse
|
||||
|
|
|
|||
34
api/types.go
34
api/types.go
|
|
@ -453,6 +453,40 @@ type EmbeddingResponse struct {
|
|||
Embedding []float64 `json:"embedding"`
|
||||
}
|
||||
|
||||
// WebSearchRequest is the request passed to [Client.WebSearch].
|
||||
type WebSearchRequest struct {
|
||||
// Query is the search query string.
|
||||
Query string `json:"query"`
|
||||
|
||||
// MaxResults is the optional maximum number of results to return (default 5, max 10).
|
||||
MaxResults int `json:"max_results,omitempty"`
|
||||
}
|
||||
|
||||
// WebSearchResult represents a single web search result.
|
||||
type WebSearchResult struct {
|
||||
Title string `json:"title"`
|
||||
URL string `json:"url"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
// WebSearchResponse is the response from [Client.WebSearch].
|
||||
type WebSearchResponse struct {
|
||||
Results []WebSearchResult `json:"results"`
|
||||
}
|
||||
|
||||
// WebFetchRequest is the request passed to [Client.WebFetch].
|
||||
type WebFetchRequest struct {
|
||||
// URL is the address of the page to fetch.
|
||||
URL string `json:"url"`
|
||||
}
|
||||
|
||||
// WebFetchResponse is the response from [Client.WebFetch].
|
||||
type WebFetchResponse struct {
|
||||
Title string `json:"title"`
|
||||
Content string `json:"content"`
|
||||
Links []string `json:"links"`
|
||||
}
|
||||
|
||||
// CreateRequest is the request passed to [Client.Create].
|
||||
type CreateRequest struct {
|
||||
// Model is the model name to create.
|
||||
|
|
|
|||
110
server/routes.go
110
server/routes.go
|
|
@ -51,6 +51,17 @@ import (
|
|||
|
||||
const signinURLStr = "https://ollama.com/connect?name=%s&key=%s"
|
||||
|
||||
var (
|
||||
webServiceBase = func() *url.URL {
|
||||
u, err := url.Parse("https://ollama.com")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return u
|
||||
}()
|
||||
webServiceClient = api.NewClient(webServiceBase, http.DefaultClient)
|
||||
)
|
||||
|
||||
func shouldUseHarmony(model *Model) bool {
|
||||
if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) {
|
||||
// heuristic to check whether the template expects to be parsed via harmony:
|
||||
|
|
@ -767,6 +778,103 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
|||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func (s *Server) WebSearchHandler(c *gin.Context) {
|
||||
var req api.WebSearchRequest
|
||||
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||
return
|
||||
} else if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
req.Query = strings.TrimSpace(req.Query)
|
||||
if req.Query == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "query is required"})
|
||||
return
|
||||
}
|
||||
|
||||
if req.MaxResults != 0 && (req.MaxResults < 1 || req.MaxResults > 10) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "max_results must be between 1 and 10"})
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := webServiceClient.WebSearch(c.Request.Context(), &req)
|
||||
if err != nil {
|
||||
var authError api.AuthorizationError
|
||||
if errors.As(err, &authError) {
|
||||
sURL, sErr := signinURL()
|
||||
if sErr != nil {
|
||||
slog.Error(sErr.Error())
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "error getting authorization details"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(authError.StatusCode, gin.H{"error": "unauthorized", "signin_url": sURL})
|
||||
return
|
||||
}
|
||||
var apiError api.StatusError
|
||||
if errors.As(err, &apiError) {
|
||||
c.JSON(apiError.StatusCode, apiError)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if resp == nil {
|
||||
resp = &api.WebSearchResponse{}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func (s *Server) WebFetchHandler(c *gin.Context) {
|
||||
var req api.WebFetchRequest
|
||||
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||
return
|
||||
} else if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
req.URL = strings.TrimSpace(req.URL)
|
||||
if req.URL == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "url is required"})
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := webServiceClient.WebFetch(c.Request.Context(), &req)
|
||||
if err != nil {
|
||||
var authError api.AuthorizationError
|
||||
if errors.As(err, &authError) {
|
||||
sURL, sErr := signinURL()
|
||||
if sErr != nil {
|
||||
slog.Error(sErr.Error())
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "error getting authorization details"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(authError.StatusCode, gin.H{"error": "unauthorized", "signin_url": sURL})
|
||||
return
|
||||
}
|
||||
var apiError api.StatusError
|
||||
if errors.As(err, &apiError) {
|
||||
c.JSON(apiError.StatusCode, apiError)
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusBadGateway, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if resp == nil {
|
||||
resp = &api.WebFetchResponse{}
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
func (s *Server) PullHandler(c *gin.Context) {
|
||||
var req api.PullRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
|
|
@ -1447,6 +1555,8 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
|||
r.POST("/api/chat", s.ChatHandler)
|
||||
r.POST("/api/embed", s.EmbedHandler)
|
||||
r.POST("/api/embeddings", s.EmbeddingsHandler)
|
||||
r.POST("/api/web_search", s.WebSearchHandler)
|
||||
r.POST("/api/web_fetch", s.WebFetchHandler)
|
||||
|
||||
// Inference (OpenAI compatibility)
|
||||
r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
|
||||
|
|
|
|||
Loading…
Reference in New Issue