// File: ollama/client/client.go (189 lines, 4.8 KiB, Go)

package client
import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"iter"
	"maps"
	"net/http"
	"net/url"
	"runtime"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/envconfig"
	"github.com/ollama/ollama/version"
)
// Error describes an HTTP response whose status code is 400 or above.
type Error struct {
	// Status is the full status line text, e.g. "404 Not Found".
	Status string
	// StatusCode is the numeric HTTP status, e.g. 404.
	StatusCode int
}

// Error implements the error interface by returning the status text verbatim.
func (e Error) Error() string { return e.Status }
type Client struct {
baseURL *url.URL
header http.Header
}
// WithBaseURL returns an option that points the client at the server address s.
// Parsing happens when the option is applied, and the option panics at that
// point if url.Parse rejects s.
func WithBaseURL(s string) func(*Client) {
	apply := func(c *Client) {
		u, err := url.Parse(s)
		if err != nil {
			panic(err)
		}
		c.baseURL = u
	}
	return apply
}
// WithHeader returns an option that adds custom HTTP headers to every request.
//
// The headers in h are merged into the client's existing default headers. For a
// key present in both, the values from h replace the default; defaults for other
// keys (such as Content-Type and User-Agent) are kept. Keys are canonicalized and
// values are copied, so later changes to h do not affect the client.
func WithHeader(h http.Header) func(*Client) {
	return func(c *Client) {
		if c.header == nil {
			c.header = make(http.Header, len(h))
		}
		for k, vs := range h {
			// Canonicalize so that e.g. "user-agent" replaces the default
			// "User-Agent" entry instead of being sent alongside it.
			c.header[http.CanonicalHeaderKey(k)] = append([]string(nil), vs...)
		}
	}
}
// New returns a Client whose base URL comes from envconfig.Host() and whose
// default headers are "Content-Type: application/json" and a User-Agent that
// identifies this client.
//
// opts are applied in order after the defaults are set. Build them with
// WithBaseURL and WithHeader. Nil options are ignored.
func New(opts ...func(*Client)) *Client {
	c := &Client{
		baseURL: envconfig.Host(),
		header: http.Header{
			"Content-Type": {"application/json"},
			"User-Agent":   {userAgent},
		},
	}
	for _, opt := range opts {
		if opt == nil {
			continue
		}
		opt(c)
	}
	return c
}
// Ping checks that the server is reachable by sending a HEAD request to "/".
// A nil result means the server answered with a status below 400.
func (c *Client) Ping(ctx context.Context) error {
	if _, err := do[struct{}](c, ctx, http.MethodHead, "/", nil); err != nil {
		return err
	}
	return nil
}
// Chat posts r to /api/chat and returns an iterator over the streamed response
// chunks. The returned error reports failures before streaming begins; errors
// that occur while reading the stream are yielded by the iterator.
func (c *Client) Chat(ctx context.Context, r api.ChatRequest) (iter.Seq2[api.ChatResponse, error], error) {
	seq, err := doSeq[api.ChatResponse](c, ctx, http.MethodPost, "/api/chat", r)
	return seq, err
}
// Pull asks the Ollama server to download the model described by r from a
// remote repository (POST /api/pull). Progress updates are delivered through
// the returned iterator as they stream in.
func (c *Client) Pull(ctx context.Context, r api.PullRequest) (iter.Seq2[api.ProgressResponse, error], error) {
	const endpoint = "/api/pull"
	return doSeq[api.ProgressResponse](c, ctx, http.MethodPost, endpoint, r)
}
// Push asks the Ollama server to upload the model described by r to a remote
// repository (POST /api/push). Progress updates are delivered through the
// returned iterator as they stream in.
func (c *Client) Push(ctx context.Context, r api.PushRequest) (iter.Seq2[api.ProgressResponse, error], error) {
	progress, err := doSeq[api.ProgressResponse](c, ctx, http.MethodPost, "/api/push", r)
	return progress, err
}
// Create asks the Ollama server to build a new model from r (POST /api/create).
// Progress updates are delivered through the returned iterator as they stream in.
func (c *Client) Create(ctx context.Context, r api.CreateRequest) (iter.Seq2[api.ProgressResponse, error], error) {
	const endpoint = "/api/create"
	progress, err := doSeq[api.ProgressResponse](c, ctx, http.MethodPost, endpoint, r)
	return progress, err
}
// List fetches the models known to the Ollama server via GET /api/tags.
func (c *Client) List(ctx context.Context) (api.ListResponse, error) {
	models, err := do[api.ListResponse](c, ctx, http.MethodGet, "/api/tags", nil)
	return models, err
}
// Delete asks the Ollama server to remove the model described by r
// (DELETE /api/delete). Only the error outcome matters; any decoded response
// content is thrown away.
func (c *Client) Delete(ctx context.Context, r api.DeleteRequest) error {
	if _, err := do[struct{}](c, ctx, http.MethodDelete, "/api/delete", r); err != nil {
		return err
	}
	return nil
}
// Version returns the version string reported by the server's GET /api/version
// endpoint, taken from the "version" field of the JSON response.
func (c *Client) Version(ctx context.Context) (string, error) {
	// Named resp rather than version so the imported version package is not shadowed.
	resp, err := do[struct {
		Version string `json:"version"`
	}](c, ctx, http.MethodGet, "/api/version", nil)
	if err != nil {
		return "", err
	}
	return resp.Version, nil
}
// userAgent identifies this client in the User-Agent header, in the form
// "ollama/<version> (<GOARCH> <GOOS>) Go/<runtime version>".
var userAgent = fmt.Sprintf(
	"ollama/%s (%s %s) Go/%s",
	version.Version,
	runtime.GOARCH,
	runtime.GOOS,
	runtime.Version(),
)
// do sends an HTTP request for path, resolved against the client's base URL,
// and returns the raw response.
//
// A non-nil body is JSON-encoded as the request body. A nil body sends an empty
// request body instead of the literal "null". The client's default headers are
// copied into the request first and header is copied over them, so per-call
// values win.
//
// On success the caller owns the response and must close its Body. If the
// server answers with a status of 400 or above, the body is closed here and an
// Error describing the status is returned.
func (c *Client) do(ctx context.Context, method, path string, body any, header http.Header) (*http.Response, error) {
	var reqBody io.Reader = http.NoBody
	if body != nil {
		var buf bytes.Buffer
		if err := json.NewEncoder(&buf).Encode(body); err != nil {
			return nil, err
		}
		reqBody = &buf
	}

	req, err := http.NewRequestWithContext(ctx, method, c.baseURL.JoinPath(path).String(), reqBody)
	if err != nil {
		return nil, err
	}

	// Later copies override earlier ones: client defaults first, then per-call headers.
	maps.Copy(req.Header, c.header)
	maps.Copy(req.Header, header)

	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, err
	}

	if resp.StatusCode >= http.StatusBadRequest {
		// The caller never receives a failed response, so its body must be
		// released here or the connection leaks.
		resp.Body.Close()
		return nil, Error{
			Status:     resp.Status,
			StatusCode: resp.StatusCode,
		}
	}
	return resp, nil
}
// do sends the request described by method, path and body, and decodes the
// JSON response into a value of type T.
//
// Nothing is decoded for HEAD requests or when the server declares an empty
// body (Content-Length: 0). A body of unknown length is still decoded, and if
// it turns out to be empty the zero T is returned without error. Whenever an
// error is returned, the T is the zero value.
func do[T any](c *Client, ctx context.Context, method, path string, body any) (T, error) {
	var zero T
	resp, err := c.do(ctx, method, path, body, http.Header{"Accept": {"application/json"}})
	if err != nil {
		return zero, err
	}
	defer resp.Body.Close()

	// Do not gate on ContentLength > 0: it is -1 for chunked responses and for
	// responses the transport transparently gunzipped, and those still carry a body.
	if method == http.MethodHead || resp.ContentLength == 0 {
		return zero, nil
	}

	var t T
	if err := json.NewDecoder(resp.Body).Decode(&t); err != nil {
		// io.EOF is only returned when the body held no JSON at all.
		if errors.Is(err, io.EOF) {
			return zero, nil
		}
		return zero, err
	}
	return t, nil
}
// doSeq sends the request and returns an iterator over the newline-delimited
// JSON objects in the response body, each decoded as a T.
//
// The iterator yields (value, nil) for each non-blank line. If a line fails to
// decode, or reading the body fails (network error, context cancellation, or a
// line of 512 KiB or more), it yields (zero T, err) once and stops.
//
// The response body is read once and closed when iteration ends, so the
// sequence is single-use. NOTE(review): if the caller never ranges over the
// sequence, the body is never closed.
func doSeq[T any](c *Client, ctx context.Context, method, path string, body any) (iter.Seq2[T, error], error) {
	resp, err := c.do(ctx, method, path, body, http.Header{"Accept": {"application/jsonl", "application/x-ndjson"}})
	if err != nil {
		return nil, err
	}

	return func(yield func(T, error) bool) {
		defer resp.Body.Close()

		s := bufio.NewScanner(resp.Body)
		// Start with a modest buffer and let the scanner grow it as needed, up to
		// 512 KiB per line.
		s.Buffer(make([]byte, 0, 64<<10), 512<<10)
		for s.Scan() {
			line := s.Bytes()
			if len(bytes.TrimSpace(line)) == 0 {
				continue // tolerate blank lines between records
			}
			var t T
			if err := json.Unmarshal(line, &t); err != nil {
				var zero T
				yield(zero, err)
				return
			}
			if !yield(t, nil) {
				return
			}
		}
		// Without this check, read failures would silently end the stream as if
		// the server had finished normally.
		if err := s.Err(); err != nil {
			var zero T
			yield(zero, err)
		}
	}, nil
}