Compare commits

..

5 Commits

Author SHA1 Message Date
Bruce MacDonald
ed14ce2db8 convert mistral-3.1-2503 2025-03-18 15:16:23 -07:00
Bruce MacDonald
f94155fba2 do not add both consolidated and parts to model 2025-03-17 16:33:43 -07:00
jmorganca
8025781dce wip 2025-03-17 10:57:10 -07:00
jmorganca
afb34b0e60 wip 2025-03-17 10:56:20 -07:00
Bruce MacDonald
191b1b1eb3 model: support for mistral-small in the ollama runner
Mistral is a popular research lab making open source models. This updates
the forward pass of llama architecture models to support both llama models
and mistral models by accounting for additional metadata present in mistral
models, and finding the correct dimensions for the output projection.
2025-03-17 10:56:20 -07:00
43 changed files with 1137 additions and 1101 deletions

View File

@@ -1,178 +0,0 @@
package benchmark
import (
"context"
"flag"
"fmt"
"testing"
"time"
"github.com/ollama/ollama/api"
)
// Command line flags
var modelFlag string
func init() {
flag.StringVar(&modelFlag, "m", "", "Name of the model to benchmark")
flag.Lookup("m").DefValue = "model"
}
// modelName returns the model name from flags, failing the test if not set
func modelName(b *testing.B) string {
if modelFlag == "" {
b.Fatal("Error: -m flag is required for benchmark tests")
}
return modelFlag
}
type TestCase struct {
name string
prompt string
maxTokens int
}
// runGenerateBenchmark contains the common generate and metrics logic
func runGenerateBenchmark(b *testing.B, ctx context.Context, client *api.Client, req *api.GenerateRequest) {
start := time.Now()
var ttft time.Duration
var metrics api.Metrics
err := client.Generate(ctx, req, func(resp api.GenerateResponse) error {
if ttft == 0 && resp.Response != "" {
ttft = time.Since(start)
}
if resp.Done {
metrics = resp.Metrics
}
return nil
})
// Report custom metrics as part of the benchmark results
b.ReportMetric(float64(ttft.Milliseconds()), "ttft_ms")
b.ReportMetric(float64(metrics.LoadDuration.Milliseconds()), "load_ms")
// Token throughput metrics
promptThroughput := float64(metrics.PromptEvalCount) / metrics.PromptEvalDuration.Seconds()
genThroughput := float64(metrics.EvalCount) / metrics.EvalDuration.Seconds()
b.ReportMetric(promptThroughput, "prompt_tok/s")
b.ReportMetric(genThroughput, "gen_tok/s")
// Token counts
b.ReportMetric(float64(metrics.PromptEvalCount), "prompt_tokens")
b.ReportMetric(float64(metrics.EvalCount), "gen_tokens")
if err != nil {
b.Fatal(err)
}
}
// BenchmarkColdStart runs benchmarks with model loading from cold state
func BenchmarkColdStart(b *testing.B) {
client := setup(b)
tests := []TestCase{
{"short_prompt", "Write a long story", 100},
{"medium_prompt", "Write a detailed economic analysis", 500},
{"long_prompt", "Write a comprehensive AI research paper", 1000},
}
m := modelName(b)
for _, tt := range tests {
b.Run(fmt.Sprintf("%s/cold/%s", m, tt.name), func(b *testing.B) {
ctx := context.Background()
// Set number of tokens as our throughput metric
b.SetBytes(int64(tt.maxTokens))
for b.Loop() {
b.StopTimer()
// Ensure model is unloaded before each iteration
unload(client, m, b)
b.StartTimer()
req := &api.GenerateRequest{
Model: m,
Prompt: tt.prompt,
Options: map[string]interface{}{"num_predict": tt.maxTokens, "temperature": 0.1},
}
runGenerateBenchmark(b, ctx, client, req)
}
})
}
}
// BenchmarkWarmStart runs benchmarks with pre-loaded model
func BenchmarkWarmStart(b *testing.B) {
client := setup(b)
tests := []TestCase{
{"short_prompt", "Write a long story", 100},
{"medium_prompt", "Write a detailed economic analysis", 500},
{"long_prompt", "Write a comprehensive AI research paper", 1000},
}
m := modelName(b)
for _, tt := range tests {
b.Run(fmt.Sprintf("%s/warm/%s", m, tt.name), func(b *testing.B) {
ctx := context.Background()
// Pre-warm the model
warmup(client, m, tt.prompt, b)
// Set number of tokens as our throughput metric
b.SetBytes(int64(tt.maxTokens))
for b.Loop() {
req := &api.GenerateRequest{
Model: m,
Prompt: tt.prompt,
Options: map[string]any{"num_predict": tt.maxTokens, "temperature": 0.1},
}
runGenerateBenchmark(b, ctx, client, req)
}
})
}
}
// setup verifies server and model availability
func setup(b *testing.B) *api.Client {
client, err := api.ClientFromEnvironment()
if err != nil {
b.Fatal(err)
}
if _, err := client.Show(context.Background(), &api.ShowRequest{Model: modelName(b)}); err != nil {
b.Fatalf("Model unavailable: %v", err)
}
return client
}
// warmup ensures the model is loaded and warmed up
func warmup(client *api.Client, model string, prompt string, b *testing.B) {
for range 3 {
err := client.Generate(
context.Background(),
&api.GenerateRequest{
Model: model,
Prompt: prompt,
Options: map[string]interface{}{"num_predict": 50, "temperature": 0.1},
},
func(api.GenerateResponse) error { return nil },
)
if err != nil {
b.Logf("Error during model warm-up: %v", err)
}
}
}
// unload forces model unloading using KeepAlive: 0 parameter
func unload(client *api.Client, model string, b *testing.B) {
req := &api.GenerateRequest{
Model: model,
KeepAlive: &api.Duration{Duration: 0},
}
if err := client.Generate(context.Background(), req, func(api.GenerateResponse) error { return nil }); err != nil {
b.Logf("Unload error: %v", err)
}
time.Sleep(1 * time.Second)
}

View File

