Compare commits

..

1 Commits

Author SHA1 Message Date
jmorganca
ad7e641815 add batch embeddings 2024-04-26 20:13:33 -04:00
31 changed files with 507 additions and 1259 deletions

View File

@@ -51,7 +51,7 @@ Here are some example models that can be downloaded:
| ------------------ | ---------- | ----- | ------------------------------ |
| Llama 3 | 8B | 4.7GB | `ollama run llama3` |
| Llama 3 | 70B | 40GB | `ollama run llama3:70b` |
| Phi-3 | 3.8B | 2.3GB | `ollama run phi3` |
| Phi-3 | 3.8B | 2.3GB | `ollama run phi3` |
| Mistral | 7B | 4.1GB | `ollama run mistral` |
| Neural Chat | 7B | 4.1GB | `ollama run neural-chat` |
| Starling | 7B | 4.1GB | `ollama run starling-lm` |

View File

@@ -18,7 +18,6 @@ import (
"net/url"
"os"
"runtime"
"strconv"
"strings"
"github.com/ollama/ollama/format"
@@ -58,36 +57,12 @@ func checkError(resp *http.Response, body []byte) error {
// If the variable is not specified, a default ollama host and port will be
// used.
// ClientFromEnvironment builds a Client whose base URL is derived from the
// OLLAMA_HOST environment variable (parsed by GetOllamaHost). It returns an
// error when the environment specifies an invalid host or port.
func ClientFromEnvironment() (*Client, error) {
	oh, err := GetOllamaHost()
	if err != nil {
		return nil, err
	}

	client := &Client{
		base: &url.URL{
			Scheme: oh.Scheme,
			Host:   net.JoinHostPort(oh.Host, oh.Port),
		},
		http: http.DefaultClient,
	}
	return client, nil
}
// OllamaHost holds the parsed components of the OLLAMA_HOST environment
// variable: URL scheme, host name or address, and port (all as strings,
// ready for net.JoinHostPort).
type OllamaHost struct {
	Scheme string
	Host   string
	Port   string
}
func GetOllamaHost() (OllamaHost, error) {
defaultPort := "11434"
hostVar := os.Getenv("OLLAMA_HOST")
hostVar = strings.TrimSpace(strings.Trim(strings.TrimSpace(hostVar), "\"'"))
scheme, hostport, ok := strings.Cut(hostVar, "://")
scheme, hostport, ok := strings.Cut(os.Getenv("OLLAMA_HOST"), "://")
switch {
case !ok:
scheme, hostport = "http", hostVar
scheme, hostport = "http", os.Getenv("OLLAMA_HOST")
case scheme == "http":
defaultPort = "80"
case scheme == "https":
@@ -107,14 +82,12 @@ func GetOllamaHost() (OllamaHost, error) {
}
}
if portNum, err := strconv.ParseInt(port, 10, 32); err != nil || portNum > 65535 || portNum < 0 {
return OllamaHost{}, ErrInvalidHostPort
}
return OllamaHost{
Scheme: scheme,
Host: host,
Port: port,
return &Client{
base: &url.URL{
Scheme: scheme,
Host: net.JoinHostPort(host, port),
},
http: http.DefaultClient,
}, nil
}

View File

@@ -1,12 +1,6 @@
package api
import (
"fmt"
"net"
"testing"
"github.com/stretchr/testify/assert"
)
import "testing"
func TestClientFromEnvironment(t *testing.T) {
type testCase struct {
@@ -46,40 +40,4 @@ func TestClientFromEnvironment(t *testing.T) {
}
})
}
hostTestCases := map[string]*testCase{
"empty": {value: "", expect: "127.0.0.1:11434"},
"only address": {value: "1.2.3.4", expect: "1.2.3.4:11434"},
"only port": {value: ":1234", expect: ":1234"},
"address and port": {value: "1.2.3.4:1234", expect: "1.2.3.4:1234"},
"hostname": {value: "example.com", expect: "example.com:11434"},
"hostname and port": {value: "example.com:1234", expect: "example.com:1234"},
"zero port": {value: ":0", expect: ":0"},
"too large port": {value: ":66000", err: ErrInvalidHostPort},
"too small port": {value: ":-1", err: ErrInvalidHostPort},
"ipv6 localhost": {value: "[::1]", expect: "[::1]:11434"},
"ipv6 world open": {value: "[::]", expect: "[::]:11434"},
"ipv6 no brackets": {value: "::1", expect: "[::1]:11434"},
"ipv6 + port": {value: "[::1]:1337", expect: "[::1]:1337"},
"extra space": {value: " 1.2.3.4 ", expect: "1.2.3.4:11434"},
"extra quotes": {value: "\"1.2.3.4\"", expect: "1.2.3.4:11434"},
"extra space+quotes": {value: " \" 1.2.3.4 \" ", expect: "1.2.3.4:11434"},
"extra single quotes": {value: "'1.2.3.4'", expect: "1.2.3.4:11434"},
}
for k, v := range hostTestCases {
t.Run(k, func(t *testing.T) {
t.Setenv("OLLAMA_HOST", v.value)
oh, err := GetOllamaHost()
if err != v.err {
t.Fatalf("expected %s, got %s", v.err, err)
}
if err == nil {
host := net.JoinHostPort(oh.Host, oh.Port)
assert.Equal(t, v.expect, host, fmt.Sprintf("%s: expected %s, got %s", k, v.expect, host))
}
})
}
}

View File

@@ -159,15 +159,17 @@ type Runner struct {
}
type EmbeddingRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
KeepAlive *Duration `json:"keep_alive,omitempty"`
Model string `json:"model"`
Prompt string `json:"prompt,omitempty"`
PromptBatch []string `json:"prompt_batch,omitempty"`
KeepAlive *Duration `json:"keep_alive,omitempty"`
Options map[string]interface{} `json:"options"`
}
type EmbeddingResponse struct {
Embedding []float64 `json:"embedding"`
Embedding []float64 `json:"embedding,omitempty"`
EmbeddingBatch [][]float64 `json:"embedding_batch,omitempty"`
}
type CreateRequest struct {
@@ -309,7 +311,6 @@ func (m *Metrics) Summary() {
}
var ErrInvalidOpts = errors.New("invalid options")
var ErrInvalidHostPort = errors.New("invalid port specified in OLLAMA_HOST")
func (opts *Options) FromMap(m map[string]interface{}) error {
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct

View File

@@ -43,36 +43,37 @@ func getCLIFullPath(command string) string {
return command
}
func start(ctx context.Context, command string) (*exec.Cmd, error) {
func SpawnServer(ctx context.Context, command string) (chan int, error) {
done := make(chan int)
logDir := filepath.Dir(ServerLogFile)
_, err := os.Stat(logDir)
if errors.Is(err, os.ErrNotExist) {
if err := os.MkdirAll(logDir, 0o755); err != nil {
return done, fmt.Errorf("create ollama server log dir %s: %v", logDir, err)
}
}
cmd := getCmd(ctx, getCLIFullPath(command))
// send stdout and stderr to a file
stdout, err := cmd.StdoutPipe()
if err != nil {
return nil, fmt.Errorf("failed to spawn server stdout pipe: %w", err)
return done, fmt.Errorf("failed to spawn server stdout pipe %s", err)
}
stderr, err := cmd.StderrPipe()
if err != nil {
return nil, fmt.Errorf("failed to spawn server stderr pipe: %w", err)
return done, fmt.Errorf("failed to spawn server stderr pipe %s", err)
}
stdin, err := cmd.StdinPipe()
if err != nil {
return done, fmt.Errorf("failed to spawn server stdin pipe %s", err)
}
// TODO - rotation
logFile, err := os.OpenFile(ServerLogFile, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0755)
if err != nil {
return nil, fmt.Errorf("failed to create server log: %w", err)
return done, fmt.Errorf("failed to create server log %w", err)
}
logDir := filepath.Dir(ServerLogFile)
_, err = os.Stat(logDir)
if err != nil {
if !errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("stat ollama server log dir %s: %v", logDir, err)
}
if err := os.MkdirAll(logDir, 0o755); err != nil {
return nil, fmt.Errorf("create ollama server log dir %s: %v", logDir, err)
}
}
go func() {
defer logFile.Close()
io.Copy(logFile, stdout) //nolint:errcheck
@@ -116,33 +117,19 @@ func start(ctx context.Context, command string) (*exec.Cmd, error) {
// run the command and wait for it to finish
if err := cmd.Start(); err != nil {
return nil, fmt.Errorf("failed to start server %w", err)
return done, fmt.Errorf("failed to start server %w", err)
}
if cmd.Process != nil {
slog.Info(fmt.Sprintf("started ollama server with pid %d", cmd.Process.Pid))
}
slog.Info(fmt.Sprintf("ollama server logs %s", ServerLogFile))
return cmd, nil
}
func SpawnServer(ctx context.Context, command string) (chan int, error) {
done := make(chan int)
go func() {
// Keep the server running unless we're shuttind down the app
crashCount := 0
for {
slog.Info("starting server...")
cmd, err := start(ctx, command)
if err != nil {
crashCount++
slog.Error(fmt.Sprintf("failed to start server %s", err))
time.Sleep(500 * time.Millisecond * time.Duration(crashCount))
continue
}
cmd.Wait() //nolint:errcheck
stdin.Close()
var code int
if cmd.ProcessState != nil {
code = cmd.ProcessState.ExitCode()
@@ -156,12 +143,15 @@ func SpawnServer(ctx context.Context, command string) (chan int, error) {
default:
crashCount++
slog.Warn(fmt.Sprintf("server crash %d - exit code %d - respawning", crashCount, code))
time.Sleep(500 * time.Millisecond * time.Duration(crashCount))
break
time.Sleep(500 * time.Millisecond)
if err := cmd.Start(); err != nil {
slog.Error(fmt.Sprintf("failed to restart server %s", err))
// Keep trying, but back off if we keep failing
time.Sleep(time.Duration(crashCount) * time.Second)
}
}
}
}()
return done, nil
}

View File

@@ -88,8 +88,8 @@ DialogFontSize=12
[Files]
Source: ".\app.exe"; DestDir: "{app}"; DestName: "{#MyAppExeName}" ; Flags: ignoreversion 64bit
Source: "..\ollama.exe"; DestDir: "{app}"; Flags: ignoreversion 64bit
Source: "..\dist\windows-{#ARCH}\*.dll"; DestDir: "{app}"; Flags: ignoreversion 64bit
Source: "..\dist\windows-{#ARCH}\ollama_runners\*"; DestDir: "{app}\ollama_runners"; Flags: ignoreversion 64bit recursesubdirs
Source: "..\dist\windows-amd64\*.dll"; DestDir: "{app}"; Flags: ignoreversion 64bit
Source: "..\dist\windows-amd64\ollama_runners\*"; DestDir: "{app}\ollama_runners"; Flags: ignoreversion 64bit recursesubdirs
Source: "..\dist\ollama_welcome.ps1"; DestDir: "{app}"; Flags: ignoreversion
Source: ".\assets\app.ico"; DestDir: "{app}"; Flags: ignoreversion
#if DirExists("..\dist\windows-amd64\rocm")

View File

@@ -10,44 +10,12 @@ import (
"log/slog"
"os"
"path/filepath"
"strings"
"golang.org/x/crypto/ssh"
)
// defaultPrivateKey is the filename of the user's ollama SSH private key.
const defaultPrivateKey = "id_ed25519"

// keyPath returns the absolute path of the ollama private key file,
// located at ~/.ollama/id_ed25519 under the user's home directory.
func keyPath() (string, error) {
	dir, err := os.UserHomeDir()
	if err != nil {
		return "", err
	}
	p := filepath.Join(dir, ".ollama", defaultPrivateKey)
	return p, nil
}
// GetPublicKey loads the ollama SSH private key from ~/.ollama/id_ed25519
// and returns its public half in authorized_keys format, with surrounding
// whitespace trimmed. Read and parse errors are returned to the caller;
// a read failure is additionally logged.
func GetPublicKey() (string, error) {
	keyPath, err := keyPath()
	if err != nil {
		return "", err
	}

	privateKeyFile, err := os.ReadFile(keyPath)
	if err != nil {
		slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
		return "", err
	}

	privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
	if err != nil {
		return "", err
	}

	// MarshalAuthorizedKey appends a trailing newline; TrimSpace drops it.
	publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey())

	return strings.TrimSpace(string(publicKey)), nil
}
func NewNonce(r io.Reader, length int) (string, error) {
nonce := make([]byte, length)
if _, err := io.ReadFull(r, nonce); err != nil {
@@ -58,11 +26,13 @@ func NewNonce(r io.Reader, length int) (string, error) {
}
func Sign(ctx context.Context, bts []byte) (string, error) {
keyPath, err := keyPath()
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
privateKeyFile, err := os.ReadFile(keyPath)
if err != nil {
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))

View File

@@ -1,95 +0,0 @@
package apitype
import (
"cmp"
"encoding/json"
"log/slog"
"net/url"
"slices"
)
// Manifest is a model manifest: the ordered list of layers that make up
// a model.
type Manifest struct {
	Layers []*Layer `json:"layers"`
}

// CompletePart records one finished chunk upload so the server can
// assemble the multipart upload.
type CompletePart struct {
	URL  string `json:"url"` // contains partNumber and uploadId from server
	ETag string `json:"etag"`
}
// queryFromString parses s as a URL and returns its query parameters.
// A string that fails to parse yields a nil url.Values (whose Get method
// safely returns "").
func queryFromString(s string) url.Values {
	parsed, err := url.Parse(s)
	if err == nil {
		return parsed.Query()
	}
	return nil
}
// Compare orders two CompleteParts by partNumber, then uploadId (both read
// from their URL query strings), then ETag, reporting -1, 0, or +1.
// Note that partNumber is compared lexicographically as a string.
func (cp *CompletePart) Compare(o *CompletePart) int {
	left, right := queryFromString(cp.URL), queryFromString(o.URL)
	if c := cmp.Compare(left.Get("partNumber"), right.Get("partNumber")); c != 0 {
		return c
	}
	if c := cmp.Compare(left.Get("uploadId"), right.Get("uploadId")); c != 0 {
		return c
	}
	return cmp.Compare(cp.ETag, o.ETag)
}
// SortCompleteParts sorts a in place using CompletePart.Compare
// (partNumber, then uploadId, then ETag).
func SortCompleteParts(a []*CompletePart) {
	slices.SortFunc(a, (*CompletePart).Compare)
}
// Layer describes a single content-addressed blob referenced by a model
// manifest.
type Layer struct {
	Digest    string `json:"digest"`
	MediaType string `json:"mediaType"`
	Size      int64  `json:"size"`

	// If present, URL is a remote location of the layer for fetching.
	URL string `json:"url,omitempty"`
}

// LogValue renders the layer as a structured slog group for logging.
func (l *Layer) LogValue() slog.Value {
	return slog.GroupValue(
		slog.String("digest", l.Digest),
		slog.String("mediaType", l.MediaType),
		slog.Int64("size", l.Size),
		slog.String("url", l.URL),
	)
}
// PushRequest is the client's request to publish a model manifest.
type PushRequest struct {
	Name     string          `json:"ref"`
	Manifest json.RawMessage `json:"manifest,omitempty"`

	// CompleteParts is a list of upload parts that the client uploaded in
	// the previous push.
	CompleteParts []*CompletePart `json:"part_uploads"`
}

// Need describes a byte range of a layer the server still requires
// before it will accept the manifest.
type Need struct {
	Digest string `json:"digest"`
	Start  int64  `json:"start"`
	End    int64  `json:"end"`

	// URL is the url to PUT the layer to.
	//
	// Clients must include it as the URL, along with the ETag in the
	// response headers from the PUT request, in the next push request
	// in the Uploaded field.
	URL string `json:"url"`
}

// PushResponse is the server's reply to a push request.
type PushResponse struct {
	// Needs is a list of digests that the client needs to push before
	// repushing the manifest.
	Needs []*Need `json:"requirements,omitempty"`
}

// PullResponse is the server's reply to a pull request.
type PullResponse struct {
	// Name is the name of the model being pulled.
	Name string `json:"name"`

	// Manifest is the manifest of the model being pulled.
	Manifest *Manifest `json:"manifest"`
}

View File

@@ -1,421 +0,0 @@
package registry
import (
"cmp"
"context"
"encoding/json"
"encoding/xml"
"errors"
"fmt"
"io"
"iter"
"log/slog"
"net/http"
"net/url"
"os"
"sync"
"github.com/ollama/ollama/client/ollama"
"github.com/ollama/ollama/client/registry/apitype"
"github.com/ollama/ollama/types/model"
"golang.org/x/exp/constraints"
"golang.org/x/sync/errgroup"
)
// Errors
var (
	// ErrLayerNotFound indicates a layer required by the server is not
	// present locally.
	ErrLayerNotFound = errors.New("layer not found")
)

// Client talks to an ollama registry over HTTP.
type Client struct {
	BaseURL string
	Logger  *slog.Logger

	// NameFill is a string that is used to fill in the missing parts of
	// a name when it is not fully qualified. It is used to make a name
	// fully qualified before pushing or pulling it. The default is
	// "registry.ollama.ai/library/_:latest".
	//
	// NOTE(review): parseNameFill actually falls back to
	// "bllamo.com/library/_:latest" — the comment above and the code
	// disagree; confirm which default is intended.
	//
	// Most users can ignore this field. It is intended for use by
	// clients that need to push or pull names to registries other than
	// registry.ollama.ai, and for testing.
	NameFill string
}
// log returns the client's configured Logger, falling back to the
// process-wide default when none was set.
func (c *Client) log() *slog.Logger {
	if c.Logger != nil {
		return c.Logger
	}
	return slog.Default()
}
// oclient builds a low-level ollama API client bound to this registry
// client's base URL.
func (c *Client) oclient() *ollama.Client {
	return &ollama.Client{
		BaseURL: c.BaseURL,
	}
}

// ReadAtSeekCloser is a readable, seekable, closable view of a layer file.
type ReadAtSeekCloser interface {
	io.ReaderAt
	io.Seeker
	io.Closer
}
// Cache is the local store of layer blobs and manifests that Pull and
// Push read from and write to.
type Cache interface {
	// LayerFile returns the absolute file path to the layer file for
	// the given model digest.
	//
	// If the digest is invalid, or the layer does not exist, the empty
	// string is returned.
	LayerFile(model.Digest) string

	// OpenLayer opens the layer file for the given model digest and
	// returns it, or an error, if any. The caller is responsible for
	// closing the returned file.
	OpenLayer(model.Digest) (ReadAtSeekCloser, error)

	// PutLayerFile moves the layer file at fromPath to the cache for
	// the given model digest. It is a hack intended to short circuit a
	// file copy operation.
	//
	// The file returned is expected to exist for the lifetime of the
	// cache.
	//
	// TODO(bmizerany): remove this; find a better way. Once we move
	// this into a build package, we should be able to get rid of this.
	PutLayerFile(_ model.Digest, fromPath string) error

	// SetManifestData sets the provided manifest data for the given
	// model name. If the manifest data is empty, the manifest is
	// removed. If the manifest exists, it is overwritten.
	//
	// It is an error to call SetManifestData with a name that is not
	// complete.
	SetManifestData(model.Name, []byte) error

	// ManifestData returns the manifest data for the given model name.
	//
	// If the name is incomplete, or the manifest does not exist, the
	// empty string is returned.
	ManifestData(name model.Name) []byte
}
// Pull pulls the manifest for name, and downloads any of its required
// layers that are not already in the cache. It returns an error if any part
// of the process fails: an invalid or unqualified name, a failed manifest
// fetch, an empty manifest, or any layer download failure.
func (c *Client) Pull(ctx context.Context, cache Cache, name string) error {
	mn := parseNameFill(name, c.NameFill)
	if !mn.IsFullyQualified() {
		return fmt.Errorf("ollama: pull: invalid name: %s", name)
	}

	log := c.log().With("name", name)

	pr, err := ollama.Do[*apitype.PullResponse](ctx, c.oclient(), "GET", "/v1/pull/"+name, nil)
	if err != nil {
		return fmt.Errorf("ollama: pull: %w: %s", err, name)
	}

	if pr.Manifest == nil || len(pr.Manifest.Layers) == 0 {
		return fmt.Errorf("ollama: pull: invalid manifest: %s: no layers found", name)
	}

	// download required layers we do not already have
	for _, l := range pr.Manifest.Layers {
		d, err := model.ParseDigest(l.Digest)
		if err != nil {
			return fmt.Errorf("ollama: reading manifest: %w: %s", err, l.Digest)
		}
		if cache.LayerFile(d) != "" {
			continue
		}
		// Each layer download runs inside a closure so the temp-file
		// cleanup defers fire per layer, not at function return.
		err = func() error {
			log := log.With("digest", l.Digest, "mediaType", l.MediaType, "size", l.Size)
			log.Debug("starting download")

			// TODO(bmizerany): stop using temp which might not
			// be on same device as cache.... instead let cache
			// give us a place to store parts...
			tmpFile, err := os.CreateTemp("", "ollama-download-")
			if err != nil {
				return err
			}
			defer func() {
				tmpFile.Close()
				os.Remove(tmpFile.Name()) // in case we fail before committing
			}()

			// Download the layer in fixed-size chunks, up to 8 at a time.
			g, ctx := errgroup.WithContext(ctx)
			g.SetLimit(8) // TODO(bmizerany): make this configurable

			// TODO(bmizerany): make chunk size configurable
			const chunkSize = 50 * 1024 * 1024 // 50MB
			chunks(l.Size, chunkSize)(func(_ int, rng chunkRange[int64]) bool {
				g.Go(func() (err error) {
					// Annotate any chunk error with context; the presigned
					// URL is redacted so secrets never reach logs.
					defer func() {
						if err == nil {
							return
						}
						safeURL := redactAmzSignature(l.URL)
						err = fmt.Errorf("%w: %s %s bytes=%s: %s", err, pr.Name, l.Digest, rng, safeURL)
					}()

					log.Debug("downloading", "range", rng)
					// TODO(bmizerany): retry
					// TODO(bmizerany): use real http client
					// TODO(bmizerany): resumable
					// TODO(bmizerany): multipart download
					req, err := http.NewRequestWithContext(ctx, "GET", l.URL, nil)
					if err != nil {
						return err
					}
					req.Header.Set("Range", "bytes="+rng.String())

					res, err := http.DefaultClient.Do(req)
					if err != nil {
						return err
					}
					defer res.Body.Close()
					if res.StatusCode/100 != 2 {
						log.Debug("unexpected non-2XX status code", "status", res.StatusCode)
						return fmt.Errorf("unexpected status code fetching layer: %d", res.StatusCode)
					}
					if res.ContentLength != rng.Size() {
						return fmt.Errorf("unexpected content length: %d", res.ContentLength)
					}
					// Each chunk writes at its own offset in the temp file.
					w := io.NewOffsetWriter(tmpFile, rng.Start)
					_, err = io.Copy(w, res.Body)
					return err
				})
				return true
			})
			if err := g.Wait(); err != nil {
				return err
			}
			tmpFile.Close() // release our hold on the file before moving it
			return cache.PutLayerFile(d, tmpFile.Name())
		}()
		if err != nil {
			return fmt.Errorf("ollama: pull: %w", err)
		}
	}

	// do not store the presigned URLs in the cache
	for i := range pr.Manifest.Layers {
		pr.Manifest.Layers[i].URL = ""
	}
	data, err := json.Marshal(pr.Manifest)
	if err != nil {
		return err
	}
	// TODO(bmizerany): remove dep on model.Name
	return cache.SetManifestData(mn, data)
}
// nopSeeker wraps an io.Reader with a no-op Seek so it satisfies seekable
// interfaces without actually supporting repositioning.
type nopSeeker struct {
	io.Reader
}

// Seek always reports offset 0 and never fails; the underlying reader's
// position is untouched.
func (nopSeeker) Seek(int64, int) (int64, error) {
	return 0, nil
}

// parseNameFill parses name and fills in any missing parts from fill.
// It panics if fill itself is not fully qualified.
//
// NOTE(review): the fallback fill here is "bllamo.com/library/_:latest",
// which disagrees with the "registry.ollama.ai/..." default documented on
// Client.NameFill — confirm which is intended.
func parseNameFill(name, fill string) model.Name {
	fill = cmp.Or(fill, "bllamo.com/library/_:latest")
	f := model.ParseNameBare(fill)
	if !f.IsFullyQualified() {
		panic(fmt.Errorf("invalid fill: %q", fill))
	}
	return model.Merge(model.ParseNameBare(name), f)
}
// Push pushes a manifest to the server and responds to the server's
// requests for layer uploads, if any, and finally commits the manifest for
// name. It returns an error if any part of the process fails, specifically:
//
// If the server requests layers not found in the cache, ErrLayerNotFound is
// returned.
func (c *Client) Push(ctx context.Context, cache Cache, name string) error {
	mn := parseNameFill(name, c.NameFill)
	if !mn.IsFullyQualified() {
		return fmt.Errorf("ollama: push: invalid name: %s", name)
	}

	manifest := cache.ManifestData(mn)
	if len(manifest) == 0 {
		return fmt.Errorf("manifest not found: %s", name)
	}

	// completed collects parts uploaded so far; mu guards it because
	// uploads run concurrently in the errgroup below.
	var mu sync.Mutex
	var completed []*apitype.CompletePart
	push := func() (*apitype.PushResponse, error) {
		v, err := ollama.Do[*apitype.PushResponse](ctx, c.oclient(), "POST", "/v1/push", &apitype.PushRequest{
			Name:          name,
			Manifest:      manifest,
			CompleteParts: completed,
		})
		if err != nil {
			return nil, fmt.Errorf("Do: %w", err)
		}
		return v, nil
	}

	// First push: the server replies with the layer ranges it still needs.
	pr, err := push()
	if err != nil {
		return err
	}

	// Upload each needed range concurrently.
	var g errgroup.Group
	for _, need := range pr.Needs {
		g.Go(func() error {
			nd, err := model.ParseDigest(need.Digest)
			if err != nil {
				return fmt.Errorf("ParseDigest: %w: %s", err, need.Digest)
			}
			f, err := cache.OpenLayer(nd)
			if err != nil {
				return fmt.Errorf("OpenLayer: %w: %s", err, need.Digest)
			}
			defer f.Close()
			c.log().Info("pushing layer", "digest", need.Digest, "start", need.Start, "end", need.End)
			cp, err := PushLayer(ctx, f, need.URL, need.Start, need.End)
			if err != nil {
				return fmt.Errorf("PushLayer: %w: %s", err, need.Digest)
			}
			mu.Lock()
			completed = append(completed, cp)
			mu.Unlock()
			return nil
		})
	}
	if err := g.Wait(); err != nil {
		return fmt.Errorf("Push: Required: %w", err)
	}

	// Second push (only if anything was uploaded): report the completed
	// parts. The server should now need nothing; leftover needs are an
	// error.
	if len(completed) > 0 {
		pr, err := push()
		if err != nil {
			return err
		}
		if len(pr.Needs) > 0 {
			var errs []error
			for _, r := range pr.Needs {
				errs = append(errs, fmt.Errorf("Push: server failed to find part: %q", r.Digest))
			}
			return errors.Join(errs...)
		}
	}

	return cache.SetManifestData(mn, manifest)
}
// PushLayer PUTs bytes [start, end] (inclusive) of body to url and returns
// the resulting CompletePart (URL + ETag) for inclusion in a later push
// request. start must satisfy 0 <= start <= end, otherwise an error is
// returned without making a request.
func PushLayer(ctx context.Context, body io.ReaderAt, url string, start, end int64) (*apitype.CompletePart, error) {
	if start < 0 || end < start {
		return nil, errors.New("start must satisfy 0 <= start <= end")
	}

	file := io.NewSectionReader(body, start, end-start+1)
	// Bind the request to ctx so cancellation aborts the upload; the
	// original used http.NewRequest and ignored ctx entirely.
	req, err := http.NewRequestWithContext(ctx, "PUT", url, file)
	if err != nil {
		return nil, err
	}
	req.ContentLength = end - start + 1

	// TODO(bmizerany): take content type param
	req.Header.Set("Content-Type", "text/plain")

	if start != 0 || end != 0 {
		// NOTE(review): this header belongs to S3's UploadPartCopy-style
		// range copies; confirm the server expects it on ordinary part PUTs.
		req.Header.Set("x-amz-copy-source-range", fmt.Sprintf("bytes=%d-%d", start, end))
	}

	res, err := http.DefaultClient.Do(req)
	if err != nil {
		return nil, err
	}
	defer res.Body.Close()
	if res.StatusCode != 200 {
		// Non-200: decode the S3-style XML error body for context.
		e := parseS3Error(res)
		return nil, fmt.Errorf("unexpected status code: %d; %w", res.StatusCode, e)
	}
	cp := &apitype.CompletePart{
		URL:  url,
		ETag: res.Header.Get("ETag"),
		// TODO(bmizerany): checksum
	}
	return cp, nil
}
// s3Error mirrors the XML error document that S3-compatible object stores
// return on failed requests.
type s3Error struct {
	XMLName   xml.Name `xml:"Error"`
	Code      string   `xml:"Code"`
	Message   string   `xml:"Message"`
	Resource  string   `xml:"Resource"`
	RequestId string   `xml:"RequestId"`
}

// Error formats the error as "S3 (<request id>): <resource>: <code>: <message>".
func (e *s3Error) Error() string {
	fields := []any{e.RequestId, e.Resource, e.Code, e.Message}
	return fmt.Sprintf("S3 (%s): %s: %s: %s", fields...)
}
// parseS3Error parses an XML error response from S3. If the body cannot be
// decoded, the decode error is returned instead of an *s3Error.
func parseS3Error(res *http.Response) error {
	var se *s3Error
	if err := xml.NewDecoder(res.Body).Decode(&se); err != nil {
		return err
	}
	return se
}
// TODO: replace below by using upload pkg after we have rangefunc; until
// then, we need to keep this free of rangefunc for now.

// chunkRange is an inclusive byte range [Start, End] within a layer.
type chunkRange[I constraints.Integer] struct {
	// Start is the byte offset of the chunk.
	Start I

	// End is the byte offset of the last byte in the chunk.
	End I
}

// Size returns the number of bytes in the range, counting both endpoints.
func (c chunkRange[I]) Size() I {
	return c.End - c.Start + 1
}

// String renders the range as "start-end", matching HTTP Range syntax.
func (c chunkRange[I]) String() string {
	return fmt.Sprintf("%d-%d", c.Start, c.End)
}

// LogValue renders the range as a string for structured logging.
func (c chunkRange[I]) LogValue() slog.Value {
	return slog.StringValue(c.String())
}
// Chunks yields a sequence of a part number and a Chunk. The Chunk is the offset
// and size of the chunk. The last chunk may be smaller than chunkSize if size is
// not a multiple of chunkSize.
//
// The first part number is 1 and increases monotonically.
func chunks[I constraints.Integer](size, chunkSize I) iter.Seq2[int, chunkRange[I]] {
	return func(yield func(int, chunkRange[I]) bool) {
		var n int
		for off := I(0); off < size; off += chunkSize {
			n++
			// End is clamped so the final chunk never extends past size-1.
			if !yield(n, chunkRange[I]{
				Start: off,
				End:   off + min(chunkSize, size-off) - 1,
			}) {
				return
			}
		}
	}
}
// redactAmzSignature returns s with its X-Amz-Signature query parameter
// replaced by the literal "REDACTED", so presigned URLs can be logged
// safely. Unparseable input yields the empty string. Note that the
// parameter is set (to REDACTED) even when it was absent, and the query
// is re-encoded with keys in sorted order.
func redactAmzSignature(s string) string {
	parsed, err := url.Parse(s)
	if err != nil {
		return ""
	}
	query := parsed.Query()
	query.Set("X-Amz-Signature", "REDACTED")
	parsed.RawQuery = query.Encode()
	return parsed.String()
}

View File

@@ -32,13 +32,10 @@ import (
"golang.org/x/term"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
)
@@ -360,47 +357,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return generateInteractive(cmd, opts)
}
// errFromUnknownKey inspects an "unknown ollama key" error from the server.
// If the SSH public key embedded in the error message matches this machine's
// local ollama key (or, on Linux, the ollama service account's key), it
// returns a friendlier error telling the user to add their key at
// ollama.com; otherwise the original error is returned unchanged.
func errFromUnknownKey(unknownKeyErr error) error {
	// find SSH public key in the error message
	sshKeyPattern := `ssh-\w+ [^\s"]+`
	re := regexp.MustCompile(sshKeyPattern)
	matches := re.FindStringSubmatch(unknownKeyErr.Error())

	if len(matches) > 0 {
		serverPubKey := matches[0]

		localPubKey, err := auth.GetPublicKey()
		if err != nil {
			return unknownKeyErr
		}

		if runtime.GOOS == "linux" && serverPubKey != localPubKey {
			// try the ollama service public key
			svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")
			if err != nil {
				return unknownKeyErr
			}
			localPubKey = strings.TrimSpace(string(svcPubKey))
		}

		// check if the returned public key matches the local public key, this prevents adding a remote key to the user's account
		if serverPubKey != localPubKey {
			return unknownKeyErr
		}

		var msg strings.Builder
		msg.WriteString(unknownKeyErr.Error())
		msg.WriteString("\n\nYour ollama key is:\n")
		msg.WriteString(localPubKey)
		msg.WriteString("\nAdd your key at:\n")
		msg.WriteString("https://ollama.com/settings/keys")
		return errors.New(msg.String())
	}

	return unknownKeyErr
}
func PushHandler(cmd *cobra.Command, args []string) error {
client, err := api.ClientFromEnvironment()
if err != nil {
@@ -448,20 +404,6 @@ func PushHandler(cmd *cobra.Command, args []string) error {
request := api.PushRequest{Name: args[0], Insecure: insecure}
if err := client.Push(cmd.Context(), &request, fn); err != nil {
if spinner != nil {
spinner.Stop()
}
if strings.Contains(err.Error(), "access denied") {
return errors.New("you are not authorized to push to this namespace, create the model under a namespace you own")
}
host := model.ParseName(args[0]).Host
isOllamaHost := strings.HasSuffix(host, ".ollama.ai") || strings.HasSuffix(host, ".ollama.com")
if strings.Contains(err.Error(), errtypes.UnknownOllamaKeyErrMsg) && isOllamaHost {
// the user has not added their ollama key to ollama.com
// re-throw an error with a more user-friendly message
return errFromUnknownKey(err)
}
return err
}
@@ -889,17 +831,19 @@ func generate(cmd *cobra.Command, opts runOptions) error {
}
func RunServer(cmd *cobra.Command, _ []string) error {
// retrieve the OLLAMA_HOST environment variable
ollamaHost, err := api.GetOllamaHost()
host, port, err := net.SplitHostPort(strings.Trim(os.Getenv("OLLAMA_HOST"), "\"'"))
if err != nil {
return err
host, port = "127.0.0.1", "11434"
if ip := net.ParseIP(strings.Trim(os.Getenv("OLLAMA_HOST"), "[]")); ip != nil {
host = ip.String()
}
}
if err := initializeKeypair(); err != nil {
return err
}
ln, err := net.Listen("tcp", net.JoinHostPort(ollamaHost.Host, ollamaHost.Port))
ln, err := net.Listen("tcp", net.JoinHostPort(host, port))
if err != nil {
return err
}

View File

@@ -1010,7 +1010,8 @@ Generate embeddings from a model
### Parameters
- `model`: name of model to generate embeddings from
- `prompt`: text to generate embeddings for
- `prompt`: string to generate the embedding for
- `prompt_batch`: array of strings to generate a batch of embeddings for
Advanced parameters:
@@ -1038,3 +1039,33 @@ curl http://localhost:11434/api/embeddings -d '{
]
}
```
#### Request (batch)
```shell
curl http://localhost:11434/api/embeddings -d '{
"model": "all-minilm",
"prompt_batch": [
"Here is an article about llamas...",
"Here is another article about llamas..."
]
}'
```
#### Response
```json
{
"embedding_batch": [
[
0.5670403838157654, 0.009260174818336964, 0.23178744316101074, -0.2916173040866852, -0.8924556970596313,
0.8785552978515625, -0.34576427936553955, 0.5742510557174683, -0.04222835972905159, -0.137906014919281
],
[
0.5670403838157654, 0.009260174818336964, 0.23178744316101074, -0.2916173040866852, -0.8924556970596313,
0.8785552978515625, -0.34576427936553955, 0.5742510557174683, -0.04222835972905159, -0.137906014919281
]
]
}
```

View File

@@ -17,12 +17,10 @@ Let's start by asking a simple question that we can get an answer to from the **
Then we can create a model and ask the question:
```python
from langchain_community.llms import Ollama
ollama = Ollama(
base_url='http://localhost:11434',
model="llama3"
)
print(ollama.invoke("why is the sky blue"))
from langchain.llms import Ollama
ollama = Ollama(base_url='http://localhost:11434',
model="llama2")
print(ollama("why is the sky blue"))
```
Notice that we are defining the model and the base URL for Ollama.

View File

@@ -32,25 +32,9 @@ func PayloadsDir() (string, error) {
slog.Error("failed to lookup executable path", "error", err)
return "", err
}
cwd, err := os.Getwd()
if err != nil {
slog.Error("failed to lookup working directory", "error", err)
return "", err
}
var paths []string
for _, root := range []string{filepath.Dir(appExe), cwd} {
paths = append(paths,
filepath.Join(root),
filepath.Join(root, "windows-"+runtime.GOARCH),
filepath.Join(root, "dist", "windows-"+runtime.GOARCH),
)
}
// Try a few variations to improve developer experience when building from source in the local tree
for _, p := range paths {
candidate := filepath.Join(p, "ollama_runners")
for _, d := range []string{".", "windows-" + runtime.GOARCH, "dist\\windows-" + runtime.GOARCH} {
candidate := filepath.Join(filepath.Dir(appExe), d, "ollama_runners")
_, err := os.Stat(candidate)
if err == nil {
runnersDir = candidate

View File

@@ -0,0 +1,64 @@
//go:build integration
package integration
import (
"context"
"net/http"
"testing"
"time"
"github.com/ollama/ollama/api"
)
// TestAllMiniLMEmbedding exercises the single-prompt embedding path against
// the all-minilm model and pins the deterministic first component.
func TestAllMiniLMEmbedding(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
	defer cancel()

	req := api.EmbeddingRequest{
		Model:  "all-minilm",
		Prompt: "why is the sky blue?",
		Options: map[string]interface{}{
			"temperature": 0,
			"seed":        123,
		},
	}

	res := EmbeddingTestHelper(ctx, t, &http.Client{}, req)

	// all-minilm produces 384-dimensional embeddings.
	if len(res.Embedding) != 384 {
		t.Fatalf("Expected 384 floats to be returned, got %v", len(res.Embedding))
	}
	// Exact float comparison is intentional: temperature 0 plus a fixed
	// seed makes the output deterministic.
	if res.Embedding[0] != 0.146763876080513 {
		t.Fatalf("Expected first embedding float to be 0.146763876080513, got %v", res.Embedding[0])
	}
}
// TestAllMiniLMEmbeddings exercises the batch embedding path against the
// all-minilm model: two identical prompts must yield two identical,
// deterministic embeddings.
func TestAllMiniLMEmbeddings(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
	defer cancel()

	// The batch request field is PromptBatch (json "prompt_batch") and the
	// batch response field is EmbeddingBatch, matching api.EmbeddingRequest
	// and api.EmbeddingResponse; the previous Prompts/Embeddings names do
	// not exist on those structs.
	req := api.EmbeddingRequest{
		Model:       "all-minilm",
		PromptBatch: []string{"why is the sky blue?", "why is the sky blue?"},
		Options: map[string]interface{}{
			"temperature": 0,
			"seed":        123,
		},
	}

	res := EmbeddingTestHelper(ctx, t, &http.Client{}, req)

	if len(res.EmbeddingBatch) != 2 {
		t.Fatal("Expected 2 embeddings to be returned")
	}
	if len(res.EmbeddingBatch[0]) != 384 {
		t.Fatalf("Expected first embedding to have 384 floats, got %v", len(res.EmbeddingBatch[0]))
	}
	// Both prompts are identical, so BOTH first components must match the
	// expected value — fail if either differs (the previous && only failed
	// when both were wrong).
	if res.EmbeddingBatch[0][0] != 0.146763876080513 || res.EmbeddingBatch[1][0] != 0.146763876080513 {
		t.Fatalf("Expected first embedding floats to be 0.146763876080513, got %v, %v", res.EmbeddingBatch[0][0], res.EmbeddingBatch[1][0])
	}
}

View File

@@ -5,6 +5,7 @@ package integration
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
@@ -24,6 +25,7 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/app/lifecycle"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -285,6 +287,7 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
// Generate a set of requests
// By default each request uses orca-mini as the model
func GenerateRequests() ([]api.GenerateRequest, [][]string) {
stream := false
return []api.GenerateRequest{
{
Model: "orca-mini",
@@ -336,3 +339,83 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
[]string{"nitrogen", "oxygen", "carbon", "dioxide"},
}
}
// EmbeddingTestHelper starts an ollama server (unless OLLAMA_TEST_EXISTING
// points at a running one), pulls req.Model if missing, POSTs req to
// /api/embeddings, and returns the decoded response. Any failure aborts
// the calling test.
func EmbeddingTestHelper(ctx context.Context, t *testing.T, client *http.Client, req api.EmbeddingRequest) api.EmbeddingResponse {
	// TODO maybe stuff in an init routine?
	lifecycle.InitLogging()

	requestJSON, err := json.Marshal(req)
	if err != nil {
		t.Fatalf("Error serializing request: %v", err)
	}

	// When we start the server ourselves, dump its log on failure and
	// remove the log file afterwards. Registered before startup so it
	// also runs when startup fails partway through.
	defer func() {
		if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
			defer serverProcMutex.Unlock()
			if t.Failed() {
				fp, err := os.Open(lifecycle.ServerLogFile)
				if err != nil {
					slog.Error("failed to open server log", "logfile", lifecycle.ServerLogFile, "error", err)
					return
				}
				data, err := io.ReadAll(fp)
				if err != nil {
					slog.Error("failed to read server log", "logfile", lifecycle.ServerLogFile, "error", err)
					return
				}
				slog.Warn("SERVER LOG FOLLOWS")
				os.Stderr.Write(data)
				slog.Warn("END OF SERVER")
			}
			err = os.Remove(lifecycle.ServerLogFile)
			if err != nil && !os.IsNotExist(err) {
				slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err)
			}
		}
	}()

	scheme, testEndpoint := GetTestEndpoint()

	if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
		// serverProcMutex serializes tests that manage their own server;
		// unlocked by the deferred cleanup above.
		serverProcMutex.Lock()
		fp, err := os.CreateTemp("", "ollama-server-*.log")
		if err != nil {
			t.Fatalf("failed to generate log file: %s", err)
		}
		lifecycle.ServerLogFile = fp.Name()
		fp.Close()
		assert.NoError(t, StartServer(ctx, testEndpoint))
	}

	err = PullIfMissing(ctx, client, scheme, testEndpoint, req.Model)
	if err != nil {
		t.Fatalf("Error pulling model: %v", err)
	}

	// Make the request and get the response
	httpReq, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/embeddings", bytes.NewReader(requestJSON))
	if err != nil {
		t.Fatalf("Error creating request: %v", err)
	}

	// Set the content type for the request
	httpReq.Header.Set("Content-Type", "application/json")

	// Make the request with the HTTP client
	response, err := client.Do(httpReq.WithContext(ctx))
	if err != nil {
		t.Fatalf("Error making request: %v", err)
	}
	defer response.Body.Close()
	body, err := io.ReadAll(response.Body)
	assert.NoError(t, err)
	// NOTE(review): testify's assert.Equal signature is (t, expected,
	// actual) — the arguments here are (actual, expected), so failure
	// output will be swapped.
	assert.Equal(t, response.StatusCode, 200, string(body))

	// Verify the response is valid JSON
	var res api.EmbeddingResponse
	err = json.Unmarshal(body, &res)
	if err != nil {
		assert.NoError(t, err, body)
	}
	return res
}

View File

@@ -1032,7 +1032,7 @@ struct llama_server_context
slot.has_next_token = false;
}
if (!slot.cache_tokens.empty() && llama_token_is_eog(model, result.tok))
if (!slot.cache_tokens.empty() && result.tok == llama_token_eos(model))
{
slot.stopped_eos = true;
slot.has_next_token = false;
@@ -1144,15 +1144,12 @@ struct llama_server_context
res.result_json = json
{
{"content", tkn.text_to_send},
{"stop", false},
{"slot_id", slot.id},
{"multimodal", multimodal}
};
if (!llama_token_is_eog(model, tkn.tok)) {
res.result_json["content"] = tkn.text_to_send;
}
if (slot.sparams.n_probs > 0)
{
std::vector<completion_token_output> probs_output = {};
@@ -2647,18 +2644,18 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
if (strncmp(sep, "int:", 4) == 0) {
sep += 4;
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_INT;
kvo.val_i64 = std::atol(sep);
kvo.int_value = std::atol(sep);
} else if (strncmp(sep, "float:", 6) == 0) {
sep += 6;
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_FLOAT;
kvo.val_f64 = std::atof(sep);
kvo.float_value = std::atof(sep);
} else if (strncmp(sep, "bool:", 5) == 0) {
sep += 5;
kvo.tag = LLAMA_KV_OVERRIDE_TYPE_BOOL;
if (std::strcmp(sep, "true") == 0) {
kvo.val_bool = true;
kvo.bool_value = true;
} else if (std::strcmp(sep, "false") == 0) {
kvo.val_bool = false;
kvo.bool_value = false;
} else {
fprintf(stderr, "error: Invalid boolean value for KV override: %s\n", argv[i]);
invalid_param = true;
@@ -3212,54 +3209,27 @@ int main(int argc, char **argv) {
return res.set_content(data.dump(), "application/json; charset=utf-8");
});
svr.Post("/embedding", [&llama](const httplib::Request &req, httplib::Response &res)
svr.Post("/embeddings", [&llama](const httplib::Request &req, httplib::Response &res)
{
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
const json body = json::parse(req.body);
json prompt;
if (body.count("content") != 0)
{
prompt = body["content"];
}
else
{
prompt = "";
const int id = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(id);
llama.request_completion(id, {{"prompt", body["contents"]}}, false, true, -1);
task_result recv = llama.queue_results.recv(id);
llama.queue_results.remove_waiting_task_id(id);
json embeddings = json::array();
for (auto & elem : recv.result_json["results"]) {
embeddings.push_back(json_value(elem, "embedding", json::array()));
}
json image_data;
if (body.count("image_data") != 0) {
image_data = body["image_data"];
}
else
{
image_data = "";
}
// create and queue the task
const int task_id = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(task_id);
llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, false, true, -1);
// get the result
task_result result = llama.queue_results.recv(task_id);
llama.queue_results.remove_waiting_task_id(task_id);
// send the result
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
json result = json{{"embeddings", embeddings}};
return res.set_content(result.dump(), "application/json; charset=utf-8");
});
// GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!?
// "Bus error: 10" - this is on macOS, it does not crash on Linux
//std::thread t2([&]()
/*{
bool running = true;
while (running)
{
running = llama.update_slots();
}
}*/
//);
if (sparams.n_threads_http < 1) {
// +2 threads for monitoring endpoints
sparams.n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1);

View File

@@ -42,9 +42,8 @@ function init_vars {
"-DLLAMA_NATIVE=off"
)
$script:commonCpuDefs = @("-DCMAKE_POSITION_INDEPENDENT_CODE=on")
$script:ARCH = $Env:PROCESSOR_ARCHITECTURE.ToLower()
$script:ARCH = "amd64" # arm not yet supported.
$script:DIST_BASE = "${script:SRC_DIR}\dist\windows-${script:ARCH}\ollama_runners"
md "$script:DIST_BASE" -ea 0 > $null
if ($env:CGO_CFLAGS -contains "-g") {
$script:cmakeDefs += @("-DCMAKE_VERBOSE_MAKEFILE=on", "-DLLAMA_SERVER_VERBOSE=on", "-DCMAKE_BUILD_TYPE=RelWithDebInfo")
$script:config = "RelWithDebInfo"
@@ -182,7 +181,7 @@ function cleanup {
function build_static() {
if ((-not "${env:OLLAMA_SKIP_STATIC_GENERATE}") -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "static"))) {
if ($null -eq ${env:OLLAMA_SKIP_CPU_GENERATE}) {
# GCC build for direct linking into the Go binary
init_vars
# cmake will silently fallback to msvc compilers if mingw isn't in the path, so detect and fail fast
@@ -213,11 +212,11 @@ function build_static() {
}
}
function build_cpu($gen_arch) {
if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu"))) {
function build_cpu() {
if ($null -eq ${env:OLLAMA_SKIP_CPU_GENERATE}) {
# remaining llama.cpp builds use MSVC
init_vars
$script:cmakeDefs = $script:commonCpuDefs + @("-A", $gen_arch, "-DLLAMA_AVX=off", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=off", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
$script:buildDir="../build/windows/${script:ARCH}/cpu"
$script:distDir="$script:DIST_BASE\cpu"
write-host "Building LCD CPU"
@@ -230,7 +229,7 @@ function build_cpu($gen_arch) {
}
function build_cpu_avx() {
if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu_avx"))) {
if ($null -eq ${env:OLLAMA_SKIP_CPU_GENERATE}) {
init_vars
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
$script:buildDir="../build/windows/${script:ARCH}/cpu_avx"
@@ -240,12 +239,12 @@ function build_cpu_avx() {
sign
install
} else {
write-host "Skipping CPU AVX generation step as requested"
write-host "Skipping CPU generation step as requested"
}
}
function build_cpu_avx2() {
if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu_avx2"))) {
if ($null -eq ${env:OLLAMA_SKIP_CPU_GENERATE}) {
init_vars
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=on", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=on", "-DLLAMA_F16C=on") + $script:cmakeDefs
$script:buildDir="../build/windows/${script:ARCH}/cpu_avx2"
@@ -255,12 +254,12 @@ function build_cpu_avx2() {
sign
install
} else {
write-host "Skipping CPU AVX2 generation step as requested"
write-host "Skipping CPU generation step as requested"
}
}
function build_cuda() {
if ((-not "${env:OLLAMA_SKIP_CUDA_GENERATE}") -and ("${script:CUDA_LIB_DIR}")) {
if ($null -ne $script:CUDA_LIB_DIR) {
# Then build cuda as a dynamically loaded library
$nvcc = "$script:CUDA_LIB_DIR\nvcc.exe"
$script:CUDA_VERSION=(get-item ($nvcc | split-path | split-path)).Basename
@@ -284,13 +283,11 @@ function build_cuda() {
cp "${script:CUDA_LIB_DIR}\cudart64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\"
cp "${script:CUDA_LIB_DIR}\cublas64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\"
cp "${script:CUDA_LIB_DIR}\cublasLt64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\"
} else {
write-host "Skipping CUDA generation step"
}
}
function build_rocm() {
if ((-not "${env:OLLAMA_SKIP_ROCM_GENERATE}") -and ("${env:HIP_PATH}")) {
if ($null -ne $env:HIP_PATH) {
$script:ROCM_VERSION=(get-item $env:HIP_PATH).Basename
if ($null -ne $script:ROCM_VERSION) {
$script:ROCM_VARIANT="_v"+$script:ROCM_VERSION
@@ -339,8 +336,6 @@ function build_rocm() {
cp "${env:HIP_PATH}\bin\rocblas.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"
# amdhip64.dll dependency comes from the driver and must be installed on the host to use AMD GPUs
cp "${env:HIP_PATH}\bin\rocblas\library\*" "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\rocblas\library\"
} else {
write-host "Skipping ROCm generation step"
}
}
@@ -349,15 +344,11 @@ if ($($args.count) -eq 0) {
git_module_setup
apply_patches
build_static
if ($script:ARCH -eq "arm64") {
build_cpu("ARM64")
} else { # amd64
build_cpu("x64")
build_cpu_avx
build_cpu_avx2
build_cuda
build_rocm
}
build_cpu
build_cpu_avx
build_cpu_avx2
build_cuda
build_rocm
cleanup
write-host "`ngo generate completed. LLM runners: $(get-childitem -path $script:DIST_BASE)"

View File

@@ -4,7 +4,6 @@ package llm
// #cgo darwin,arm64 LDFLAGS: ${SRCDIR}/build/darwin/arm64_static/libllama.a -lstdc++
// #cgo darwin,amd64 LDFLAGS: ${SRCDIR}/build/darwin/x86_64_static/libllama.a -lstdc++
// #cgo windows,amd64 LDFLAGS: ${SRCDIR}/build/windows/amd64_static/libllama.a -static -lstdc++
// #cgo windows,arm64 LDFLAGS: ${SRCDIR}/build/windows/arm64_static/libllama.a -static -lstdc++
// #cgo linux,amd64 LDFLAGS: ${SRCDIR}/build/linux/x86_64_static/libllama.a -lstdc++
// #cgo linux,arm64 LDFLAGS: ${SRCDIR}/build/linux/arm64_static/libllama.a -lstdc++
// #include <stdlib.h>

View File

@@ -32,7 +32,7 @@ type LlamaServer interface {
Ping(ctx context.Context) error
WaitUntilRunning(ctx context.Context) error
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
Embedding(ctx context.Context, prompt string) ([]float64, error)
Embeddings(ctx context.Context, prompt []string) ([][]float64, error)
Tokenize(ctx context.Context, content string) ([]int, error)
Detokenize(ctx context.Context, tokens []int) (string, error)
Close() error
@@ -73,7 +73,8 @@ func LoadModel(model string) (*GGML, error) {
func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, projectors []string, opts api.Options) (LlamaServer, error) {
var err error
if opts.NumCtx > int(ggml.KV().ContextLength()) {
slog.Warn("requested context length is greater than the model's training context window size", "requested", opts.NumCtx, "training size", ggml.KV().ContextLength())
slog.Warn("requested context length is greater than model max context length", "requested", opts.NumCtx, "model", ggml.KV().ContextLength())
opts.NumCtx = int(ggml.KV().ContextLength())
}
if opts.NumCtx < 4 {
@@ -735,15 +736,15 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
return fmt.Errorf("max retries exceeded")
}
type EmbeddingRequest struct {
Content string `json:"content"`
type EmbeddingsRequest struct {
Contents []string `json:"contents"`
}
type EmbeddingResponse struct {
Embedding []float64 `json:"embedding"`
type EmbeddingsResponse struct {
Embeddings [][]float64 `json:"embeddings"`
}
func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
func (s *llmServer) Embeddings(ctx context.Context, prompts []string) ([][]float64, error) {
if err := s.sem.Acquire(ctx, 1); err != nil {
slog.Error("Failed to acquire semaphore", "error", err)
return nil, err
@@ -757,12 +758,12 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
}
data, err := json.Marshal(TokenizeRequest{Content: prompt})
data, err := json.Marshal(EmbeddingsRequest{Contents: prompts})
if err != nil {
return nil, fmt.Errorf("error marshaling embed data: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data))
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embeddings", s.port), bytes.NewBuffer(data))
if err != nil {
return nil, fmt.Errorf("error creating embed request: %w", err)
}
@@ -779,17 +780,19 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
return nil, fmt.Errorf("error reading embed response: %w", err)
}
fmt.Println("embeddings response", string(body))
if resp.StatusCode >= 400 {
log.Printf("llm encode error: %s", body)
return nil, fmt.Errorf("%s", body)
}
var embedding EmbeddingResponse
var embedding EmbeddingsResponse
if err := json.Unmarshal(body, &embedding); err != nil {
return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
}
return embedding.Embedding, nil
return embedding.Embeddings, nil
}
type TokenizeRequest struct {

View File

@@ -19,7 +19,7 @@ export default function () {
const [step, setStep] = useState<Step>(Step.WELCOME)
const [commandCopied, setCommandCopied] = useState<boolean>(false)
const command = 'ollama run llama3'
const command = 'ollama run llama2'
return (
<div className='drag'>

View File

@@ -7,8 +7,6 @@
$ErrorActionPreference = "Stop"
function checkEnv() {
$script:TARGET_ARCH=$Env:PROCESSOR_ARCHITECTURE.ToLower()
Write-host "Building for ${script:TARGET_ARCH}"
write-host "Locating required tools and paths"
$script:SRC_DIR=$PWD
if (!$env:VCToolsRedistDir) {
@@ -32,7 +30,7 @@ function checkEnv() {
$script:INNO_SETUP_DIR=(get-item "C:\Program Files*\Inno Setup*\")[0]
$script:DEPS_DIR="${script:SRC_DIR}\dist\windows-${script:TARGET_ARCH}"
$script:DEPS_DIR="${script:SRC_DIR}\dist\windows-amd64"
$env:CGO_ENABLED="1"
echo "Checking version"
if (!$env:VERSION) {
@@ -83,8 +81,8 @@ function buildOllama() {
/csp "Google Cloud KMS Provider" /kc ${env:KEY_CONTAINER} ollama.exe
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
}
New-Item -ItemType Directory -Path .\dist\windows-${script:TARGET_ARCH}\ -Force
cp .\ollama.exe .\dist\windows-${script:TARGET_ARCH}\
New-Item -ItemType Directory -Path .\dist\windows-amd64\ -Force
cp .\ollama.exe .\dist\windows-amd64\
}
function buildApp() {
@@ -129,16 +127,16 @@ function buildInstaller() {
cd "${script:SRC_DIR}\app"
$env:PKG_VERSION=$script:PKG_VERSION
if ("${env:KEY_CONTAINER}") {
& "${script:INNO_SETUP_DIR}\ISCC.exe" /DARCH=$script:TARGET_ARCH /SMySignTool="${script:SignTool} sign /fd sha256 /t http://timestamp.digicert.com /f ${script:OLLAMA_CERT} /csp `$qGoogle Cloud KMS Provider`$q /kc ${env:KEY_CONTAINER} `$f" .\ollama.iss
& "${script:INNO_SETUP_DIR}\ISCC.exe" /SMySignTool="${script:SignTool} sign /fd sha256 /t http://timestamp.digicert.com /f ${script:OLLAMA_CERT} /csp `$qGoogle Cloud KMS Provider`$q /kc ${env:KEY_CONTAINER} `$f" .\ollama.iss
} else {
& "${script:INNO_SETUP_DIR}\ISCC.exe" /DARCH=$script:TARGET_ARCH .\ollama.iss
& "${script:INNO_SETUP_DIR}\ISCC.exe" .\ollama.iss
}
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
}
function distZip() {
write-host "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-${script:TARGET_ARCH}.zip"
Compress-Archive -Path "${script:SRC_DIR}\dist\windows-${script:TARGET_ARCH}\*" -DestinationPath "${script:SRC_DIR}\dist\ollama-windows-${script:TARGET_ARCH}.zip" -Force
write-host "Generating stand-alone distribution zip file ${script:SRC_DIR}\dist\ollama-windows-amd64.zip"
Compress-Archive -Path "${script:SRC_DIR}\dist\windows-amd64\*" -DestinationPath "${script:SRC_DIR}\dist\ollama-windows-amd64.zip" -Force
}
try {

View File

@@ -1,75 +0,0 @@
package server
import (
"cmp"
"fmt"
"os"
"path/filepath"
"github.com/ollama/ollama/client/registry"
"github.com/ollama/ollama/types/model"
)
// cache is a simple demo disk cache implementing a registry cache on top of
// the local filesystem. It does not validate anything beyond the digest/name
// well-formedness checks performed by the individual methods.
type cache struct {
	dir string // root directory; layers live under dir/blobs, manifests under dir/manifests
}
// defaultCache returns a filesystem-backed registry.Cache rooted at
// $OLLAMA_MODELS if set, otherwise ~/.ollama/models.
//
// It panics if the user's home directory cannot be determined, since no
// sensible fallback location exists.
func defaultCache() registry.Cache {
	homeDir, err := os.UserHomeDir()
	if err != nil || homeDir == "" {
		// Include the underlying error so the failure is diagnosable,
		// rather than discarding it and reporting only the symptom.
		panic(fmt.Sprintf("could not determine home directory: %v", err))
	}
	modelsDir := cmp.Or(
		os.Getenv("OLLAMA_MODELS"),
		filepath.Join(homeDir, ".ollama", "models"),
	)
	return &cache{modelsDir}
}
// invalidDigest builds the error reported for a malformed digest string.
func invalidDigest(d string) error {
	return fmt.Errorf("invalid digest: %s", d)
}
// OpenLayer opens the blob stored for the given digest for reading.
func (c *cache) OpenLayer(d model.Digest) (registry.ReadAtSeekCloser, error) {
	path := c.LayerFile(d)
	return os.Open(path)
}
// LayerFile returns the on-disk path of the blob for the given digest.
func (c *cache) LayerFile(d model.Digest) string {
	name := d.String()
	return filepath.Join(c.dir, "blobs", name)
}
// PutLayerFile moves the file at fromPath into the blob store under the
// given digest, creating the blobs directory as needed. The digest must be
// valid; otherwise an invalid-digest error is returned.
func (c *cache) PutLayerFile(d model.Digest, fromPath string) error {
	if !d.IsValid() {
		return invalidDigest(d.String())
	}
	dest := c.LayerFile(d)
	if err := os.MkdirAll(filepath.Dir(dest), 0755); err != nil {
		return err
	}
	return os.Rename(fromPath, dest)
}
// ManifestData returns the stored manifest bytes for name, or nil when the
// name is not fully qualified or the manifest file cannot be read.
func (c *cache) ManifestData(name model.Name) []byte {
	if !name.IsFullyQualified() {
		return nil
	}
	path := filepath.Join(c.dir, "manifests", name.Filepath())
	if data, err := os.ReadFile(path); err == nil {
		return data
	}
	return nil
}
// SetManifestData writes data as the manifest for name, creating parent
// directories as needed. The name must be fully qualified.
func (c *cache) SetManifestData(name model.Name, data []byte) error {
	if !name.IsFullyQualified() {
		return fmt.Errorf("invalid name: %s", name)
	}
	path := filepath.Join(c.dir, "manifests", name.Filepath())
	if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil {
		return err
	}
	return os.WriteFile(path, data, 0644)
}

View File

@@ -5,7 +5,6 @@ import (
"bytes"
"context"
"crypto/sha256"
"encoding/base64"
"encoding/hex"
"encoding/json"
"errors"
@@ -26,12 +25,10 @@ import (
"golang.org/x/exp/slices"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/convert"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
)
@@ -713,10 +710,6 @@ func CopyModel(src, dst model.Name) error {
return model.Unqualified(src)
}
if src.Filepath() == dst.Filepath() {
return nil
}
manifests, err := GetManifestPath()
if err != nil {
return err
@@ -983,6 +976,9 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
for _, layer := range layers {
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
slog.Info(fmt.Sprintf("error uploading blob: %v", err))
if errors.Is(err, errUnauthorized) {
return fmt.Errorf("unable to push %s, make sure this namespace exists and you are authorized to push to it", ParseModelPath(name).GetNamespaceRepository())
}
return err
}
}
@@ -1145,40 +1141,9 @@ func GetSHA256Digest(r io.Reader) (string, int64) {
return fmt.Sprintf("sha256:%x", h.Sum(nil)), n
}
var errUnauthorized = fmt.Errorf("unauthorized: access denied")
// getTokenSubject returns the subject of a JWT token, it does not validate the token
func getTokenSubject(token string) string {
parts := strings.Split(token, ".")
if len(parts) != 3 {
slog.Error("jwt token does not contain 3 parts")
return ""
}
payload := parts[1]
payloadBytes, err := base64.RawURLEncoding.DecodeString(payload)
if err != nil {
slog.Error(fmt.Sprintf("failed to decode jwt payload: %v", err))
return ""
}
var payloadMap map[string]interface{}
if err := json.Unmarshal(payloadBytes, &payloadMap); err != nil {
slog.Error(fmt.Sprintf("failed to unmarshal payload JSON: %v", err))
return ""
}
sub, ok := payloadMap["sub"]
if !ok {
slog.Error("jwt does not contain 'sub' field")
return ""
}
return fmt.Sprintf("%s", sub)
}
var errUnauthorized = errors.New("unauthorized")
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) {
anonymous := true // access will default to anonymous if no user is found associated with the public key
for i := 0; i < 2; i++ {
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
if err != nil {
@@ -1197,7 +1162,6 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
if err != nil {
return nil, err
}
anonymous = getTokenSubject(token) == "anonymous"
regOpts.Token = token
if body != nil {
_, err = body.Seek(0, io.SeekStart)
@@ -1218,16 +1182,6 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
}
}
if anonymous {
// no user is associated with the public key, and the request requires non-anonymous access
pubKey, nestedErr := auth.GetPublicKey()
if nestedErr != nil {
slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
return nil, errUnauthorized
}
return nil, &errtypes.UnknownOllamaKey{Key: pubKey}
}
// user is associated with the public key, but is not authorized to make the request
return nil, errUnauthorized
}

View File

@@ -1,7 +1,6 @@
package server
import (
"cmp"
"context"
"encoding/json"
"errors"
@@ -18,7 +17,6 @@ import (
"path/filepath"
"strconv"
"strings"
"sync"
"syscall"
"time"
@@ -27,8 +25,6 @@ import (
"golang.org/x/exp/slices"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/client/ollama"
"github.com/ollama/ollama/client/registry"
"github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/openai"
@@ -37,23 +33,6 @@ import (
"github.com/ollama/ollama/version"
)
// envs
var (
envRegistryBaseURL = cmp.Or(os.Getenv("OLLAMA_REGISTRY_BASE_URL"), "https://bllamo.com")
)
func init() {
ollama.I_Acknowledge_This_API_Is_Unstable = true
}
var experiments = sync.OnceValue(func() []string {
return strings.Split(strings.ToLower(os.Getenv("OLLAMA_EXPERIMENT")), ",")
})
func useExperiment(flag string) bool {
return slices.Contains(experiments(), flag)
}
var mode string = gin.DebugMode
type Server struct {
@@ -424,23 +403,39 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return
}
// an empty request loads the model
if req.Prompt == "" {
c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: []float64{}})
return
}
switch {
// single embedding
case len(req.Prompt) > 0:
embeddings, err := runner.llama.Embeddings(c.Request.Context(), []string{req.Prompt})
if err != nil {
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
return
}
embedding, err := runner.llama.Embedding(c.Request.Context(), req.Prompt)
if err != nil {
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
return
}
resp := api.EmbeddingResponse{Embedding: embeddings[0]}
c.JSON(http.StatusOK, resp)
resp := api.EmbeddingResponse{
Embedding: embedding,
// batch embeddings
case len(req.PromptBatch) > 0:
embeddings, err := runner.llama.Embeddings(c.Request.Context(), req.PromptBatch)
if err != nil {
slog.Info(fmt.Sprintf("batch embedding generation failed: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
return
}
resp := api.EmbeddingResponse{EmbeddingBatch: embeddings}
c.JSON(http.StatusOK, resp)
// empty prompt loads the model
default:
if req.PromptBatch != nil {
c.JSON(http.StatusOK, api.EmbeddingResponse{EmbeddingBatch: [][]float64{}})
} else {
c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: []float64{}})
}
}
c.JSON(http.StatusOK, resp)
}
func (s *Server) PullModelHandler(c *gin.Context) {
@@ -465,25 +460,6 @@ func (s *Server) PullModelHandler(c *gin.Context) {
return
}
if useExperiment("pull") {
rc := &registry.Client{
BaseURL: envRegistryBaseURL,
}
modelsDir, err := modelsDir()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
cache := &cache{dir: modelsDir}
println("DIR: ", modelsDir)
// TODO(bmizerany): progress updates
if err := rc.Pull(c.Request.Context(), cache, model); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
return
}
ch := make(chan any)
go func() {
defer close(ch)

View File

@@ -149,14 +149,6 @@ func (s *Scheduler) processPending(ctx context.Context) {
break
}
// If we're CPU only mode, just limit by loadedMax above
// TODO handle system memory exhaustion
if (len(gpus) == 1 && gpus[0].Library == "cpu") || pending.opts.NumGPU == 0 {
slog.Debug("cpu mode with existing models, loading")
s.loadFn(pending, ggml, gpus)
break
}
// No models loaded. Load the model but prefer the best fit.
if loadedCount == 0 {
slog.Debug("loading first model", "model", pending.model.ModelPath)

View File

@@ -28,33 +28,19 @@ func TestInitScheduler(t *testing.T) {
ctx, done := context.WithCancel(context.Background())
defer done()
initialMax := loadedMax
initialParallel := numParallel
s := InitScheduler(ctx)
require.Equal(t, initialMax, loadedMax)
s.loadedMu.Lock()
require.NotNil(t, s.loaded)
s.loadedMu.Unlock()
os.Setenv("OLLAMA_MAX_LOADED_MODELS", "blue")
s = InitScheduler(ctx)
require.Equal(t, initialMax, loadedMax)
s.loadedMu.Lock()
require.NotNil(t, s.loaded)
s.loadedMu.Unlock()
os.Setenv("OLLAMA_MAX_LOADED_MODELS", "0")
s = InitScheduler(ctx)
require.Equal(t, 0, loadedMax)
s.loadedMu.Lock()
require.NotNil(t, s.loaded)
s.loadedMu.Unlock()
os.Setenv("OLLAMA_NUM_PARALLEL", "blue")
_ = InitScheduler(ctx)
require.Equal(t, initialParallel, numParallel)
os.Setenv("OLLAMA_NUM_PARALLEL", "10")
_ = InitScheduler(ctx)
require.Equal(t, 10, numParallel)
}
func TestLoad(t *testing.T) {
@@ -65,7 +51,6 @@ func TestLoad(t *testing.T) {
req := &LlmRequest{
ctx: ctx,
model: &Model{ModelPath: "foo"},
opts: api.DefaultOptions(),
successCh: make(chan *runnerRef, 1),
errCh: make(chan error, 1),
sessionDuration: 2,
@@ -78,9 +63,7 @@ func TestLoad(t *testing.T) {
s.load(req, ggml, gpus)
require.Len(t, req.successCh, 0)
require.Len(t, req.errCh, 1)
s.loadedMu.Lock()
require.Len(t, s.loaded, 0)
s.loadedMu.Unlock()
err := <-req.errCh
require.Contains(t, err.Error(), "this model may be incompatible")
@@ -95,9 +78,7 @@ func TestLoad(t *testing.T) {
case resp := <-req.successCh:
require.Equal(t, uint64(10), resp.estimatedVRAM)
require.Equal(t, uint(1), resp.refCount)
s.loadedMu.Lock()
require.Len(t, s.loaded, 1)
s.loadedMu.Unlock()
}
req.model.ModelPath = "dummy_model_path"
@@ -109,9 +90,7 @@ func TestLoad(t *testing.T) {
case resp := <-req.successCh:
t.Errorf("unexpected success %v", resp)
}
s.loadedMu.Lock()
runner := s.loaded["dummy_model_path"]
s.loadedMu.Unlock()
require.NotNil(t, runner)
require.Equal(t, uint(0), runner.refCount)
time.Sleep(1 * time.Millisecond)
@@ -164,7 +143,6 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
scenario.req = &LlmRequest{
ctx: scenario.ctx,
model: model,
opts: api.DefaultOptions(),
sessionDuration: 5 * time.Millisecond,
successCh: make(chan *runnerRef, 1),
errCh: make(chan error, 1),
@@ -193,9 +171,7 @@ func TestRequests(t *testing.T) {
// Multiple loaded models
scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte)
scenario3b := newScenario(t, ctx, "ollama-model-3b", 24*format.GigaByte)
scenario3c := newScenario(t, ctx, "ollama-model-4a", 30)
scenario3c.req.opts.NumGPU = 0 // CPU load, will be allowed
scenario3d := newScenario(t, ctx, "ollama-model-3c", 30) // Needs prior unloaded
scenario3c := newScenario(t, ctx, "ollama-model-3c", 30) // Needs prior unloaded
s := InitScheduler(ctx)
s.getGpuFn = func() gpu.GpuInfoList {
@@ -264,9 +240,7 @@ func TestRequests(t *testing.T) {
case <-ctx.Done():
t.Errorf("timeout")
}
s.loadedMu.Lock()
require.Len(t, s.loaded, 1)
s.loadedMu.Unlock()
loadedMax = 0
s.newServerFn = scenario3b.newServer
@@ -280,14 +254,19 @@ func TestRequests(t *testing.T) {
case <-ctx.Done():
t.Errorf("timeout")
}
s.loadedMu.Lock()
require.Len(t, s.loaded, 2)
s.loadedMu.Unlock()
// This is a CPU load with NumGPU = 0 so it should load
// Try to load a model that wont fit
s.newServerFn = scenario3c.newServer
slog.Info("scenario3c")
require.Len(t, s.loaded, 2)
scenario3a.ctxDone() // Won't help since this one isn't big enough to make room
time.Sleep(2 * time.Millisecond)
s.pendingReqCh <- scenario3c.req
// finish prior request, so new model can load
time.Sleep(6 * time.Millisecond)
require.Len(t, s.loaded, 1)
scenario3b.ctxDone()
select {
case resp := <-scenario3c.req.successCh:
require.Equal(t, resp.llama, scenario3c.srv)
@@ -296,36 +275,7 @@ func TestRequests(t *testing.T) {
case <-ctx.Done():
t.Errorf("timeout")
}
s.loadedMu.Lock()
require.Len(t, s.loaded, 3)
s.loadedMu.Unlock()
// Try to load a model that wont fit
s.newServerFn = scenario3d.newServer
slog.Info("scenario3d")
s.loadedMu.Lock()
require.Len(t, s.loaded, 3)
s.loadedMu.Unlock()
scenario3a.ctxDone() // Won't help since this one isn't big enough to make room
time.Sleep(2 * time.Millisecond)
s.pendingReqCh <- scenario3d.req
// finish prior request, so new model can load
time.Sleep(6 * time.Millisecond)
s.loadedMu.Lock()
require.Len(t, s.loaded, 2)
s.loadedMu.Unlock()
scenario3b.ctxDone()
select {
case resp := <-scenario3d.req.successCh:
require.Equal(t, resp.llama, scenario3d.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, scenario3d.req.errCh, 0)
case <-ctx.Done():
t.Errorf("timeout")
}
s.loadedMu.Lock()
require.Len(t, s.loaded, 2)
s.loadedMu.Unlock()
require.Len(t, s.loaded, 1)
}
func TestGetRunner(t *testing.T) {
@@ -368,9 +318,7 @@ func TestGetRunner(t *testing.T) {
t.Errorf("timeout")
}
scenario1a.ctxDone()
s.loadedMu.Lock()
require.Len(t, s.loaded, 1)
s.loadedMu.Unlock()
scenario1c.req.model.ModelPath = "bad path"
slog.Info("scenario1c")
@@ -380,9 +328,7 @@ func TestGetRunner(t *testing.T) {
require.Len(t, errCh1c, 0)
time.Sleep(5 * time.Millisecond)
s.loadedMu.Lock()
require.Len(t, s.loaded, 0)
s.loadedMu.Unlock()
require.Len(t, errCh1c, 1)
err = <-errCh1c
require.Contains(t, err.Error(), "bad path")
@@ -412,9 +358,7 @@ func TestPrematureExpired(t *testing.T) {
require.Equal(t, resp.llama, scenario1a.srv)
require.Len(t, s.pendingReqCh, 0)
require.Len(t, errCh1a, 0)
s.loadedMu.Lock()
require.Len(t, s.loaded, 1)
s.loadedMu.Unlock()
slog.Info("sending premature expired event now")
s.expiredCh <- resp // Shouldn't happen in real life, but make sure its safe
case <-ctx.Done():
@@ -439,7 +383,6 @@ func TestUseLoadedRunner(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
req := &LlmRequest{
ctx: ctx,
opts: api.DefaultOptions(),
successCh: make(chan *runnerRef, 1),
sessionDuration: 2,
}
@@ -483,10 +426,8 @@ func TestUpdateFreeSpace(t *testing.T) {
r2 := &runnerRef{llama: llm2, gpus: gpus}
s := InitScheduler(ctx)
s.loadedMu.Lock()
s.loaded["a"] = r1
s.loaded["b"] = r2
s.loadedMu.Unlock()
s.updateFreeSpace(gpus)
require.Equal(t, uint64(850), gpus[0].FreeMemory)
@@ -496,18 +437,13 @@ func TestUpdateFreeSpace(t *testing.T) {
func TestFindRunnerToUnload(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 5*time.Millisecond)
defer done()
req := &LlmRequest{
ctx: ctx,
opts: api.DefaultOptions(),
}
req := &LlmRequest{ctx: ctx}
r1 := &runnerRef{refCount: 1, sessionDuration: 1}
r2 := &runnerRef{sessionDuration: 2}
s := InitScheduler(ctx)
s.loadedMu.Lock()
s.loaded["a"] = r1
s.loaded["b"] = r2
s.loadedMu.Unlock()
resp := s.findRunnerToUnload(req)
require.Equal(t, r2, resp)
@@ -522,11 +458,10 @@ func TestNeedsReload(t *testing.T) {
defer done()
llm := &mockLlm{}
do := api.DefaultOptions()
runner := &runnerRef{
adapters: []string{"adapter1"},
projectors: []string{"projector1"},
Options: &do,
Options: &api.Options{},
llama: llm,
}
req := &LlmRequest{
@@ -534,7 +469,7 @@ func TestNeedsReload(t *testing.T) {
AdapterPaths: []string{"adapter2"},
ProjectorPaths: []string{"projector2"},
},
opts: api.DefaultOptions(),
opts: api.Options{},
}
resp := runner.needsReload(ctx, req)
require.True(t, resp)
@@ -573,10 +508,8 @@ func TestUnloadAllRunners(t *testing.T) {
r1 := &runnerRef{llama: llm1}
r2 := &runnerRef{llama: llm2}
s.loadedMu.Lock()
s.loaded["a"] = r1
s.loaded["b"] = r2
s.loadedMu.Unlock()
s.unloadAllRunners()
require.True(t, llm1.closeCalled)
@@ -597,7 +530,7 @@ type mockLlm struct {
pingResp error
waitResp error
completionResp error
embeddingResp []float64
embeddingResp [][]float64
embeddingRespErr error
tokenizeResp []int
tokenizeRespErr error
@@ -613,7 +546,7 @@ func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitRes
func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
return s.completionResp
}
func (s *mockLlm) Embedding(ctx context.Context, prompt string) ([]float64, error) {
func (s *mockLlm) Embeddings(ctx context.Context, prompts []string) ([][]float64, error) {
return s.embeddingResp, s.embeddingRespErr
}
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {

View File

@@ -1,18 +0,0 @@
// Package errtypes contains custom error types
package errtypes
import (
"fmt"
"strings"
)
const UnknownOllamaKeyErrMsg = "unknown ollama key"
// TODO: This should have a structured response from the API
type UnknownOllamaKey struct {
Key string
}
func (e *UnknownOllamaKey) Error() string {
return fmt.Sprintf("unauthorized: %s %q", UnknownOllamaKeyErrMsg, strings.TrimSpace(e.Key))
}

View File

@@ -81,6 +81,9 @@ func (k partKind) String() string {
//
// It is not guaranteed to be valid. Use [Name.IsValid] to check if the name
// is valid.
//
// It is not directly comparable with other Names. Use [Name.Equal] and
// [Name.MapHash] for determining equality and using as a map key.
type Name struct {
Host string
Namespace string
@@ -107,20 +110,20 @@ type Name struct {
// { model }
// "@" { digest }
// host:
// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." | ":" }*
// pattern: alphanum { alphanum | "-" | "_" | "." | ":" }*
// length: [1, 350]
// namespace:
// pattern: { alphanum | "_" } { alphanum | "-" | "_" }*
// length: [1, 80]
// pattern: alphanum { alphanum | "-" | "_" }*
// length: [2, 80]
// model:
// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }*
// length: [1, 80]
// pattern: alphanum { alphanum | "-" | "_" | "." }*
// length: [2, 80]
// tag:
// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }*
// pattern: alphanum { alphanum | "-" | "_" | "." }*
// length: [1, 80]
// digest:
// pattern: { alphanum | "_" } { alphanum | "-" | ":" }*
// length: [1, 80]
// pattern: alphanum { alphanum | "-" | ":" }*
// length: [2, 80]
//
// Most users should use [ParseName] instead, unless need to support
// different defaults than DefaultName.
@@ -168,6 +171,11 @@ func Merge(a, b Name) Name {
return a
}
// Digest returns the result of [ParseDigest] with the RawDigest field.
func (n Name) Digest() Digest {
return ParseDigest(n.RawDigest)
}
// String returns the name string, in the format that [ParseNameNoDefaults]
// accepts as valid, if [Name.IsValid] reports true; otherwise the empty
// string is returned.
@@ -196,7 +204,7 @@ func (n Name) String() string {
// IsValid reports whether all parts of the name are present and valid. The
// digest is a special case, and is checked for validity only if present.
func (n Name) IsValid() bool {
if n.RawDigest != "" && !isValidPart(kindDigest, n.RawDigest) {
if n.RawDigest != "" && !ParseDigest(n.RawDigest).IsValid() {
return false
}
return n.IsFullyQualified()
@@ -232,12 +240,12 @@ func (n Name) Filepath() string {
if !n.IsFullyQualified() {
panic("illegal attempt to get filepath of invalid name")
}
return strings.ToLower(filepath.Join(
n.Host,
n.Namespace,
n.Model,
n.Tag,
))
return filepath.Join(
strings.ToLower(n.Host),
strings.ToLower(n.Namespace),
strings.ToLower(n.Model),
strings.ToLower(n.Tag),
)
}
// LogValue returns a slog.Value that represents the name as a string.
@@ -252,7 +260,7 @@ func isValidLen(kind partKind, s string) bool {
case kindTag:
return len(s) >= 1 && len(s) <= 80
default:
return len(s) >= 1 && len(s) <= 80
return len(s) >= 2 && len(s) <= 80
}
}
@@ -262,7 +270,7 @@ func isValidPart(kind partKind, s string) bool {
}
for i := range s {
if i == 0 {
if !isAlphanumericOrUnderscore(s[i]) {
if !isAlphanumeric(s[i]) {
return false
}
continue
@@ -274,11 +282,11 @@ func isValidPart(kind partKind, s string) bool {
return false
}
case ':':
if kind != kindHost && kind != kindDigest {
if kind != kindHost {
return false
}
default:
if !isAlphanumericOrUnderscore(s[i]) {
if !isAlphanumeric(s[i]) {
return false
}
}
@@ -286,8 +294,8 @@ func isValidPart(kind partKind, s string) bool {
return true
}
func isAlphanumericOrUnderscore(c byte) bool {
return c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z' || c >= '0' && c <= '9' || c == '_'
func isAlphanumeric(c byte) bool {
return c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z' || c >= '0' && c <= '9'
}
func cutLast(s, sep string) (before, after string, ok bool) {
@@ -310,7 +318,7 @@ func cutPromised(s, sep string) (before, after string, ok bool) {
return cmp.Or(before, MissingPart), cmp.Or(after, MissingPart), true
}
type DigestType byte
type DigestType int
const (
DigestTypeInvalid DigestType = iota
@@ -318,48 +326,66 @@ const (
)
func (t DigestType) String() string {
switch t {
case DigestTypeSHA256:
if t == DigestTypeSHA256 {
return "sha256"
default:
return "invalid"
}
return "unknown"
}
// Digest represents a type and hash of a digest. It is comparable and can
// be used as a map key.
type Digest struct {
Type DigestType
Sum [32]byte
Hash [32]byte
}
func ParseDigest(s string) (Digest, error) {
i := strings.IndexAny(s, "-:")
if i < 0 {
return Digest{}, fmt.Errorf("invalid digest %q", s)
// ParseDigest parses a digest string into a Digest struct. It accepts both
// the forms:
//
// sha256:deadbeef
// sha256-deadbeef
//
// The hash part must be exactly 64 characters long.
//
// The form "type:hash" does not round trip through [Digest.String].
func ParseDigest(s string) Digest {
typ, hash, ok := cutLast(s, ":")
if !ok {
typ, hash, ok = cutLast(s, "-")
if !ok {
return Digest{}
}
}
typ, encSum := s[:i], s[i+1:]
if typ != "sha256" {
return Digest{}, fmt.Errorf("unsupported digest type %q", typ)
return Digest{}
}
d := Digest{
Type: DigestTypeSHA256,
var d Digest
n, err := hex.Decode(d.Hash[:], []byte(hash))
if err != nil || n != 32 {
return Digest{}
}
n, err := hex.Decode(d.Sum[:], []byte(encSum))
if err != nil {
return Digest{}, err
}
if n != 32 {
return Digest{}, fmt.Errorf("digest %q decoded to %d bytes; want 32", encSum, n)
}
return d, nil
}
func (d Digest) String() string {
if d.Type == DigestTypeInvalid {
return ""
}
return fmt.Sprintf("sha256-%x", d.Sum)
return Digest{Type: DigestTypeSHA256, Hash: d.Hash}
}
// IsValid returns true if the digest has a valid Type and Hash.
func (d Digest) IsValid() bool {
return d.Type != DigestTypeInvalid
if d.Type != DigestTypeSHA256 {
return false
}
return d.Hash != [32]byte{}
}
// String returns the digest as a string in the form "type-hash". The hash
// is encoded as a hex string.
func (d Digest) String() string {
var b strings.Builder
b.WriteString(d.Type.String())
b.WriteByte('-')
b.WriteString(hex.EncodeToString(d.Hash[:]))
return b.String()
}
// LogValue returns a slog.Value that represents the digest as a string.
func (d Digest) LogValue() slog.Value {
return slog.StringValue(d.String())
}

View File

@@ -2,7 +2,7 @@ package model
import (
"reflect"
"runtime"
"strings"
"testing"
)
@@ -82,10 +82,10 @@ func TestParseNameParts(t *testing.T) {
wantValidDigest: false,
},
{
in: "model@sha256:123",
in: "model@sha256:" + validSHA256Hex,
want: Name{
Model: "model",
RawDigest: "sha256:123",
RawDigest: "sha256:" + validSHA256Hex,
},
wantValidDigest: true,
},
@@ -97,18 +97,14 @@ func TestParseNameParts(t *testing.T) {
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("parseName(%q) = %v; want %v", tt.in, got, tt.want)
}
if got.Digest().IsValid() != tt.wantValidDigest {
t.Errorf("parseName(%q).Digest().IsValid() = %v; want %v", tt.in, got.Digest().IsValid(), tt.wantValidDigest)
}
})
}
}
var testCases = map[string]bool{ // name -> valid
"": false,
"_why/_the/_lucky:_stiff": true,
// minimal
"h/n/m:t@d": true,
"host/namespace/model:tag": true,
"host/namespace/model": false,
"namespace/model": false,
@@ -124,12 +120,11 @@ var testCases = map[string]bool{ // name -> valid
"h/nn/mm:t@sha256-1000000000000000000000000000000000000000000000000000000000000000": true, // bare minimum part sizes
"h/nn/mm:t@sha256:1000000000000000000000000000000000000000000000000000000000000000": true, // bare minimum part sizes
// unqualified
"m": false,
"n/m:": false,
"h/n/m": false,
"@t": false,
"m@d": false,
"m": false, // model too short
"n/mm:": false, // namespace too short
"h/n/mm:t": false, // namespace too short
"@t": false, // digest too short
"mm@d": false, // digest too short
// invalids
"^": false,
@@ -149,6 +144,8 @@ var testCases = map[string]bool{ // name -> valid
"hh/nn/mm:-tt@dd": false,
"hh/nn/mm:tt@-dd": false,
"": false,
// hosts
"host:https/namespace/model:tag": true,
@@ -170,6 +167,7 @@ func TestNameIsValid(t *testing.T) {
var numStringTests int
for s, want := range testCases {
n := ParseNameBare(s)
t.Logf("n: %#v", n)
got := n.IsValid()
if got != want {
t.Errorf("parseName(%q).IsValid() = %v; want %v", s, got, want)
@@ -218,54 +216,6 @@ func TestNameIsValidPart(t *testing.T) {
}
func TestFilepathAllocs(t *testing.T) {
n := ParseNameBare("HOST/NAMESPACE/MODEL:TAG")
allocs := testing.AllocsPerRun(1000, func() {
n.Filepath()
})
allowedAllocs := 2.0
if runtime.GOOS == "windows" {
allowedAllocs = 4
}
if allocs > allowedAllocs {
t.Errorf("allocs = %v; allowed %v", allocs, allowedAllocs)
}
}
const (
validSha256 = "sha256-1000000000000000000000000000000000000000000000000000000000000000"
validSha256Old = "sha256:1000000000000000000000000000000000000000000000000000000000000000"
)
func TestParseDigest(t *testing.T) {
cases := []struct {
in string
want string
}{
{"", ""}, // empty
{"sha123-12", ""}, // invalid type
{"sha256-", ""}, // invalid sum
{"sha256-123", ""}, // invalid odd length sum
{validSha256, validSha256},
{validSha256Old, validSha256},
}
for _, tt := range cases {
t.Run(tt.in, func(t *testing.T) {
got, err := ParseDigest(tt.in)
if err != nil {
if tt.want != "" {
t.Errorf("parseDigest(%q) = %v; want %v", tt.in, err, tt.want)
}
return
}
if got.String() != tt.want {
t.Errorf("parseDigest(%q).String() = %q; want %q", tt.in, got, tt.want)
}
})
}
}
func FuzzName(f *testing.F) {
for s := range testCases {
f.Add(s)
@@ -289,3 +239,57 @@ func FuzzName(f *testing.F) {
})
}
const validSHA256Hex = "abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789"
func TestParseDigest(t *testing.T) {
cases := map[string]bool{
"sha256-1000000000000000000000000000000000000000000000000000000000000000": true,
"sha256:1000000000000000000000000000000000000000000000000000000000000000": true,
"sha256:0000000000000000000000000000000000000000000000000000000000000000": false,
"sha256:" + validSHA256Hex: true,
"sha256-" + validSHA256Hex: true,
"": false,
"sha134:" + validSHA256Hex: false,
"sha256:" + validSHA256Hex + "x": false,
"sha256:x" + validSHA256Hex: false,
"sha256-" + validSHA256Hex + "x": false,
"sha256-x": false,
}
for s, want := range cases {
t.Run(s, func(t *testing.T) {
d := ParseDigest(s)
if d.IsValid() != want {
t.Errorf("ParseDigest(%q).IsValid() = %v; want %v", s, d.IsValid(), want)
}
norm := strings.ReplaceAll(s, ":", "-")
if d.IsValid() && d.String() != norm {
t.Errorf("ParseDigest(%q).String() = %q; want %q", s, d.String(), norm)
}
})
}
}
func TestDigestString(t *testing.T) {
cases := []struct {
in string
want string
}{
{in: "sha256:" + validSHA256Hex, want: "sha256-" + validSHA256Hex},
{in: "sha256-" + validSHA256Hex, want: "sha256-" + validSHA256Hex},
{in: "", want: "unknown-0000000000000000000000000000000000000000000000000000000000000000"},
{in: "blah-100000000000000000000000000000000000000000000000000000000000000", want: "unknown-0000000000000000000000000000000000000000000000000000000000000000"},
}
for _, tt := range cases {
t.Run(tt.in, func(t *testing.T) {
d := ParseDigest(tt.in)
if d.String() != tt.want {
t.Errorf("ParseDigest(%q).String() = %q; want %q", tt.in, d.String(), tt.want)
}
})
}
}

15
types/structs/structs.go Normal file
View File

@@ -0,0 +1,15 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package structs contains the Incomparable type.
package structs
// Incomparable is a zero-width incomparable type. If added as the
// first field in a struct, it marks that struct as not comparable
// (can't do == or be a map key) and usually doesn't add any width to
// the struct (unless the struct has only small fields).
//
// By making a struct incomparable, you can prevent misuse (prevent
// people from using ==), but also you can shrink generated binaries,
// as the compiler can omit equality funcs from the binary.
type Incomparable [0]func()