Compare commits

1 Commit

Author       SHA1        Message                                          Date
ParthSareen  0de5bbd0fe  sample: use json unmarshal for sampling params  2025-03-20 15:03:42 -04:00
56 changed files with 274 additions and 2349 deletions

View File

@@ -512,7 +512,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Ollama for Zig](https://github.com/dravenk/ollama-zig)
- [Abso](https://github.com/lunary-ai/abso) (OpenAI-compatible TypeScript SDK for any LLM provider)
- [Nichey](https://github.com/goodreasonai/nichey) is a Python package for generating custom wikis for your research topic
- [Ollama for D](https://github.com/kassane/ollama-d)
### Mobile

View File

@@ -1,178 +0,0 @@
package benchmark
import (
"context"
"flag"
"fmt"
"testing"
"time"
"github.com/ollama/ollama/api"
)
// Command line flags
var modelFlag string
func init() {
flag.StringVar(&modelFlag, "m", "", "Name of the model to benchmark")
flag.Lookup("m").DefValue = "model"
}
// modelName returns the model name from flags, failing the test if not set
func modelName(b *testing.B) string {
if modelFlag == "" {
b.Fatal("Error: -m flag is required for benchmark tests")
}
return modelFlag
}
type TestCase struct {
name string
prompt string
maxTokens int
}
// runGenerateBenchmark contains the common generate and metrics logic
func runGenerateBenchmark(b *testing.B, ctx context.Context, client *api.Client, req *api.GenerateRequest) {
start := time.Now()
var ttft time.Duration
var metrics api.Metrics
err := client.Generate(ctx, req, func(resp api.GenerateResponse) error {
if ttft == 0 && resp.Response != "" {
ttft = time.Since(start)
}
if resp.Done {
metrics = resp.Metrics
}
return nil
})
// Report custom metrics as part of the benchmark results
b.ReportMetric(float64(ttft.Milliseconds()), "ttft_ms")
b.ReportMetric(float64(metrics.LoadDuration.Milliseconds()), "load_ms")
// Token throughput metrics
promptThroughput := float64(metrics.PromptEvalCount) / metrics.PromptEvalDuration.Seconds()
genThroughput := float64(metrics.EvalCount) / metrics.EvalDuration.Seconds()
b.ReportMetric(promptThroughput, "prompt_tok/s")
b.ReportMetric(genThroughput, "gen_tok/s")
// Token counts
b.ReportMetric(float64(metrics.PromptEvalCount), "prompt_tokens")
b.ReportMetric(float64(metrics.EvalCount), "gen_tokens")
if err != nil {
b.Fatal(err)
}
}
// BenchmarkColdStart runs benchmarks with model loading from cold state
func BenchmarkColdStart(b *testing.B) {
client := setup(b)
tests := []TestCase{
{"short_prompt", "Write a long story", 100},
{"medium_prompt", "Write a detailed economic analysis", 500},
{"long_prompt", "Write a comprehensive AI research paper", 1000},
}
m := modelName(b)
for _, tt := range tests {
b.Run(fmt.Sprintf("%s/cold/%s", m, tt.name), func(b *testing.B) {
ctx := context.Background()
// Set number of tokens as our throughput metric
b.SetBytes(int64(tt.maxTokens))
for b.Loop() {
b.StopTimer()
// Ensure model is unloaded before each iteration
unload(client, m, b)
b.StartTimer()
req := &api.GenerateRequest{
Model: m,
Prompt: tt.prompt,
Options: map[string]interface{}{"num_predict": tt.maxTokens, "temperature": 0.1},
}
runGenerateBenchmark(b, ctx, client, req)
}
})
}
}
// BenchmarkWarmStart runs benchmarks with pre-loaded model
func BenchmarkWarmStart(b *testing.B) {
client := setup(b)
tests := []TestCase{
{"short_prompt", "Write a long story", 100},
{"medium_prompt", "Write a detailed economic analysis", 500},
{"long_prompt", "Write a comprehensive AI research paper", 1000},
}
m := modelName(b)
for _, tt := range tests {
b.Run(fmt.Sprintf("%s/warm/%s", m, tt.name), func(b *testing.B) {
ctx := context.Background()
// Pre-warm the model
warmup(client, m, tt.prompt, b)
// Set number of tokens as our throughput metric
b.SetBytes(int64(tt.maxTokens))
for b.Loop() {
req := &api.GenerateRequest{
Model: m,
Prompt: tt.prompt,
Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
}
runGenerateBenchmark(b, ctx, client, req)
}
})
}
}
// setup verifies server and model availability
func setup(b *testing.B) *api.Client {
client, err := api.ClientFromEnvironment()
if err != nil {
b.Fatal(err)
}
if _, err := client.Show(context.Background(), &api.ShowRequest{Model: modelName(b)}); err != nil {
b.Fatalf("Model unavailable: %v", err)
}
return client
}
// warmup ensures the model is loaded and warmed up
func warmup(client *api.Client, model string, prompt string, b *testing.B) {
for range 3 {
err := client.Generate(
context.Background(),
&api.GenerateRequest{
Model: model,
Prompt: prompt,
Options: map[string]interface{}{"num_predict": 50, "temperature": 0.1},
},
func(api.GenerateResponse) error { return nil },
)
if err != nil {
b.Logf("Error during model warm-up: %v", err)
}
}
}
// unload forces model unloading using KeepAlive: 0 parameter
func unload(client *api.Client, model string, b *testing.B) {
req := &api.GenerateRequest{
Model: model,
KeepAlive: &api.Duration{Duration: 0},
}
if err := client.Generate(context.Background(), req, func(api.GenerateResponse) error { return nil }); err != nil {
b.Logf("Unload error: %v", err)
}
time.Sleep(1 * time.Second)
}
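
Aside: the `for b.Loop()` construct used throughout this file is the Go 1.24 replacement for the classic `b.N` counting loop (note the `go1.24` build constraint elsewhere in this diff). A minimal sketch of the same loop-plus-custom-metrics pattern; all names below are illustrative, not from this diff:

```go
package benchmark

import (
	"testing"
	"time"
)

// work is a hypothetical stand-in for the operation under test.
func work() { time.Sleep(time.Millisecond) }

// BenchmarkSketch illustrates the b.Loop + ReportMetric pattern above.
func BenchmarkSketch(b *testing.B) {
	var last time.Duration
	for b.Loop() { // Go 1.24: replaces `for i := 0; i < b.N; i++`
		start := time.Now()
		work()
		last = time.Since(start)
	}
	// ReportMetric attaches custom columns to the benchmark output; this is
	// how ttft_ms, prompt_tok/s, etc. are surfaced by the file above.
	b.ReportMetric(float64(last.Milliseconds()), "step_ms")
}
```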

View File

@@ -703,8 +703,6 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
for _, k := range keys {
var v string
switch vData := resp.ModelInfo[k].(type) {
case bool:
v = fmt.Sprintf("%t", vData)
case string:
v = vData
case float64:

View File

@@ -87,8 +87,6 @@ func TestShowInfo(t *testing.T) {
ModelInfo: map[string]any{
"general.architecture": "test",
"general.parameter_count": float64(8_000_000_000),
"some.true_bool": true,
"some.false_bool": false,
"test.context_length": float64(1000),
"test.embedding_length": float64(11434),
},
@@ -113,8 +111,6 @@ func TestShowInfo(t *testing.T) {
Metadata
general.architecture test
general.parameter_count 8e+09
some.false_bool false
some.true_bool true
test.context_length 1000
test.embedding_length 11434

View File

@@ -558,10 +558,6 @@ Final response:
{
"model": "llama3.2",
"created_at": "2023-08-04T19:22:45.499127Z",
"message": {
"role": "assistant",
"content": ""
},
"done": true,
"total_duration": 4883583458,
"load_duration": 1334875,

View File

@@ -1,59 +0,0 @@
# Benchmark
Go benchmark tests that measure end-to-end performance of a running Ollama server. Run these tests to evaluate model inference performance on your hardware and measure the impact of code changes.
## When to use
Run these benchmarks when:
- Making changes to the model inference engine
- Modifying model loading/unloading logic
- Changing prompt processing or token generation code
- Implementing a new model architecture
- Testing performance across different hardware setups
## Prerequisites
- Ollama server running locally with `ollama serve` on `127.0.0.1:11434`
## Usage and Examples
> [!NOTE]
> All commands must be run from the root directory of the Ollama project.
Basic syntax:
```bash
go test -bench=. ./benchmark/... -m $MODEL_NAME
```
Required flags:
- `-bench=.`: Run all benchmarks
- `-m`: Model name to benchmark
Optional flags:
- `-count N`: Number of times to run the benchmark (useful for statistical analysis)
- `-timeout T`: Maximum time for the benchmark to run (e.g. "10m" for 10 minutes)
Common usage patterns:
Single benchmark run with a model specified:
```bash
go test -bench=. ./benchmark/... -m llama3.3
```
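A sketch combining the optional flags (model name and values illustrative):
```bash
go test -bench=. ./benchmark/... -m llama3.3 -count 5 -timeout 30m
```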
## Output metrics
The benchmark reports several key metrics:
- `gen_tok/s`: Generated tokens per second
- `prompt_tok/s`: Prompt processing tokens per second
- `ttft_ms`: Time to first token in milliseconds
- `load_ms`: Model load time in milliseconds
- `gen_tokens`: Total tokens generated
- `prompt_tokens`: Total prompt tokens processed
Each benchmark runs two scenarios:
- Cold start: Model is loaded from disk for each test
- Warm start: Model is pre-loaded in memory
Three prompt lengths are tested for each scenario:
- Short prompt (100 tokens)
- Medium prompt (500 tokens)
- Long prompt (1000 tokens)

View File

@@ -9,7 +9,7 @@ cat ~/.ollama/logs/server.log
On **Linux** systems with systemd, the logs can be found with this command:
```shell
journalctl -u ollama --no-pager --follow --pager-end
journalctl -u ollama --no-pager
```
When you run Ollama in a **container**, the logs go to stdout/stderr in the container:

View File

@@ -1,22 +0,0 @@
//go:build go1.24
package grammar
import "testing"
func BenchmarkFromSchema(b *testing.B) {
for tt := range testCases(b) {
b.Run("", func(b *testing.B) {
s := []byte(tt.schema)
b.ReportAllocs()
for b.Loop() {
_, err := FromSchema(nil, s)
if err != nil {
b.Fatalf("GrammarFromSchema: %v", err)
}
}
})
return
}
}

View File

@@ -1,227 +0,0 @@
package grammar
import (
"bytes"
"encoding/json"
"fmt"
"iter"
"strconv"
"github.com/ollama/ollama/grammar/jsonschema"
)
const jsonTerms = `
# Unicode
#
# Unicode characters can be specified directly in the grammar, for example
# hiragana ::= [ぁ-ゟ], or with escapes: 8-bit (\xXX), 16-bit (\uXXXX) or 32-bit
# (\UXXXXXXXX).
unicode ::= \x{hex}{2} | \u{hex}{4} | \U{hex}{8}
# JSON grammar from RFC 7159
null ::= "null"
object ::= "{" (kv ("," kv)*)? "}"
array ::= "[" (value ("," value)*)? "]"
kv ::= string ":" value
integer ::= "0" | [1-9] [0-9]*
number ::= "-"? integer frac? exp?
frac ::= "." [0-9]+
exp ::= ("e" | "E") ("+" | "-") [0-9]+
string ::= "\"" char* "\""
escape ::= ["/" | "b" | "f" | "n" | "r" | "t" | unicode]
char ::= [^"\\] | escape
space ::= (" " | "\t" | "\n" | "\r")*
hex ::= [0-9] | [a-f] | [A-F]
boolean ::= "true" | "false"
value ::= object | array | string | number | boolean | "null"
# User-defined
`
// FromSchema generates a grammar from a JSON schema.
func FromSchema(buf []byte, jsonSchema []byte) ([]byte, error) {
var s *jsonschema.Schema
if err := json.Unmarshal(jsonSchema, &s); err != nil {
return nil, err
}
var g builder
// "root" is the only rule that is guaranteed to exist, so we start
// with its length for padding, and then adjust it as we go.
g.pad = len("root")
for id := range dependencies("root", s) {
g.pad = max(g.pad, len(id))
}
g.b.WriteString(jsonTerms)
ids := make(map[*jsonschema.Schema]string)
for id, s := range dependencies("root", s) {
ids[s] = id
g.define(id)
if err := fromSchema(&g, ids, s); err != nil {
return nil, err
}
}
g.define("root")
if err := fromSchema(&g, ids, s); err != nil {
return nil, err
}
g.define("") // finalize the last rule
return g.b.Bytes(), nil
}
func fromSchema(g *builder, ids map[*jsonschema.Schema]string, s *jsonschema.Schema) error {
switch typ := s.EffectiveType(); typ {
case "array":
if len(s.PrefixItems) == 0 && s.Items == nil {
g.u("array")
} else {
g.q("[")
for i, s := range s.PrefixItems {
if i > 0 {
g.q(",")
}
g.u(ids[s])
}
if s.Items != nil {
g.u("(")
if len(s.PrefixItems) > 0 {
g.q(",")
}
g.u(ids[s.Items])
g.u(")*")
}
g.q("]")
}
case "object":
if len(s.Properties) == 0 {
g.u("object")
} else {
g.q("{")
for i, p := range s.Properties {
name := ids[p]
if i > 0 {
g.q(",")
}
g.q(p.Name)
g.q(":")
g.u(name)
}
g.q("}")
}
case "number":
buildConstrainedNumber(g, s)
case "string":
if len(s.Enum) == 0 {
g.u("string")
} else {
g.u("(")
for i, e := range s.Enum {
if i > 0 {
g.q("|")
}
g.q(string(e))
}
g.u(")")
}
case "boolean", "value", "null", "integer":
g.u(typ)
default:
return fmt.Errorf("%s: unsupported type %q", s.Name, typ)
}
return nil
}
// dependencies returns a sequence of all child dependencies of the schema in
// post-order.
//
// The first value is the id/pointer to the dependency, and the second value
// is the schema.
func dependencies(id string, s *jsonschema.Schema) iter.Seq2[string, *jsonschema.Schema] {
return func(yield func(string, *jsonschema.Schema) bool) {
for i, p := range s.Properties {
id := fmt.Sprintf("%s_%d", id, i)
for did, d := range dependencies(id, p) {
if !yield(did, d) {
return
}
}
if !yield(id, p) {
return
}
}
for i, p := range s.PrefixItems {
id := fmt.Sprintf("tuple_%d", i)
for did, d := range dependencies(id, p) {
id := fmt.Sprintf("%s_%s", id, did)
if !yield(id, d) {
return
}
}
if !yield(id, p) {
return
}
}
if s.Items != nil {
id := fmt.Sprintf("%s_tuple_%d", id, len(s.PrefixItems))
for did, d := range dependencies(id, s.Items) {
if !yield(did, d) {
return
}
}
if !yield(id, s.Items) {
return
}
}
}
}
type builder struct {
b bytes.Buffer
pad int
rules int
items int
}
// define terminates the current rule, if any, and then either starts a new
// rule or does nothing else if the name is empty.
func (b *builder) define(name string) {
if b.rules > 0 {
b.b.WriteString(";\n")
}
if name == "" {
return
}
fmt.Fprintf(&b.b, "%-*s", b.pad, name)
b.b.WriteString(" ::=")
b.rules++
b.items = 0
}
// q appends a quoted terminal to the current rule.
func (b *builder) q(s string) {
if b.items > 0 {
b.b.WriteString(" ")
}
b.b.WriteString(" ")
b.b.WriteString(strconv.Quote(s))
}
// u appends a non-terminal to the current rule.
func (b *builder) u(s string) {
if b.items > 0 {
b.b.WriteString(" ")
}
b.b.WriteString(" ")
b.b.WriteString(s)
}
func buildConstrainedNumber(b *builder, s *jsonschema.Schema) {
if s.Minimum == 0 && s.Maximum == 0 {
b.u("TODO")
} else {
b.u("number")
}
}
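
As a sketch of driving this API directly (assuming the grammar package above is present in the build): per the testdata later in this diff, a bare boolean schema yields a grammar whose user-defined rule is `root ::= boolean;`.

```go
package main

import (
	"fmt"
	"log"

	"github.com/ollama/ollama/grammar"
)

func main() {
	// FromSchema appends the generated grammar to buf (nil here) and
	// returns the combined bytes: the jsonTerms preamble, then the rules.
	g, err := grammar.FromSchema(nil, []byte(`{"type": "boolean"}`))
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(string(g)) // ends with: root ::= boolean;
}
```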

View File

@@ -1,75 +0,0 @@
package grammar
import (
"bufio"
"cmp"
"iter"
"strings"
"testing"
_ "embed"
"github.com/ollama/ollama/grammar/internal/diff"
)
func TestFromSchema(t *testing.T) {
for tt := range testCases(t) {
t.Run(tt.name, func(t *testing.T) {
g, err := FromSchema(nil, []byte(tt.schema))
if err != nil {
t.Fatalf("FromSchema: %v", err)
}
got := string(g)
got = strings.TrimPrefix(got, jsonTerms)
if got != tt.want {
t.Logf("schema:\n%s", tt.schema)
t.Fatal(string(diff.Diff("got", []byte(got), "want", []byte(tt.want))))
}
})
}
}
type testCase struct {
name string
schema string
want string
}
//go:embed testdata/schemas.txt
var tests string
func testCases(t testing.TB) iter.Seq[testCase] {
t.Helper()
return func(yield func(testCase) bool) {
t.Helper()
sc := bufio.NewScanner(strings.NewReader(tests))
name := ""
for sc.Scan() {
line := strings.TrimSpace(sc.Text())
if line == "" {
name = ""
continue
}
if line[0] == '#' {
name = cmp.Or(name, strings.TrimSpace(line[1:]))
continue
}
s := sc.Text()
g := ""
for sc.Scan() {
line = strings.TrimSpace(sc.Text())
if line == "" || line[0] == '#' {
break
}
g += sc.Text() + "\n"
}
if !yield(testCase{name, s, g}) {
return
}
name = strings.TrimSpace(strings.TrimPrefix(line, "#"))
}
if err := sc.Err(); err != nil {
t.Fatalf("error reading tests: %v", err)
}
}
}

View File

@@ -1,261 +0,0 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package diff
import (
"bytes"
"fmt"
"sort"
"strings"
)
// A pair is a pair of values tracked for both the x and y side of a diff.
// It is typically a pair of line indexes.
type pair struct{ x, y int }
// Diff returns an anchored diff of the two texts old and new
// in the “unified diff” format. If old and new are identical,
// Diff returns a nil slice (no output).
//
// Unix diff implementations typically look for a diff with
// the smallest number of lines inserted and removed,
// which can in the worst case take time quadratic in the
// number of lines in the texts. As a result, many implementations
// either can be made to run for a long time or cut off the search
// after a predetermined amount of work.
//
// In contrast, this implementation looks for a diff with the
// smallest number of “unique” lines inserted and removed,
// where unique means a line that appears just once in both old and new.
// We call this an “anchored diff” because the unique lines anchor
// the chosen matching regions. An anchored diff is usually clearer
// than a standard diff, because the algorithm does not try to
// reuse unrelated blank lines or closing braces.
// The algorithm also guarantees to run in O(n log n) time
// instead of the standard O(n²) time.
//
// Some systems call this approach a “patience diff,” named for
// the “patience sorting” algorithm, itself named for a solitaire card game.
// We avoid that name for two reasons. First, the name has been used
// for a few different variants of the algorithm, so it is imprecise.
// Second, the name is frequently interpreted as meaning that you have
// to wait longer (to be patient) for the diff, meaning that it is a slower algorithm,
// when in fact the algorithm is faster than the standard one.
func Diff(oldName string, old []byte, newName string, new []byte) []byte {
if bytes.Equal(old, new) {
return nil
}
x := lines(old)
y := lines(new)
// Print diff header.
var out bytes.Buffer
fmt.Fprintf(&out, "diff %s %s\n", oldName, newName)
fmt.Fprintf(&out, "--- %s\n", oldName)
fmt.Fprintf(&out, "+++ %s\n", newName)
// Loop over matches to consider,
// expanding each match to include surrounding lines,
// and then printing diff chunks.
// To avoid setup/teardown cases outside the loop,
// tgs returns a leading {0,0} and trailing {len(x), len(y)} pair
// in the sequence of matches.
var (
done pair // printed up to x[:done.x] and y[:done.y]
chunk pair // start lines of current chunk
count pair // number of lines from each side in current chunk
ctext []string // lines for current chunk
)
for _, m := range tgs(x, y) {
if m.x < done.x {
// Already handled scanning forward from earlier match.
continue
}
// Expand matching lines as far as possible,
// establishing that x[start.x:end.x] == y[start.y:end.y].
// Note that on the first (or last) iteration we may (or definitely do)
// have an empty match: start.x==end.x and start.y==end.y.
start := m
for start.x > done.x && start.y > done.y && x[start.x-1] == y[start.y-1] {
start.x--
start.y--
}
end := m
for end.x < len(x) && end.y < len(y) && x[end.x] == y[end.y] {
end.x++
end.y++
}
// Emit the mismatched lines before start into this chunk.
// (No effect on first sentinel iteration, when start = {0,0}.)
for _, s := range x[done.x:start.x] {
ctext = append(ctext, "-"+s)
count.x++
}
for _, s := range y[done.y:start.y] {
ctext = append(ctext, "+"+s)
count.y++
}
// If we're not at EOF and have too few common lines,
// the chunk includes all the common lines and continues.
const C = 3 // number of context lines
if (end.x < len(x) || end.y < len(y)) &&
(end.x-start.x < C || (len(ctext) > 0 && end.x-start.x < 2*C)) {
for _, s := range x[start.x:end.x] {
ctext = append(ctext, " "+s)
count.x++
count.y++
}
done = end
continue
}
// End chunk with common lines for context.
if len(ctext) > 0 {
n := end.x - start.x
if n > C {
n = C
}
for _, s := range x[start.x : start.x+n] {
ctext = append(ctext, " "+s)
count.x++
count.y++
}
done = pair{start.x + n, start.y + n}
// Format and emit chunk.
// Convert line numbers to 1-indexed.
// Special case: empty file shows up as 0,0 not 1,0.
if count.x > 0 {
chunk.x++
}
if count.y > 0 {
chunk.y++
}
fmt.Fprintf(&out, "@@ -%d,%d +%d,%d @@\n", chunk.x, count.x, chunk.y, count.y)
for _, s := range ctext {
out.WriteString(s)
}
count.x = 0
count.y = 0
ctext = ctext[:0]
}
// If we reached EOF, we're done.
if end.x >= len(x) && end.y >= len(y) {
break
}
// Otherwise start a new chunk.
chunk = pair{end.x - C, end.y - C}
for _, s := range x[chunk.x:end.x] {
ctext = append(ctext, " "+s)
count.x++
count.y++
}
done = end
}
return out.Bytes()
}
// lines returns the lines in the file x, including newlines.
// If the file does not end in a newline, one is supplied
// along with a warning about the missing newline.
func lines(x []byte) []string {
l := strings.SplitAfter(string(x), "\n")
if l[len(l)-1] == "" {
l = l[:len(l)-1]
} else {
// Treat last line as having a message about the missing newline attached,
// using the same text as BSD/GNU diff (including the leading backslash).
l[len(l)-1] += "\n\\ No newline at end of file\n"
}
return l
}
// tgs returns the pairs of indexes of the longest common subsequence
// of unique lines in x and y, where a unique line is one that appears
// once in x and once in y.
//
// The longest common subsequence algorithm is as described in
// Thomas G. Szymanski, “A Special Case of the Maximal Common
// Subsequence Problem,” Princeton TR #170 (January 1975),
// available at https://research.swtch.com/tgs170.pdf.
func tgs(x, y []string) []pair {
// Count the number of times each string appears in x and y.
// We only care about 0, 1, many, counted as 0, -1, -2
// for the x side and 0, -4, -8 for the y side.
// Using negative numbers now lets us distinguish positive line numbers later.
m := make(map[string]int)
for _, s := range x {
if c := m[s]; c > -2 {
m[s] = c - 1
}
}
for _, s := range y {
if c := m[s]; c > -8 {
m[s] = c - 4
}
}
// Now unique strings can be identified by m[s] = -1+-4.
//
// Gather the indexes of those strings in x and y, building:
// xi[i] = increasing indexes of unique strings in x.
// yi[i] = increasing indexes of unique strings in y.
// inv[i] = index j such that x[xi[i]] = y[yi[j]].
var xi, yi, inv []int
for i, s := range y {
if m[s] == -1+-4 {
m[s] = len(yi)
yi = append(yi, i)
}
}
for i, s := range x {
if j, ok := m[s]; ok && j >= 0 {
xi = append(xi, i)
inv = append(inv, j)
}
}
// Apply Algorithm A from Szymanski's paper.
// In those terms, A = J = inv and B = [0, n).
// We add sentinel pairs {0,0}, and {len(x),len(y)}
// to the returned sequence, to help the processing loop.
J := inv
n := len(xi)
T := make([]int, n)
L := make([]int, n)
for i := range T {
T[i] = n + 1
}
for i := range n {
k := sort.Search(n, func(k int) bool {
return T[k] >= J[i]
})
T[k] = J[i]
L[i] = k + 1
}
k := 0
for _, v := range L {
if k < v {
k = v
}
}
seq := make([]pair, 2+k)
seq[1+k] = pair{len(x), len(y)} // sentinel at end
lastj := n
for i := n - 1; i >= 0; i-- {
if L[i] == k && J[i] < lastj {
seq[k] = pair{xi[i], yi[J[i]]}
k--
}
}
seq[0] = pair{0, 0} // sentinel at start
return seq
}
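
A sketch of the call shape (written as a hypothetical test inside the same package, since the diff package is internal and not importable from outside):

```go
package diff

import "testing"

// TestDiffSketch is illustrative only: Diff returns nil for identical
// inputs and a unified diff with a three-line header otherwise.
func TestDiffSketch(t *testing.T) {
	if out := Diff("old", []byte("same\n"), "new", []byte("same\n")); out != nil {
		t.Fatalf("identical inputs should produce nil, got %q", out)
	}
	out := Diff("old", []byte("a\nb\nc\n"), "new", []byte("a\nB\nc\n"))
	if out == nil {
		t.Fatal("differing inputs should produce a non-nil diff")
	}
	// out begins: "diff old new\n--- old\n+++ new\n"
}
```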

View File

@@ -1,44 +0,0 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package diff
import (
"bytes"
"path/filepath"
"testing"
"golang.org/x/tools/txtar"
)
func clean(text []byte) []byte {
text = bytes.ReplaceAll(text, []byte("$\n"), []byte("\n"))
text = bytes.TrimSuffix(text, []byte("^D\n"))
return text
}
func Test(t *testing.T) {
files, _ := filepath.Glob("testdata/*.txt")
if len(files) == 0 {
t.Fatalf("no testdata")
}
for _, file := range files {
t.Run(filepath.Base(file), func(t *testing.T) {
a, err := txtar.ParseFile(file)
if err != nil {
t.Fatal(err)
}
if len(a.Files) != 3 || a.Files[2].Name != "diff" {
t.Fatalf("%s: want three files, third named \"diff\"", file)
}
diffs := Diff(a.Files[0].Name, clean(a.Files[0].Data), a.Files[1].Name, clean(a.Files[1].Data))
want := clean(a.Files[2].Data)
if !bytes.Equal(diffs, want) {
t.Fatalf("%s: have:\n%s\nwant:\n%s\n%s", file,
diffs, want, Diff("have", diffs, "want", want))
}
})
}
}

View File

@@ -1,13 +0,0 @@
-- old --
-- new --
a
b
c
-- diff --
diff old new
--- old
+++ new
@@ -0,0 +1,3 @@
+a
+b
+c

View File

@@ -1,13 +0,0 @@
-- old --
a
b
c
-- new --
-- diff --
diff old new
--- old
+++ new
@@ -1,3 +0,0 @@
-a
-b
-c

View File

@@ -1,35 +0,0 @@
Example from Hunt and McIlroy, “An Algorithm for Differential File Comparison.”
https://www.cs.dartmouth.edu/~doug/diff.pdf
-- old --
a
b
c
d
e
f
g
-- new --
w
a
b
x
y
z
e
-- diff --
diff old new
--- old
+++ new
@@ -1,7 +1,7 @@
+w
a
b
-c
-d
+x
+y
+z
e
-f
-g

View File

@@ -1,40 +0,0 @@
-- old --
a
b
c
d
e
f
-- new --
a
B
C
d
e
f
-- diff --
diff old new
--- old
+++ new
@@ -1,8 +1,8 @@
a
$
-b
-
-c
+B
+
+C
$
d
$

View File

@@ -1,38 +0,0 @@
-- old --
1
2
3
4
5
6
7
eight
nine
ten
eleven
-- new --
1
2
3
4
5
6
7
8
9
10
-- diff --
diff old new
--- old
+++ new
@@ -5,7 +5,6 @@
5
6
7
-eight
-nine
-ten
-eleven
+8
+9
+10

View File

@@ -1,9 +0,0 @@
-- old --
a
b
c^D
-- new --
a
b
c^D
-- diff --

View File

@@ -1,18 +0,0 @@
-- old --
a
b
c
-- new --
a
b
c^D
-- diff --
diff old new
--- old
+++ new
@@ -1,3 +1,3 @@
a
b
-c
+c
\ No newline at end of file

View File

@@ -1,18 +0,0 @@
-- old --
a
b
c^D
-- new --
a
b
c
-- diff --
diff old new
--- old
+++ new
@@ -1,3 +1,3 @@
a
b
-c
\ No newline at end of file
+c

View File

@@ -1,62 +0,0 @@
-- old --
1
2
3
4
5
6
7
8
9
10
11
12
13
14
14½
15
16
17
18
19
20
-- new --
1
2
3
4
5
6
8
9
10
11
12
13
14
17
18
19
20
-- diff --
diff old new
--- old
+++ new
@@ -4,7 +4,6 @@
4
5
6
-7
8
9
10
@@ -12,9 +11,6 @@
12
13
14
-14½
-15
-16
17
18
19

View File

@@ -1,5 +0,0 @@
-- old --
hello world
-- new --
hello world
-- diff --

View File

@@ -1,34 +0,0 @@
-- old --
e
pi
4
5
6
7
8
9
10
-- new --
1
2
3
4
5
6
7
8
9
10
-- diff --
diff old new
--- old
+++ new
@@ -1,5 +1,6 @@
-e
-pi
+1
+2
+3
4
5
6

View File

@@ -1,40 +0,0 @@
Another example from Hunt and McIlroy,
“An Algorithm for Differential File Comparison.”
https://www.cs.dartmouth.edu/~doug/diff.pdf
Anchored diff gives up on finding anything,
since there are no unique lines.
-- old --
a
b
c
a
b
b
a
-- new --
c
a
b
a
b
c
-- diff --
diff old new
--- old
+++ new
@@ -1,7 +1,6 @@
-a
-b
-c
-a
-b
-b
-a
+c
+a
+b
+a
+b
+c

View File

@@ -1,171 +0,0 @@
package jsonschema
import (
"bytes"
"encoding/json"
"errors"
)
// Schema holds a JSON schema.
type Schema struct {
// Name is the name of the property. For the parent/root property, this
// is "root". For child properties, this is the name of the property.
Name string `json:"-"`
// Type is the type of the property.
//
// TODO: Union types (e.g. make this a []string).
Type string
// PrefixItems is a list of schemas for each item in a tuple. By
// default, the tuple is "closed" unless Items is set to true or a
// valid Schema.
PrefixItems []*Schema
// Items is the schema for each item in a list.
//
// If it is missing, or its JSON value is "null" or "false", it is nil.
// If the JSON value is "true", it is set to the empty Schema. If the
// JSON value is an object, it will be decoded as a Schema.
Items *Schema
// MinItems specifies the minimum number of items allowed in a list.
MinItems int
// MaxItems specifies the maximum number of items allowed in a list.
MaxItems int
// Properties is the schema for each property of an object.
Properties []*Schema
// Format is the format of the property. This is used to validate the
// property against a specific format.
//
// It is the caller's responsibility to validate the property against
// the format.
Format string
// Minimum specifies the minimum value for numeric properties.
Minimum float64
// Maximum specifies the maximum value for numeric properties.
Maximum float64
// Enum is a list of valid values for the property.
Enum []json.RawMessage
}
func (s *Schema) UnmarshalJSON(data []byte) error {
type S Schema
w := struct {
Properties props
Items items
*S
}{
S: (*S)(s),
}
if err := json.Unmarshal(data, &w); err != nil {
return err
}
if w.Items.set {
s.Items = &w.Items.Schema
}
s.Properties = w.Properties
return nil
}
type items struct {
Schema
set bool
}
func (s *items) UnmarshalJSON(data []byte) error {
switch b := data[0]; b {
case 't':
*s = items{set: true}
case '{':
type I items
if err := json.Unmarshal(data, (*I)(s)); err != nil {
return err
}
s.set = true
case 'n', 'f':
default:
return errors.New("invalid Items")
}
return nil
}
// EffectiveType returns the effective type of the schema. If the Type field is
// not empty, it is returned; otherwise:
//
// - If the schema has Properties, it returns "object".
// - If the schema has PrefixItems or Items, it returns "array".
// - If the schema has none of these, it returns "value".
//
// The returned string is never empty.
func (d *Schema) EffectiveType() string {
if d.Type == "" {
if len(d.Properties) > 0 {
return "object"
}
if len(d.PrefixItems) > 0 || d.Items != nil {
return "array"
}
return "value"
}
return d.Type
}
// props is an ordered list of properties. The order of the properties
// is the order in which they were defined in the schema.
type props []*Schema
var _ json.Unmarshaler = (*props)(nil)
func (v *props) UnmarshalJSON(data []byte) error {
if len(data) == 0 {
return nil
}
if data[0] != '{' {
return errors.New("expected object")
}
d := json.NewDecoder(bytes.NewReader(data))
// TODO(bmizerany): Consider DisallowUnknownFields. Currently, we, like
// llama.cpp, ignore unknown fields, which could lead to unexpected
// behavior for clients of this package, since they may not be aware
// that "additionalFields", "itemsPrefix", etc, are being ignored.
//
// For now, just do what llama.cpp does.
t, err := d.Token()
if err != nil {
return err
}
if t != json.Delim('{') {
return errors.New("expected object")
}
for d.More() {
// Use the first token (map key) as the property name, then
// decode the rest of the object fields into a Schema and
// append.
t, err := d.Token()
if err != nil {
return err
}
if t == json.Delim('}') {
return nil
}
s := &Schema{
Name: t.(string),
}
if err := d.Decode(s); err != nil {
return err
}
*v = append(*v, s)
}
return nil
}
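
A sketch of the open/closed tuple rule implemented by items.UnmarshalJSON above, mirroring the test table that follows: `true` (and any object) yields a non-nil Items, while a missing, `null`, or `false` value leaves it nil.

```go
package jsonschema

import (
	"encoding/json"
	"testing"
)

// TestItemsSketch is a hypothetical test restating the decoding rules.
func TestItemsSketch(t *testing.T) {
	var open, closed Schema
	_ = json.Unmarshal([]byte(`{"prefixItems": [{}], "items": true}`), &open)
	_ = json.Unmarshal([]byte(`{"prefixItems": [{}], "items": false}`), &closed)
	if open.Items == nil {
		t.Error(`"items": true should produce a non-nil (open) Items`)
	}
	if closed.Items != nil {
		t.Error(`"items": false should leave Items nil (closed tuple)`)
	}
}
```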

View File

@@ -1,104 +0,0 @@
package jsonschema
import (
"encoding/json"
"reflect"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
)
const testSchemaBasic = `
{
"properties": {
"tupleClosedEmpty": { "prefixItems": [] },
"tupleClosedMissing": { "prefixItems": [{}] },
"tupleClosedNull": { "prefixItems": [{}], "items": null },
"tupleClosedFalse": { "prefixItems": [{}], "items": false },
"tupleOpenTrue": { "prefixItems": [{}], "items": true },
"tupleOpenEmpty": { "prefixItems": [{}], "items": {} },
"tupleOpenTyped": { "prefixItems": [{}], "items": {"type": "boolean"} },
"tupleOpenMax": { "prefixItems": [{}], "items": true, "maxItems": 3},
"array": { "items": {"type": "number"} },
"null": { "type": "null" },
"string": { "type": "string" },
"boolean": { "type": "boolean" }
}
}
`
func TestSchemaUnmarshal(t *testing.T) {
var got *Schema
if err := json.Unmarshal([]byte(testSchemaBasic), &got); err != nil {
t.Fatalf("Unmarshal: %v", err)
}
want := &Schema{
Properties: []*Schema{
{Name: "tupleClosedEmpty", PrefixItems: []*Schema{}, Items: nil},
{Name: "tupleClosedMissing", PrefixItems: []*Schema{{}}, Items: nil},
{Name: "tupleClosedNull", PrefixItems: []*Schema{{}}, Items: nil},
{Name: "tupleClosedFalse", PrefixItems: []*Schema{{}}, Items: nil},
{Name: "tupleOpenTrue", PrefixItems: []*Schema{{}}, Items: &Schema{}},
{Name: "tupleOpenEmpty", PrefixItems: []*Schema{{}}, Items: &Schema{}},
{Name: "tupleOpenTyped", PrefixItems: []*Schema{{}}, Items: &Schema{Type: "boolean"}},
{Name: "tupleOpenMax", PrefixItems: []*Schema{{}}, Items: &Schema{}, MaxItems: 3},
{Name: "array", Items: &Schema{Type: "number"}},
{Name: "null", Type: "null"},
{Name: "string", Type: "string"},
{Name: "boolean", Type: "boolean"},
},
}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("(-want, +got)\n%s", diff)
}
}
func TestEffectiveType(t *testing.T) {
const schema = `
{"properties": {
"o": {"type": "object"},
"a": {"type": "array"},
"n": {"type": "number"},
"s": {"type": "string"},
"z": {"type": "null"},
"b": {"type": "boolean"},
"t0": {"prefixItems": [{}], "items": {"type": "number"}},
"t1": {"items": {"type": "number"}, "maxItems": 3},
"v": {"maxItems": 3}
}}
`
var s *Schema
if err := json.Unmarshal([]byte(schema), &s); err != nil {
t.Fatalf("json.Unmarshal: %v", err)
}
var got []string
for _, p := range s.Properties {
got = append(got, p.EffectiveType())
}
want := strings.Fields(`
object
array
number
string
null
boolean
array
array
value
`)
if !reflect.DeepEqual(want, got) {
t.Errorf("\ngot:\n\t%v\nwant:\n\t%v", got, want)
}
}

View File

@@ -1,76 +0,0 @@
# This file holds tests for JSON schema to EBNF grammar conversions.
#
# The format is a JSON schema, followed by the expected EBNF grammar. Each test
# MAY be preceded by a comment that describes the test (e.g. the test name), followed by
# the JSON schema and the expected EBNF grammar. If no comment is present, the
# test name is the test's number in the file (e.g. "#0", "#1", etc.).
#
# Blank lines signify the end of one test and the start of the next. Comments
# can be added anywhere in the file, but they must start with a '#' character
# at the beginning of the line.
# default
{}
root ::= value;
{"properties": {}}
root ::= value;
# array
{"properties": {"a": {"type": "array", "items": {"type": "string"}}}}
root_0_tuple_0 ::= string;
root_0 ::= "[" ( root_0_tuple_0 )* "]";
root ::= "{" "a" ":" root_0 "}";
# array with nested array
{"type": "array", "items": {"type": "array", "items": {"type": "string"}}}
root_tuple_0_tuple_0 ::= string;
root_tuple_0 ::= "[" ( root_tuple_0_tuple_0 )* "]";
root ::= "[" ( root_tuple_0 )* "]";
# object
{"properties": {"e": {}}}
root_0 ::= value;
root ::= "{" "e" ":" root_0 "}";
# object with nested object
{"properties": {"o": {"type": "object", "properties": {"e": {}}}}}
root_0_0 ::= value;
root_0 ::= "{" "e" ":" root_0_0 "}";
root ::= "{" "o" ":" root_0 "}";
# boolean
{"type": "boolean"}
root ::= boolean;
# number
{"properties": {"n": {"type": "number", "minimum": 123, "maximum": 4567}}}
root_0 ::= number;
root ::= "{" "n" ":" root_0 "}";
# string
{"type": "string"}
root ::= string;
# string with enum
{"type": "string", "enum": ["a", "b", "c"]}
root ::= ( "\"a\"" "|" "\"b\"" "|" "\"c\"" );
# spaces in key
{"properties": {"a b": {}}}
root_0 ::= value;
root ::= "{" "a b" ":" root_0 "}";
# issue7978
{ "type": "object", "properties": { "steps": { "type": "array", "items": { "type": "object", "properties": { "explanation": { "type": "string" }, "output": { "type": "string" } }, "required": [ "explanation", "output" ], "additionalProperties": false } }, "final_answer": { "type": "string" } }, "required": [ "steps", "final_answer" ], "additionalProperties": false }
root_0_tuple_0_0 ::= string;
root_0_tuple_0_1 ::= string;
root_0_tuple_0 ::= "{" "explanation" ":" root_0_tuple_0_0 "," "output" ":" root_0_tuple_0_1 "}";
root_0 ::= "[" ( root_0_tuple_0 )* "]";
root_1 ::= string;
root ::= "{" "steps" ":" root_0 "," "final_answer" ":" root_1 "}";
# !! # special characters in key
# !! {"properties": {"a!b": {}}}
# !! !invalid character '!' in key
# !!

View File

@@ -43,13 +43,8 @@ type Cache interface {
// ** cache management **
// Init sets up runtime parameters.
// backend: Used to allocate cache data storage and execute management operations (such as defrag)
// dtype: The data type for storing cache entries
// maxSequences: The maximum number of sequences stored in the cache - across all batches
// capacity: The number of cache entries to store, per sequence
// maxBatch: The maximum number of tokens that can occur in a single batch
Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int)
// Init sets up runtime parameters
Init(backend ml.Backend, dtype ml.DType, capacity int32)
// Close closes the cache and frees resources associated with it
Close()
@@ -57,7 +52,7 @@ type Cache interface {
// StartForward is called before the start of the model's forward pass.
// For each token in the coming batch, there must be a corresponding
// entry in positions and seqs.
StartForward(ctx ml.Context, batch input.Batch) error
StartForward(ctx ml.Context, opts input.Options) error
// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
CopyPrefix(srcSeq, dstSeq int, len int32)
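
For orientation, the Options-side call shape (as exercised by the cache tests later in this diff) is roughly the following sketch; import paths are assumed from this diff's context:

```go
package sketch // hypothetical

import (
	"github.com/ollama/ollama/kvcache"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model/input"
)

// startBatch mirrors how the tests below drive the cache: one parallel
// position/sequence entry per token in the incoming batch.
func startBatch(ctx ml.Context, cache kvcache.Cache, pos []int32, seqs []int) error {
	return cache.StartForward(ctx, input.Options{Positions: pos, Sequences: seqs})
}
```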

View File

@@ -20,6 +20,7 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e
// The mask is of shape history size, batch size
type Causal struct {
DType ml.DType
Capacity int32
windowSize int32
opts CausalOptions
@@ -97,7 +98,7 @@ func NewSWACache(windowSize int32, shift shiftFn) *Causal {
}
}
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
if c.config == nil {
var config ml.CacheConfig
if cc, ok := backend.(ml.BackendCacheConfig); ok {
@@ -118,16 +119,9 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity
c.config.MaskDType = ml.DTypeF32
}
var cacheSize int
if c.windowSize == math.MaxInt32 || capacity < int(c.windowSize)+maxBatch {
cacheSize = maxSequences * capacity
} else {
cacheSize = maxSequences * (int(c.windowSize) + maxBatch)
}
cacheSize = roundUp(cacheSize, c.config.CachePadding)
c.cells = make([]cacheCell, cacheSize)
c.DType = dtype
c.Capacity = int32(roundUp(int(capacity), c.config.CachePadding))
c.cells = make([]cacheCell, c.Capacity)
c.cellRanges = make(map[int]cellRange)
c.backend = backend
}
@@ -146,14 +140,12 @@ func (c *Causal) Close() {
}
}
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch) error {
c.curBatchSize = len(batch.Positions)
c.curSequences = batch.Sequences
c.curPositions = batch.Positions
func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error {
c.curBatchSize = len(opts.Positions)
c.curSequences = opts.Sequences
c.curPositions = opts.Positions
c.opts.Except = nil
c.updateSlidingWindow()
var err error
c.curLoc, err = c.findStartLoc()
if errors.Is(err, ErrKvCacheFull) {
@@ -165,8 +157,8 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch) error {
}
c.curCellRange = newRange()
for i, pos := range batch.Positions {
seq := batch.Sequences[i]
for i, pos := range opts.Positions {
seq := opts.Sequences[i]
c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
@@ -218,51 +210,7 @@ func (c *Causal) findStartLoc() (int, error) {
}
}
return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, len(c.cells))
}
func (c *Causal) updateSlidingWindow() {
if c.windowSize == math.MaxInt32 {
return
}
// create a map of unique sequences to the lowest position in that sequence
lowestPos := make(map[int]int32)
for i := range c.curPositions {
seq := c.curSequences[i]
pos, ok := lowestPos[seq]
if !ok {
pos = c.curPositions[i]
} else if c.curPositions[i] < pos {
pos = c.curPositions[i]
}
lowestPos[seq] = pos
}
// delete any entries that are beyond the window of the oldest position in the sequence
for seq, pos := range lowestPos {
oldRange, ok := c.cellRanges[seq]
if !ok {
continue
}
newRange := newRange()
for i := oldRange.min; i <= oldRange.max; i++ {
if slices.Contains(c.cells[i].sequences, seq) {
if c.cells[i].pos < pos-c.windowSize {
c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
} else {
newRange.min = min(newRange.min, i)
newRange.max = max(newRange.max, i)
}
}
}
c.cellRanges[seq] = newRange
}
return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity)
}
func roundDown(length, pad int) int {
@@ -317,7 +265,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
return maskTensor, nil
}
func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) {
func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
for i, key := range c.keys {
if key == nil {
continue
@@ -327,8 +275,8 @@ func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) {
numKVHeads := key.Dim(1)
rowSize := key.Stride(2)
kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*length)
kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*length)
kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*len)
kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*len)
value := c.values[i]
var vSrcView, vDstView ml.Tensor
@@ -336,14 +284,14 @@ func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) {
vHeadDim := value.Dim(1)
elemSize := value.Stride(0)
vSrcView = value.View(ctx, elemSize*src, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
vDstView = value.View(ctx, elemSize*dst, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
vSrcView = value.View(ctx, elemSize*src, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
vDstView = value.View(ctx, elemSize*dst, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
} else {
vHeadDim := value.Dim(0)
rowSize := value.Stride(2)
vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*length)
vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*length)
vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*len)
vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*len)
}
ctx.Forward(
@@ -373,8 +321,7 @@ func (c *Causal) defrag() {
ctx := c.backend.NewContext()
// For every move, 6 tensors are required per layer (2 views and a
// copy for each of k and v). We also need to refer to the original
// k and v cache tensors - once per layer, not per move.
// copy for each of k and v).
layers := 0
for _, key := range c.keys {
if key == nil {
@@ -383,7 +330,7 @@ func (c *Causal) defrag() {
layers++
}
maxMoves := (ctx.MaxGraphNodes() - 2*layers) / (6 * layers)
maxMoves := ctx.MaxGraphNodes() / (6 * layers)
moves := 0
var pendingSrc, pendingDst, pendingLen int
@@ -532,14 +479,14 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
}
if _, ok := c.keys[c.curLayer]; !ok {
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, len(c.cells))
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity))
}
if _, ok := c.values[c.curLayer]; !ok {
if c.config.PermutedV {
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), vHeadDim, numKVHeads)
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads)
} else {
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, len(c.cells))
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity))
}
}
@@ -550,7 +497,7 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
elemSize := c.values[c.curLayer].Stride(0)
value = value.Permute(ctx, 1, 2, 0, 3)
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, len(c.cells)*elemSize, vHeadDim*numKVHeads)))
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)))
} else {
rowSize := c.values[c.curLayer].Stride(2)

View File

@@ -25,7 +25,7 @@ func TestStore(t *testing.T) {
cache := NewCausalCache(nil)
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
cache.Init(backend, ml.DTypeF16, 16)
tests := []testCase{
{
@@ -58,11 +58,11 @@ func TestSWA(t *testing.T) {
cache := NewSWACache(1, nil)
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
cache.Init(backend, ml.DTypeF32, 16)
tests := []testCase{
{
name: "FirstBatch",
name: "SlidingWindow",
in: []float32{1, 2, 3, 4},
inShape: []int{1, 1, 4},
seqs: []int{0, 0, 0, 0},
@@ -71,16 +71,6 @@ func TestSWA(t *testing.T) {
expectedShape: []int{1, 1, 4},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
},
{
name: "SecondBatch",
in: []float32{5, 6},
inShape: []int{1, 1, 2},
seqs: []int{0, 0},
pos: []int32{4, 5},
expected: []float32{5, 6, 3, 4},
expectedShape: []int{1, 1, 4},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1))},
},
}
testCache(t, backend, cache, tests)
@@ -91,7 +81,7 @@ func TestSequences(t *testing.T) {
cache := NewCausalCache(nil)
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
cache.Init(backend, ml.DTypeF16, 16)
tests := []testCase{
{
@@ -126,7 +116,7 @@ func TestRemove(t *testing.T) {
})
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
cache.Init(backend, ml.DTypeF16, 16)
tests := []testCase{
{
@@ -191,7 +181,7 @@ func TestDefrag(t *testing.T) {
})
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
cache.Init(backend, ml.DTypeF16, 16)
tests := []testCase{
{
@@ -239,7 +229,7 @@ func TestCopy(t *testing.T) {
cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil })
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
cache.Init(backend, ml.DTypeF16, 16)
tests := []testCase{
{
@@ -280,7 +270,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
context := backend.NewContext()
defer context.Close()
err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs})
err := cache.StartForward(context, input.Options{Positions: test.pos, Sequences: test.seqs})
if err != nil {
panic(err)
}

View File

@@ -49,7 +49,7 @@ func NewEncoderCache() *EncoderCache {
}
}
func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
if c.config == nil {
var config ml.CacheConfig
if cc, ok := backend.(ml.BackendCacheConfig); ok {
@@ -58,10 +58,6 @@ func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, ca
c.config = &config
}
if maxSequences > 1 {
panic(fmt.Errorf("encoder cache does not support multiple sequences; requested: %v", maxSequences))
}
if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
}
@@ -83,10 +79,10 @@ func (c *EncoderCache) Close() {
}
}
func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch) error {
func (c *EncoderCache) StartForward(ctx ml.Context, opts input.Options) error {
// We work with the most recent image
if len(batch.Multimodal) > 0 {
c.curPos = batch.Positions[batch.Multimodal[len(batch.Multimodal)-1].Index]
if len(opts.Multimodal) > 0 {
c.curPos = opts.Positions[opts.Multimodal[len(opts.Multimodal)-1].Index]
}
return nil

View File

@@ -23,9 +23,9 @@ func NewWrapperCache(caches ...Cache) *WrapperCache {
}
}
func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
for _, cache := range c.caches {
cache.Init(backend, dtype, maxSequences, capacity, maxBatch)
cache.Init(backend, dtype, capacity)
}
}
@@ -41,14 +41,14 @@ func (c *WrapperCache) Close() {
}
}
func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch) error {
func (c *WrapperCache) StartForward(ctx ml.Context, opts input.Options) error {
for i, cache := range c.caches {
err := cache.StartForward(ctx, batch)
err := cache.StartForward(ctx, opts)
if err != nil {
// unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
for j := i - 1; j >= 0; j-- {
for k := range batch.Positions {
_ = c.caches[j].Remove(batch.Sequences[k], batch.Positions[k], math.MaxInt32)
for k := range opts.Positions {
_ = c.caches[j].Remove(opts.Sequences[k], opts.Positions[k], math.MaxInt32)
}
}
return err

View File

@@ -29,7 +29,6 @@ import (
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/grammar"
"github.com/ollama/ollama/llama"
"github.com/ollama/ollama/model"
)
@@ -701,9 +700,9 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
}
// User provided a JSON schema
g, err := grammar.FromSchema(nil, req.Format)
if err != nil {
return fmt.Errorf("invalid JSON schema in format: %w", err)
g := llama.SchemaToGrammar(req.Format)
if g == nil {
return fmt.Errorf("invalid JSON schema in format")
}
req.Grammar = string(g)
}
@@ -714,11 +713,6 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
req.Options = &opts
}
if req.Options == nil {
opts := api.DefaultOptions()
req.Options = &opts
}
if err := s.sem.Acquire(ctx, 1); err != nil {
if errors.Is(err, context.Canceled) {
slog.Info("aborting completion request due to client closing the connection")
@@ -733,6 +727,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
if req.Options.NumPredict < 0 || req.Options.NumPredict > 10*s.options.NumCtx {
req.Options.NumPredict = 10 * s.options.NumCtx
}
// Make sure the server is ready
status, err := s.getServerStatusRetry(ctx)
if err != nil {

View File

@@ -2,7 +2,6 @@ package ml
import (
"bytes"
"context"
"encoding/binary"
"fmt"
"os"
@@ -61,10 +60,6 @@ type CacheConfig struct {
// BackendParams controls how the backend loads and executes models
type BackendParams struct {
// Progress is a callback function that allows reporting percentage completion
// of model loading
Progress func(float32)
// NumThreads sets the number of threads to use if running on the CPU
NumThreads int
@@ -81,9 +76,9 @@ type BackendParams struct {
FlashAttention bool
}
var backends = make(map[string]func(context.Context, *os.File, BackendParams) (Backend, error))
var backends = make(map[string]func(*os.File, BackendParams) (Backend, error))
func RegisterBackend(name string, f func(context.Context, *os.File, BackendParams) (Backend, error)) {
func RegisterBackend(name string, f func(*os.File, BackendParams) (Backend, error)) {
if _, ok := backends[name]; ok {
panic("backend: backend already registered")
}
@@ -91,9 +86,9 @@ func RegisterBackend(name string, f func(context.Context, *os.File, BackendParam
backends[name] = f
}
func NewBackend(ctx context.Context, f *os.File, params BackendParams) (Backend, error) {
func NewBackend(f *os.File, params BackendParams) (Backend, error) {
if backend, ok := backends["ggml"]; ok {
return backend(ctx, f, params)
return backend(f, params)
}
return nil, fmt.Errorf("unsupported backend")
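
The registry pattern above is completed on the backend side; a sketch of the registration half (the ggml package in this repo does the equivalent in its init):

```go
package ggml

import "github.com/ollama/ollama/ml"

// Sketch: register this backend's constructor under its name so that
// ml.NewBackend can dispatch to it by key. New is the constructor defined
// in this package, with whichever signature is in effect on this side of
// the diff (with or without a context.Context parameter).
func init() {
	ml.RegisterBackend("ggml", New)
}
```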

View File

@@ -9,17 +9,15 @@ package ggml
import "C"
import (
"context"
"errors"
"fmt"
"io"
"log/slog"
"maps"
"os"
"runtime"
"slices"
"strconv"
"strings"
"sync/atomic"
"unicode"
"unsafe"
@@ -60,7 +58,7 @@ type Backend struct {
maxGraphNodes int
}
func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, error) {
func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
meta, n, err := fs.Decode(r, -1)
if err != nil {
return nil, err
@@ -299,16 +297,12 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend,
}
}
var doneBytes atomic.Uint64
totalBytes := uint64(n) - meta.Tensors().Offset
g, ctx := errgroup.WithContext(ctx)
g.SetLimit(runtime.GOMAXPROCS(0))
// concurrently read in tensor data. uses a section reader which is safe for concurrent reads
sr := io.NewSectionReader(r, int64(meta.Tensors().Offset), n-int64(meta.Tensors().Offset))
var g errgroup.Group
for _, t := range meta.Tensors().Items() {
g.Go(func() error {
tts := make([]*C.struct_ggml_tensor, max(1, len(targets[t.Name])))
for i := range tts {
target := targets[t.Name][i]
for _, target := range targets[t.Name] {
g.Go(func() error {
if target == "" {
target = t.Name
}
@@ -318,43 +312,24 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend,
return fmt.Errorf("unassigned tensor: %s", t.Name)
}
tts[i] = tt
}
bts := C.malloc(C.size_t(t.Size()))
if bts == nil {
return errors.New("failed to allocate tensor buffer")
}
defer C.free(bts)
sr := io.NewSectionReader(r, int64(meta.Tensors().Offset+t.Offset), int64(t.Size()))
bts := make([]byte, 128*format.KibiByte)
var s uint64
for s < t.Size() {
n, err := io.ReadFull(sr, bts[:min(len(bts), int(t.Size()-s))])
if err != nil {
return err
buf := unsafe.Slice((*byte)(bts), t.Size())
n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), buf)
if err != nil || n != len(buf) {
return errors.New("read failed")
}
for _, tt := range tts {
C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), C.size_t(s), C.size_t(n))
}
s += uint64(n)
if params.Progress != nil {
done := doneBytes.Add(uint64(n))
params.Progress(float32(done) / float32(totalBytes))
}
}
return nil
})
C.ggml_backend_tensor_set(tt, bts, 0, C.size_t(t.Size()))
return nil
})
}
}
// start a goroutine to cancel the errgroup if the parent context is done
go func() {
<-ctx.Done()
g.Go(func() error {
return ctx.Err()
})
}()
if err := g.Wait(); err != nil {
return nil, err
}
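
The concurrency-safety point in the comment above ("a section reader which is safe for concurrent reads") comes from io.SectionReader carrying its own offset over an io.ReaderAt. A sketch of the per-tensor read, with hypothetical names and sizes:

```go
package main

import (
	"io"
	"os"
)

// readTensor is a hypothetical helper mirroring the loading loop above.
// Each call builds its own SectionReader over the shared section, so
// concurrent goroutines never race on a shared file offset.
func readTensor(sr *io.SectionReader, off, size int64) ([]byte, error) {
	buf := make([]byte, size)
	_, err := io.ReadFull(io.NewSectionReader(sr, off, size), buf)
	return buf, err
}

func main() {
	f, _ := os.Open("model.gguf") // illustrative path
	defer f.Close()
	sr := io.NewSectionReader(f, 0, 1<<20)
	_, _ = readTensor(sr, 128, 256)
}
```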

View File

@@ -1,7 +1,5 @@
package input
import "github.com/ollama/ollama/ml"
// Input represents one token in the input stream
type Input struct {
// Token is a single element of text.
@@ -35,24 +33,11 @@ type MultimodalIndex struct {
Multimodal any
}
// Batch contains the inputs for a model forward pass
type Batch struct {
// Inputs is the input tokens, including placeholders for multimodal inputs.
Inputs ml.Tensor
// Multimodal is a set of multimodal embeddings previously created by
// EncodeMultimodal, along with an index into Inputs. Unused for text-only
// models or for batches without multimodal elements.
// Options contains the inputs for a model forward pass
type Options struct {
Inputs []int32
Multimodal []MultimodalIndex
// Positions is the position for each Input, relative to its sequence. Equal
// in length to Inputs.
Positions []int32
// Sequences is the sequence for each Input. Equal in length to Inputs.
Sequences []int
// Outputs are the set of indices into Inputs for which output data should
// be returned.
Outputs []int32
Positions []int32
Sequences []int
Outputs []int32
}
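
A sketch of populating this struct (hypothetical values): Positions and Sequences are parallel to Inputs, one entry per token, and Outputs indexes the tokens whose logits the caller wants back.

```go
// Hypothetical three-token batch in a single sequence; only the last
// token's output is requested.
opts := input.Options{
	Inputs:    []int32{101, 102, 103},
	Positions: []int32{0, 1, 2},
	Sequences: []int{0, 0, 0},
	Outputs:   []int32{2},
}
```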

View File

@@ -1,7 +1,6 @@
package model
import (
"context"
"errors"
"fmt"
_ "image/jpeg"
@@ -27,7 +26,7 @@ var ErrNoVisionModel = errors.New("this model is missing data required for image
// Model implements a specific model architecture, defining the forward pass and any model-specific configuration
type Model interface {
Forward(ml.Context, input.Batch) (ml.Tensor, error)
Forward(ml.Context, input.Options) (ml.Tensor, error)
Backend() ml.Backend
Config() config
@@ -95,14 +94,14 @@ func Register(name string, f func(ml.Config) (Model, error)) {
}
// New initializes a new model instance with the provided configuration based on the metadata in the model file
func New(ctx context.Context, modelPath string, params ml.BackendParams) (Model, error) {
func New(modelPath string, params ml.BackendParams) (Model, error) {
r, err := os.Open(modelPath)
if err != nil {
return nil, err
}
defer r.Close()
b, err := ml.NewBackend(ctx, r, params)
b, err := ml.NewBackend(r, params)
if err != nil {
return nil, err
}
@@ -281,30 +280,24 @@ func canNil(t reflect.Type) bool {
t.Kind() == reflect.Slice
}
func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Tensor, error) {
if len(batch.Positions) != len(batch.Sequences) {
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
func Forward(ctx ml.Context, m Model, opts input.Options) (ml.Tensor, error) {
if len(opts.Positions) != len(opts.Sequences) {
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(opts.Positions), len(opts.Sequences))
}
if len(batch.Positions) < 1 {
if len(opts.Positions) < 1 {
return nil, errors.New("batch size cannot be less than 1")
}
var err error
batch.Inputs, err = ctx.Input().FromIntSlice(inputs, len(inputs))
if err != nil {
return nil, err
}
cache := m.Config().Cache
if cache != nil {
err := cache.StartForward(ctx, batch)
err := cache.StartForward(ctx, opts)
if err != nil {
return nil, err
}
}
t, err := m.Forward(ctx, batch)
t, err := m.Forward(ctx, opts)
if err != nil {
return nil, err
}

View File

@@ -163,7 +163,7 @@ func TestGetTextProcessor(t *testing.T) {
type notTextProcessorModel struct{}
func (notTextProcessorModel) Forward(ml.Context, input.Batch) (ml.Tensor, error) {
func (notTextProcessorModel) Forward(ml.Context, input.Options) (ml.Tensor, error) {
panic("unimplemented")
}

View File

@@ -168,18 +168,23 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
return hiddenState.Add(ctx, residual)
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
if err != nil {
return nil, err
}
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
if err != nil {
return nil, err
}
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil {
return nil, err
}
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
if len(m.Layers) == gemma27BLayerCount {

View File

@@ -139,18 +139,23 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
return result, nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
if err != nil {
return nil, err
}
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
if err != nil {
return nil, err
}
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil {
return nil, err
}
return m.TextModel.Forward(ctx, inputs, positions, outputs, opts, m.Cache), nil
}
func init() {

View File

@@ -171,13 +171,13 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
return hiddenState.Add(ctx, residual)
}
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor {
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
// set image embeddings
var except []int
for _, image := range batch.Multimodal {
for _, image := range opts.Multimodal {
visionOutputs := image.Multimodal.(ml.Tensor)
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
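
The `View`/`Copy` calls above splice each image's vision embeddings into the hidden state at the token index recorded during tokenization. A simplified sketch of that indexing using plain slices (the real code operates on `ml.Tensor` views, so the embedding type here is a stand-in):

```go
// Sketch: overwrite rows of the token-embedding matrix with vision
// embeddings, starting at each image's placeholder token index.
// The [][]float32 assertion is a stand-in for the real ml.Tensor type.
func injectImages(hidden [][]float32, images []input.MultimodalIndex) {
	for _, img := range images {
		emb := img.Multimodal.([][]float32)
		copy(hidden[img.Index:], emb)
	}
}
```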

View File

@@ -139,18 +139,23 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
return hiddenState.Add(ctx, residual)
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
if err != nil {
return nil, err
}
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
if err != nil {
return nil, err
}
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil {
return nil, err
}
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
for i, layer := range m.Layers {
m.Cache.SetLayer(i)

View File

@@ -135,27 +135,32 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
return inputs, nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
var crossAttentionStates ml.Tensor
if len(batch.Multimodal) > 0 {
images := batch.Multimodal[len(batch.Multimodal)-1].Multimodal.([]ml.Tensor)
if len(opts.Multimodal) > 0 {
images := opts.Multimodal[len(opts.Multimodal)-1].Multimodal.([]ml.Tensor)
if len(images) > 0 {
crossAttentionStates = images[len(images)-1]
}
}
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
if err != nil {
return nil, err
}
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
if err != nil {
return nil, err
}
outputs, err := ctx.Input().FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil {
return nil, err
}
// TODO: attention mask, cross attention mask
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
return m.TextModel.Forward(ctx, inputs, positions, outputs, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
}
func init() {

View File

@@ -32,8 +32,6 @@ type TextProcessor interface {
Encode(s string, addSpecial bool) ([]int32, error)
Decode([]int32) (string, error)
Is(int32, Special) bool
Vocab() *Vocabulary
}
type Vocabulary struct {

View File

@@ -49,10 +49,6 @@ func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
}
}
func (spm SentencePieceModel) Vocab() *Vocabulary {
return spm.vocab
}
func (spm SentencePieceModel) Is(id int32, special Special) bool {
return spm.vocab.Is(id, special)
}

View File

@@ -31,10 +31,8 @@ type InputCache struct {
cache kvcache.Cache
}
func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, batchSize int, multiUserCache bool) (*InputCache, error) {
numCtx := kvSize / int32(numSlots)
if numCtx < 1 {
func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, multiUserCache bool) (*InputCache, error) {
if kvSize/int32(numSlots) < 1 {
return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots)
}
@@ -46,11 +44,11 @@ func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots
cache := model.Config().Cache
if cache != nil {
cache.Init(model.Backend(), kvCacheTypeFromStr(kvCacheType), numSlots, int(numCtx), batchSize)
cache.Init(model.Backend(), kvCacheTypeFromStr(kvCacheType), kvSize)
}
return &InputCache{
numCtx: numCtx,
numCtx: kvSize / int32(numSlots),
enabled: cache != nil,
slots: slots,
multiUserCache: multiUserCache,
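
With `batchSize` gone from the constructor, the per-sequence context length is the only derived quantity left: an even split of the KV cache across parallel slots. A small sketch of the invariant being enforced (the helper name is hypothetical):

```go
package main

import "fmt"

// perSequenceContext mirrors the check in NewInputCache: each parallel
// sequence gets an equal share of the KV cache, and a share smaller
// than one entry is a configuration error.
func perSequenceContext(kvSize int32, numSlots int) (int32, error) {
	numCtx := kvSize / int32(numSlots)
	if numCtx < 1 {
		return 0, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots)
	}
	return numCtx, nil
}
```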

View File

@@ -348,8 +348,7 @@ func (s *Server) processBatch() error {
}
defer s.mu.Unlock()
var batchInputs []int32
var batch input.Batch
var options input.Options
for i, seq := range s.seqs {
if seq == nil {
@@ -396,17 +395,17 @@ func (s *Server) processBatch() error {
}
}
batchInputs = append(batchInputs, inp.Token)
options.Inputs = append(options.Inputs, inp.Token)
if inp.Multimodal != nil {
batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: inp.Multimodal})
options.Multimodal = append(options.Multimodal, input.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: inp.Multimodal})
}
batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
batch.Sequences = append(batch.Sequences, seq.cache.Id)
options.Positions = append(options.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
options.Sequences = append(options.Sequences, seq.cache.Id)
seq.iBatch = len(batch.Outputs)
seq.iBatch = len(options.Outputs)
if j+1 == len(seq.inputs) {
batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
options.Outputs = append(options.Outputs, int32(len(options.Inputs)-1))
}
seq.pendingInputs = append(seq.pendingInputs, inp)
}
@@ -414,14 +413,14 @@ func (s *Server) processBatch() error {
seq.inputs = seq.inputs[len(seq.pendingInputs):]
}
if len(batchInputs) == 0 {
if len(options.Inputs) == 0 {
return nil
}
ctx := s.model.Backend().NewContext()
defer ctx.Close()
modelOutput, err := model.Forward(ctx, s.model, batchInputs, batch)
modelOutput, err := model.Forward(ctx, s.model, options)
if err != nil {
return fmt.Errorf("failed to decode batch: %w", err)
}
@@ -461,7 +460,7 @@ func (s *Server) processBatch() error {
}
// sample a token
vocabSize := len(logits) / len(batch.Outputs)
vocabSize := len(logits) / len(options.Outputs)
token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
if err != nil {
@@ -562,14 +561,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
}
}
sampler := sample.NewSampler(
req.Options.Temperature,
req.Options.TopK,
req.Options.TopP,
req.Options.MinP,
req.Options.Seed,
grammar,
)
sampler := sample.NewSampler(req.Options, grammar)
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
numPredict: req.Options.NumPredict,
@@ -678,7 +670,6 @@ func (m *multiLPath) String() string {
}
func (s *Server) loadModel(
ctx context.Context,
mpath string,
params ml.BackendParams,
lpath multiLPath,
@@ -688,7 +679,7 @@ func (s *Server) loadModel(
multiUserCache bool,
) {
var err error
s.model, err = model.New(ctx, mpath, params)
s.model, err = model.New(mpath, params)
if err != nil {
panic(err)
}
@@ -700,7 +691,7 @@ func (s *Server) loadModel(
panic("loras are not yet implemented")
}
s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, s.batchSize, multiUserCache)
s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, multiUserCache)
if err != nil {
panic(err)
}
@@ -784,9 +775,6 @@ func Execute(args []string) error {
}
params := ml.BackendParams{
Progress: func(progress float32) {
server.progress = progress
},
NumThreads: *threads,
NumGPULayers: *numGPULayers,
MainGPU: *mainGPU,
@@ -795,13 +783,13 @@ func Execute(args []string) error {
}
server.ready.Add(1)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go server.loadModel(ctx, *mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
go server.loadModel(*mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
server.cond = sync.NewCond(&server.mu)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go server.run(ctx)
addr := "127.0.0.1:" + strconv.Itoa(*port)

View File

@@ -1,12 +1,14 @@
package sample
import (
"encoding/json"
"errors"
"math"
"math/rand/v2"
"slices"
"sync"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llama"
)
@@ -126,40 +128,65 @@ func (s *Sampler) sample(tokens []token) (token, error) {
return tokens[idx], nil
}
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
// SamplerParams contains the validated and normalized parameters for a sampler
type SamplerParams struct {
Temperature float32 `json:"temperature"`
TopK int `json:"top_k"`
TopP float32 `json:"top_p"`
MinP float32 `json:"min_p"`
Seed int `json:"seed"`
}
// UnmarshalJSON implements json.Unmarshaler to handle validation during JSON unmarshaling
func (p *SamplerParams) UnmarshalJSON(data []byte) error {
type rawParams SamplerParams
if err := json.Unmarshal(data, (*rawParams)(p)); err != nil {
return err
}
// Validate and normalize after unmarshaling
if p.Temperature < 0.0 {
p.Temperature = 0.0
}
if p.TopP < 0.0 {
p.TopP = 0.0
}
if p.TopP >= 1.0 {
p.TopP = 1.0
}
if p.MinP < 0.0 {
p.MinP = 0.0
}
if p.MinP >= 1.0 {
p.MinP = 1.0
}
return nil
}
// NewSampler creates a new sampler with the given options
func NewSampler(opts *api.Options, grammar *Grammar) Sampler {
var params SamplerParams
data, _ := json.Marshal(opts)
_ = json.Unmarshal(data, &params)
var rng *rand.Rand
if seed != -1 {
if params.Seed != -1 {
// PCG requires two parameters: sequence and stream
// Use original seed for sequence
sequence := uint64(seed)
sequence := uint64(params.Seed)
// Use golden ratio hash to generate statistically independent seeds
rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9))
}
if temperature < 0.0 {
temperature = 0.0
}
if topP < 0.0 {
topP = 0.0
}
if topP >= 1.0 {
topP = 1.0
}
if minP < 0.0 {
minP = 0.0
}
if minP >= 1.0 {
minP = 1.0
}
return Sampler{
rng: rng,
topK: topK,
topP: topP,
minP: minP,
temperature: temperature,
topK: params.TopK,
topP: params.TopP,
minP: params.MinP,
temperature: params.Temperature,
grammar: grammar,
}
}
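
The marshal/unmarshal hop is what applies the clamping: `api.Options` is serialized and re-read through `SamplerParams.UnmarshalJSON`, which normalizes out-of-range values. A usage sketch from inside the `sample` package, with field values chosen to trigger each clamp (this assumes `api.Options` carries the matching `temperature`/`top_k`/`top_p`/`min_p`/`seed` JSON tags):

```go
// Sketch: out-of-range options arrive normalized after the JSON round-trip.
opts := &api.Options{Temperature: -0.5, TopK: 40, TopP: 1.7, MinP: -0.1, Seed: 42}

data, _ := json.Marshal(opts)
var params SamplerParams
_ = json.Unmarshal(data, &params)
// params.Temperature == 0.0 (clamped up)
// params.TopP == 1.0        (clamped down)
// params.MinP == 0.0        (clamped up)

s := NewSampler(opts, nil) // nil grammar: unconstrained sampling
_ = s
```

Note that `NewSampler` discards both the `Marshal` and `Unmarshal` errors, so a serialization failure silently falls back to zero-valued params, which with zero temperature behaves as greedy sampling.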

View File

@@ -16,7 +16,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
logits[i] = float32(rand.Float64()*10 - 5)
}
sampler := NewSampler(0.8, 0, 0, 0, 42, nil)
sampler := NewSampler(createSamplerOptions(0.8, 0, 0, 0, 42), nil)
b.ResetTimer()
for b.Loop() {
sampler.Sample(logits)
@@ -49,7 +49,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
for _, tc := range configs {
b.Run("Config"+tc.name, func(b *testing.B) {
sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed, nil)
sampler := NewSampler(createSamplerOptions(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed), nil)
sampler.Sample(logits)
b.ResetTimer()
@@ -62,7 +62,7 @@ func BenchmarkWeightedSampler(b *testing.B) {
// Test with combined transforms separately - topK influences performance greatly
b.Run("TransformCombined", func(b *testing.B) {
sampler := NewSampler(0.8, 50, 0.9, 0.05, 42, nil)
sampler := NewSampler(createSamplerOptions(0.8, 50, 0.9, 0.05, 42), nil)
b.ResetTimer()
for b.Loop() {
@@ -81,7 +81,7 @@ func BenchmarkGreedySampler(b *testing.B) {
logits[i] = float32(rand.Float64()*10 - 5)
}
sampler := NewSampler(0, -1, 0, 0, -1, nil)
sampler := NewSampler(createSamplerOptions(0, -1, 0, 0, -1), nil)
b.ResetTimer()
for b.Loop() {

View File

@@ -4,11 +4,23 @@ import (
"math"
"math/rand/v2"
"testing"
"github.com/ollama/ollama/api"
)
func createSamplerOptions(temperature float32, topK int, topP float32, minP float32, seed int) *api.Options {
return &api.Options{
Temperature: temperature,
TopK: topK,
TopP: topP,
MinP: minP,
Seed: seed,
}
}
func TestWeighted(t *testing.T) {
logits := []float32{-10, 3, -10, -10}
sampler := NewSampler(0, 0, 0, 0, 0, nil)
sampler := NewSampler(createSamplerOptions(0, 0, 0, 0, 0), nil)
got, err := sampler.Sample(logits)
if err != nil {
t.Error(err)
@@ -20,7 +32,7 @@ func TestWeighted(t *testing.T) {
}
logits = []float32{-100, -10, 0, 10}
sampler = NewSampler(0, 0, 0, 0, 0, nil)
sampler = NewSampler(createSamplerOptions(0, 0, 0, 0, 0), nil)
got, err = sampler.Sample(logits)
if err != nil {
t.Error(err)
@@ -34,7 +46,7 @@ func TestWeighted(t *testing.T) {
// Test very high p
logits = []float32{1.0, 0.9999999999999999, 0.5, 0.1}
// Use extremely small topP to filter out all tokens
sampler = NewSampler(1.0, 0, 1e-10, 0, 0, nil)
sampler = NewSampler(createSamplerOptions(1.0, 0, 1e-10, 0, 0), nil)
got, err = sampler.Sample(logits)
if err != nil {
t.Error(err)
@@ -47,7 +59,7 @@ func TestWeighted(t *testing.T) {
}
logits = []float32{float32(math.NaN()), float32(math.NaN()), float32(math.NaN())}
sampler = NewSampler(1, 0, 0.95, 0.05, 0, nil)
sampler = NewSampler(createSamplerOptions(1, 0, 0.95, 0.05, 0), nil)
got, err = sampler.Sample(logits)
if err == nil {
t.Errorf("expected error, got %d", got)
@@ -57,8 +69,8 @@ func TestWeighted(t *testing.T) {
func BenchmarkSample(b *testing.B) {
samplers := map[string]Sampler{
"Greedy": NewSampler(0, 0, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy
"Weighted": NewSampler(0.5, 10, 0.9, 0.2, -1, nil),
"Greedy": NewSampler(createSamplerOptions(0, 0, 0, 0, 0), nil), // Use NewSampler with temp=0 for greedy
"Weighted": NewSampler(createSamplerOptions(0.5, 10, 0.9, 0.2, -1), nil),
}
// Generate random logits for benchmarking

View File

@@ -1,176 +0,0 @@
package sample
import (
"bytes"
"strings"
"github.com/ollama/ollama/model"
)
type Node struct {
TransitionEdges map[rune]*Node
}
type Graph struct {
proc model.TextProcessor
decodedToks []string
curNode *Node
grammar []byte
rules map[string]string
}
// baseRules is the set of rules that are used to parse the grammar
// JSON grammar from RFC 7159
var baseRules = map[string]string{
"object": "\"{\" (kv (\",\" kv)*)? \"}\"",
"array": "\"[\" (value (\",\" value)*)? \"]\"",
"string": "\"\\\"\" char* \"\\\"\"",
"number": "\"-\"? integer frac? exp?",
"kv": "string \":\" value",
"integer": "\"0\" | [1-9] [0-9]*",
"frac": "\".\" [0-9]+",
"exp": "(\"e\" | \"E\") (\"+\" | \"-\") [0-9]+",
"escape": "[\"/\" | \"b\" | \"f\" | \"n\" | \"r\" | \"t\" | unicode]",
"char": "[^\"\\\\] | escape",
"space": "(\" \" | \"\\t\" | \"\\n\" | \"\\r\")*",
"hex": "[0-9] | [a-f] | [A-F]",
"boolean": "\"true\" | \"false\"",
"value": "object | array | string | number | boolean | \"null\"",
"null": "\"null\"",
}
func (g *Graph) BuildGraph(node *Node) error {
vocab := g.proc.Vocab()
decodedToks := make([]string, len(vocab.Values))
for i := range vocab.Values {
token, err := g.proc.Decode([]int32{int32(i)})
if err != nil {
return err
}
decodedToks[i] = token
}
g.decodedToks = decodedToks
g.rules = baseRules
g.rootPrefixes()
rootNode := &Node{
TransitionEdges: make(map[rune]*Node),
}
g.parseRule(g.rules["root"], rootNode)
return nil
}
// rootPrefixes extracts all root prefixes from the grammar
// and parses the grammar string to extract root prefixes
func (g *Graph) rootPrefixes() {
lines := bytes.Split(g.grammar, []byte("\n"))
for _, line := range lines {
line = bytes.TrimSpace(line)
if len(line) == 0 || bytes.HasPrefix(line, []byte("#")) {
continue
}
parts := bytes.SplitN(line, []byte("::="), 2)
if len(parts) != 2 {
continue
}
ruleName := string(bytes.TrimSpace(parts[0]))
if strings.HasPrefix(ruleName, "root") {
g.rules[ruleName] = string(bytes.TrimSpace(parts[1]))
}
}
}
// parseRule parses a grammar rule and returns a Node
func (g *Graph) parseRule(rule string, curNode *Node) *Node {
/*
Here are the special characters in BNF grammar and their functions:
::= - Definition operator, means "is defined as"
| - Alternation, means "or"
* - Zero or more repetitions of preceding element
+ - One or more repetitions
? - Optional (zero or one occurrence)
[] - Character class, matches any single character within brackets
[^] - Negated character class, matches any character NOT listed
() - Grouping of elements
- - Range operator in character classes (e.g., [a-z])
"" - Literal string match
*/
// Split rule into tokens by whitespace
tokens := strings.Fields(rule)
if len(tokens) == 0 {
return &Node{
TransitionEdges: make(map[rune]*Node),
}
}
// Handle integer rule
if strings.Contains(rule, "[0-9]+") {
// Create node for first digit 1-9
firstDigitNode := &Node{
TransitionEdges: make(map[rune]*Node),
}
for r := '1'; r <= '9'; r++ {
curNode.TransitionEdges[r] = firstDigitNode
}
// Create node for subsequent digits 0-9
zeroToNineNode := &Node{
TransitionEdges: make(map[rune]*Node),
}
for r := '0'; r <= '9'; r++ {
// Loop back to same node for * operator
zeroToNineNode.TransitionEdges[r] = zeroToNineNode
}
// Connect first digit to subsequent digits
firstDigitNode.TransitionEdges = zeroToNineNode.TransitionEdges
// Also handle the "0" case
if strings.Contains(rule, "\"0\"") {
zeroNode := &Node{
TransitionEdges: make(map[rune]*Node),
}
curNode.TransitionEdges['0'] = zeroNode
}
return curNode
}
// recursive case
// grammar options
// TODO: handle left recursion
if strings.Contains(rule, "|") {
parts := strings.Split(rule, "|")
savedNode := curNode
for _, part := range parts {
// TODO: add correct transitions
g.parseRule(part, savedNode)
}
}
for _, token := range tokens {
if strings.HasPrefix(token, "\"") && strings.HasSuffix(token, "\"") {
token = strings.Trim(token, "\"")
for _, r := range token {
newNode := &Node{
TransitionEdges: make(map[rune]*Node),
}
curNode.TransitionEdges[r] = newNode
curNode = newNode
}
// strNode := &Node{
// TransitionEdges: make(map[rune]*Node),
// }
// TODO: length constraint
// to self
}
}
return curNode
}
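
Since this file is being deleted, for context: the graph above accepts input by walking rune-labeled edges. A hedged sketch of what a matcher over these `Node`s looked like in spirit (not code from the repository):

```go
// Sketch: walk TransitionEdges rune by rune; the input is accepted so
// far if every rune has an outgoing edge from the current node.
func matches(root *Node, s string) bool {
	cur := root
	for _, r := range s {
		next, ok := cur.TransitionEdges[r]
		if !ok {
			return false
		}
		cur = next
	}
	return true
}
```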

View File

@@ -1,3 +0,0 @@
package sample
type StructuredOutput struct{}

View File

@@ -1,194 +0,0 @@
package sample
import (
"testing"
"github.com/ollama/ollama/model"
)
func TestBuildGraph(t *testing.T) {
tests := []struct {
name string
grammar []byte
wantErr bool
}{
{
name: "empty grammar",
grammar: []byte{},
wantErr: false,
},
{
name: "valid grammar",
grammar: []byte(`root ::= value
value ::= string | number`),
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
g := &Graph{
proc: &mockProcessor{},
grammar: tt.grammar,
rules: make(map[string]string),
}
node := &Node{
TransitionEdges: make(map[rune]*Node),
}
err := g.BuildGraph(node)
if (err != nil) != tt.wantErr {
t.Errorf("BuildGraph() error = %v, wantErr %v", err, tt.wantErr)
}
if !tt.wantErr {
if len(g.decodedToks) == 0 {
t.Error("Expected decoded tokens, got none")
}
if len(g.rules) == 0 {
t.Error("Expected rules to be populated")
}
}
})
}
}
func TestRootPrefixes(t *testing.T) {
tests := []struct {
name string
grammar []byte
expected map[string]string
}{
{
name: "empty grammar",
grammar: []byte{},
expected: map[string]string{},
},
{
name: "grammar with root prefix",
grammar: []byte(`root ::= value
root_string ::= string`),
expected: map[string]string{
"root": "value",
"root_string": "string",
},
},
{
name: "grammar with comments and empty lines",
grammar: []byte(`# comment
root ::= value
# another comment
root_number ::= number`),
expected: map[string]string{
"root": "value",
"root_number": "number",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
g := &Graph{
grammar: tt.grammar,
rules: make(map[string]string),
}
g.rootPrefixes()
for k, v := range tt.expected {
if actual, ok := g.rules[k]; !ok || actual != v {
t.Errorf("Expected rule %s = %s, got %s", k, v, actual)
}
}
})
}
}
func TestParseRule(t *testing.T) {
tests := []struct {
name string
rule string
expected string
}{
{
name: "empty rule",
rule: "",
expected: "",
},
{
name: "simple string",
rule: "root ::= \"test_string\"",
expected: "test_string",
},
{
name: "simple string",
rule: "root ::= \"test_string\" | \"test_string2\"",
expected: "test_stringtest_string2",
},
{
name: "integer",
rule: "root ::= [0-9]+",
// TODO: this is actually infinite
expected: "0123456789",
},
// TODO: handle left recursion
// {
// name: "left recursion",
// rule: "root ::= root \"test_string\"",
// expected: "test_string",
// },
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
g := &Graph{
rules: make(map[string]string),
}
rootNode := &Node{
TransitionEdges: make(map[rune]*Node),
}
curNode := rootNode
g.parseRule(tt.rule, curNode)
sb := ""
for {
if len(curNode.TransitionEdges) == 0 {
break
}
for r, n := range curNode.TransitionEdges {
sb += string(r)
curNode = n
}
t.Logf("sb: %s", sb)
}
if sb != tt.expected {
t.Errorf("Expected %s, got %s", tt.expected, sb)
}
})
}
}
// mockProcessor implements the TextProcessor interface for testing
type mockProcessor struct{}
func (m *mockProcessor) Decode(tokens []int32) (string, error) {
return "test", nil
}
func (m *mockProcessor) Vocab() *model.Vocabulary {
return &model.Vocabulary{
Values: []string{"test1", "test2"},
}
}
func (m *mockProcessor) Encode(s string, addSpecial bool) ([]int32, error) {
return []int32{0, 1}, nil
}
func (m *mockProcessor) Is(token int32, special model.Special) bool {
return false
}

View File

@@ -59,11 +59,6 @@ var (
// ErrCached is passed to [Trace.PushUpdate] when a layer already
// exists. It is a non-fatal error and is never returned by [Registry.Push].
ErrCached = errors.New("cached")
// ErrIncomplete is returned by [Registry.Pull] when a model pull was
// incomplete due to one or more layer download failures. Users that
// want specific errors should use [WithTrace].
ErrIncomplete = errors.New("incomplete")
)
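
`ErrIncomplete` (removed here, along with its use in `Pull`'s final size check further down) was wrapped with `%w`, letting callers detect a partial pull with `errors.Is` rather than by matching message text. A sketch of that prior usage, with a hypothetical registry value and model name:

```go
// Sketch: detecting a partial pull via the sentinel error.
if err := reg.Pull(ctx, "library/model"); err != nil {
	if errors.Is(err, ErrIncomplete) {
		// Some layers failed to download; cached layers are skipped
		// on retry, so re-running Pull resumes the remainder.
		log.Printf("pull incomplete, retrying: %v", err)
	}
}
```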
// Defaults
@@ -276,19 +271,8 @@ func DefaultRegistry() (*Registry, error) {
func UserAgent() string {
buildinfo, _ := debug.ReadBuildInfo()
version := buildinfo.Main.Version
if version == "(devel)" {
// When using `go run .` the version is "(devel)". This is seen
// as an invalid version by ollama.com and so it defaults to
// "needs upgrade" for some requests, such as pulls. These
// checks can be skipped by using the special version "v0.0.0",
// so we set it to that here.
version = "v0.0.0"
}
return fmt.Sprintf("ollama/%s (%s %s) Go/%s",
version,
buildinfo.Main.Version,
runtime.GOARCH,
runtime.GOOS,
runtime.Version(),
@@ -434,14 +418,13 @@ func canRetry(err error) bool {
//
// It always calls update with a nil error.
type trackingReader struct {
l *Layer
r io.Reader
update func(l *Layer, n int64, err error)
r io.Reader
n *atomic.Int64
}
func (r *trackingReader) Read(p []byte) (n int, err error) {
n, err = r.r.Read(p)
r.update(r.l, int64(n), nil)
r.n.Add(int64(n))
return
}
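
The simplified `trackingReader` is now just a byte counter feeding a shared `atomic.Int64`. A standalone equivalent in spirit, with `io.Copy` as the consumer:

```go
// Sketch: a byte-counting reader modeled on the new trackingReader.
package main

import (
	"io"
	"strings"
	"sync/atomic"
)

type countingReader struct {
	r io.Reader
	n *atomic.Int64
}

func (c *countingReader) Read(p []byte) (int, error) {
	n, err := c.r.Read(p)
	c.n.Add(int64(n)) // count whatever was read, even on error
	return n, err
}

func main() {
	var progress atomic.Int64
	r := &countingReader{r: strings.NewReader("hello"), n: &progress}
	_, _ = io.Copy(io.Discard, r)
	_ = progress.Load() // 5 bytes observed
}
```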
@@ -479,50 +462,43 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
// Send initial layer trace events to allow clients to have an
// understanding of work to be done before work starts.
var expected int64
t := traceFromContext(ctx)
for _, l := range layers {
t.update(l, 0, nil)
expected += l.Size
}
var received atomic.Int64
var g errgroup.Group
g.SetLimit(r.maxStreams())
for _, l := range layers {
info, err := c.Get(l.Digest)
if err == nil && info.Size == l.Size {
received.Add(l.Size)
t.update(l, l.Size, ErrCached)
continue
}
var wg sync.WaitGroup
chunked, err := c.Chunked(l.Digest, l.Size)
if err != nil {
t.update(l, 0, err)
continue
}
// TODO(bmizerany): fix this unbounded use of defer
defer chunked.Close()
var progress atomic.Int64
for cs, err := range r.chunksums(ctx, name, l) {
if err != nil {
// Chunksum stream interrupted. Note in trace
// log and let in-flight downloads complete.
// This will naturally trigger ErrIncomplete
// since received < expected bytes.
t.update(l, 0, err)
break
// Bad chunksums response, update tracing
// clients and then bail.
t.update(l, progress.Load(), err)
return err
}
wg.Add(1)
g.Go(func() (err error) {
defer func() {
if err == nil {
received.Add(cs.Chunk.Size())
} else {
if err != nil {
err = fmt.Errorf("error downloading %s: %w", cs.Digest.Short(), err)
}
wg.Done()
t.update(l, progress.Load(), err)
}()
req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
@@ -536,35 +512,25 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
}
defer res.Body.Close()
body := &trackingReader{l: l, r: res.Body, update: t.update}
// Count bytes towards progress, as they
// arrive, so that our bytes piggyback other
// chunk updates on completion.
//
// This tactic is enough to show "smooth"
// progress given the current CLI client. In
// the near future, the server should report
// download rate since it knows better than a
// client that is measuring rate based on
// wall-clock time-since-last-update.
body := &trackingReader{r: res.Body, n: &progress}
return chunked.Put(cs.Chunk, cs.Digest, body)
})
}
// Close writer immediately after downloads finish, not at Pull
// exit. Using defer would keep file descriptors open until all
// layers complete, potentially exhausting system limits with
// many layers.
//
// The WaitGroup tracks when all chunks finish downloading,
// allowing precise writer closure in a background goroutine.
// Each layer briefly uses one extra goroutine while at most
// maxStreams()-1 chunks download in parallel.
//
// This caps file descriptors at maxStreams() instead of
// growing with layer count.
g.Go(func() error {
wg.Wait()
chunked.Close()
return nil
})
}
if err := g.Wait(); err != nil {
return err
}
if received.Load() != expected {
return fmt.Errorf("%w: received %d/%d", ErrIncomplete, received.Load(), expected)
}
md := blob.DigestFromBytes(m.Data)
if err := blob.PutBytes(c, md, m.Data); err != nil {
@@ -791,12 +757,15 @@ func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Se
}
blobURL := res.Header.Get("Content-Location")
var size int64
s := bufio.NewScanner(res.Body)
s.Split(bufio.ScanWords)
for {
if !s.Scan() {
if s.Err() != nil {
yield(chunksum{}, s.Err())
} else if size != l.Size {
yield(chunksum{}, fmt.Errorf("size mismatch: layer size %d != sum of chunks %d", size, l.Size))
}
return
}
@@ -820,6 +789,12 @@ func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Se
return
}
size += chunk.Size()
if size > l.Size {
yield(chunksum{}, fmt.Errorf("chunk size %d exceeds layer size %d", size, l.Size))
return
}
cs := chunksum{
URL: blobURL,
Chunk: chunk,
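
The two new checks bound the chunk list from both sides: the running total may never exceed the layer size mid-stream, and must equal it exactly at end of stream. Condensed into a sketch over a plain slice of sizes (the word scanner and digest parsing are elided, and the helper is hypothetical):

```go
// Sketch: the running size validation from chunksums, with a slice of
// chunk sizes standing in for the scanned response body.
func validateChunkSizes(chunkSizes []int64, layerSize int64) error {
	var size int64
	for _, c := range chunkSizes {
		size += c
		if size > layerSize {
			return fmt.Errorf("chunk size %d exceeds layer size %d", size, layerSize)
		}
	}
	if size != layerSize {
		return fmt.Errorf("size mismatch: layer size %d != sum of chunks %d", layerSize, size)
	}
	return nil
}
```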

View File

@@ -25,28 +25,6 @@ import (
"github.com/ollama/ollama/server/internal/testutil"
)
func ExampleRegistry_cancelOnFirstError() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx = WithTrace(ctx, &Trace{
Update: func(l *Layer, n int64, err error) {
if err != nil {
// Discontinue pulling layers if there is an
// error instead of continuing to pull more
// data.
cancel()
}
},
})
var r Registry
if err := r.Pull(ctx, "model"); err != nil {
// panic for demo purposes
panic(err)
}
}
func TestManifestMarshalJSON(t *testing.T) {
// All manifests should contain an "empty" config object.
var m Manifest
@@ -835,13 +813,8 @@ func TestPullChunksums(t *testing.T) {
)
err := rc.Pull(ctx, "test")
check(err)
wantReads := []int64{
0, // initial signaling of layer pull starting
3, // first chunk read
2, // second chunk read
}
if !slices.Equal(reads, wantReads) {
t.Errorf("reads = %v; want %v", reads, wantReads)
if !slices.Equal(reads, []int64{0, 3, 5}) {
t.Errorf("reads = %v; want %v", reads, []int64{0, 3, 5})
}
mw, err := rc.Resolve(t.Context(), "test")

View File

@@ -200,7 +200,7 @@ type params struct {
//
// Unfortunately, this API was designed to be a bit awkward. Stream is
// defined to default to true if not present, so we need a way to check
// if the client decisively set it to false. So, we use a pointer to a
// if the client decisively set it to false. So, we use a pointer to a
// bool. Gross.
//
// Use [stream()] to get the correct value for this field.
@@ -280,17 +280,17 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
progress := make(map[*ollama.Layer]int64)
progressCopy := make(map[*ollama.Layer]int64, len(progress))
flushProgress := func() {
pushUpdate := func() {
defer maybeFlush()
// TODO(bmizerany): Flushing every layer in one update doesn't
// scale well. We could flush only the modified layers or track
// the full download. Needs further consideration, though it's
// fine for now.
// TODO(bmizerany): This scales poorly with more layers due to
// needing to flush out them all in one big update. We _could_
// just flush on the changed ones, or just track the whole
// download. Needs more thought. This is fine for now.
mu.Lock()
maps.Copy(progressCopy, progress)
mu.Unlock()
for l, n := range progressCopy {
for l, n := range progress {
enc.Encode(progressUpdateJSON{
Digest: l.Digest,
Total: l.Size,
@@ -298,26 +298,19 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
})
}
}
defer flushProgress()
t := time.NewTicker(1000 * time.Hour) // "unstarted" timer
t := time.NewTicker(time.Hour) // "unstarted" timer
start := sync.OnceFunc(func() {
flushProgress() // flush initial state
pushUpdate()
t.Reset(100 * time.Millisecond)
})
ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
Update: func(l *ollama.Layer, n int64, err error) {
if n > 0 {
// Block flushing progress updates until every
// layer is accounted for. Clients depend on a
// complete model size to calculate progress
// correctly; if they use an incomplete total,
// progress indicators would erratically jump
// as new layers are registered.
start()
start() // flush initial state
}
mu.Lock()
progress[l] += n
progress[l] = n
mu.Unlock()
},
})
@@ -330,9 +323,9 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
for {
select {
case <-t.C:
flushProgress()
pushUpdate()
case err := <-done:
flushProgress()
pushUpdate()
if err != nil {
var status string
if errors.Is(err, ollama.ErrModelNotFound) {
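
The ticker trick above is worth calling out: `time.NewTicker(time.Hour)` keeps progress flushes dormant until the first byte arrives, at which point `sync.OnceFunc` emits the initial state exactly once and re-arms the ticker at 100ms. A standalone sketch of the pattern, with placeholder timings and flush body:

```go
// Sketch: the "unstarted ticker" pattern used by handlePull.
package main

import (
	"fmt"
	"sync"
	"time"
)

func main() {
	flush := func() { fmt.Println("flush progress") }

	t := time.NewTicker(time.Hour) // dormant until explicitly re-armed
	defer t.Stop()

	start := sync.OnceFunc(func() {
		flush()                         // emit initial state exactly once
		t.Reset(100 * time.Millisecond) // then flush on a real cadence
	})

	go func() {
		time.Sleep(50 * time.Millisecond)
		start() // first progress update arms the ticker
	}()

	deadline := time.After(300 * time.Millisecond)
	for {
		select {
		case <-t.C:
			flush()
		case <-deadline:
			flush() // final flush on completion
			return
		}
	}
}
```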