@@ -703,8 +703,6 @@ func showInfo(resp *api.ShowResponse, verbose bool, w io.Writer) error {
for _, k := range keys {
var v string
switch vData := resp.ModelInfo[k].(type) {
case bool:
v = fmt.Sprintf("%t", vData)
case string:
v = vData
case float64:

View File

@@ -87,8 +87,6 @@ func TestShowInfo(t *testing.T) {
ModelInfo: map[string]any{
"general.architecture": "test",
"general.parameter_count": float64(8_000_000_000),
"some.true_bool": true,
"some.false_bool": false,
"test.context_length": float64(1000),
"test.embedding_length": float64(11434),
},
@@ -113,8 +111,6 @@ func TestShowInfo(t *testing.T) {
Metadata
general.architecture test
general.parameter_count 8e+09
some.false_bool false
some.true_bool true
test.context_length 1000
test.embedding_length 11434

View File

@@ -184,6 +184,8 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
switch p.Architectures[0] {
case "LlamaForCausalLM", "MistralForCausalLM":
conv = &llamaModel{}
case "Mistral3ForConditionalGeneration":
conv = &mistralModel{}
case "MixtralForCausalLM":
conv = &mixtralModel{}
case "GemmaForCausalLM":
@@ -201,7 +203,7 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
case "CohereForCausalLM":
conv = &commandrModel{}
default:
return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
return errors.New("unsupported architecture")
}
if err := json.Unmarshal(bts, conv); err != nil {

246
convert/convert_mistral.go Normal file
View File

@@ -0,0 +1,246 @@
package convert
import (
"cmp"
"fmt"
"math"
"strings"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs/ggml"
)
type mistralModel struct {
ModelParameters
// Text model parameters
TextConfig struct {
NumHiddenLayers uint32 `json:"num_hidden_layers"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RopeTheta float32 `json:"rope_theta"`
RMSNormEPS float32 `json:"rms_norm_eps"`
HeadDim uint32 `json:"head_dim"`
} `json:"text_config"`
// Vision model parameters
VisionConfig struct {
NumHiddenLayers uint32 `json:"num_hidden_layers"`
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
ImageSize uint32 `json:"image_size"`
PatchSize uint32 `json:"patch_size"`
RopeTheta float32 `json:"rope_theta"`
} `json:"vision_config"`
// Multimodal specific parameters
ImageTokenIndex uint32 `json:"image_token_index"`
MultimodalProjectorBias bool `json:"multimodal_projector_bias"`
ProjectorHiddenAct string `json:"projector_hidden_act"`
SpatialMergeSize uint32 `json:"spatial_merge_size"`
VisionFeatureLayer int32 `json:"vision_feature_layer"`
// For RoPE scaling if needed
RopeScaling struct {
Type string `json:"type"`
RopeType string `json:"rope_type"`
Factor float32 `json:"factor"`
LowFrequencyFactor float32 `json:"low_freq_factor"`
HighFrequencyFactor float32 `json:"high_freq_factor"`
OriginalMaxPositionalEmbeddings uint32 `json:"original_max_positional_embeddings"`
factors ropeFactor
} `json:"rope_scaling"`
}
func (p *mistralModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "mistral"
kv["mistral.vocab_size"] = p.VocabSize
kv["mistral.image_token_index"] = p.ImageTokenIndex
kv["mistral.multimodal_projector_bias"] = p.MultimodalProjectorBias
kv["mistral.projector_hidden_act"] = p.ProjectorHiddenAct
kv["mistral.spatial_merge_size"] = p.SpatialMergeSize
// kv["mistral.vision_feature_layer"] = p.VisionFeatureLayer
// Text model config
kv["mistral.block_count"] = p.TextConfig.NumHiddenLayers
kv["mistral.context_length"] = p.TextConfig.MaxPositionEmbeddings
kv["mistral.embedding_length"] = p.TextConfig.HiddenSize
kv["mistral.feed_forward_length"] = p.TextConfig.IntermediateSize
kv["mistral.attention.head_count"] = p.TextConfig.NumAttentionHeads
kv["mistral.attention.head_count_kv"] = p.TextConfig.NumKeyValueHeads
kv["mistral.rope.dimension_count"] = p.TextConfig.HiddenSize / p.TextConfig.NumAttentionHeads
kv["mistral.rope.freq_base"] = p.TextConfig.RopeTheta
kv["mistral.attention.layer_norm_rms_epsilon"] = p.TextConfig.RMSNormEPS
kv["mistral.attention.key_length"] = p.TextConfig.HeadDim
kv["mistral.attention.value_length"] = p.TextConfig.HeadDim
// Vision model config
kv["mistral.vision.block_count"] = p.VisionConfig.NumHiddenLayers
kv["mistral.vision.embedding_length"] = p.VisionConfig.HiddenSize
kv["mistral.vision.feed_forward_length"] = p.VisionConfig.IntermediateSize
kv["mistral.vision.attention.head_count"] = p.VisionConfig.NumAttentionHeads
kv["mistral.vision.image_size"] = p.VisionConfig.ImageSize
kv["mistral.vision.patch_size"] = p.VisionConfig.PatchSize
kv["mistral.vision.rope.freq_base"] = p.VisionConfig.RopeTheta
// If RoPE scaling is present
if p.RopeScaling.Type == "linear" {
kv["mistral.rope.scaling.type"] = p.RopeScaling.Type
kv["mistral.rope.scaling.factor"] = p.RopeScaling.Factor
} else if p.RopeScaling.RopeType == "llama3" {
dim := p.TextConfig.HiddenSize / p.TextConfig.NumAttentionHeads
for i := uint32(0); i < dim; i += 2 {
factor := cmp.Or(p.RopeScaling.Factor, 8.0)
factorLow := cmp.Or(p.RopeScaling.LowFrequencyFactor, 1.0)
factorHigh := cmp.Or(p.RopeScaling.HighFrequencyFactor, 4.0)
original := cmp.Or(p.RopeScaling.OriginalMaxPositionalEmbeddings, 8192)
lambdaLow := float32(original) / factorLow
lambdaHigh := float32(original) / factorHigh
lambda := 2 * math.Pi * math.Pow(float64(p.TextConfig.RopeTheta), float64(i)/float64(dim))
if lambda < float64(lambdaHigh) {
p.RopeScaling.factors = append(p.RopeScaling.factors, 1.0)
} else if lambda > float64(lambdaLow) {
p.RopeScaling.factors = append(p.RopeScaling.factors, factor)
} else {
smooth := (float32(original)/float32(lambda) - factorLow) / (factorHigh - factorLow)
p.RopeScaling.factors = append(p.RopeScaling.factors, 1.0/((1-smooth)/factor+smooth))
}
}
}
return kv
}
func (p *mistralModel) Tensors(ts []Tensor) []ggml.Tensor {
var out []ggml.Tensor
if p.RopeScaling.factors != nil {
out = append(out, ggml.Tensor{
Name: "rope_freqs.weight",
Kind: 0,
Shape: []uint64{uint64(len(p.RopeScaling.factors))},
WriterTo: p.RopeScaling.factors,
})
}
for _, t := range ts {
// Process tensors that require repacking
if strings.HasSuffix(t.Name(), "attn_q.weight") ||
strings.HasSuffix(t.Name(), "attn_k.weight") {
t.SetRepacker(p.repack)
}
// Add all tensors to output
out = append(out, ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (p *mistralModel) Replacements() []string {
return []string{
// Language model replacements
"language_model.model.embed_tokens", "token_embd",
"language_model.model.norm", "output_norm",
"language_model.model.layers", "blk",
"language_model.model.layers.*.input_layernorm", "input_layernorm",
"language_model.model.layers.*.self_attn.q_proj", "self_attn.q_proj",
"language_model.model.layers.*.self_attn.k_proj", "self_attn.k_proj",
"language_model.model.layers.*.self_attn.v_proj", "self_attn.v_proj",
"language_model.model.layers.*.self_attn.o_proj", "self_attn.o_proj",
"language_model.model.layers.*.mlp.gate_proj", "mlp.gate_proj",
"language_model.model.layers.*.mlp.down_proj", "mlp.down_proj",
"language_model.model.layers.*.mlp.up_proj", "mlp.up_proj",
"language_model.model.layers.*.post_attention_layernorm", "post_attention_layernorm",
"language_model.lm_head", "output",
// Vision model replacements - map to shorter prefixes
"vision_tower", "v",
"multi_modal_projector", "mm",
// Vision transformer blocks - these should be updated accordingly
"vision_tower.transformer.layers", "v.blk",
"vision_tower.transformer.layers.*.attention_norm", "v.attn_norm",
"vision_tower.transformer.layers.*.attention.q_proj", "v.attn_q",
"vision_tower.transformer.layers.*.attention.k_proj", "v.attn_k",
"vision_tower.transformer.layers.*.attention.v_proj", "v.attn_v",
"vision_tower.transformer.layers.*.attention.o_proj", "v.attn_output",
"vision_tower.transformer.layers.*.feed_forward.gate_proj", "v.ffn_gate",
"vision_tower.transformer.layers.*.feed_forward.down_proj", "v.ffn_down",
"vision_tower.transformer.layers.*.feed_forward.up_proj", "v.ffn_up",
"vision_tower.transformer.layers.*.ffn_norm", "v.ffn_norm",
"vision_tower.ln_pre", "v.encoder_norm",
"vision_tower.patch_conv", "v.patch_conv",
// Multimodal projector components
"multi_modal_projector.patch_merger", "mm.patch_merger",
"multi_modal_projector.norm", "mm.norm",
}
}
func (p *mistralModel) repack(name string, data []float32, shape []uint64) ([]float32, error) {
var dims []int
for _, dim := range shape {
dims = append(dims, int(dim))
}
var heads uint32
if strings.HasSuffix(name, "attn_q.weight") {
if strings.Contains(name, "vision") {
heads = p.VisionConfig.NumAttentionHeads
} else {
heads = p.TextConfig.NumAttentionHeads
}
} else if strings.HasSuffix(name, "attn_k.weight") {
if strings.Contains(name, "vision") {
heads = p.VisionConfig.NumAttentionHeads
} else {
heads = cmp.Or(p.TextConfig.NumKeyValueHeads, p.TextConfig.NumAttentionHeads)
}
} else {
return nil, fmt.Errorf("unknown tensor for repack: %s", name)
}
n := tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
if err := n.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil {
return nil, err
}
if err := n.T(0, 2, 1, 3); err != nil {
return nil, err
}
if err := n.Reshape(dims...); err != nil {
return nil, err
}
if err := n.Transpose(); err != nil {
return nil, err
}
ts, err := native.SelectF32(n, 1)
if err != nil {
return nil, err
}
var f32s []float32
for _, t := range ts {
f32s = append(f32s, t...)
}
return f32s, nil
}

View File

@@ -558,10 +558,6 @@ Final response:
{
"model": "llama3.2",
"created_at": "2023-08-04T19:22:45.499127Z",
"message": {
"role": "assistant",
"content": ""
},
"done": true,
"total_duration": 4883583458,
"load_duration": 1334875,

View File

@@ -1,59 +0,0 @@
# Benchmark
Go benchmark tests that measure end-to-end performance of a running Ollama server. Run these tests to evaluate model inference performance on your hardware and measure the impact of code changes.
## When to use
Run these benchmarks when:
- Making changes to the model inference engine
- Modifying model loading/unloading logic
- Changing prompt processing or token generation code
- Implementing a new model architecture
- Testing performance across different hardware setups
## Prerequisites
- Ollama server running locally with `ollama serve` on `127.0.0.1:11434`
## Usage and Examples
>[!NOTE]
>All commands must be run from the root directory of the Ollama project.
Basic syntax:
```bash
go test -bench=. ./benchmark/... -m $MODEL_NAME
```
Required flags:
- `-bench=.`: Run all benchmarks
- `-m`: Model name to benchmark
Optional flags:
- `-count N`: Number of times to run the benchmark (useful for statistical analysis)
- `-timeout T`: Maximum time for the benchmark to run (e.g. "10m" for 10 minutes)
Common usage patterns:
Single benchmark run with a model specified:
```bash
go test -bench=. ./benchmark/... -m llama3.3
```
## Output metrics
The benchmark reports several key metrics:
- `gen_tok/s`: Generated tokens per second
- `prompt_tok/s`: Prompt processing tokens per second
- `ttft_ms`: Time to first token in milliseconds
- `load_ms`: Model load time in milliseconds
- `gen_tokens`: Total tokens generated
- `prompt_tokens`: Total prompt tokens processed
Each benchmark runs two scenarios:
- Cold start: Model is loaded from disk for each test
- Warm start: Model is pre-loaded in memory
Three prompt lengths are tested for each scenario:
- Short prompt (100 tokens)
- Medium prompt (500 tokens)
- Long prompt (1000 tokens)

View File

@@ -43,13 +43,8 @@ type Cache interface {
// ** cache management **
// Init sets up runtime parameters.
// backend: Used to allocate cache data storage and execute management operations (such as defrag)
// dtype: The data type for storing cache entries
// maxSequences: The maximum number of sequences stored in the cache - across all batches
// capacity: The number of cache entries to store, per sequence
// maxBatch: The maximum number of tokens that can occur in a single batch
Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int)
// Init sets up runtime parameters
Init(backend ml.Backend, dtype ml.DType, capacity int32)
// Close closes the cache and frees resources associated with it
Close()
@@ -57,7 +52,7 @@ type Cache interface {
// StartForward is called before the start of the model's forward pass.
// For each token in the coming batch, there must be a corresponding
// entry in positions and seqs.
StartForward(ctx ml.Context, batch input.Batch) error
StartForward(ctx ml.Context, opts input.Options) error
// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
CopyPrefix(srcSeq, dstSeq int, len int32)

View File

@@ -20,6 +20,7 @@ type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, e
// The mask is of shape history size, batch size
type Causal struct {
DType ml.DType
Capacity int32
windowSize int32
opts CausalOptions
@@ -97,7 +98,7 @@ func NewSWACache(windowSize int32, shift shiftFn) *Causal {
}
}
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
if c.config == nil {
var config ml.CacheConfig
if cc, ok := backend.(ml.BackendCacheConfig); ok {
@@ -118,16 +119,9 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity
c.config.MaskDType = ml.DTypeF32
}
var cacheSize int
if c.windowSize == math.MaxInt32 || capacity < int(c.windowSize)+maxBatch {
cacheSize = maxSequences * capacity
} else {
cacheSize = maxSequences * (int(c.windowSize) + maxBatch)
}
cacheSize = roundUp(cacheSize, c.config.CachePadding)
c.cells = make([]cacheCell, cacheSize)
c.DType = dtype
c.Capacity = int32(roundUp(int(capacity), c.config.CachePadding))
c.cells = make([]cacheCell, c.Capacity)
c.cellRanges = make(map[int]cellRange)
c.backend = backend
}
@@ -146,14 +140,12 @@ func (c *Causal) Close() {
}
}
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch) error {
c.curBatchSize = len(batch.Positions)
c.curSequences = batch.Sequences
c.curPositions = batch.Positions
func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error {
c.curBatchSize = len(opts.Positions)
c.curSequences = opts.Sequences
c.curPositions = opts.Positions
c.opts.Except = nil
c.updateSlidingWindow()
var err error
c.curLoc, err = c.findStartLoc()
if errors.Is(err, ErrKvCacheFull) {
@@ -165,8 +157,8 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch) error {
}
c.curCellRange = newRange()
for i, pos := range batch.Positions {
seq := batch.Sequences[i]
for i, pos := range opts.Positions {
seq := opts.Sequences[i]
c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
@@ -218,51 +210,7 @@ func (c *Causal) findStartLoc() (int, error) {
}
}
return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, len(c.cells))
}
func (c *Causal) updateSlidingWindow() {
if c.windowSize == math.MaxInt32 {
return
}
// create a map of unique sequences to the lowest position in that sequence
lowestPos := make(map[int]int32)
for i := range c.curPositions {
seq := c.curSequences[i]
pos, ok := lowestPos[seq]
if !ok {
pos = c.curPositions[i]
} else if c.curPositions[i] < pos {
pos = c.curPositions[i]
}
lowestPos[seq] = pos
}
// delete any entries that are beyond the window of the oldest position in the sequence
for seq, pos := range lowestPos {
oldRange, ok := c.cellRanges[seq]
if !ok {
continue
}
newRange := newRange()
for i := oldRange.min; i <= oldRange.max; i++ {
if slices.Contains(c.cells[i].sequences, seq) {
if c.cells[i].pos < pos-c.windowSize {
c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
} else {
newRange.min = min(newRange.min, i)
newRange.max = max(newRange.max, i)
}
}
}
c.cellRanges[seq] = newRange
}
return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity)
}
func roundDown(length, pad int) int {
@@ -317,7 +265,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
return maskTensor, nil
}
func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) {
func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
for i, key := range c.keys {
if key == nil {
continue
@@ -327,8 +275,8 @@ func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) {
numKVHeads := key.Dim(1)
rowSize := key.Stride(2)
kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*length)
kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*length)
kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*len)
kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*len)
value := c.values[i]
var vSrcView, vDstView ml.Tensor
@@ -336,14 +284,14 @@ func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) {
vHeadDim := value.Dim(1)
elemSize := value.Stride(0)
vSrcView = value.View(ctx, elemSize*src, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
vDstView = value.View(ctx, elemSize*dst, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
vSrcView = value.View(ctx, elemSize*src, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
vDstView = value.View(ctx, elemSize*dst, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
} else {
vHeadDim := value.Dim(0)
rowSize := value.Stride(2)
vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*length)
vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*length)
vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*len)
vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*len)
}
ctx.Forward(
@@ -373,8 +321,7 @@ func (c *Causal) defrag() {
ctx := c.backend.NewContext()
// For every move, 6 tensors are required per layer (2 views and a
// copy for each of k and v). We also need to refer to the original
// k and v cache tensors - once per layer, not per move.
// copy for each of k and v).
layers := 0
for _, key := range c.keys {
if key == nil {
@@ -383,7 +330,7 @@ func (c *Causal) defrag() {
layers++
}
maxMoves := (ctx.MaxGraphNodes() - 2*layers) / (6 * layers)
maxMoves := ctx.MaxGraphNodes() / (6 * layers)
moves := 0
var pendingSrc, pendingDst, pendingLen int
@@ -532,14 +479,14 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
}
if _, ok := c.keys[c.curLayer]; !ok {
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, len(c.cells))
c.keys[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity))
}
if _, ok := c.values[c.curLayer]; !ok {
if c.config.PermutedV {
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, len(c.cells), vHeadDim, numKVHeads)
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads)
} else {
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, len(c.cells))
c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity))
}
}
@@ -550,7 +497,7 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
elemSize := c.values[c.curLayer].Stride(0)
value = value.Permute(ctx, 1, 2, 0, 3)
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, len(c.cells)*elemSize, vHeadDim*numKVHeads)))
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)))
} else {
rowSize := c.values[c.curLayer].Stride(2)

View File

@@ -25,7 +25,7 @@ func TestStore(t *testing.T) {
cache := NewCausalCache(nil)
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
cache.Init(backend, ml.DTypeF16, 16)
tests := []testCase{
{
@@ -58,11 +58,11 @@ func TestSWA(t *testing.T) {
cache := NewSWACache(1, nil)
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
cache.Init(backend, ml.DTypeF32, 16)
tests := []testCase{
{
name: "FirstBatch",
name: "SlidingWindow",
in: []float32{1, 2, 3, 4},
inShape: []int{1, 1, 4},
seqs: []int{0, 0, 0, 0},
@@ -71,16 +71,6 @@ func TestSWA(t *testing.T) {
expectedShape: []int{1, 1, 4},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0},
},
{
name: "SecondBatch",
in: []float32{5, 6},
inShape: []int{1, 1, 2},
seqs: []int{0, 0},
pos: []int32{4, 5},
expected: []float32{5, 6, 3, 4},
expectedShape: []int{1, 1, 4},
expectedMask: []float32{0, float32(math.Inf(-1)), float32(math.Inf(-1)), 0, 0, 0, float32(math.Inf(-1)), float32(math.Inf(-1))},
},
}
testCache(t, backend, cache, tests)
@@ -91,7 +81,7 @@ func TestSequences(t *testing.T) {
cache := NewCausalCache(nil)
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
cache.Init(backend, ml.DTypeF16, 16)
tests := []testCase{
{
@@ -126,7 +116,7 @@ func TestRemove(t *testing.T) {
})
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
cache.Init(backend, ml.DTypeF16, 16)
tests := []testCase{
{
@@ -191,7 +181,7 @@ func TestDefrag(t *testing.T) {
})
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
cache.Init(backend, ml.DTypeF16, 16)
tests := []testCase{
{
@@ -239,7 +229,7 @@ func TestCopy(t *testing.T) {
cache := NewCausalCache(func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { return key, nil })
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 1, 16, 16)
cache.Init(backend, ml.DTypeF16, 16)
tests := []testCase{
{
@@ -280,7 +270,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
context := backend.NewContext()
defer context.Close()
err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs})
err := cache.StartForward(context, input.Options{Positions: test.pos, Sequences: test.seqs})
if err != nil {
panic(err)
}

View File

@@ -49,7 +49,7 @@ func NewEncoderCache() *EncoderCache {
}
}
func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
if c.config == nil {
var config ml.CacheConfig
if cc, ok := backend.(ml.BackendCacheConfig); ok {
@@ -58,10 +58,6 @@ func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, ca
c.config = &config
}
if maxSequences > 1 {
panic(fmt.Errorf("encoder cache does not support multiple sequences; requested: %v", maxSequences))
}
if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
}
@@ -83,10 +79,10 @@ func (c *EncoderCache) Close() {
}
}
func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch) error {
func (c *EncoderCache) StartForward(ctx ml.Context, opts input.Options) error {
// We work with the most recent image
if len(batch.Multimodal) > 0 {
c.curPos = batch.Positions[batch.Multimodal[len(batch.Multimodal)-1].Index]
if len(opts.Multimodal) > 0 {
c.curPos = opts.Positions[opts.Multimodal[len(opts.Multimodal)-1].Index]
}
return nil

View File

@@ -23,9 +23,9 @@ func NewWrapperCache(caches ...Cache) *WrapperCache {
}
}
func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {
func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
for _, cache := range c.caches {
cache.Init(backend, dtype, maxSequences, capacity, maxBatch)
cache.Init(backend, dtype, capacity)
}
}
@@ -41,14 +41,14 @@ func (c *WrapperCache) Close() {
}
}
func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch) error {
func (c *WrapperCache) StartForward(ctx ml.Context, opts input.Options) error {
for i, cache := range c.caches {
err := cache.StartForward(ctx, batch)
err := cache.StartForward(ctx, opts)
if err != nil {
// unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
for j := i - 1; j >= 0; j-- {
for k := range batch.Positions {
_ = c.caches[j].Remove(batch.Sequences[k], batch.Positions[k], math.MaxInt32)
for k := range opts.Positions {
_ = c.caches[j].Remove(opts.Sequences[k], opts.Positions[k], math.MaxInt32)
}
}
return err

View File

@@ -2,7 +2,6 @@ package ml
import (
"bytes"
"context"
"encoding/binary"
"fmt"
"os"
@@ -61,10 +60,6 @@ type CacheConfig struct {
// BackendParams controls how the backend loads and executes models
type BackendParams struct {
// Progress is a callback function that allows reporting percentage completion
// of model loading
Progress func(float32)
// NumThreads sets the number of threads to use if running on the CPU
NumThreads int
@@ -81,9 +76,9 @@ type BackendParams struct {
FlashAttention bool
}
var backends = make(map[string]func(context.Context, *os.File, BackendParams) (Backend, error))
var backends = make(map[string]func(*os.File, BackendParams) (Backend, error))
func RegisterBackend(name string, f func(context.Context, *os.File, BackendParams) (Backend, error)) {
func RegisterBackend(name string, f func(*os.File, BackendParams) (Backend, error)) {
if _, ok := backends[name]; ok {
panic("backend: backend already registered")
}
@@ -91,9 +86,9 @@ func RegisterBackend(name string, f func(context.Context, *os.File, BackendParam
backends[name] = f
}
func NewBackend(ctx context.Context, f *os.File, params BackendParams) (Backend, error) {
func NewBackend(f *os.File, params BackendParams) (Backend, error) {
if backend, ok := backends["ggml"]; ok {
return backend(ctx, f, params)
return backend(f, params)
}
return nil, fmt.Errorf("unsupported backend")

View File

@@ -9,17 +9,15 @@ package ggml
import "C"
import (
"context"
"errors"
"fmt"
"io"
"log/slog"
"maps"
"os"
"runtime"
"slices"
"strconv"
"strings"
"sync/atomic"
"unicode"
"unsafe"
@@ -60,7 +58,7 @@ type Backend struct {
maxGraphNodes int
}
func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend, error) {
func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
meta, n, err := fs.Decode(r, -1)
if err != nil {
return nil, err
@@ -299,16 +297,12 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend,
}
}
var doneBytes atomic.Uint64
totalBytes := uint64(n) - meta.Tensors().Offset
g, ctx := errgroup.WithContext(ctx)
g.SetLimit(runtime.GOMAXPROCS(0))
// concurrently read in tensor data. uses a section reader which is safe for concurrent reads
sr := io.NewSectionReader(r, int64(meta.Tensors().Offset), n-int64(meta.Tensors().Offset))
var g errgroup.Group
for _, t := range meta.Tensors().Items() {
g.Go(func() error {
tts := make([]*C.struct_ggml_tensor, max(1, len(targets[t.Name])))
for i := range tts {
target := targets[t.Name][i]
for _, target := range targets[t.Name] {
g.Go(func() error {
if target == "" {
target = t.Name
}
@@ -318,44 +312,23 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend,
return fmt.Errorf("unassigned tensor: %s", t.Name)
}
tts[i] = tt
}
sr := io.NewSectionReader(r, int64(meta.Tensors().Offset+t.Offset), int64(t.Size()))
bts := make([]byte, 128*format.KibiByte)
var s uint64
for s < t.Size() {
n, err := io.ReadFull(sr, bts[:min(len(bts), int(t.Size()-s))])
bts := make([]byte, t.Size())
n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), bts)
if err != nil {
return err
}
for _, tt := range tts {
C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), C.size_t(s), C.size_t(n))
if n != len(bts) {
return errors.New("short read")
}
s += uint64(n)
if params.Progress != nil {
done := doneBytes.Add(uint64(n))
params.Progress(float32(done) / float32(totalBytes))
}
}
return nil
})
C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), 0, C.size_t(t.Size()))
return nil
})
}
}
// start a goroutine to cancel the errgroup if the parent context is done
go func() {
<-ctx.Done()
g.Go(func() error {
return ctx.Err()
})
}()
if err := g.Wait(); err != nil {
if g.Wait() != nil {
return nil, err
}
@@ -398,7 +371,7 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend,
(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&schedBufts[0])),
C.int(len(schedBackends)),
C.size_t(maxGraphNodes),
C._Bool(len(gpus) > 1 && slices.Contains(gpus, output.d)),
true,
),
input: deviceBufferTypes[input.d],
output: deviceBufferTypes[output.d],

View File

@@ -1,7 +1,5 @@
package input
import "github.com/ollama/ollama/ml"
// Input represents one token in the input stream
type Input struct {
// Token is a single element of text.
@@ -35,24 +33,11 @@ type MultimodalIndex struct {
Multimodal any
}
// Batch contains the inputs for a model forward pass
type Batch struct {
// Inputs is the input tokens, including placeholders for multimodal inputs.
Inputs ml.Tensor
// Multimodal is a set of multimodal embeddings previously created by
// EncodeMultimodal, along with an index into Inputs. Unused for text-only
// models or for batches without multimodal elements.
// Options contains the inputs for a model forward pass
type Options struct {
Inputs []int32
Multimodal []MultimodalIndex
// Positions is the position for each Input, relative to its sequence. Equal
// in length to Inputs.
Positions []int32
// Sequences is the sequence for each Input. Equal in length to Inputs.
Sequences []int
// Outputs are the set of indicies into Inputs for which output data should
// be returned.
Outputs []int32
Positions []int32
Sequences []int
Outputs []int32
}

View File

@@ -1,7 +1,6 @@
package model
import (
"context"
"errors"
"fmt"
_ "image/jpeg"
@@ -27,7 +26,7 @@ var ErrNoVisionModel = errors.New("this model is missing data required for image
// Model implements a specific model architecture, defining the forward pass and any model-specific configuration
type Model interface {
Forward(ml.Context, input.Batch) (ml.Tensor, error)
Forward(ml.Context, input.Options) (ml.Tensor, error)
Backend() ml.Backend
Config() config
@@ -95,14 +94,14 @@ func Register(name string, f func(ml.Config) (Model, error)) {
}
// New initializes a new model instance with the provided configuration based on the metadata in the model file
func New(ctx context.Context, modelPath string, params ml.BackendParams) (Model, error) {
func New(modelPath string, params ml.BackendParams) (Model, error) {
r, err := os.Open(modelPath)
if err != nil {
return nil, err
}
defer r.Close()
b, err := ml.NewBackend(ctx, r, params)
b, err := ml.NewBackend(r, params)
if err != nil {
return nil, err
}
@@ -281,30 +280,24 @@ func canNil(t reflect.Type) bool {
t.Kind() == reflect.Slice
}
func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Tensor, error) {
if len(batch.Positions) != len(batch.Sequences) {
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
func Forward(ctx ml.Context, m Model, opts input.Options) (ml.Tensor, error) {
if len(opts.Positions) != len(opts.Sequences) {
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(opts.Positions), len(opts.Sequences))
}
if len(batch.Positions) < 1 {
if len(opts.Positions) < 1 {
return nil, errors.New("batch size cannot be less than 1")
}
var err error
batch.Inputs, err = ctx.Input().FromIntSlice(inputs, len(inputs))
if err != nil {
return nil, err
}
cache := m.Config().Cache
if cache != nil {
err := cache.StartForward(ctx, batch)
err := cache.StartForward(ctx, opts)
if err != nil {
return nil, err
}
}
t, err := m.Forward(ctx, batch)
t, err := m.Forward(ctx, opts)
if err != nil {
return nil, err
}

View File

@@ -163,7 +163,7 @@ func TestGetTextProcessor(t *testing.T) {
type notTextProcessorModel struct{}
func (notTextProcessorModel) Forward(ml.Context, input.Batch) (ml.Tensor, error) {
func (notTextProcessorModel) Forward(ml.Context, input.Options) (ml.Tensor, error) {
panic("unimplemented")
}

View File

@@ -168,18 +168,23 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
return hiddenState.Add(ctx, residual)
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
if err != nil {
return nil, err
}
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
if err != nil {
return nil, err
}
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil {
return nil, err
}
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
if len(m.Layers) == gemma27BLayerCount {
@@ -206,7 +211,8 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
// final logit softcap
hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.Options.finalLogitSoftcap))
hiddenState = hiddenState.Tanh(ctx)
return hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap)), nil
hiddenState = hiddenState.Scale(ctx, float64(m.Options.finalLogitSoftcap))
return hiddenState.Rows(ctx, outputs), nil
}
func init() {

View File

@@ -139,18 +139,23 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
return result, nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
if err != nil {
return nil, err
}
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
if err != nil {
return nil, err
}
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil {
return nil, err
}
return m.TextModel.Forward(ctx, inputs, positions, outputs, opts, m.Cache), nil
}
func init() {

View File

@@ -171,13 +171,13 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
return hiddenState.Add(ctx, residual)
}
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor {
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
// set image embeddings
var except []int
for _, image := range batch.Multimodal {
for _, image := range opts.Multimodal {
visionOutputs := image.Multimodal.(ml.Tensor)
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))

View File

@@ -139,18 +139,23 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
return hiddenState.Add(ctx, residual)
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
if err != nil {
return nil, err
}
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
if err != nil {
return nil, err
}
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil {
return nil, err
}
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
for i, layer := range m.Layers {
m.Cache.SetLayer(i)

View File

@@ -0,0 +1,207 @@
package llama
import (
"fmt"
"math"
"strings"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Options struct {
hiddenSize, numHeads, numKVHeads, headDim int
eps, ropeBase, ropeScale float32
ropeDim uint32
}
type Model struct {
model.Base
model.BytePairEncoding
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
*Options
}
func New(c ml.Config) (model.Model, error) {
if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") {
return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model"))
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer", `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Uints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
},
),
Layers: make([]Layer, c.Uint("block_count")),
Options: &Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
headDim: int(c.Uint("attention.key_length")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1),
ropeDim: c.Uint("rope.dimension_count"),
},
}
fmt.Println("Model Parameters:")
fmt.Printf(" model_type: %q\n", "gpt2")
fmt.Printf(" vocab_size: %d\n", len(c.Strings("tokenizer.ggml.tokens")))
fmt.Printf(" hidden_size: %d\n", m.Options.hiddenSize)
fmt.Printf(" num_hidden_layers: %d\n", c.Uint("block_count"))
fmt.Printf(" num_attention_heads: %d\n", m.Options.numHeads)
fmt.Printf(" num_key_value_heads: %d\n", m.Options.numKVHeads)
fmt.Printf(" rms_norm_eps: %g\n", m.Options.eps)
fmt.Printf(" rope_theta: %g\n", m.Options.ropeBase)
fmt.Printf(" bos_token_id: %d\n", c.Uint("tokenizer.ggml.bos_token_id"))
fmt.Printf(" eos_token_id: %d\n", c.Uint("tokenizer.ggml.eos_token_id"))
fmt.Printf(" pad_token_id: %d\n", c.Uint("tokenizer.ggml.pad_token_id", 0))
m.Cache = kvcache.NewCausalCache(m.Shift)
return &m, nil
}
type SelfAttention struct {
Query *nn.Linear `gguf:"self_attn.q_proj"`
Key *nn.Linear `gguf:"self_attn.k_proj"`
Value *nn.Linear `gguf:"self_attn.v_proj"`
Output *nn.Linear `gguf:"self_attn.o_proj"`
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
}
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
batchSize := hiddenState.Dim(1)
ropeType := uint32(0)
// Get head dimension - use explicit value if available, otherwise calculate
headDim := opts.headDim
if headDim == 0 {
headDim = opts.hiddenSize / opts.numHeads
}
// Query projection and reshape
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
// Key projection and reshape
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
// Value projection and reshape
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
// Attention computation
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
// Reshape attention output for final projection
outputDim := headDim * opts.numHeads
kqv = kqv.Reshape(ctx, outputDim, batchSize)
// Apply output projection
return sa.Output.Forward(ctx, kqv)
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil
}
type MLP struct {
Up *nn.Linear `gguf:"mlp.up_proj"`
Down *nn.Linear `gguf:"mlp.down_proj"`
Gate *nn.Linear `gguf:"mlp.gate_proj"`
}
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
type Layer struct {
AttentionNorm *nn.RMSNorm `gguf:"input_layernorm"`
SelfAttention *SelfAttention
MLPNorm *nn.RMSNorm `gguf:"post_attention_layernorm"`
MLP *MLP
}
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
// In the final layer (outputs != nil), optimize by pruning to just the token positions
// we need logits for.
if outputs != nil {
hiddenState = hiddenState.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
return hiddenState.Add(ctx, residual)
}
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
if err != nil {
return nil, err
}
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
if err != nil {
return nil, err
}
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil {
return nil, err
}
// Get token embeddings
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
for i, layer := range m.Layers {
m.Cache.SetLayer(i)
var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 {
lastLayerOutputs = outputs
}
hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
}
// Apply output normalization
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
// Apply output projection
return m.Output.Forward(ctx, hiddenState), nil
}
func init() {
model.Register("mistral", New)
}

View File

@@ -135,27 +135,32 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
return inputs, nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
var crossAttentionStates ml.Tensor
if len(batch.Multimodal) > 0 {
images := batch.Multimodal[len(batch.Multimodal)-1].Multimodal.([]ml.Tensor)
if len(opts.Multimodal) > 0 {
images := opts.Multimodal[len(opts.Multimodal)-1].Multimodal.([]ml.Tensor)
if len(images) > 0 {
crossAttentionStates = images[len(images)-1]
}
}
positions, err := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
if err != nil {
return nil, err
}
outputs, err := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
if err != nil {
return nil, err
}
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil {
return nil, err
}
// TODO: attention mask, cross attention mask
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
return m.TextModel.Forward(ctx, inputs, positions, outputs, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
}
func init() {

View File

@@ -4,5 +4,6 @@ import (
_ "github.com/ollama/ollama/model/models/gemma2"
_ "github.com/ollama/ollama/model/models/gemma3"
_ "github.com/ollama/ollama/model/models/llama"
_ "github.com/ollama/ollama/model/models/mistral"
_ "github.com/ollama/ollama/model/models/mllama"
)

View File

@@ -263,6 +263,10 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
continue
}
if id := bpe.vocab.Encode(pair.value); id < 0 {
continue
}
merges[pair.a].runes = append(left.runes, right.runes...)
merges[pair.b].runes = nil

View File

@@ -209,6 +209,326 @@ func TestLlama(t *testing.T) {
})
}
// tekken loads the Tekken tokenizer for testing
func tekken(t testing.TB) TextProcessor {
t.Helper()
// Load tokenizer config from mistral-small
tokenizerConfigPath := filepath.Join("testdata", "mistral-small", "tokenizer_config.json")
configFile, err := os.Open(tokenizerConfigPath)
if err != nil {
t.Fatal(err)
}
defer configFile.Close()
var config struct {
AddBosToken bool `json:"add_bos_token"`
AddEosToken bool `json:"add_eos_token"`
BosToken struct {
Content string `json:"content"`
} `json:"bos_token"`
EosToken struct {
Content string `json:"content"`
} `json:"eos_token"`
}
if err := json.NewDecoder(configFile).Decode(&config); err != nil {
t.Fatal(err)
}
// Load tokenizer.json which contains the vocabulary and other settings
tokenizerJsonPath := filepath.Join("testdata", "mistral-small", "tokenizer.json")
tokenizerFile, err := os.Open(tokenizerJsonPath)
if err != nil {
t.Fatal(err)
}
defer tokenizerFile.Close()
var tokenizerData struct {
Model struct {
Type string `json:"type"`
Vocab map[string]int32 `json:"vocab"`
Merges []string `json:"merges"`
} `json:"model"`
AddedTokens []struct {
Id int32 `json:"id"`
Content string `json:"content"`
Special bool `json:"special"`
} `json:"added_tokens"`
PreTokenizer struct {
Type string `json:"type"`
Pretokenizers []struct {
Type string `json:"type"`
Pattern struct {
String string `json:"String"`
} `json:"pattern"`
Behavior string `json:"behavior"`
} `json:"pretokenizers"`
} `json:"pre_tokenizer"`
}
if err := json.NewDecoder(tokenizerFile).Decode(&tokenizerData); err != nil {
t.Fatal(err)
}
// Extract the pattern from pre_tokenizer if available
var pattern string
if tokenizerData.PreTokenizer.Type == "Sequence" && len(tokenizerData.PreTokenizer.Pretokenizers) > 0 {
pattern = tokenizerData.PreTokenizer.Pretokenizers[0].Pattern.String
}
// Combine regular vocab and added tokens
vocab := tokenizerData.Model.Vocab
// Add special tokens from added_tokens
for _, token := range tokenizerData.AddedTokens {
vocab[token.Content] = token.Id
}
// Create vocabulary arrays
maxId := int32(-1)
for _, id := range vocab {
if id > maxId {
maxId = id
}
}
vocabSize := int(maxId + 1)
types := make([]uint32, vocabSize)
tokens := make([]string, vocabSize)
scores := make([]float32, vocabSize)
for token, id := range vocab {
tokens[id] = token
types[id] = TOKEN_TYPE_NORMAL
// Assign appropriate token types for special tokens
if token == "<s>" {
types[id] = TOKEN_TYPE_CONTROL
} else if token == "</s>" {
types[id] = TOKEN_TYPE_CONTROL
} else if token == "[INST]" || token == "[/INST]" {
types[id] = TOKEN_TYPE_CONTROL
}
}
// In Tekken, we don't need to load merges separately as they're part of the model
var merges []string
// Create vocabulary object
vocabObj := &Vocabulary{
Values: tokens,
Types: types,
Scores: scores,
Merges: merges,
BOS: vocab[config.BosToken.Content],
EOS: vocab[config.EosToken.Content],
AddBOS: config.AddBosToken,
AddEOS: config.AddEosToken,
}
// Use pattern from tokenizer.json if available
if pattern != "" {
// Ensure pattern has proper escaping for Go regexp
pattern = strings.ReplaceAll(pattern, "p{", "\\p{")
return NewBytePairEncoding(pattern, vocabObj)
}
// Fallback pattern if not found
return NewBytePairEncoding(
`\p{L}+|\p{N}+|[^\s\p{L}\p{N}]+|\s+`,
vocabObj,
)
}
func TestTekken(t *testing.T) {
// Skip if the test data isn't available
if _, err := os.Stat(filepath.Join("testdata", "mistral-small")); os.IsNotExist(err) {
t.Skip("Mistral-small test data not available")
}
tokenizer := tekken(t)
t.Run("whitespace_handling", func(t *testing.T) {
t.Parallel()
// The key difference from SentencePiece is that Tekken doesn't prepend whitespace
cases := []struct {
input string
expected string
}{
{" hello", " hello"},
{"hello ", "hello "},
{"hello world", "hello world"},
{" hello world ", " hello world "},
}
for _, tc := range cases {
ids, err := tokenizer.Encode(tc.input, false)
if err != nil {
t.Errorf("Failed to encode %q: %v", tc.input, err)
continue
}
decoded, err := tokenizer.Decode(ids)
if err != nil {
t.Errorf("Failed to decode tokens for %q: %v", tc.input, err)
continue
}
if decoded != tc.expected {
t.Errorf("Whitespace handling: got %q, want %q", decoded, tc.expected)
}
}
})
t.Run("chat_templates", func(t *testing.T) {
t.Parallel()
// Test the Tekken chat template format which doesn't have spaces after special tokens
templates := []struct {
input string
expectSpace bool // whether we expect a space after special tokens
}{
{"<s>[INST]user message[/INST]", false},
{"<s>[INST] user message[/INST]", true},
{"<s>[INST]user message [/INST]", true},
}
for _, tc := range templates {
ids, err := tokenizer.Encode(tc.input, false)
if err != nil {
t.Errorf("Failed to encode %q: %v", tc.input, err)
continue
}
decoded, err := tokenizer.Decode(ids)
if err != nil {
t.Errorf("Failed to decode tokens for %q: %v", tc.input, err)
continue
}
// Check if there's a space after special tokens
hasSpaceAfterINST := strings.Contains(decoded, "[INST] ")
if hasSpaceAfterINST != tc.expectSpace {
t.Errorf("Chat template space handling: got space=%v, want space=%v for %q",
hasSpaceAfterINST, tc.expectSpace, tc.input)
}
}
})
t.Run("special_tokens", func(t *testing.T) {
t.Parallel()
// Test how Tekken handles special tokens
cases := []struct {
input string
expected []string // We'll check if these tokens are in the decoded output
}{
{"<s>[INST]hello[/INST]", []string{"<s>", "[INST]", "hello", "[/INST]"}},
{"[INST]hello[/INST]</s>", []string{"[INST]", "hello", "[/INST]", "</s>"}},
{"<s>[INST]hello[/INST]</s>[INST]again[/INST]", []string{"<s>", "[INST]", "hello", "[/INST]", "</s>", "[INST]", "again", "[/INST]"}},
}
for _, tc := range cases {
ids, err := tokenizer.Encode(tc.input, false)
if err != nil {
t.Errorf("Failed to encode %q: %v", tc.input, err)
continue
}
decoded, err := tokenizer.Decode(ids)
if err != nil {
t.Errorf("Failed to decode tokens for %q: %v", tc.input, err)
continue
}
for _, expected := range tc.expected {
if !strings.Contains(decoded, expected) {
t.Errorf("Special token handling: %q missing in decoded output %q", expected, decoded)
}
}
}
})
t.Run("vocabulary_coverage", func(t *testing.T) {
t.Parallel()
// Tekken has a larger vocabulary, so test coverage of various token types
samples := []string{
"Hello world!",
"This is a test of the Tekken tokenizer.",
"It has a considerably larger vocabulary size.",
"Special characters: !@#$%^&*()",
"Numbers: 1234567890",
"Multiple languages: こんにちは 你好 안녕하세요",
"Code snippets: def function(): return True",
}
for _, sample := range samples {
ids, err := tokenizer.Encode(sample, false)
if err != nil {
t.Errorf("Failed to encode %q: %v", sample, err)
continue
}
decoded, err := tokenizer.Decode(ids)
if err != nil {
t.Errorf("Failed to decode tokens for %q: %v", sample, err)
continue
}
if decoded != sample {
t.Errorf("Vocabulary coverage: got %q, want %q", decoded, sample)
}
}
})
t.Run("splitting_behavior", func(t *testing.T) {
t.Parallel()
// Test the splitting behavior which might differ from SentencePiece
cases := map[string][]string{
"Hello World!": {"Hello", " World", "!"},
"user message": {"user", " message"},
"[INST]hello": {"[INST]", "hello"},
"hello[/INST]": {"hello", "[/INST]"},
}
for s, want := range cases {
got := slices.Collect(tokenizer.(*BytePairEncoding).split(s))
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("Splitting behavior no match (-want +got):\n%s", diff)
}
}
})
t.Run("full_chat_sequence", func(t *testing.T) {
t.Parallel()
// Test a complete chat sequence with Tekken's format
chatSequence := "<s>[INST]user message[/INST]assistant message</s>[INST]new user message[/INST]"
ids, err := tokenizer.Encode(chatSequence, false)
if err != nil {
t.Fatalf("Failed to encode chat sequence: %v", err)
}
decoded, err := tokenizer.Decode(ids)
if err != nil {
t.Fatalf("Failed to decode chat sequence tokens: %v", err)
}
// In Tekken, the whitespace shouldn't be added after special tokens
if strings.Contains(decoded, "[INST] ") {
t.Errorf("Tekken chat sequence has unexpected space after [INST]: %q", decoded)
}
if strings.Contains(decoded, "[/INST] ") {
t.Errorf("Tekken chat sequence has unexpected space after [/INST]: %q", decoded)
}
})
}
func BenchmarkBytePairEncoding(b *testing.B) {
tokenizer := llama(b)
bts, err := os.ReadFile(filepath.Join("testdata", "war-and-peace.txt"))

View File

@@ -31,10 +31,8 @@ type InputCache struct {
cache kvcache.Cache
}
func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, batchSize int, multiUserCache bool) (*InputCache, error) {
numCtx := kvSize / int32(numSlots)
if numCtx < 1 {
func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots int, multiUserCache bool) (*InputCache, error) {
if kvSize/int32(numSlots) < 1 {
return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots)
}
@@ -46,11 +44,11 @@ func NewInputCache(model model.Model, kvCacheType string, kvSize int32, numSlots
cache := model.Config().Cache
if cache != nil {
cache.Init(model.Backend(), kvCacheTypeFromStr(kvCacheType), numSlots, int(numCtx), batchSize)
cache.Init(model.Backend(), kvCacheTypeFromStr(kvCacheType), kvSize)
}
return &InputCache{
numCtx: numCtx,
numCtx: kvSize / int32(numSlots),
enabled: cache != nil,
slots: slots,
multiUserCache: multiUserCache,
@@ -91,7 +89,7 @@ type InputCacheSlot struct {
lastUsed time.Time
}
func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []input.Input, error) {
func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*InputCacheSlot, []input.Input, error) {
var slot *InputCacheSlot
var numPast int32
var err error
@@ -109,6 +107,11 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []inp
return nil, nil, err
}
// TODO (brucemacd): cachePrompt is always true for completion, but false for embedding, can this be improved?
if !cachePrompt {
numPast = 0
}
slot.InUse = true
slot.lastUsed = time.Now()

View File

@@ -297,131 +297,3 @@ func TestShiftDiscard(t *testing.T) {
})
}
}
func TestLoadCacheSlot(t *testing.T) {
tests := []struct {
name string
cache InputCache
prompt []input.Input
wantErr bool
expectedSlotId int
expectedPrompt int // expected length of remaining prompt
}{
{
name: "Basic cache hit - single user",
cache: InputCache{
multiUserCache: false,
slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input.Input{},
InUse: false,
lastUsed: time.Now().Add(-2 * time.Second),
},
},
},
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
wantErr: false,
expectedSlotId: 0,
expectedPrompt: 1, // Only token 3 remains
},
{
name: "Basic cache hit - multi user",
cache: InputCache{
multiUserCache: true,
slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input.Input{},
InUse: false,
lastUsed: time.Now().Add(-2 * time.Second),
},
},
},
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
wantErr: false,
expectedSlotId: 0,
expectedPrompt: 1, // Only token 3 remains
},
{
name: "Exact match - leave one input",
cache: InputCache{
multiUserCache: false,
slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
},
},
prompt: []input.Input{{Token: 1}, {Token: 2}},
wantErr: false,
expectedSlotId: 0,
expectedPrompt: 1, // Should leave 1 token for sampling
},
{
name: "No available slots",
cache: InputCache{
multiUserCache: false,
slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: true,
lastUsed: time.Now().Add(-time.Second),
},
},
},
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
wantErr: true,
expectedSlotId: -1,
expectedPrompt: -1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
slot, remainingPrompt, err := tt.cache.LoadCacheSlot(tt.prompt)
// Check error state
if (err != nil) != tt.wantErr {
t.Errorf("LoadCacheSlot() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr {
return // Skip further checks if we expected an error
}
// Verify slot ID
if slot.Id != tt.expectedSlotId {
t.Errorf("LoadCacheSlot() slot ID = %v, expected %v", slot.Id, tt.expectedSlotId)
}
// Verify slot is now marked in use
if !slot.InUse {
t.Errorf("LoadCacheSlot() slot not marked InUse")
}
// Verify remaining prompt length
if len(remainingPrompt) != tt.expectedPrompt {
t.Errorf("LoadCacheSlot() remaining prompt length = %v, expected %v",
len(remainingPrompt), tt.expectedPrompt)
}
})
}
}

View File

@@ -115,9 +115,6 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
params.numKeep = int32(len(inputs))
}
// TODO(jessegross): We should ensure that we always leave minBatch of context space to shift,
// otherwise we might truncate or split the batch against the model's wishes
// Ensure that at least 1 input can be discarded during shift
params.numKeep = min(params.numKeep, s.cache.numCtx-1)
@@ -182,6 +179,10 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, *
return nil, nil, err
}
for _, t := range tokens {
decoded, _ := s.model.(model.TextProcessor).Decode([]int32{t})
fmt.Println("token", t, "decoded", decoded)
}
for _, t := range tokens {
inputs = append(inputs, input.Input{Token: t})
}
@@ -348,8 +349,7 @@ func (s *Server) processBatch() error {
}
defer s.mu.Unlock()
var batchInputs []int32
var batch input.Batch
var options input.Options
for i, seq := range s.seqs {
if seq == nil {
@@ -370,6 +370,17 @@ func (s *Server) processBatch() error {
batchSize := s.batchSize
for j, inp := range seq.inputs {
if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+1) > s.cache.numCtx {
if len(seq.pendingInputs) == 0 {
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
if err != nil {
return err
}
} else {
break
}
}
// If we are required to put following inputs into a single batch then extend the
// batch size. Since we are only extending the size the minimum amount possible, this
// will cause a break if we have pending inputs.
@@ -382,31 +393,17 @@ func (s *Server) processBatch() error {
break
}
// If the sum of our working set (already processed tokens, tokens we added to this
// batch, required following tokens) exceeds the context size, then trigger a shift
// now so we don't have to do one later when we can't break the batch.
if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+minBatch) > s.cache.numCtx {
if len(seq.pendingInputs) != 0 {
break
}
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
if err != nil {
return err
}
}
batchInputs = append(batchInputs, inp.Token)
options.Inputs = append(options.Inputs, inp.Token)
if inp.Multimodal != nil {
batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: inp.Multimodal})
options.Multimodal = append(options.Multimodal, input.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: inp.Multimodal})
}
batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
batch.Sequences = append(batch.Sequences, seq.cache.Id)
options.Positions = append(options.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
options.Sequences = append(options.Sequences, seq.cache.Id)
seq.iBatch = len(batch.Outputs)
seq.iBatch = len(options.Outputs)
if j+1 == len(seq.inputs) {
batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
options.Outputs = append(options.Outputs, int32(len(options.Inputs)-1))
}
seq.pendingInputs = append(seq.pendingInputs, inp)
}
@@ -414,14 +411,14 @@ func (s *Server) processBatch() error {
seq.inputs = seq.inputs[len(seq.pendingInputs):]
}
if len(batchInputs) == 0 {
if len(options.Inputs) == 0 {
return nil
}
ctx := s.model.Backend().NewContext()
defer ctx.Close()
modelOutput, err := model.Forward(ctx, s.model, batchInputs, batch)
modelOutput, err := model.Forward(ctx, s.model, options)
if err != nil {
return fmt.Errorf("failed to decode batch: %w", err)
}
@@ -461,7 +458,7 @@ func (s *Server) processBatch() error {
}
// sample a token
vocabSize := len(logits) / len(batch.Outputs)
vocabSize := len(logits) / len(options.Outputs)
token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
if err != nil {
@@ -597,7 +594,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
found := false
for i, sq := range s.seqs {
if sq == nil {
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs)
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
if err != nil {
s.mu.Unlock()
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
@@ -678,7 +675,6 @@ func (m *multiLPath) String() string {
}
func (s *Server) loadModel(
ctx context.Context,
mpath string,
params ml.BackendParams,
lpath multiLPath,
@@ -688,7 +684,7 @@ func (s *Server) loadModel(
multiUserCache bool,
) {
var err error
s.model, err = model.New(ctx, mpath, params)
s.model, err = model.New(mpath, params)
if err != nil {
panic(err)
}
@@ -700,7 +696,7 @@ func (s *Server) loadModel(
panic("loras are not yet implemented")
}
s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, s.batchSize, multiUserCache)
s.cache, err = NewInputCache(s.model, kvCacheType, int32(kvSize), parallel, multiUserCache)
if err != nil {
panic(err)
}
@@ -784,9 +780,6 @@ func Execute(args []string) error {
}
params := ml.BackendParams{
Progress: func(progress float32) {
server.progress = progress
},
NumThreads: *threads,
NumGPULayers: *numGPULayers,
MainGPU: *mainGPU,
@@ -795,13 +788,13 @@ func Execute(args []string) error {
}
server.ready.Add(1)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go server.loadModel(ctx, *mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
go server.loadModel(*mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
server.cond = sync.NewCond(&server.mu)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go server.run(ctx)
addr := "127.0.0.1:" + strconv.Itoa(*port)

View File

@@ -26,10 +26,6 @@ type Sampler struct {
}
func (s *Sampler) Sample(logits []float32) (int32, error) {
if len(logits) == 0 {
return -1, errors.New("sample: no logits provided to sample")
}
tokens := make([]token, len(logits))
for i := range logits {
tokens[i].id = int32(i)
@@ -91,13 +87,19 @@ func (s *Sampler) sample(tokens []token) (token, error) {
// topK also sorts the tokens in descending order of logits
tokens = topK(tokens, s.topK)
// scale and normalize the tokens in place
temperature(tokens, s.temperature)
softmax(tokens)
tokens = temperature(tokens, s.temperature)
tokens = softmax(tokens)
tokens = topP(tokens, s.topP)
tokens = minP(tokens, s.minP)
// TODO: this should fall back to greedy sampling
// or topP, topK values etc should be such that
// there are always tokens to sample from
if len(tokens) == 0 {
return token{}, errors.New("no tokens to sample from")
}
var r float32
if s.rng != nil {
r = s.rng.Float32()
@@ -120,9 +122,6 @@ func (s *Sampler) sample(tokens []token) (token, error) {
return 1
})
if math.IsNaN(float64(sum)) {
return token{}, errors.New("sample: logits sum to NaN, check model output")
}
return tokens[idx], nil
}

View File

@@ -1,7 +1,6 @@
package sample
import (
"math"
"math/rand/v2"
"testing"
)
@@ -30,29 +29,6 @@ func TestWeighted(t *testing.T) {
if want != got {
t.Errorf("index mismatch: want %d, got %d", want, got)
}
// Test very high p
logits = []float32{1.0, 0.9999999999999999, 0.5, 0.1}
// Use extremely small topP to filter out all tokens
sampler = NewSampler(1.0, 0, 1e-10, 0, 0, nil)
got, err = sampler.Sample(logits)
if err != nil {
t.Error(err)
return
}
// Should get the token with the highest logit
want = int32(0)
if want != got {
t.Errorf("index mismatch: want %d, got %d", want, got)
}
logits = []float32{float32(math.NaN()), float32(math.NaN()), float32(math.NaN())}
sampler = NewSampler(1, 0, 0.95, 0.05, 0, nil)
got, err = sampler.Sample(logits)
if err == nil {
t.Errorf("expected error, got %d", got)
return
}
}
func BenchmarkSample(b *testing.B) {

View File

@@ -26,16 +26,17 @@ func (h *tokenHeap) Pop() any {
}
// temperature applies scaling to the logits
func temperature(ts []token, temp float32) {
func temperature(ts []token, temp float32) []token {
// Ensure temperature clipping near 0 to avoid numerical instability
temp = max(temp, 1e-7)
for i := range ts {
ts[i].value = ts[i].value / temp
}
return ts
}
// softmax applies normalization to the logits
func softmax(ts []token) {
func softmax(ts []token) []token {
// Find max logit for numerical stability
maxLogit := float32(math.Inf(-1))
for _, t := range ts {
@@ -55,6 +56,8 @@ func softmax(ts []token) {
for i := range ts {
ts[i].value /= sum
}
return ts
}
// topK limits the number of tokens considered to the k highest logits
@@ -96,7 +99,6 @@ func topK(ts []token, k int) []token {
}
// topP limits tokens to those with cumulative probability p
// requires ts to be sorted in descending order of probabilities
func topP(ts []token, p float32) []token {
if p == 1.0 {
return ts
@@ -107,24 +109,37 @@ func topP(ts []token, p float32) []token {
for i, t := range ts {
sum += t.value
if sum > float32(p) {
return ts[:i+1]
ts = ts[:i+1]
return ts
}
}
return ts
}
// minP filters tokens with probabilities >= p * max_prob
// requires ts to be sorted in descending order of probabilities
// minP limits tokens to those with cumulative probability p
func minP(ts []token, p float32) []token {
maxProb := ts[0].value
if p == 1.0 {
return ts
}
threshold := maxProb * p
for i, t := range ts {
if t.value < threshold {
return ts[:i]
maxProb := float32(math.Inf(-1))
for _, token := range ts {
if token.value > maxProb {
maxProb = token.value
}
}
threshold := maxProb * float32(p)
// Filter tokens in-place
validTokens := ts[:0]
for i, token := range ts {
if token.value >= threshold {
validTokens = append(validTokens, ts[i])
}
}
ts = validTokens
return ts
}

View File

@@ -34,22 +34,17 @@ func compareLogits(t *testing.T, name string, want []float32, got []token) {
func TestTemperature(t *testing.T) {
input := []float32{1.0, 4.0, -2.0, 0.0}
tokens := toTokens(input)
temperature(tokens, 0.5)
got := temperature(toTokens(input), 0.5)
want := []float32{2.0, 8.0, -4.0, 0.0}
compareLogits(t, "temperature(0.5)", want, tokens)
compareLogits(t, "temperature(0.5)", want, got)
input = []float32{1.0, 4.0, -2.0, 0.0}
tokens = toTokens(input)
temperature(tokens, 1.0)
got = temperature(toTokens(input), 1.0)
want = []float32{1.0, 4.0, -2.0, 0.0}
compareLogits(t, "temperature(1)", want, tokens)
compareLogits(t, "temperature(1)", want, got)
input = []float32{1.0, 4.0, -2.0, 0.0}
tokens = toTokens(input)
temperature(tokens, 0.0)
got = temperature(toTokens(input), 0.0)
want = []float32{1e7, 4e7, -2e7, 0.0}
compareLogits(t, "temperature(0)", want, tokens)
compareLogits(t, "temperature(0)", want, got)
}
func TestSoftmax(t *testing.T) {
@@ -95,17 +90,16 @@ func TestSoftmax(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tokens := toTokens(tt.input)
softmax(tokens)
got := softmax(toTokens(tt.input))
if tt.expected != nil {
compareLogits(t, tt.name, tt.expected, tokens)
compareLogits(t, tt.name, tt.expected, got)
return
}
// Check probabilities sum to 1
var sum float32
for _, token := range tokens {
for _, token := range got {
sum += token.value
if token.value < 0 || token.value > 1 {
t.Errorf("probability out of range [0,1]: got %f", token.value)
@@ -120,44 +114,38 @@ func TestSoftmax(t *testing.T) {
func TestTopK(t *testing.T) {
input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
tokens := toTokens(input)
tokens = topK(tokens, 5)
if len(tokens) != 5 {
t.Errorf("topK(5): wrong length: want 5, got %d", len(tokens))
// Test k=5
got := topK(toTokens(input), 5)
if len(got) != 5 {
t.Errorf("topK(5): wrong length: want 5, got %d", len(got))
}
// Should keep highest 3 values in descending order
want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154}
compareLogits(t, "topK(3)", want, tokens)
compareLogits(t, "topK(3)", want, got)
tokens = toTokens(input)
tokens = topK(tokens, 20)
if len(tokens) != len(input) {
t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(tokens))
got = topK(toTokens(input), 20)
if len(got) != len(input) {
t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(got))
}
// Test k=-1
input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
tokens = toTokens(input)
tokens = topK(tokens, -1)
if len(tokens) != len(input) {
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(tokens))
got = topK(toTokens(input), -1)
if len(got) != len(input) {
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got))
}
compareLogits(t, "topK(-1)", want, tokens)
compareLogits(t, "topK(-1)", want, got)
// Test k=0
input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
tokens = toTokens(input)
tokens = topK(tokens, 0)
if len(tokens) != len(input) {
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(tokens))
}
compareLogits(t, "topK(-1)", want, tokens)
input = []float32{-1e7, -2e7, -3e7, -4e7}
tokens = toTokens(input)
tokens = topK(tokens, 1)
if len(tokens) < 1 {
t.Error("topK should keep at least one token")
got = topK(toTokens(input), 0)
if len(got) != len(input) {
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got))
}
compareLogits(t, "topK(-1)", want, got)
}
func TestTopP(t *testing.T) {
@@ -165,134 +153,50 @@ func TestTopP(t *testing.T) {
tokens := toTokens(input)
// First apply temperature and softmax to get probabilities
softmax(tokens)
tokens = softmax(tokens)
tokens = topK(tokens, 20)
// Test with very high p value
got := topP(tokens, 1.0)
// Should keep all tokens since p is 1
if len(got) != len(input) {
t.Errorf("topP(1.0): should keep all tokens, got %d, want %d", len(got), len(input))
}
// Test with normal p value
got = topP(tokens, 0.95)
// Then apply topP
got := topP(tokens, 0.95)
// Should keep tokens until cumsum > 0.95
if len(got) > 3 {
t.Errorf("topP(0.95): kept too many tokens: got %d", len(tokens))
t.Logf("got: %v", got)
}
// Test edge case - ensure at least one token remains
input = []float32{-1e6, -1e6, -1e7}
tokens = toTokens(input)
tokens = topK(tokens, 20)
softmax(tokens)
got = topP(tokens, 0.0)
if len(got) < 1 {
t.Error("topP should keep at least one token")
}
// Test with zero p value
got = topP(tokens, 0.0)
// Should keep only the highest probability token
if len(got) != 1 {
t.Errorf("topP(0.0): should keep only one token, got %d", len(got))
t.Logf("got: %v", got)
}
tokens = toTokens(input)
tokens = topK(tokens, 20)
softmax(tokens)
got = topP(tokens, 1e-10)
if len(got) == 0 {
t.Errorf("topP(1e-10): should keep at least one token, got %d", len(got))
t.Errorf("topP(0.95): kept too many tokens: got %d", len(got))
t.Logf("got: %v", got)
}
}
func TestMinP(t *testing.T) {
input := []float32{-2, 0, -1, -3, 2, 1, 4, 3}
input := []float32{-3, -2, -1, 0, 1, 2, 4, 3}
tokens := toTokens(input)
// First apply temperature and softmax
tokens = topK(tokens, 20)
softmax(tokens)
tokens = softmax(tokens)
tokens = minP(tokens, 1.0)
if len(tokens) != 1 {
t.Errorf("minP(1.0): should keep all tokens, got %d, want %d", len(tokens), len(tokens))
}
// Test with normal p value
tokens = toTokens(input) // Reset tokens
tokens = topK(tokens, 20)
softmax(tokens)
tokens = minP(tokens, 0.2)
// Then apply minP
got := minP(tokens, 0.2)
// Should keep tokens with prob >= 0.2 * max_prob
if len(tokens) > 3 {
t.Errorf("minP(0.2): kept too many tokens: got %d", len(tokens))
t.Logf("got: %v", tokens)
if len(got) > 3 {
t.Errorf("minP(0.2): kept too many tokens: got %d", len(got))
}
}
func TestSortLogits(t *testing.T) {
input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
tokens := toTokens(input)
// Test with zero p value
tokens = toTokens(input) // Reset tokens
tokens = topK(tokens, 20)
softmax(tokens)
tokens = minP(tokens, 0.0)
// Should keep only the highest probability token
if len(tokens) != len(input) {
t.Errorf("minP(0.0): should keep only one token, got %d", len(tokens))
t.Logf("got: %v", tokens)
}
// Test with single token
tokens = toTokens(input[:1])
tokens = topK(tokens, 20)
softmax(tokens)
tokens = minP(tokens, 0.1)
// Should keep only the highest probability token
if len(tokens) != 1 {
t.Errorf("minP(0.1): should return single token, got %d", len(tokens))
t.Logf("got: %v", tokens)
}
input = []float32{1e-10, 1e-10, 1e-10}
tokens = toTokens(input)
softmax(tokens)
tokens = minP(tokens, 1.0)
if len(tokens) < 1 {
t.Error("minP should keep at least one token even with extreme probabilities")
got := minP(tokens, 1.0)
if len(got) != 1 {
t.Errorf("minP(1.0): should keep all tokens, got %d, want %d", len(got), len(tokens))
}
// Test with normal p value
got = minP(tokens, 0.2)
// Should keep tokens with prob >= 0.2 * max_prob
if len(got) > 3 {
t.Errorf("minP(0.2): kept too many tokens: got %d", len(got))
t.Logf("got: %v", got)
}
// Test with zero p value
got = minP(tokens, 0.0)
// Should keep only the highest probability token
if len(got) != len(tokens) {
t.Errorf("minP(0.0): should keep only one token, got %d", len(got))
t.Logf("got: %v", got)
for i := 1; i < len(tokens); i++ {
if tokens[i].value > tokens[i-1].value {
t.Errorf("sortLogits: tokens not sorted in descending order at index %d: %f > %f",
i, tokens[i].value, tokens[i-1].value)
}
}
want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
compareLogits(t, "sortLogits", want, tokens)
}
func BenchmarkTransforms(b *testing.B) {
@@ -327,7 +231,7 @@ func BenchmarkTransforms(b *testing.B) {
b.ResetTimer()
for b.Loop() {
copy(tokensCopy, tokens)
tokens = topK(tokensCopy, 10)
topK(tokensCopy, 10)
}
})
@@ -335,7 +239,7 @@ func BenchmarkTransforms(b *testing.B) {
b.ResetTimer()
for b.Loop() {
copy(tokensCopy, tokens)
tokens = topP(tokensCopy, 0.9)
topP(tokensCopy, 0.9)
}
})
@@ -343,7 +247,7 @@ func BenchmarkTransforms(b *testing.B) {
b.ResetTimer()
for b.Loop() {
copy(tokensCopy, tokens)
tokens = minP(tokensCopy, 0.2)
minP(tokensCopy, 0.2)
}
})
@@ -351,7 +255,7 @@ func BenchmarkTransforms(b *testing.B) {
b.ResetTimer()
for b.Loop() {
copy(tokensCopy, tokens)
tokens = topK(tokensCopy, 200000)
topK(tokensCopy, 200000)
}
})
}

View File

@@ -37,6 +37,7 @@ import (
"golang.org/x/sync/errgroup"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/internal/backoff"
"github.com/ollama/ollama/server/internal/internal/names"
_ "embed"
@@ -59,11 +60,6 @@ var (
// ErrCached is passed to [Trace.PushUpdate] when a layer already
// exists. It is a non-fatal error and is never returned by [Registry.Push].
ErrCached = errors.New("cached")
// ErrIncomplete is returned by [Registry.Pull] when a model pull was
// incomplete due to one or more layer download failures. Users that
// want specific errors should use [WithTrace].
ErrIncomplete = errors.New("incomplete")
)
// Defaults
@@ -217,6 +213,12 @@ type Registry struct {
// request. If zero, [DefaultChunkingThreshold] is used.
ChunkingThreshold int64
// MaxChunkSize is the maximum size of a chunk to download. If zero,
// the default is [DefaultMaxChunkSize].
//
// It is only used when a layer is larger than [MaxChunkingThreshold].
MaxChunkSize int64
// Mask, if set, is the name used to convert non-fully qualified names
// to fully qualified names. If empty, [DefaultMask] is used.
Mask string
@@ -276,19 +278,8 @@ func DefaultRegistry() (*Registry, error) {
func UserAgent() string {
buildinfo, _ := debug.ReadBuildInfo()
version := buildinfo.Main.Version
if version == "(devel)" {
// When using `go run .` the version is "(devel)". This is seen
// as an invalid version by ollama.com and so it defaults to
// "needs upgrade" for some requests, such as pulls. These
// checks can be skipped by using the special version "v0.0.0",
// so we set it to that here.
version = "v0.0.0"
}
return fmt.Sprintf("ollama/%s (%s %s) Go/%s",
version,
buildinfo.Main.Version,
runtime.GOARCH,
runtime.GOOS,
runtime.Version(),
@@ -434,14 +425,13 @@ func canRetry(err error) bool {
//
// It always calls update with a nil error.
type trackingReader struct {
l *Layer
r io.Reader
update func(l *Layer, n int64, err error)
r io.Reader
n *atomic.Int64
}
func (r *trackingReader) Read(p []byte) (n int, err error) {
n, err = r.r.Read(p)
r.update(r.l, int64(n), nil)
r.n.Add(int64(n))
return
}
@@ -457,11 +447,6 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
if err != nil {
return err
}
// TODO(bmizerany): decide if this should be considered valid. Maybe
// server-side we special case '{}' to have some special meaning? Maybe
// "archiving" a tag (which is how we reason about it in the registry
// already, just with a different twist).
if len(m.Layers) == 0 {
return fmt.Errorf("%w: no layers", ErrManifestInvalid)
}
@@ -471,7 +456,11 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
return err
}
// TODO(bmizerany): work to remove the need to do this
exists := func(l *Layer) bool {
info, err := c.Get(l.Digest)
return err == nil && info.Size == l.Size
}
layers := m.Layers
if m.Config != nil && m.Config.Digest.IsValid() {
layers = append(layers, m.Config)
@@ -479,21 +468,20 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
// Send initial layer trace events to allow clients to have an
// understanding of work to be done before work starts.
var expected int64
t := traceFromContext(ctx)
for _, l := range layers {
skip := make([]bool, len(layers))
for i, l := range layers {
t.update(l, 0, nil)
expected += l.Size
if exists(l) {
skip[i] = true
t.update(l, l.Size, ErrCached)
}
}
var total atomic.Int64
var g errgroup.Group
g, ctx := errgroup.WithContext(ctx)
g.SetLimit(r.maxStreams())
for _, l := range layers {
info, err := c.Get(l.Digest)
if err == nil && info.Size == l.Size {
total.Add(l.Size)
t.update(l, l.Size, ErrCached)
for i, l := range layers {
if skip[i] {
continue
}
@@ -502,68 +490,77 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
t.update(l, 0, err)
continue
}
// TODO(bmizerany): fix this unbounded use of defer
defer chunked.Close()
var progress atomic.Int64
for cs, err := range r.chunksums(ctx, name, l) {
if err != nil {
// Chunksum stream was interrupted, so tell
// trace about it, and let in-flight chunk
// downloads finish. Once they finish, return
// ErrIncomplete, which is triggered by the
// fact that the total bytes received is less
// than the expected bytes.
t.update(l, 0, err)
t.update(l, progress.Load(), err)
break
}
g.Go(func() (err error) {
defer func() {
if err == nil || errors.Is(err, ErrCached) {
total.Add(cs.Chunk.Size())
} else {
err = fmt.Errorf("error downloading %s: %w", cs.Digest.Short(), err)
defer func() { t.update(l, progress.Load(), err) }()
for _, err := range backoff.Loop(ctx, 3*time.Second) {
if err != nil {
return err
}
}()
err := func() error {
req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
if err != nil {
return err
}
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", cs.Chunk.Start, cs.Chunk.End))
res, err := sendRequest(r.client(), req)
if err != nil {
return err
}
defer res.Body.Close()
req, err := http.NewRequestWithContext(ctx, "GET", cs.URL, nil)
if err != nil {
return err
// Count bytes towards
// progress, as they arrive, so
// that our bytes piggyback
// other chunk updates on
// completion.
//
// This tactic is enough to
// show "smooth" progress given
// the current CLI client. In
// the near future, the server
// should report download rate
// since it knows better than
// a client that is measuring
// rate based on wall-clock
// time-since-last-update.
body := &trackingReader{r: res.Body, n: &progress}
err = chunked.Put(cs.Chunk, cs.Digest, body)
if err != nil {
return err
}
return nil
}()
if !canRetry(err) {
return err
}
}
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", cs.Chunk.Start, cs.Chunk.End))
res, err := sendRequest(r.client(), req)
if err != nil {
return err
}
defer res.Body.Close()
// Count bytes towards progress, as they
// arrive, so that our bytes piggyback other
// chunk updates on completion.
//
// This tactic is enough to show "smooth"
// progress given the current CLI client. In
// the near future, the server should report
// download rate since it knows better than a
// client that is measuring rate based on
// wall-clock time-since-last-update.
body := &trackingReader{l: l, r: res.Body, update: t.update}
return chunked.Put(cs.Chunk, cs.Digest, body)
return nil
})
}
}
if err := g.Wait(); err != nil {
return err
}
if total.Load() != expected {
return fmt.Errorf("%w: received %d/%d", ErrIncomplete, total.Load(), expected)
}
// store the manifest blob
md := blob.DigestFromBytes(m.Data)
if err := blob.PutBytes(c, md, m.Data); err != nil {
return err
}
// commit the manifest with a link
return c.Link(m.Name, md)
}

View File

@@ -17,7 +17,6 @@ import (
"reflect"
"slices"
"strings"
"sync"
"testing"
"time"
@@ -25,28 +24,6 @@ import (
"github.com/ollama/ollama/server/internal/testutil"
)
func ExampleRegistry_cancelOnFirstError() {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ctx = WithTrace(ctx, &Trace{
Update: func(l *Layer, n int64, err error) {
if err != nil {
// Discontinue pulling layers if there is an
// error instead of continuing to pull more
// data.
cancel()
}
},
})
var r Registry
if err := r.Pull(ctx, "model"); err != nil {
// panic for demo purposes
panic(err)
}
}
func TestManifestMarshalJSON(t *testing.T) {
// All manifests should contain an "empty" config object.
var m Manifest
@@ -79,21 +56,21 @@ func (rr recordRoundTripper) RoundTrip(req *http.Request) (*http.Response, error
// newClient constructs a cache with predefined manifests for testing. The manifests are:
//
// empty: no data
// zero: no layers
// single: one layer with the contents "exists"
// multiple: two layers with the contents "exists" and "here"
// notfound: a layer that does not exist in the cache
// null: one null layer (e.g. [null])
// sizemismatch: one valid layer, and one with a size mismatch (file size is less than the reported size)
// invalid: a layer with invalid JSON data
// empty: no data
// zero: no layers
// single: one layer with the contents "exists"
// multiple: two layers with the contents "exists" and "here"
// notfound: a layer that does not exist in the cache
// null: one null layer (e.g. [null])
// sizemismatch: one valid layer, and one with a size mismatch (file size is less than the reported size)
// invalid: a layer with invalid JSON data
//
// Tests that want to ensure the client does not communicate with the upstream
// registry should pass a nil handler, which will cause a panic if
// communication is attempted.
//
// To simulate a network error, pass a handler that returns a 499 status code.
func newClient(t *testing.T, upstreamRegistry http.HandlerFunc) (*Registry, *blob.DiskCache) {
func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
t.Helper()
c, err := blob.Open(t.TempDir())
@@ -111,7 +88,7 @@ func newClient(t *testing.T, upstreamRegistry http.HandlerFunc) (*Registry, *blo
r := &Registry{
Cache: c,
HTTPClient: &http.Client{
Transport: recordRoundTripper(upstreamRegistry),
Transport: recordRoundTripper(h),
},
}
@@ -790,79 +767,3 @@ func TestUnlink(t *testing.T) {
}
})
}
func TestPullChunksums(t *testing.T) {
check := testutil.Checker(t)
content := "hello"
var chunksums string
contentDigest := func() blob.Digest {
return blob.DigestFromBytes(content)
}
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "/manifests/latest"):
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":%d}]}`, contentDigest(), len(content))
case strings.HasSuffix(r.URL.Path, "/chunksums/"+contentDigest().String()):
loc := fmt.Sprintf("http://blob.store/v2/library/test/blobs/%s", contentDigest())
w.Header().Set("Content-Location", loc)
io.WriteString(w, chunksums)
case strings.Contains(r.URL.Path, "/blobs/"+contentDigest().String()):
http.ServeContent(w, r, contentDigest().String(), time.Time{}, strings.NewReader(content))
default:
t.Errorf("unexpected request: %v", r)
http.NotFound(w, r)
}
})
rc.MaxStreams = 1 // prevent concurrent chunk downloads
rc.ChunkingThreshold = 1 // for all blobs to be chunked
var mu sync.Mutex
var reads []int64
ctx := WithTrace(t.Context(), &Trace{
Update: func(l *Layer, n int64, err error) {
t.Logf("Update: %v %d %v", l, n, err)
mu.Lock()
reads = append(reads, n)
mu.Unlock()
},
})
chunksums = fmt.Sprintf("%s 0-2\n%s 3-4\n",
blob.DigestFromBytes("hel"),
blob.DigestFromBytes("lo"),
)
err := rc.Pull(ctx, "test")
check(err)
wantReads := []int64{
0, // initial signaling of layer pull starting
3, // first chunk read
2, // second chunk read
}
if !slices.Equal(reads, wantReads) {
t.Errorf("reads = %v; want %v", reads, wantReads)
}
mw, err := rc.Resolve(t.Context(), "test")
check(err)
mg, err := rc.ResolveLocal("test")
check(err)
if !reflect.DeepEqual(mw, mg) {
t.Errorf("mw = %v; mg = %v", mw, mg)
}
for i := range mg.Layers {
_, err = c.Get(mg.Layers[i].Digest)
if err != nil {
t.Errorf("Get(%v): %v", mg.Layers[i].Digest, err)
}
}
// missing chunks
content = "llama"
chunksums = fmt.Sprintf("%s 0-1\n", blob.DigestFromBytes("ll"))
err = rc.Pull(ctx, "missingchunks")
if err == nil {
t.Error("expected error because of missing chunks")
}
}

View File

@@ -200,7 +200,7 @@ type params struct {
//
// Unfortunately, this API was designed to be a bit awkward. Stream is
// defined to default to true if not present, so we need a way to check
// if the client decisively set it to false. So, we use a pointer to a
// if the client decisively it to false. So, we use a pointer to a
// bool. Gross.
//
// Use [stream()] to get the correct value for this field.
@@ -280,17 +280,17 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
progress := make(map[*ollama.Layer]int64)
progressCopy := make(map[*ollama.Layer]int64, len(progress))
flushProgress := func() {
pushUpdate := func() {
defer maybeFlush()
// TODO(bmizerany): Flushing every layer in one update doesn't
// scale well. We could flush only the modified layers or track
// the full download. Needs further consideration, though it's
// fine for now.
// TODO(bmizerany): This scales poorly with more layers due to
// needing to flush out them all in one big update. We _could_
// just flush on the changed ones, or just track the whole
// download. Needs more thought. This is fine for now.
mu.Lock()
maps.Copy(progressCopy, progress)
mu.Unlock()
for l, n := range progressCopy {
for l, n := range progress {
enc.Encode(progressUpdateJSON{
Digest: l.Digest,
Total: l.Size,
@@ -298,26 +298,19 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
})
}
}
defer flushProgress()
t := time.NewTicker(1000 * time.Hour) // "unstarted" timer
t := time.NewTicker(time.Hour) // "unstarted" timer
start := sync.OnceFunc(func() {
flushProgress() // flush initial state
pushUpdate()
t.Reset(100 * time.Millisecond)
})
ctx := ollama.WithTrace(r.Context(), &ollama.Trace{
Update: func(l *ollama.Layer, n int64, err error) {
if n > 0 {
// Block flushing progress updates until every
// layer is accounted for. Clients depend on a
// complete model size to calculate progress
// correctly; if they use an incomplete total,
// progress indicators would erratically jump
// as new layers are registered.
start()
start() // flush initial state
}
mu.Lock()
progress[l] += n
progress[l] = n
mu.Unlock()
},
})
@@ -330,9 +323,9 @@ func (s *Local) handlePull(w http.ResponseWriter, r *http.Request) error {
for {
select {
case <-t.C:
flushProgress()
pushUpdate()
case err := <-done:
flushProgress()
pushUpdate()
if err != nil {
var status string
if errors.Is(err, ollama.ErrModelNotFound) {

View File

@@ -82,7 +82,7 @@ func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
for _, layer := range layers {
if s := layer.GGML.KV().ChatTemplate(); s != "" {
if t, err := template.Named(s); err != nil {
slog.Debug("template detection", "error", err, "template", s)
slog.Debug("template detection", "error", err)
} else {
layer, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
if err != nil {

View File

@@ -1,13 +0,0 @@
{{- range $i, $_ := .Messages }}
{{- $last := eq (len (slice $.Messages $i)) 1 }}
{{- if eq .Role "user" }}<start_of_turn>user
{{- if and (eq $i 1) $.System }}
{{ $.System }}
{{ end }}
{{ .Content }}<end_of_turn>
{{ else if eq .Role "assistant" }}<start_of_turn>model
{{ .Content }}<end_of_turn>
{{ end }}
{{- if $last }}<start_of_turn>model
{{ end }}
{{- end }}

View File

@@ -1,6 +0,0 @@
{
"stop": [
"<end_of_turn>"
],
"temperature": 0.1
}

View File

@@ -87,10 +87,6 @@
"template": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
"name": "gemma-instruct"
},
{
"template": "{{ bos_token }}\n{%- if messages[0]['role'] == 'system' -%}\n {%- if messages[0]['content'] is string -%}\n {%- set first_user_prefix = messages[0]['content'] + '\n\n' -%}\n {%- else -%}\n {%- set first_user_prefix = messages[0]['content'][0]['text'] + '\n\n' -%}\n {%- endif -%}\n {%- set loop_messages = messages[1:] -%}\n{%- else -%}\n {%- set first_user_prefix = \"\" -%}\n {%- set loop_messages = messages -%}\n{%- endif -%}\n{%- for message in loop_messages -%}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) -%}\n {{ raise_exception(\"Conversation roles must alternate user/assistant/user/assistant/...\") }}\n {%- endif -%}\n {%- if (message['role'] == 'assistant') -%}\n {%- set role = \"model\" -%}\n {%- else -%}\n {%- set role = message['role'] -%}\n {%- endif -%}\n {{ '<start_of_turn>' + role + '\n' + (first_user_prefix if loop.first else \"\") }}\n {%- if message['content'] is string -%}\n {{ message['content'] | trim }}\n {%- elif message['content'] is iterable -%}\n {%- for item in message['content'] -%}\n {%- if item['type'] == 'image' -%}\n {{ '<start_of_image>' }}\n {%- elif item['type'] == 'text' -%}\n {{ item['text'] | trim }}\n {%- endif -%}\n {%- endfor -%}\n {%- else -%}\n {{ raise_exception(\"Invalid content type\") }}\n {%- endif -%}\n {{ '<end_of_turn>\n' }}\n{%- endfor -%}\n{%- if add_generation_prompt -%}\n {{'<start_of_turn>model\n'}}\n{%- endif -%}\n",
"name": "gemma3-instruct"
},
{
"template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
"name": "llama3-instruct"

View File

@@ -1,10 +0,0 @@
<start_of_turn>user
You are a helpful assistant.
Hello, how are you?<end_of_turn>
<start_of_turn>model
I'm doing great. How can I help you today?<end_of_turn>
<start_of_turn>user
I'd like to show off how chat templating works!<end_of_turn>
<start_of_turn>model

View File

@@ -1,4 +0,0 @@
<start_of_turn>user
Hello, how are you?<end_of_turn>
<start_of_turn>model

View File

@@ -1,8 +0,0 @@
<start_of_turn>user
Hello, how are you?<end_of_turn>
<start_of_turn>model
I'm doing great. How can I help you today?<end_of_turn>
<start_of_turn>user
I'd like to show off how chat templating works!<end_of_turn>
<start_of_turn>model