Compare commits

..

2 Commits

Author SHA1 Message Date
Michael Yang
3dcc31dfac chore: remove embedded metal file 2025-12-08 13:25:14 -08:00
Michael Yang
260b165a1a fix: qwen2.5vl metal argsort 2025-12-08 13:22:35 -08:00
46 changed files with 201 additions and 13571 deletions

View File

@@ -200,8 +200,6 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
conv = &qwen25VLModel{}
case "Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration":
conv = &qwen3VLModel{}
case "OLMo2ForCausalLM", "Olmo2ForCausalLM", "OLMo3ForCausalLM", "Olmo3ForCausalLM":
conv = &olmoModel{}
case "BertModel":
conv = &bertModel{}
case "CohereForCausalLM":

View File

@@ -2,7 +2,6 @@ package convert
import (
"cmp"
"slices"
"github.com/ollama/ollama/fs/ggml"
)
@@ -27,26 +26,16 @@ type gemma3Model struct {
NumChannels uint32 `json:"num_channels"` // num_channels 3
PatchSize uint32 `json:"patch_size"` // patch_size 14
} `json:"vision_config"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RMSNormEPS float32 `json:"rms_norm_eps"`
HeadDim uint32 `json:"head_dim"`
FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
RopeLocalTheta float32 `json:"rope_local_base_freq"`
RopeTheta float32 `json:"rope_theta"`
SlidingWindow uint32 `json:"sliding_window"`
SlidingWindowPattern *uint32 `json:"sliding_window_pattern"`
LayerTypes []string `json:"layer_types"`
MultiModalTokensPerImage uint32 `json:"mm_tokens_per_image"`
RopeScaling *struct {
Type string `json:"rope_type"`
Factor float32 `json:"factor"`
OriginalMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
ExtrapolationFactor float32 `json:"extrapolation_factor"`
BetaFast float32 `json:"beta_fast"`
BetaSlow float32 `json:"beta_slow"`
} `json:"rope_scaling"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RMSNormEPS float32 `json:"rms_norm_eps"`
HeadDim uint32 `json:"head_dim"`
FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
RopeLocalTheta float32 `json:"rope_local_base_freq"`
RopeGlobalTheta float32 `json:"rope_global_base_freq"`
SlidingWindow uint32 `json:"sliding_window"`
MultiModalTokensPerImage uint32 `json:"mm_tokens_per_image"`
}
const (
@@ -92,38 +81,9 @@ func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
kv["gemma3.attention.key_length"] = p.HeadDim
kv["gemma3.attention.value_length"] = p.HeadDim
kv["gemma3.attention.sliding_window"] = p.SlidingWindow
// The sliding window pattern is either provided as the sliding_window_pattern
// key (an int) or as the layer_types key (a list of strings).
if p.SlidingWindowPattern != nil || len(p.LayerTypes) > 0 {
kv["gemma3.attention.sliding_window_pattern"] = slices.Collect(func(yield func(bool) bool) {
for i := range numBlocks {
var isLocal bool
if len(p.LayerTypes) > 0 && int(i) < len(p.LayerTypes) {
isLocal = p.LayerTypes[i] == "sliding_attention"
} else if p.SlidingWindowPattern != nil && *p.SlidingWindowPattern > 0 {
isLocal = (i+1)%*p.SlidingWindowPattern != 0
}
if !yield(isLocal) {
break
}
}
})
}
if p.FinalLogitSoftcap > 0 {
kv["gemma3.final_logit_softcapping"] = p.FinalLogitSoftcap
}
kv["gemma3.final_logit_softcapping"] = cmp.Or(p.FinalLogitSoftcap, 30)
kv["gemma3.rope.local.freq_base"] = cmp.Or(p.RopeLocalTheta, 10000.0)
kv["gemma3.rope.freq_base"] = cmp.Or(p.RopeTheta, 1000000.0)
if p.RopeScaling != nil && p.RopeScaling.Type == "yarn" && p.RopeScaling.Factor > 0 {
kv["gemma3.rope.scaling.type"] = "yarn"
kv["gemma3.rope.scaling.factor"] = p.RopeScaling.Factor
kv["gemma3.rope.scaling.original_context_length"] = p.RopeScaling.OriginalMaxPositionEmbeddings
kv["gemma3.rope.scaling.extrapolation_factor"] = cmp.Or(p.RopeScaling.ExtrapolationFactor, float32(1.0))
kv["gemma3.rope.scaling.beta_fast"] = cmp.Or(p.RopeScaling.BetaFast, float32(64.0))
kv["gemma3.rope.scaling.beta_slow"] = cmp.Or(p.RopeScaling.BetaSlow, float32(1.0))
}
kv["gemma3.rope.global.freq_base"] = cmp.Or(p.RopeGlobalTheta, 1000000.0)
kv["gemma3.embedding_length"] = p.HiddenSize
kv["gemma3.feed_forward_length"] = p.IntermediateSize
default:

View File

@@ -1,124 +0,0 @@
package convert
import (
"cmp"
"github.com/ollama/ollama/fs/ggml"
)
type ropeScaling struct {
Factor float32 `json:"factor"`
OriginalMaxPositionEmbeds uint32 `json:"original_max_position_embeddings"`
AttentionFactor float32 `json:"attention_factor"`
BetaFast float32 `json:"beta_fast"`
BetaSlow float32 `json:"beta_slow"`
RopeType string `json:"rope_type"`
ExtrapolationFactor float32 `json:"extrapolation_factor"`
}
type olmoModel struct {
ModelParameters
HiddenSize uint32 `json:"hidden_size"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
RMSNormEPS float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
RopeScaling *ropeScaling `json:"rope_scaling"`
ClampKQV float32 `json:"f_clamp_kqv"`
SlidingWindow uint32 `json:"sliding_window"`
LayerTypes []string `json:"layer_types"`
}
var _ ModelConverter = (*olmoModel)(nil)
func (p *olmoModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "olmo2"
kv["olmo2.block_count"] = p.NumHiddenLayers
kv["olmo2.context_length"] = p.MaxPositionEmbeddings
kv["olmo2.embedding_length"] = p.HiddenSize
kv["olmo2.feed_forward_length"] = p.IntermediateSize
kv["olmo2.attention.head_count"] = p.NumAttentionHeads
kv["olmo2.attention.head_count_kv"] = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)
if p.RopeTheta > 0 {
kv["olmo2.rope.freq_base"] = p.RopeTheta
} else {
kv["olmo2.rope.freq_base"] = float32(10000.0)
}
if p.RopeScaling != nil {
if p.RopeScaling.Factor > 0 {
kv["olmo2.rope.scaling.factor"] = p.RopeScaling.Factor
}
if p.RopeScaling.OriginalMaxPositionEmbeds > 0 {
kv["olmo2.rope.scaling.original_context_length"] = p.RopeScaling.OriginalMaxPositionEmbeds
}
if p.RopeScaling.AttentionFactor > 0 {
kv["olmo2.rope.scaling.attn_factor"] = p.RopeScaling.AttentionFactor
}
if p.RopeScaling.RopeType != "" {
kv["olmo2.rope.scaling.type"] = p.RopeScaling.RopeType
}
}
if p.RMSNormEPS > 0 {
kv["olmo2.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
}
if p.ClampKQV > 0 {
kv["olmo2.attention.clamp_kqv"] = p.ClampKQV
}
if p.SlidingWindow > 0 {
kv["olmo2.attention.sliding_window"] = p.SlidingWindow
}
if len(p.LayerTypes) > 0 {
slidingPattern := make([]bool, len(p.LayerTypes))
for i, layerType := range p.LayerTypes {
slidingPattern[i] = (layerType == "sliding_attention")
}
kv["olmo2.attention.sliding_window_pattern"] = slidingPattern
}
return kv
}
func (p *olmoModel) Tensors(ts []Tensor) []*ggml.Tensor {
out := make([]*ggml.Tensor, 0, len(ts))
for _, t := range ts {
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (p *olmoModel) Replacements() []string {
return []string{
"lm_head", "output",
"model.embed_tokens", "token_embd",
"model.layers", "blk",
"model.norm", "output_norm",
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
"self_attn.q_norm", "attn_q_norm",
"self_attn.k_norm", "attn_k_norm",
"post_attention_layernorm", "post_attention_norm",
"post_feedforward_layernorm", "post_ffw_norm",
"mlp.gate_proj", "ffn_gate",
"mlp.down_proj", "ffn_down",
"mlp.up_proj", "ffn_up",
}
}

View File

@@ -252,7 +252,6 @@ func (kv KV) OllamaEngineRequired() bool {
"deepseekocr",
"deepseek2",
"nomic-bert",
"olmo2",
}, kv.Architecture())
}

File diff suppressed because it is too large Load Diff

View File

@@ -2,7 +2,7 @@
package metal
//go:generate sh -c "{ echo // Code generated by 'go generate'. DO NOT EDIT.; sed -e '/__embed_ggml-common.h__/r ../ggml-common.h' -e '/__embed_ggml-common.h__/d' -e '/#include \"ggml-metal-impl.h\"/r ggml-metal-impl.h' -e '/#include \"ggml-metal-impl.h\"/d' ggml-metal.metal; } >ggml-metal-embed.metal"
//go:generate sh -c "{ echo // Code generated by 'go generate'. DO NOT EDIT.; sed -e '/__embed_ggml-common.h__/r ../ggml-common.h' -e '/__embed_ggml-common.h__/d' -e '/#include \"ggml-metal-impl.h\"/r ggml-metal-impl.h' -e '/#include \"ggml-metal-impl.h\"/d' ggml-metal.metal; } >ggml-metal-embed.metal; rm ggml-metal.metal"
// #cgo CXXFLAGS: -std=c++17
// #cgo CPPFLAGS: -DGGML_METAL_NDEBUG -DGGML_METAL_EMBED_LIBRARY -DGGML_METAL_HAS_BF16 -I.. -I../../include

View File

@@ -1,4 +1,5 @@
package nn
// fast provides implementations of fast (fused) operations for increased performance.
package fast
import (
"github.com/ollama/ollama/ml"
@@ -7,7 +8,7 @@ import (
// fastRoPE is an interface for tensors that support fast rotary positional embedding.
type fastRoPE interface {
RoPE(ctx ml.Context, positions ml.Tensor, dim int, base, scale float32, options ...func(*rope.Options)) ml.Tensor
RoPE(ctx ml.Context, positionIDs ml.Tensor, dim int, base, scale float32, options ...func(*rope.Options)) ml.Tensor
}
// RoPE applies rotary positional embedding to tensor `t`.

View File

@@ -1,4 +1,3 @@
// Package rope provides options for RoPE
package rope
import "github.com/ollama/ollama/ml"
@@ -58,18 +57,6 @@ func WithAttentionFactor(attentionFactor float32) func(*Options) {
}
}
func WithBetaFast(betaFast float32) func(*Options) {
return func(opts *Options) {
opts.YaRN.BetaFast = betaFast
}
}
func WithBetaSlow(betaSlow float32) func(*Options) {
return func(opts *Options) {
opts.YaRN.BetaSlow = betaSlow
}
}
func WithMRoPE(sections []int) func(*Options) {
return func(opts *Options) {
opts.Type |= 1 << 3

View File

@@ -10,6 +10,7 @@ import (
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/fast"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
@@ -41,12 +42,13 @@ type Options struct {
kqScale float64
}
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, t, p ml.Tensor) ml.Tensor {
return nn.RoPE(ctx, t, p, o.qkRopeHeadDim, o.ropeBase, 1./o.ropeScale,
func (o Options) RoPEOptions() []func(*rope.Options) {
attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
return []func(*rope.Options){
rope.WithOriginalContextLength(o.originalContextLength),
rope.WithExtrapolationFactor(1.),
rope.WithAttentionFactor(float32(1.0/(1.0+0.1*math.Log(float64(o.ropeScale))))),
)
rope.WithAttentionFactor(attnFactor),
}
}
type Attention struct {
@@ -89,8 +91,8 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
compressedKV.Stride(1), compressedKV.Dim(1),
)
qRot := opts.applyRotaryPositionEmbeddings(ctx, queryChunks[1], positions)
kRot = opts.applyRotaryPositionEmbeddings(ctx, kRot, positions)
qRot := fast.RoPE(ctx, queryChunks[1], positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
kRot = fast.RoPE(ctx, kRot, positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
var attention ml.Tensor
@@ -325,7 +327,7 @@ func New(c fs.Config) (model.Model, error) {
}
func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
return fast.RoPE(ctx, key, shift, m.qkRopeHeadDim, m.ropeBase, 1./m.ropeScale, m.RoPEOptions()...), nil
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {

View File

@@ -6,6 +6,7 @@ import (
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/fast"
"github.com/ollama/ollama/ml/nn/rope"
)
@@ -19,7 +20,7 @@ type textModel struct {
}
func (m *textModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return m.Options.applyRotaryPositionEmbeddings(ctx, key, shift), nil
return m.Options.applyRotaryPositionalEmbedding(ctx, key, shift), nil
}
type textOptions struct {
@@ -37,8 +38,8 @@ func (o textOptions) headDim() int {
return o.hiddenSize / o.numHeads
}
func (o textOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
return nn.RoPE(ctx, states, positions, o.headDim(), o.ropeBase, 1/o.ropeScale, rope.WithTypeNeoX())
func (o textOptions) applyRotaryPositionalEmbedding(ctx ml.Context, t, p ml.Tensor) ml.Tensor {
return fast.RoPE(ctx, t, p, o.headDim(), o.ropeBase, 1/o.ropeScale, rope.WithTypeNeoX())
}
type textBlock struct {
@@ -82,8 +83,8 @@ func (m *textAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tenso
value := m.Value.Forward(ctx, hiddenStates)
value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, -1)
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
query = opts.applyRotaryPositionalEmbedding(ctx, query, positions)
key = opts.applyRotaryPositionalEmbedding(ctx, key, positions)
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
attention = attention.Reshape(ctx, -1, attention.Dim(2))

View File

@@ -7,6 +7,7 @@ import (
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/fast"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
@@ -21,10 +22,6 @@ type Options struct {
largeModelScaling bool
}
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
return nn.RoPE(ctx, states, positions, o.attnKeyLen, o.ropeBase, 1./o.ropeScale, rope.WithTypeNeoX())
}
type Model struct {
model.Base
model.SentencePiece
@@ -91,7 +88,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
q = opts.applyRotaryPositionEmbeddings(ctx, q, positionIDs)
q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
if opts.largeModelScaling {
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@@ -101,7 +98,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
k = opts.applyRotaryPositionEmbeddings(ctx, k, positionIDs)
k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
@@ -131,7 +128,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
return fast.RoPE(ctx, key, shift, m.Options.attnKeyLen, m.Options.ropeBase, 1/m.Options.ropeScale, rope.WithTypeNeoX()), nil
}
type MLP struct {

View File

@@ -16,7 +16,7 @@ import (
type Model struct {
model.Base
model.TextProcessor
model.SentencePiece
*VisionModel `gguf:"v"`
*TextModel
@@ -54,35 +54,24 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
}
func New(c fs.Config) (model.Model, error) {
vocabulary := model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{
int32(c.Uint("tokenizer.ggml.eos_token_id")),
},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
}
var processor model.TextProcessor
switch c.String("tokenizer.ggml.model") {
case "gpt2":
processor = model.NewBytePairEncoding(&vocabulary)
default:
// Previous uploads of Gemma 3 on Ollama did not have token 106
// (i.e. "<end_of_turn>") so we need to add in case it's not already present
vocabulary.EOS = append(vocabulary.EOS, int32(c.Uint("tokenizer.ggml.eot_token_id", 106)))
processor = model.NewSentencePiece(&vocabulary)
}
m := Model{
TextProcessor: processor,
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{
int32(c.Uint("tokenizer.ggml.eos_token_id")),
int32(c.Uint("tokenizer.ggml.eot_token_id", 106)),
},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
},
),
ImageProcessor: newImageProcessor(c),
VisionModel: newVisionModel(c),
TextModel: newTextModel(c),
@@ -152,16 +141,8 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenState := m.TextModel.Forward(ctx, batch, m.Cache)
hiddenState = m.Output.Forward(ctx, hiddenState)
if m.TextConfig.finalLogitSoftcap > 0.0 {
hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.TextConfig.finalLogitSoftcap))
hiddenState = hiddenState.Tanh(ctx)
hiddenState = hiddenState.Scale(ctx, float64(m.TextConfig.finalLogitSoftcap))
}
return hiddenState, nil
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
return m.Output.Forward(ctx, hiddenStates), nil
}
func init() {

View File

@@ -2,12 +2,12 @@ package gemma3
import (
"math"
"slices"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/fast"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model/input"
)
@@ -16,32 +16,8 @@ type TextConfig struct {
hiddenSize, numHeads, numKVHeads int
attnKeyLen, attnValLen int
eps, ropeScale float32
ropeLocalBase float32
ropeLocalBase, ropeGlobalBase float32
largeModelScaling bool
slidingWindowPattern []bool
ropeBase float32
ropeType string
ropeOriginalContext int
ropeExtrapolation float32
ropeBetaFast float32
ropeBetaSlow float32
finalLogitSoftcap float32
}
func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor, base float32) ml.Tensor {
ropeOpts := []func(*rope.Options){rope.WithTypeNeoX()}
if o.ropeType == "yarn" {
attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
ropeOpts = append(ropeOpts,
rope.WithOriginalContextLength(o.ropeOriginalContext),
rope.WithExtrapolationFactor(o.ropeExtrapolation),
rope.WithAttentionFactor(attnFactor),
rope.WithBetaFast(o.ropeBetaFast),
rope.WithBetaSlow(o.ropeBetaSlow),
)
}
return nn.RoPE(ctx, states, positions, o.attnKeyLen, base, 1./o.ropeScale, ropeOpts...)
}
type TextModel struct {
@@ -69,35 +45,21 @@ func newTextModel(c fs.Config) *TextModel {
m := TextModel{
Layers: make([]TextLayer, numBlocks),
TextConfig: &TextConfig{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
attnKeyLen: int(c.Uint("attention.key_length", 256)),
attnValLen: int(c.Uint("attention.value_length", 256)),
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
ropeLocalBase: c.Float("rope.local.freq_base", 10000.0),
ropeBase: c.Float("rope.freq_base", 1000000.0),
slidingWindowPattern: c.Bools("attention.sliding_window_pattern"),
ropeType: c.String("rope.scaling.type"),
ropeOriginalContext: int(c.Uint("rope.scaling.original_context_length")),
ropeExtrapolation: c.Float("rope.scaling.extrapolation_factor", 1.0),
ropeBetaFast: c.Float("rope.scaling.beta_fast", 64.0),
ropeBetaSlow: c.Float("rope.scaling.beta_slow", 1.0),
ropeScale: c.Float("rope.scaling.factor", 1.0),
finalLogitSoftcap: c.Float("final_logit_softcapping", 0.0),
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
attnKeyLen: int(c.Uint("attention.key_length", 256)),
attnValLen: int(c.Uint("attention.value_length", 256)),
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
ropeLocalBase: c.Float("rope.local.freq_base", 10000.0),
ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
ropeScale: 1,
// NOTE: the rope.scaling.factor is set incorrectly in the official QAT weights
// (8 instead of 1)
// ropeScale: c.Float("rope.scaling.factor", 1.0),
},
}
// Google's Gemma 3 release with sliding window attention does
// not use final logit softcapping, and so force it to 0.0
// TODO (jmorganca): this should ideally be set to 0.0 in the
// model configuration instead of here, as future versions of
// models may include both sliding window attention and final
// logit softcapping.
if slices.Contains(m.TextConfig.slidingWindowPattern, true) {
m.TextConfig.finalLogitSoftcap = 0.0
}
if numBlocks == gemma27BLayerCount {
m.largeModelScaling = true
}
@@ -114,31 +76,18 @@ type TextSelfAttention struct {
Output *nn.Linear `gguf:"attn_output"`
}
func (opts *TextConfig) ropeBaseForLayer(layer int) float32 {
if opts.slidingWindowPattern != nil && opts.slidingWindowPattern[layer] {
return opts.ropeLocalBase
}
// Standard Gemma3: only every n-th layer is global,
// where n = gemmaGlobalCacheCount, otherwise use
// the local rope base
if (layer+1)%gemmaGlobalCacheCount > 0 {
return opts.ropeLocalBase
}
// default to global rope base
return opts.ropeBase
}
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
batchSize := hiddenState.Dim(1)
ropeBase := opts.ropeBaseForLayer(layer)
ropeBase := opts.ropeLocalBase
if (layer+1)%gemmaGlobalCacheCount == 0 {
ropeBase = opts.ropeGlobalBase
}
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
q = sa.QueryNorm.Forward(ctx, q, opts.eps)
q = opts.applyRotaryPositionEmbeddings(ctx, q, positionIDs, ropeBase)
q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
if opts.largeModelScaling {
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@@ -149,7 +98,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
k = sa.KeyNorm.Forward(ctx, k, opts.eps)
k = opts.applyRotaryPositionEmbeddings(ctx, k, positionIDs, ropeBase)
k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
@@ -162,7 +111,12 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return m.applyRotaryPositionEmbeddings(ctx, key, shift, m.TextConfig.ropeBaseForLayer(layer)), nil
ropeBase := m.TextConfig.ropeLocalBase
if (layer+1)%gemmaGlobalCacheCount == 0 {
ropeBase = m.TextConfig.ropeGlobalBase
}
return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, 1/m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil
}
type TextMLP struct {
@@ -250,5 +204,6 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig)
}
return m.OutputNorm.Forward(ctx, hiddenState, m.eps)
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
return hiddenState
}

View File

@@ -8,6 +8,7 @@ import (
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/fast"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model/input"
)
@@ -94,7 +95,7 @@ func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.T
ropeBase = m.ropeBaseLocal
}
return m.applyRotaryPositionEmbeddings(ctx, key, shift, ropeBase), nil
return fast.RoPE(ctx, key, shift, m.headDim(), ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil
}
type TextScaledWordEmbedding struct {
@@ -255,14 +256,14 @@ func (attn TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Ten
query := attn.Query.Forward(ctx, hiddenStates)
query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize)
query = attn.QueryNorm.Forward(ctx, query, opts.eps)
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions, ropeBase)
query = fast.RoPE(ctx, query, positions, opts.headDim(), ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
var key, value ml.Tensor
if !sharedKV {
key = attn.Key.Forward(ctx, hiddenStates)
key = key.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
key = attn.KeyNorm.Forward(ctx, key, opts.eps)
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions, ropeBase)
key = fast.RoPE(ctx, key, positions, opts.headDim(), ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
value = attn.Value.Forward(ctx, hiddenStates)
value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
@@ -329,10 +330,6 @@ func (o *TextOptions) isLocal(i int) bool {
return o.slidingWindowPattern[i]
}
func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, t, p ml.Tensor, base float32) ml.Tensor {
return nn.RoPE(ctx, t, p, o.headDim(), base, 1./o.ropeScale, rope.WithTypeNeoX())
}
func newTextModel(c fs.Config) *TextModel {
return &TextModel{
TextLayers: make([]TextLayer, c.Uint("block_count")),

View File

@@ -9,6 +9,7 @@ import (
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/fast"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
@@ -51,7 +52,7 @@ func (m *Transformer) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, err
}
func (m *Transformer) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
return fast.RoPE(ctx, key, shift, m.headDim(), m.ropeBase, 1./m.ropeScale, m.RoPEOptions()...), nil
}
type Options struct {
@@ -69,14 +70,14 @@ type Options struct {
ropeScale float32
}
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
return nn.RoPE(ctx, states, positions, o.headDim(), o.ropeBase, 1./o.ropeScale,
func (o Options) RoPEOptions() []func(*rope.Options) {
return []func(*rope.Options){
rope.WithTypeNeoX(),
rope.WithOriginalContextLength(o.originalContextLength),
rope.WithExtrapolationFactor(1.),
// NOTE: ggml sets this implicitly so there's no need to set it here
// rope.WithAttentionFactor(0.1*float32(math.Log(float64(o.ropeScale))) + 1.0),
)
// NOTE: ggml sets this implicitly so there's no need to set it here
// rope.WithAttentionFactor(0.1*float32(math.Log(float64(o.ropeScale))) + 1.0),
}
}
func (o Options) headDim() int {
@@ -134,8 +135,8 @@ func (attn *AttentionBlock) Forward(ctx ml.Context, hiddenStates, positions ml.T
value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
}
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
query = fast.RoPE(ctx, query, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
key = fast.RoPE(ctx, key, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
attention := nn.AttentionWithSinks(ctx, query, key, value, attn.Sinks, 1/math.Sqrt(float64(opts.headDim())), cache)
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)

View File

@@ -8,6 +8,7 @@ import (
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/fast"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
@@ -19,10 +20,6 @@ type Options struct {
eps, ropeBase, ropeScale float32
}
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions, factors ml.Tensor) ml.Tensor {
return nn.RoPE(ctx, states, positions, cmp.Or(o.ropeDim, o.headDim, o.hiddenSize/o.numHeads), o.ropeBase, 1./o.ropeScale, rope.WithFactors(factors))
}
type Model struct {
model.Base
model.TextProcessor
@@ -118,6 +115,7 @@ type SelfAttention struct {
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
batchSize := hiddenState.Dim(1)
headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads)
ropeDim := cmp.Or(opts.ropeDim, headDim)
query := sa.Query.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
@@ -128,8 +126,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso
value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions, sa.RopeFactors)
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions, sa.RopeFactors)
query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize)
@@ -137,7 +135,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return m.applyRotaryPositionEmbeddings(ctx, key, shift, m.Layers[layer].SelfAttention.RopeFactors), nil
ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads)
return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(m.Layers[layer].SelfAttention.RopeFactors)), nil
}
type MLP struct {

View File

@@ -8,6 +8,7 @@ import (
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/fast"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model/input"
)
@@ -32,8 +33,8 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions, attent
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
if useRope {
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions, sa.RopeFactors)
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions, sa.RopeFactors)
query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
}
if opts.useQKNorm {
@@ -151,10 +152,6 @@ type TextOptions struct {
attentionFloorScale float64
}
func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions, factors ml.Tensor) ml.Tensor {
return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale, rope.WithFactors(factors))
}
type TextModel struct {
Layers []TextLayer `gguf:"blk"`
@@ -239,5 +236,5 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return m.applyRotaryPositionEmbeddings(ctx, key, shift, m.Layers[layer].Attention.RopeFactors), nil
return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(m.Layers[layer].Attention.RopeFactors)), nil
}

View File

@@ -8,6 +8,7 @@ import (
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/fast"
"github.com/ollama/ollama/model/input"
)
@@ -19,10 +20,6 @@ type TextOptions struct {
ropeScalingBeta float32
}
func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale)
}
type TextModel struct {
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -45,11 +42,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs, posit
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
q = opts.applyRotaryPositionEmbeddings(ctx, q, positionIDs)
q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale)
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
k = opts.applyRotaryPositionEmbeddings(ctx, k, positionIDs)
k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -64,7 +61,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs, posit
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale), nil
}
type MLP struct {

View File

@@ -16,8 +16,8 @@ func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
return x2.Scale(ctx, -1).Concat(ctx, x1, 0)
}
func applyRotaryPositionEmbeddings(ctx ml.Context, states, cos, sin ml.Tensor) ml.Tensor {
return states.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, states).Mul(ctx, sin))
func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
return t.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, t).Mul(ctx, sin))
}
type VisionSelfAttention struct {
@@ -36,8 +36,8 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml
key = key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1), batchSize)
value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1), batchSize)
query = applyRotaryPositionEmbeddings(ctx, query, cos, sin)
key = applyRotaryPositionEmbeddings(ctx, key, cos, sin)
query = applyRotaryPositionalEmbedding(ctx, query, cos, sin)
key = applyRotaryPositionalEmbedding(ctx, key, cos, sin)
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim)), nil)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)

View File

@@ -8,6 +8,7 @@ import (
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/fast"
"github.com/ollama/ollama/ml/nn/rope"
)
@@ -25,11 +26,11 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T
query := sa.Query.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions, sa.RopeFactors)
query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
key := sa.Key.Forward(ctx, hiddenState)
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions, sa.RopeFactors)
key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -43,8 +44,8 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
// This will only get called for layers in the cache, which are just the self attention layers
if layer, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
return m.applyRotaryPositionEmbeddings(ctx, key, shift, layer.SelfAttention.RopeFactors), nil
if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(sa.SelfAttention.RopeFactors)), nil
}
return key, nil
@@ -205,10 +206,6 @@ type TextModelOptions struct {
crossAttentionLayers []int32
}
func (o TextModelOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions, factors ml.Tensor) ml.Tensor {
return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale, rope.WithFactors(factors))
}
type TextModel struct {
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Transformer *TextDecoder `gguf:"blk"`

View File

@@ -13,7 +13,6 @@ import (
_ "github.com/ollama/ollama/model/models/mistral3"
_ "github.com/ollama/ollama/model/models/mllama"
_ "github.com/ollama/ollama/model/models/nomicbert"
_ "github.com/ollama/ollama/model/models/olmo"
_ "github.com/ollama/ollama/model/models/qwen2"
_ "github.com/ollama/ollama/model/models/qwen25vl"
_ "github.com/ollama/ollama/model/models/qwen3"

View File

@@ -7,6 +7,7 @@ import (
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/fast"
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
@@ -36,10 +37,6 @@ type Options struct {
ropeFreqBase float32
}
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
return nn.RoPE(ctx, states, positions, o.headDim, o.ropeFreqBase, 1.0, rope.WithTypeNeoX())
}
// Single Encoder Layer
type EncoderLayer struct {
*Attention
@@ -108,8 +105,8 @@ func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, positions ml
chunks := qkv.Chunk(ctx, 1, opts.numHeads)
query, key, value := chunks[0], chunks[1], chunks[2]
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
query = fast.RoPE(ctx, query, positions, opts.headDim, opts.ropeFreqBase, 1.0, rope.WithTypeNeoX())
key = fast.RoPE(ctx, key, positions, opts.headDim, opts.ropeFreqBase, 1.0, rope.WithTypeNeoX())
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(opts.headDim)), nil)

View File

@@ -1,298 +0,0 @@
package olmo
import (
"fmt"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
const (
cacheTypeSWA = iota
cacheTypeCausal
)
type Options struct {
hiddenSize, numHeads, numKVHeads int
// headDim, ropeDim int
eps, ropeBase, ropeScale float32
originalContextLength int
attnFactor float32
ropeType string
ropeExtrapolation float32
ropeBetaFast float32
ropeBetaSlow float32
slidingWindowPattern []bool
}
type Model struct {
model.Base
model.TextProcessor
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
Options
}
func New(c fs.Config) (model.Model, error) {
vocabulary := model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
c.Ints("tokenizer.ggml.eos_token_ids")...,
),
}
if c.String("tokenizer.ggml.model") != "gpt2" {
return nil, model.ErrUnsupportedTokenizer
}
var pretokenizers []string
if c.String("tokenizer.ggml.pre") != "default" {
pretokenizers = []string{
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
}
}
processor := model.NewBytePairEncoding(&vocabulary, pretokenizers...)
hiddenSize := int(c.Uint("embedding_length"))
numHeads := int(c.Uint("attention.head_count"))
numKVHeads := int(c.Uint("attention.head_count_kv"))
eps := c.Float("attention.layer_norm_rms_epsilon")
ropeBase := c.Float("rope.freq_base", 1e4)
ropeScale := c.Float("rope.scaling.factor", 1)
originalContextLength := int(c.Uint("rope.scaling.original_context_length"))
attnFactor := c.Float("rope.scaling.attn_factor", 1)
ropeType := c.String("rope.scaling.type")
ropeExtrapolation := c.Float("rope.scaling.extrapolation_factor", 1.0)
fmt.Printf("hiddenSize: %d\n", hiddenSize)
fmt.Printf("numHeads: %d\n", numHeads)
fmt.Printf("numKVHeads: %d\n", numKVHeads)
fmt.Printf("eps: %f\n", eps)
fmt.Printf("ropeBase: %f\n", ropeBase)
fmt.Printf("ropeScale: %f\n", ropeScale)
fmt.Printf("originalContextLength: %d\n", originalContextLength)
fmt.Printf("attnFactor: %f\n", attnFactor)
fmt.Printf("ropeType: %s\n", ropeType)
fmt.Printf("ropeExtrapolation: %f\n", ropeExtrapolation)
fmt.Printf("sliding_window_pattern: %v\n", c.Bools("attention.sliding_window_pattern"))
m := Model{
TextProcessor: processor,
Layers: make([]Layer, c.Uint("block_count")),
Options: Options{
hiddenSize: hiddenSize,
numHeads: numHeads,
numKVHeads: numKVHeads,
eps: eps,
ropeBase: ropeBase,
ropeScale: ropeScale,
originalContextLength: originalContextLength,
attnFactor: attnFactor,
ropeType: ropeType,
ropeExtrapolation: ropeExtrapolation,
ropeBetaFast: 32.0,
ropeBetaSlow: 1.0,
slidingWindowPattern: c.Bools("attention.sliding_window_pattern"),
},
}
m.Cache = kvcache.NewWrapperCache(
kvcache.NewSWACache(int32(c.Uint("attention.sliding_window")), m.Shift),
kvcache.NewCausalCache(m.Shift),
)
return &m, nil
}
type SelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
QNorm *nn.RMSNorm `gguf:"attn_q_norm"`
KNorm *nn.RMSNorm `gguf:"attn_k_norm"`
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
}
func (m *Model) applyRoPE(ctx ml.Context, states, positions ml.Tensor, ropeDim int, isSWA bool) ml.Tensor {
var ropeOpts []func(*rope.Options)
ropeOpts = append(ropeOpts, rope.WithTypeNeoX())
// Both SWA and non-SWA use beta_fast and beta_slow
// defaults
ropeOpts = append(ropeOpts,
rope.WithBetaFast(m.ropeBetaFast),
rope.WithBetaSlow(m.ropeBetaSlow),
)
// SWA uses freq_scale=1.0, ext_factor=0.0, attn_factor=1.0
// Non-SWA uses full yarn parameters
if m.originalContextLength > 0 {
ropeOpts = append(ropeOpts,
rope.WithOriginalContextLength(m.originalContextLength),
)
// no yarn for swa
if isSWA {
ropeOpts = append(ropeOpts,
rope.WithExtrapolationFactor(0),
rope.WithAttentionFactor(1.),
)
} else {
ropeOpts = append(ropeOpts,
rope.WithExtrapolationFactor(m.ropeExtrapolation),
rope.WithAttentionFactor(m.attnFactor),
)
}
}
freqScale := float32(1.0)
if !isSWA {
freqScale = 1. / m.ropeScale
}
return nn.RoPE(ctx, states, positions, ropeDim, m.ropeBase, freqScale, ropeOpts...)
}
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tensor, cache kvcache.Cache, m *Model, isSWA bool) ml.Tensor {
batchSize := hiddenState.Dim(1)
headDim := m.hiddenSize / m.numHeads
ropeDim := headDim
query := sa.Query.Forward(ctx, hiddenState)
// double check type
query = sa.QNorm.Forward(ctx, query, m.eps)
query = query.Reshape(ctx, headDim, m.numHeads, batchSize)
//check here
query = m.applyRoPE(ctx, query, positions, ropeDim, isSWA)
key := sa.Key.Forward(ctx, hiddenState)
key = sa.KNorm.Forward(ctx, key, m.eps)
key = key.Reshape(ctx, headDim, m.numKVHeads, batchSize)
key = m.applyRoPE(ctx, key, positions, ropeDim, isSWA)
value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, m.numKVHeads, batchSize)
// check attention scaling as well
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
attention = attention.Reshape(ctx, m.hiddenSize, batchSize)
return sa.Output.Forward(ctx, attention)
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
ropeDim := m.hiddenSize / m.numHeads
isSWA := m.isSWALayer(layer)
return m.applyRoPE(ctx, key, shift, ropeDim, isSWA), nil
}
type MLP struct {
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
Gate *nn.Linear `gguf:"ffn_gate"`
}
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, m *Model) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
type Layer struct {
SelfAttention *SelfAttention
PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
MLP *MLP
PostFFWNorm *nn.RMSNorm `gguf:"post_ffw_norm"`
}
func (l *Layer) Forward(ctx ml.Context, hiddenState, positions, outputs ml.Tensor, cache kvcache.Cache, m *Model, isSWA bool) ml.Tensor {
residual := hiddenState
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positions, cache, m, isSWA)
// return hiddenState
if outputs != nil {
hiddenState = hiddenState.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
// i think this should be after getting the rows?
hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, m.eps)
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
hiddenState = l.MLP.Forward(ctx, hiddenState, m)
hiddenState = l.PostFFWNorm.Forward(ctx, hiddenState, m.eps)
return hiddenState.Add(ctx, residual)
}
// Olmo3 has Sliding Window Attention (SWA) 3 out of 4 layers.
func (m *Model) isSWALayer(layerIdx int) bool {
return m.Options.slidingWindowPattern[layerIdx]
}
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
for i, layer := range m.Layers {
m.Cache.SetLayer(i)
cacheType := cacheTypeSWA
isSWA := m.isSWALayer(i)
if !isSWA {
cacheType = cacheTypeCausal
}
wc := m.Cache.(*kvcache.WrapperCache)
wc.SetLayerType(cacheType)
// would need to check the cache at the layer instead
if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
// TODO: not sure about the index here
causal.SetCausal(ctx, kvcache.CausalOptions{Except: []int{}})
}
var outputs ml.Tensor
if i == len(m.Layers)-1 {
outputs = batch.Outputs
}
hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m, isSWA)
// return hiddenState, nil
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
return m.Output.Forward(ctx, hiddenState), nil
}
func init() {
model.Register("olmo2", New)
}

View File

@@ -1,568 +0,0 @@
package olmo
import (
"encoding/binary"
"encoding/json"
"flag"
"fmt"
"log/slog"
"math"
"os"
"path/filepath"
"strings"
"testing"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
typemodel "github.com/ollama/ollama/types/model"
)
var args struct {
model,
prompt string
layers int
}
func TestMain(m *testing.M) {
flag.StringVar(&args.model, "model", "", "path to model (e.g., olmo3:latest)")
flag.StringVar(&args.prompt, "prompt", "Hello, how are", "model prompt")
flag.IntVar(&args.layers, "layers", math.MaxInt, "num of gpu layers")
flag.Parse()
slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
os.Exit(m.Run())
}
func blob(tb testing.TB, modelName string) string {
tb.Helper()
models := envconfig.Models()
manifest, err := os.Open(filepath.Join(models, "manifests", typemodel.ParseName(modelName).Filepath()))
if err != nil {
tb.Fatal(err)
}
defer manifest.Close()
var m struct {
Layers []struct {
MediaType string `json:"mediaType"`
Digest string `json:"digest"`
} `json:"layers"`
}
if err := json.NewDecoder(manifest).Decode(&m); err != nil {
tb.Fatal(err)
}
for _, layer := range m.Layers {
if layer.MediaType == "application/vnd.ollama.image.model" {
tb.Log("using model blob", layer.Digest)
return filepath.Join(models, "blobs", strings.ReplaceAll(layer.Digest, ":", "-"))
}
}
tb.Fatal("model blob not found")
return ""
}
func loadFloatsFromBinary(filename string) ([]float32, error) {
f, err := os.Open(filename)
if err != nil {
return nil, err
}
defer f.Close()
fi, err := f.Stat()
if err != nil {
return nil, err
}
if fi.Size()%4 != 0 {
return nil, fmt.Errorf("file size %d not multiple of 4", fi.Size())
}
n := int(fi.Size() / 4)
floats := make([]float32, n)
if err := binary.Read(f, binary.LittleEndian, floats); err != nil {
return nil, err
}
return floats, nil
}
func TestTokenization(t *testing.T) {
if args.model == "" {
t.Skip("no model specified, use -model flag")
}
m, err := model.New(blob(t, args.model), ml.BackendParams{AllocMemory: true})
if err != nil {
t.Fatal(err)
}
if err := m.Backend().Load(t.Context(), func(float32) {}); err != nil {
t.Fatal(err)
}
prompt := args.prompt
if prompt == "" {
prompt = "hello, how are you?"
}
tp := m.(model.TextProcessor)
tokens, err := tp.Encode(prompt, false)
if err != nil {
t.Fatal(err)
}
t.Logf("prompt: %q", prompt)
t.Logf("tokens: %v", tokens)
t.Logf("num tokens: %d", len(tokens))
decoded, err := tp.Decode(tokens)
if err != nil {
t.Fatal(err)
}
t.Logf("decoded: %q", decoded)
}
func TestAttentionForward(t *testing.T) {
if args.model == "" {
t.Skip("no model specified, use -model flag")
}
m, err := model.New(blob(t, args.model), ml.BackendParams{AllocMemory: true})
if err != nil {
t.Fatal(err)
}
if err := m.Backend().Load(t.Context(), func(float32) {}); err != nil {
t.Fatal(err)
}
olmoModel := m.(*Model)
t.Logf("Model options: hiddenSize=%d, numHeads=%d, numKVHeads=%d",
olmoModel.hiddenSize, olmoModel.numHeads, olmoModel.numKVHeads)
t.Logf("Layer 0 attention: %+v", olmoModel.Layers[0].SelfAttention)
ctx := m.Backend().NewContext()
// Create test hidden states: (hiddenSize, batchSize)
batchSize := 4
hiddenSize := olmoModel.hiddenSize
hsFloats := make([]float32, hiddenSize*batchSize)
for i := range hsFloats {
hsFloats[i] = float32(i%100) / 100.0 // Simple test values
}
hiddenStates := ctx.Input().FromFloats(hsFloats, hiddenSize, batchSize)
t.Logf("hiddenStates shape: %v", hiddenStates.Shape())
positions := ctx.Input().FromInts([]int32{0, 1, 2, 3}, batchSize)
// Test attention forward (without cache for simplicity)
attentionBlock := olmoModel.Layers[0].SelfAttention
isSWA := olmoModel.isSWALayer(0)
t.Logf("Layer 0 isSWA: %v", isSWA)
result := attentionBlock.Forward(ctx, hiddenStates, positions, nil, olmoModel, isSWA)
result = result.Contiguous(ctx)
ctx.Forward(result).Compute(result)
t.Logf("Attention result shape: %v dtype: %v", result.Shape(), result.DType())
// Optionally dump to file
// if err := os.WriteFile("/tmp/olmo_attention_output.bin", result.Bytes(), 0644); err != nil {
// t.Fatal(err)
// }
}
func TestMLPForward(t *testing.T) {
if args.model == "" {
t.Skip("no model specified, use -model flag")
}
m, err := model.New(blob(t, args.model), ml.BackendParams{AllocMemory: true})
if err != nil {
t.Fatal(err)
}
if err := m.Backend().Load(t.Context(), func(float32) {}); err != nil {
t.Fatal(err)
}
olmoModel := m.(*Model)
ctx := m.Backend().NewContext()
// Create test hidden states
batchSize := 4
hiddenSize := olmoModel.hiddenSize
hsFloats := make([]float32, hiddenSize*batchSize)
for i := range hsFloats {
hsFloats[i] = float32(i%100) / 100.0
}
hiddenStates := ctx.Input().FromFloats(hsFloats, hiddenSize, batchSize)
t.Logf("hiddenStates shape: %v", hiddenStates.Shape())
mlpBlock := olmoModel.Layers[0].MLP
result := mlpBlock.Forward(ctx, hiddenStates, olmoModel)
result = result.Contiguous(ctx)
ctx.Forward(result).Compute(result)
t.Logf("MLP result shape: %v dtype: %v", result.Shape(), result.DType())
// Parse result bytes to float32
resultBytes := result.Bytes()
resultFloats := make([]float32, len(resultBytes)/4)
for i := range resultFloats {
bits := binary.LittleEndian.Uint32(resultBytes[i*4 : (i+1)*4])
resultFloats[i] = math.Float32frombits(bits)
}
// Compute statistics
var minVal, maxVal, sum float32
minVal = resultFloats[0]
maxVal = resultFloats[0]
for _, v := range resultFloats {
if v < minVal {
minVal = v
}
if v > maxVal {
maxVal = v
}
sum += v
}
mean := sum / float32(len(resultFloats))
// Build readable output
var sb strings.Builder
sb.WriteString("# MLP Forward Output\n\n")
sb.WriteString(fmt.Sprintf("# Input Shape: [%d, %d] (hiddenSize, batchSize)\n", hiddenSize, batchSize))
sb.WriteString(fmt.Sprintf("# Output Shape: %v\n", result.Shape()))
sb.WriteString(fmt.Sprintf("# DType: %v\n", result.DType()))
sb.WriteString(fmt.Sprintf("# Layer: 0\n\n"))
sb.WriteString("## Statistics\n\n")
sb.WriteString(fmt.Sprintf(" Total elements: %d\n", len(resultFloats)))
sb.WriteString(fmt.Sprintf(" Min: %v\n", minVal))
sb.WriteString(fmt.Sprintf(" Max: %v\n", maxVal))
sb.WriteString(fmt.Sprintf(" Mean: %v\n\n", mean))
sb.WriteString("## Input Hidden States (first 20 values)\n\n")
sb.WriteString(" [")
for i := 0; i < min(20, len(hsFloats)); i++ {
if i > 0 {
sb.WriteString(", ")
}
sb.WriteString(fmt.Sprintf("%v", hsFloats[i]))
}
sb.WriteString("]\n\n")
sb.WriteString("## Output Values\n\n")
// Per-position output (each position in batch)
for pos := 0; pos < batchSize; pos++ {
sb.WriteString(fmt.Sprintf("Position %d (hiddenSize=%d values):\n", pos, hiddenSize))
// Extract values for this position
posStart := pos * hiddenSize
posEnd := posStart + hiddenSize
if posEnd > len(resultFloats) {
posEnd = len(resultFloats)
}
posValues := resultFloats[posStart:posEnd]
// Full tensor values
sb.WriteString(" [")
for i, v := range posValues {
if i > 0 {
sb.WriteString(", ")
}
sb.WriteString(fmt.Sprintf("%v", v))
}
sb.WriteString("]\n\n")
}
// Save to file
if err := os.WriteFile("/tmp/olmo_mlp_forward.txt", []byte(sb.String()), 0644); err != nil {
t.Fatal(err)
}
t.Log("Saved /tmp/olmo_mlp_forward.txt")
// Also save binary
if err := os.WriteFile("/tmp/olmo_mlp_forward.bin", resultBytes, 0644); err != nil {
t.Fatal(err)
}
t.Log("Saved /tmp/olmo_mlp_forward.bin")
// Print summary to console
fmt.Println(sb.String())
}
func TestFullForward(t *testing.T) {
if args.model == "" {
t.Skip("no model specified, use -model flag")
}
m, err := model.New(blob(t, args.model), ml.BackendParams{AllocMemory: true})
if err != nil {
t.Fatal(err)
}
if err := m.Backend().Load(t.Context(), func(float32) {}); err != nil {
t.Fatal(err)
}
ctx := m.Backend().NewContext()
prompt := args.prompt
if prompt == "" {
prompt = "Hello, how are you?"
}
tp := m.(model.TextProcessor)
tokens, err := tp.Encode(prompt, false)
if err != nil {
t.Fatal(err)
}
t.Logf("prompt: %q", prompt)
t.Logf("tokens: %v", tokens)
decoded, err := tp.Decode(tokens)
if err != nil {
t.Fatal(err)
}
t.Logf("decoded: %q", decoded)
seqLen := len(tokens)
inputsTensor := ctx.Input().FromInts(tokens, seqLen)
positions := make([]int32, seqLen)
sequences := make([]int, seqLen)
for i := range tokens {
positions[i] = int32(i)
sequences[i] = 0
}
// Output ALL positions
outputIndices := make([]int32, seqLen)
for i := range outputIndices {
outputIndices[i] = int32(i)
}
outputs := ctx.Input().FromInts(outputIndices, seqLen)
batch := input.Batch{
Inputs: inputsTensor,
Positions: positions,
Sequences: sequences,
Outputs: outputs,
}
// Initialize cache
if cache := m.Config().Cache; cache != nil {
cache.Init(m.Backend(), ml.DTypeF16, 1, 4096, seqLen)
}
result, err := model.Forward(ctx, m, batch)
if err != nil {
t.Fatal(err)
}
result = result.Contiguous(ctx)
ctx.Forward(result).Compute(result)
t.Logf("Forward pass completed, result shape: %v", result.Shape())
// Dump logits to binary file
if err := os.WriteFile("/tmp/olmo_logits.bin", result.Bytes(), 0644); err != nil {
t.Fatal(err)
}
t.Log("Saved /tmp/olmo_logits.bin")
// Parse logits from bytes for detailed analysis
logitsBytes := result.Bytes()
vocabSize := result.Shape()[0]
// Read float32 values - shape is (vocab_size, seq_len)
allLogits := make([]float32, len(logitsBytes)/4)
for i := range allLogits {
bits := binary.LittleEndian.Uint32(logitsBytes[i*4 : (i+1)*4])
allLogits[i] = math.Float32frombits(bits)
}
// Create detailed text dump matching Python format
var sb strings.Builder
sb.WriteString("# Full Forward Logits\n\n")
sb.WriteString(fmt.Sprintf("# Shape: [1, %d, %d]\n", seqLen, vocabSize))
sb.WriteString(fmt.Sprintf("# Layout: (batch=1, seq_len=%d, vocab_size=%d)\n", seqLen, vocabSize))
sb.WriteString(fmt.Sprintf("# Prompt: '%s'\n", prompt))
sb.WriteString(fmt.Sprintf("# Tokens: %v\n\n", tokens))
type logitPair struct {
tokenID int
value float32
}
// Process each position
for pos := 0; pos < seqLen; pos++ {
// Extract logits for this position
// Shape is (vocab_size, seq_len), so logits[v*seqLen + pos] gives logit for vocab v at position pos
posLogits := make([]float32, vocabSize)
for v := 0; v < vocabSize; v++ {
posLogits[v] = allLogits[v*seqLen+pos]
}
// Find top 10 logits
pairs := make([]logitPair, len(posLogits))
for i, v := range posLogits {
pairs[i] = logitPair{tokenID: i, value: v}
}
// Sort by value descending (simple bubble sort for small top-k)
for i := 0; i < min(10, len(pairs)); i++ {
for j := i + 1; j < len(pairs); j++ {
if pairs[j].value > pairs[i].value {
pairs[i], pairs[j] = pairs[j], pairs[i]
}
}
}
tokenStr, _ := tp.Decode([]int32{tokens[pos]})
sb.WriteString(fmt.Sprintf("Position %d (token_id=%d, token='%s'):\n", pos, tokens[pos], tokenStr))
sb.WriteString(" Top 10 logits:\n")
for i := 0; i < min(10, len(pairs)); i++ {
tokStr, _ := tp.Decode([]int32{int32(pairs[i].tokenID)})
// Pad token string to 20 chars for alignment
paddedTok := fmt.Sprintf("%-20s", fmt.Sprintf("'%s'", tokStr))
sb.WriteString(fmt.Sprintf(" %d. token_id=%6d (%s): %f\n", i+1, pairs[i].tokenID, paddedTok, pairs[i].value))
}
// First 20 logits
sb.WriteString(" Full logits (first 20): [")
for i := 0; i < min(20, len(posLogits)); i++ {
if i > 0 {
sb.WriteString(", ")
}
sb.WriteString(fmt.Sprintf("%v", posLogits[i]))
}
sb.WriteString("]\n")
// Last 20 logits
sb.WriteString(" Full logits (last 20): [")
start := max(0, len(posLogits)-20)
for i := start; i < len(posLogits); i++ {
if i > start {
sb.WriteString(", ")
}
sb.WriteString(fmt.Sprintf("%v", posLogits[i]))
}
sb.WriteString("]\n\n")
}
if err := os.WriteFile("/tmp/olmo_logits.txt", []byte(sb.String()), 0644); err != nil {
t.Fatal(err)
}
t.Log("Saved /tmp/olmo_logits.txt")
// Print to console as well
fmt.Println(sb.String())
}
func TestRoPE(t *testing.T) {
if args.model == "" {
t.Skip("no model specified, use -model flag")
}
m, err := model.New(blob(t, args.model), ml.BackendParams{AllocMemory: true})
if err != nil {
t.Fatal(err)
}
if err := m.Backend().Load(t.Context(), func(float32) {}); err != nil {
t.Fatal(err)
}
olmoModel := m.(*Model)
// Test RoPE on a simple tensor
headDim := olmoModel.hiddenSize / olmoModel.numHeads
batchSize := 4
numHeads := olmoModel.numHeads
t.Logf("headDim: %d, numHeads: %d", headDim, numHeads)
t.Logf("ropeBase: %f, ropeScale: %f, originalContextLength: %d",
olmoModel.ropeBase, olmoModel.ropeScale, olmoModel.originalContextLength)
// Create test query tensor: (headDim, numHeads, batchSize)
queryFloats := make([]float32, headDim*numHeads*batchSize)
for i := range queryFloats {
queryFloats[i] = float32(i%100) / 100.0
}
// Test 1: Dump initial query values (fresh context)
{
ctx := m.Backend().NewContext()
query := ctx.Input().FromFloats(queryFloats, headDim, numHeads, batchSize)
t.Logf("query shape: %v", query.Shape())
query = query.Contiguous(ctx)
ctx.Forward(query).Compute(query)
dump := ml.Dump(ctx, query, ml.DumpWithPrecision(6), ml.DumpWithThreshold(1000000))
t.Logf("Query BEFORE RoPE sample values: %s", dump[:min(500, len(dump))])
// Write to file
header := fmt.Sprintf("Shape: %v\nDType: %v\n\n", query.Shape(), query.DType())
if err := os.WriteFile("/tmp/olmo_query_before_rope.txt", []byte(header+dump), 0644); err != nil {
t.Errorf("Failed to write file: %v", err)
}
if err := os.WriteFile("/tmp/olmo_query_before_rope.bin", query.Bytes(), 0644); err != nil {
t.Errorf("Failed to write binary file: %v", err)
}
t.Log("Wrote /tmp/olmo_query_before_rope.txt and .bin")
}
// Test 2: SWA RoPE (fresh context)
{
ctx := m.Backend().NewContext()
query := ctx.Input().FromFloats(queryFloats, headDim, numHeads, batchSize)
positions := ctx.Input().FromInts([]int32{0, 1, 2, 3}, batchSize)
resultSWA := olmoModel.applyRoPE(ctx, query, positions, headDim, true)
resultSWA = resultSWA.Contiguous(ctx)
ctx.Forward(resultSWA).Compute(resultSWA)
t.Logf("SWA RoPE result shape: %v", resultSWA.Shape())
dump := ml.Dump(ctx, resultSWA, ml.DumpWithPrecision(6), ml.DumpWithThreshold(1000000))
t.Logf("Query AFTER SWA RoPE sample values: %s", dump[:min(500, len(dump))])
// Write to file
header := fmt.Sprintf("Shape: %v\nDType: %v\nfreqScale: 1.0 (SWA)\n\n", resultSWA.Shape(), resultSWA.DType())
if err := os.WriteFile("/tmp/olmo_query_after_swa_rope.txt", []byte(header+dump), 0644); err != nil {
t.Errorf("Failed to write file: %v", err)
}
if err := os.WriteFile("/tmp/olmo_query_after_swa_rope.bin", resultSWA.Bytes(), 0644); err != nil {
t.Errorf("Failed to write binary file: %v", err)
}
t.Log("Wrote /tmp/olmo_query_after_swa_rope.txt and .bin")
}
// Test 3: Global (non-SWA) RoPE (fresh context)
{
ctx := m.Backend().NewContext()
query := ctx.Input().FromFloats(queryFloats, headDim, numHeads, batchSize)
positions := ctx.Input().FromInts([]int32{0, 1, 2, 3}, batchSize)
resultGlobal := olmoModel.applyRoPE(ctx, query, positions, headDim, false)
resultGlobal = resultGlobal.Contiguous(ctx)
ctx.Forward(resultGlobal).Compute(resultGlobal)
t.Logf("Global RoPE result shape: %v", resultGlobal.Shape())
dump := ml.Dump(ctx, resultGlobal, ml.DumpWithPrecision(6), ml.DumpWithThreshold(1000000))
t.Logf("Query AFTER Global RoPE sample values: %s", dump[:min(500, len(dump))])
// Write to file
header := fmt.Sprintf("Shape: %v\nDType: %v\nfreqScale: %f (Global, 1/ropeScale)\n\n",
resultGlobal.Shape(), resultGlobal.DType(), 1.0/olmoModel.ropeScale)
if err := os.WriteFile("/tmp/olmo_query_after_global_rope.txt", []byte(header+dump), 0644); err != nil {
t.Errorf("Failed to write file: %v", err)
}
if err := os.WriteFile("/tmp/olmo_query_after_global_rope.bin", resultGlobal.Bytes(), 0644); err != nil {
t.Errorf("Failed to write binary file: %v", err)
}
t.Log("Wrote /tmp/olmo_query_after_global_rope.txt and .bin")
}
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/fast"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
@@ -21,10 +22,6 @@ type Options struct {
eps, ropeBase, ropeScale float32
}
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
return nn.RoPE(ctx, states, positions, cmp.Or(o.ropeDim, o.headDim, o.hiddenSize/o.numHeads), o.ropeBase, 1./o.ropeScale, rope.WithTypeNeoX())
}
type Attention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
@@ -35,6 +32,7 @@ type Attention struct {
func (attn Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
batchSize := hiddenStates.Dim(1)
headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads)
ropeDim := cmp.Or(opts.ropeDim, headDim)
query := attn.Query.Forward(ctx, hiddenStates)
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
@@ -45,8 +43,8 @@ func (attn Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor,
value := attn.Value.Forward(ctx, hiddenStates)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize)
@@ -125,7 +123,8 @@ func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
}
func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads)
return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil
}
func New(c fs.Config) (model.Model, error) {

View File

@@ -7,6 +7,7 @@ import (
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/fast"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model/input"
)
@@ -17,13 +18,6 @@ type TextOptions struct {
eps, ropeBase, ropeScale float32
}
func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale,
rope.WithOriginalContextLength(o.originalContextLength),
rope.WithTypeNeoX(),
)
}
type TextModel struct {
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -66,11 +60,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
q = opts.applyRotaryPositionEmbeddings(ctx, q, positionIDs)
q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX())
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
k = opts.applyRotaryPositionEmbeddings(ctx, k, positionIDs)
k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX())
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -84,7 +78,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
// Shift applies rotary position embeddings to the key tensor for causal attention caching
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithOriginalContextLength(m.originalContextLength), rope.WithTypeNeoX()), nil
}
// MLP implements the feed-forward network component with SwiGLU activation

View File

@@ -18,8 +18,8 @@ func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
return x2.Scale(ctx, -1).Concat(ctx, x1, 0)
}
func applyRotaryPositionEmbeddings(ctx ml.Context, states, cos, sin ml.Tensor) ml.Tensor {
return states.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, states).Mul(ctx, sin))
func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
return t.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, t).Mul(ctx, sin))
}
func blockDiagonalMask(ctx ml.Context, seqLength int, bounds []int, numHeads int) ml.Tensor {
@@ -67,8 +67,8 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin, m
key = key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1), batchSize)
value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1), batchSize)
query = applyRotaryPositionEmbeddings(ctx, query, cos, sin)
key = applyRotaryPositionEmbeddings(ctx, key, cos, sin)
query = applyRotaryPositionalEmbedding(ctx, query, cos, sin)
key = applyRotaryPositionalEmbedding(ctx, key, cos, sin)
// Scale factor for scaled dot-product attention
scale := 1.0 / math.Sqrt(float64(opts.headDim))

View File

@@ -9,6 +9,7 @@ import (
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/fast"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
@@ -45,7 +46,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
rope.WithAttentionFactor(attnFactor),
)
}
return nn.RoPE(ctx, states, positions, o.headDim(), o.ropeBase, 1./o.ropeScale, opts...)
return fast.RoPE(ctx, states, positions, o.headDim(), o.ropeBase, 1./o.ropeScale, opts...)
}
type Attention struct {

View File

@@ -195,7 +195,7 @@ func New(c fs.Config) (model.Model, error) {
m.Cache = kvcache.NewCausalCache(func(ctx ml.Context, layer int, key, positions ml.Tensor) (ml.Tensor, error) {
m.positionCache = nil
positions = positions.Repeat(ctx, 1, 4).Reshape(ctx, -1)
return m.Options.applyRotaryPositionEmbeddings(ctx, key, positions), nil
return m.Options.applyRotaryPositionalEmbedding(ctx, key, positions), nil
})
return &m, nil
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/fast"
"github.com/ollama/ollama/ml/nn/rope"
"github.com/ollama/ollama/model"
)
@@ -34,8 +35,8 @@ func (o TextOptions) headDim() int {
return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads)
}
func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
return nn.RoPE(ctx, states, positions, o.headDim(), o.ropeBase, 1/float32(math.Sqrt(float64(o.ropeScale))),
func (o TextOptions) applyRotaryPositionalEmbedding(ctx ml.Context, t, p ml.Tensor) ml.Tensor {
return fast.RoPE(ctx, t, p, o.headDim(), o.ropeBase, 1/float32(math.Sqrt(float64(o.ropeScale))),
rope.WithInterleaveMRoPE(o.mropeSections),
)
}
@@ -63,8 +64,8 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tens
query = sa.QueryNorm.Forward(ctx, query, opts.eps)
key = sa.KeyNorm.Forward(ctx, key, opts.eps)
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
query = opts.applyRotaryPositionalEmbedding(ctx, query, positions)
key = opts.applyRotaryPositionalEmbedding(ctx, key, positions)
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)

View File

@@ -23,18 +23,18 @@ func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
return x2.Scale(ctx, -1).Concat(ctx, x1, 0)
}
func applyRotaryPositionEmbeddings(ctx ml.Context, states, cos, sin ml.Tensor) ml.Tensor {
return states.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, states).Mul(ctx, sin))
func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
return t.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, t).Mul(ctx, sin))
}
func (sa *VisionAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts VisionOptions) ml.Tensor {
query := sa.Query.Forward(ctx, hiddenStates)
query = query.Reshape(ctx, opts.headDim(), opts.numHeads, query.Dim(1))
query = applyRotaryPositionEmbeddings(ctx, query, cos, sin)
query = applyRotaryPositionalEmbedding(ctx, query, cos, sin)
key := sa.Key.Forward(ctx, hiddenStates)
key = key.Reshape(ctx, opts.headDim(), opts.numHeads, key.Dim(1))
key = applyRotaryPositionEmbeddings(ctx, key, cos, sin)
key = applyRotaryPositionalEmbedding(ctx, key, cos, sin)
value := sa.Value.Forward(ctx, hiddenStates)
value = value.Reshape(ctx, opts.headDim(), opts.numHeads, value.Dim(1))

View File

@@ -1,469 +0,0 @@
package parsers
import (
"context"
"fmt"
"log/slog"
"regexp"
"strconv"
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil"
)
type olmo3ParserState int
const (
olmo3StateContent olmo3ParserState = iota
olmo3StateToolCalls
olmo3StateToolCallsDone
)
const (
olmo3FuncCallsOpenTag = "<function_calls>"
olmo3FuncCallsCloseTag = "</function_calls>"
)
type Olmo3Parser struct {
state olmo3ParserState
buffer strings.Builder
}
func (p *Olmo3Parser) HasToolSupport() bool {
return true
}
func (p *Olmo3Parser) HasThinkingSupport() bool {
return false
}
func (p *Olmo3Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.state = olmo3StateContent
return tools
}
type olmo3ParserEvent interface {
isOlmo3ParserEvent()
}
type olmo3ParserEventContent struct {
content string
}
type olmo3ParserEventToolCalls struct {
calls []api.ToolCall
}
func (olmo3ParserEventContent) isOlmo3ParserEvent() {}
func (olmo3ParserEventToolCalls) isOlmo3ParserEvent() {}
func (p *Olmo3Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
p.buffer.WriteString(s)
if done {
// Drain any remaining content
bufStr := p.buffer.String()
p.buffer.Reset()
if p.state == olmo3StateContent && len(bufStr) > 0 {
return bufStr, "", nil, nil
}
return "", "", nil, nil
}
events := p.parseEvents()
var contentSb strings.Builder
var allCalls []api.ToolCall
for _, event := range events {
switch event := event.(type) {
case olmo3ParserEventContent:
contentSb.WriteString(event.content)
case olmo3ParserEventToolCalls:
allCalls = append(allCalls, event.calls...)
}
}
return contentSb.String(), "", allCalls, nil
}
func (p *Olmo3Parser) parseEvents() []olmo3ParserEvent {
var all []olmo3ParserEvent
keepLooping := true
for keepLooping {
var events []olmo3ParserEvent
events, keepLooping = p.eat()
if len(events) > 0 {
all = append(all, events...)
}
}
if len(all) > 0 {
slog.Log(context.TODO(), logutil.LevelTrace, "olmo3 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
}
return all
}
func (p *Olmo3Parser) eat() ([]olmo3ParserEvent, bool) {
var events []olmo3ParserEvent
bufStr := p.buffer.String()
if bufStr == "" {
return events, false
}
switch p.state {
case olmo3StateContent:
if strings.Contains(bufStr, olmo3FuncCallsOpenTag) {
// Found <function_calls> tag
split := strings.SplitN(bufStr, olmo3FuncCallsOpenTag, 2)
content := split[0]
remaining := split[1]
p.buffer.Reset()
p.buffer.WriteString(remaining)
p.state = olmo3StateToolCalls
if len(content) > 0 {
events = append(events, olmo3ParserEventContent{content: content})
}
return events, true
} else if overlapLen := overlap(bufStr, olmo3FuncCallsOpenTag); overlapLen > 0 {
// Partial <function_calls> tag - withhold ambiguous content
unambiguous := bufStr[:len(bufStr)-overlapLen]
ambiguous := bufStr[len(bufStr)-overlapLen:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, olmo3ParserEventContent{content: unambiguous})
}
return events, false
} else {
// Regular content - emit all
p.buffer.Reset()
if len(bufStr) > 0 {
events = append(events, olmo3ParserEventContent{content: bufStr})
}
return events, false
}
case olmo3StateToolCalls:
if strings.Contains(bufStr, olmo3FuncCallsCloseTag) {
// Found </function_calls> tag
split := strings.SplitN(bufStr, olmo3FuncCallsCloseTag, 2)
toolCallsStr := split[0]
remaining := split[1]
p.buffer.Reset()
p.buffer.WriteString(remaining)
p.state = olmo3StateToolCallsDone
// Parse the function calls
calls, err := parseOlmo3FunctionCalls(toolCallsStr)
if err != nil {
slog.Log(context.TODO(), logutil.LevelTrace, "failed to parse olmo3 function calls", "error", err, "content", toolCallsStr)
} else if len(calls) > 0 {
events = append(events, olmo3ParserEventToolCalls{calls: calls})
}
return events, true
} else if overlapLen := overlap(bufStr, olmo3FuncCallsCloseTag); overlapLen > 0 {
// Partial </function_calls> tag - wait for more
return events, false
}
// Still collecting tool calls, wait for close tag
return events, false
case olmo3StateToolCallsDone:
// After tool calls, emit remaining content
p.buffer.Reset()
p.state = olmo3StateContent
if len(bufStr) > 0 {
events = append(events, olmo3ParserEventContent{content: bufStr})
}
return events, false
}
return events, false
}
// parseOlmo3FunctionCalls parses function calls in Python-esque format:
// func_name(arg1="value1", arg2=123)
// Multiple calls are separated by newlines
func parseOlmo3FunctionCalls(s string) ([]api.ToolCall, error) {
var calls []api.ToolCall
s = strings.TrimSpace(s)
if s == "" {
return calls, nil
}
// Split by newlines for multiple function calls
lines := strings.Split(s, "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
if line == "" {
continue
}
call, err := parseOlmo3SingleFunctionCall(line)
if err != nil {
return nil, fmt.Errorf("failed to parse function call %q: %w", line, err)
}
calls = append(calls, call)
}
return calls, nil
}
// Regex to match function call: func_name(args)
var funcCallRegex = regexp.MustCompile(`^(\w+)\((.*)\)$`)
// Regex to match a single argument: key=value
// Value can be: "string", 'string', number, true, false, null, or nested structures
var argRegex = regexp.MustCompile(`^(\w+)=(.+)$`)
func parseOlmo3SingleFunctionCall(s string) (api.ToolCall, error) {
matches := funcCallRegex.FindStringSubmatch(s)
if matches == nil {
return api.ToolCall{}, fmt.Errorf("invalid function call format")
}
funcName := matches[1]
argsStr := matches[2]
args, err := parseOlmo3Arguments(argsStr)
if err != nil {
return api.ToolCall{}, fmt.Errorf("failed to parse arguments: %w", err)
}
return api.ToolCall{
Function: api.ToolCallFunction{
Name: funcName,
Arguments: args,
},
}, nil
}
// parseOlmo3Arguments parses comma-separated key=value pairs
// Handles nested parentheses, brackets, braces, and quoted strings
func parseOlmo3Arguments(s string) (map[string]any, error) {
args := make(map[string]any)
s = strings.TrimSpace(s)
if s == "" {
return args, nil
}
// Split by commas, but respect nested structures and quotes
parts := splitArguments(s)
for _, part := range parts {
part = strings.TrimSpace(part)
if part == "" {
continue
}
// Find the first = sign
eqIdx := strings.Index(part, "=")
if eqIdx == -1 {
return nil, fmt.Errorf("invalid argument format: %s", part)
}
key := strings.TrimSpace(part[:eqIdx])
valueStr := strings.TrimSpace(part[eqIdx+1:])
value, err := parseOlmo3Value(valueStr)
if err != nil {
return nil, fmt.Errorf("failed to parse value for %s: %w", key, err)
}
args[key] = value
}
return args, nil
}
// splitArguments splits arguments by commas, respecting quotes and nested structures
func splitArguments(s string) []string {
var parts []string
var current strings.Builder
depth := 0
inString := false
stringChar := byte(0)
escaped := false
for i := 0; i < len(s); i++ {
c := s[i]
if escaped {
current.WriteByte(c)
escaped = false
continue
}
if c == '\\' && inString {
current.WriteByte(c)
escaped = true
continue
}
if (c == '"' || c == '\'') && !inString {
inString = true
stringChar = c
current.WriteByte(c)
continue
}
if c == stringChar && inString {
inString = false
stringChar = 0
current.WriteByte(c)
continue
}
if !inString {
switch c {
case '(', '[', '{':
depth++
current.WriteByte(c)
case ')', ']', '}':
depth--
current.WriteByte(c)
case ',':
if depth == 0 {
parts = append(parts, current.String())
current.Reset()
continue
}
current.WriteByte(c)
default:
current.WriteByte(c)
}
} else {
current.WriteByte(c)
}
}
if current.Len() > 0 {
parts = append(parts, current.String())
}
return parts
}
// parseOlmo3Value parses a value which can be a string, number, boolean, null, array, or object
func parseOlmo3Value(s string) (any, error) {
s = strings.TrimSpace(s)
// Check for quoted string
if (strings.HasPrefix(s, `"`) && strings.HasSuffix(s, `"`)) ||
(strings.HasPrefix(s, `'`) && strings.HasSuffix(s, `'`)) {
// Remove quotes and unescape
inner := s[1 : len(s)-1]
return unescapeString(inner), nil
}
// Check for boolean
if s == "true" || s == "True" {
return true, nil
}
if s == "false" || s == "False" {
return false, nil
}
// Check for null/None
if s == "null" || s == "None" || s == "nil" {
return nil, nil
}
// Check for number
if i, err := strconv.ParseInt(s, 10, 64); err == nil {
return i, nil
}
if f, err := strconv.ParseFloat(s, 64); err == nil {
return f, nil
}
// Check for array [...]
if strings.HasPrefix(s, "[") && strings.HasSuffix(s, "]") {
return parseOlmo3Array(s[1 : len(s)-1])
}
// Check for object {...}
if strings.HasPrefix(s, "{") && strings.HasSuffix(s, "}") {
return parseOlmo3Object(s[1 : len(s)-1])
}
// Default to string without quotes
return s, nil
}
func parseOlmo3Array(s string) ([]any, error) {
s = strings.TrimSpace(s)
if s == "" {
return []any{}, nil
}
parts := splitArguments(s)
var arr []any
for _, part := range parts {
val, err := parseOlmo3Value(part)
if err != nil {
return nil, err
}
arr = append(arr, val)
}
return arr, nil
}
func parseOlmo3Object(s string) (map[string]any, error) {
s = strings.TrimSpace(s)
if s == "" {
return map[string]any{}, nil
}
// Objects use key: value or "key": value format
obj := make(map[string]any)
parts := splitArguments(s)
for _, part := range parts {
part = strings.TrimSpace(part)
if part == "" {
continue
}
// Find colon separator
colonIdx := strings.Index(part, ":")
if colonIdx == -1 {
return nil, fmt.Errorf("invalid object entry: %s", part)
}
keyStr := strings.TrimSpace(part[:colonIdx])
valueStr := strings.TrimSpace(part[colonIdx+1:])
// Remove quotes from key if present
if (strings.HasPrefix(keyStr, `"`) && strings.HasSuffix(keyStr, `"`)) ||
(strings.HasPrefix(keyStr, `'`) && strings.HasSuffix(keyStr, `'`)) {
keyStr = keyStr[1 : len(keyStr)-1]
}
val, err := parseOlmo3Value(valueStr)
if err != nil {
return nil, fmt.Errorf("failed to parse value for key %s: %w", keyStr, err)
}
obj[keyStr] = val
}
return obj, nil
}
func unescapeString(s string) string {
// Handle common escape sequences
s = strings.ReplaceAll(s, `\\`, "\x00") // Placeholder for backslash
s = strings.ReplaceAll(s, `\"`, `"`)
s = strings.ReplaceAll(s, `\'`, `'`)
s = strings.ReplaceAll(s, `\n`, "\n")
s = strings.ReplaceAll(s, `\t`, "\t")
s = strings.ReplaceAll(s, `\r`, "\r")
s = strings.ReplaceAll(s, "\x00", `\`) // Restore backslash
return s
}

View File

@@ -1,483 +0,0 @@
package parsers
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
func TestOlmo3Parser(t *testing.T) {
tests := []struct {
name string
input string
expectedContent string
expectedThinking string
expectedCalls []api.ToolCall
}{
{
name: "simple content",
input: "Hello, how can I help you?",
expectedContent: "Hello, how can I help you?",
},
{
name: "simple tool call",
input: `<function_calls>get_weather(location="San Francisco")</function_calls>`,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: map[string]any{"location": "San Francisco"},
},
},
},
},
{
name: "content then tool call",
input: `Let me check the weather.<function_calls>get_weather(location="NYC")</function_calls>`,
expectedContent: "Let me check the weather.",
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: map[string]any{"location": "NYC"},
},
},
},
},
{
name: "tool call with multiple arguments",
input: `<function_calls>book_flight(from="SFO", to="NYC", date="2024-01-15")</function_calls>`,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "book_flight",
Arguments: map[string]any{
"from": "SFO",
"to": "NYC",
"date": "2024-01-15",
},
},
},
},
},
{
name: "multiple tool calls",
input: `<function_calls>get_weather(location="San Francisco")
get_weather(location="New York")</function_calls>`,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: map[string]any{"location": "San Francisco"},
},
},
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: map[string]any{"location": "New York"},
},
},
},
},
{
name: "tool call with numeric argument",
input: `<function_calls>set_temperature(value=72)</function_calls>`,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "set_temperature",
Arguments: map[string]any{"value": int64(72)},
},
},
},
},
{
name: "tool call with float argument",
input: `<function_calls>set_price(amount=19.99)</function_calls>`,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "set_price",
Arguments: map[string]any{"amount": 19.99},
},
},
},
},
{
name: "tool call with boolean argument",
input: `<function_calls>toggle_setting(enabled=true)</function_calls>`,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "toggle_setting",
Arguments: map[string]any{"enabled": true},
},
},
},
},
{
name: "tool call with null argument",
input: `<function_calls>clear_value(field=null)</function_calls>`,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "clear_value",
Arguments: map[string]any{"field": nil},
},
},
},
},
{
name: "tool call with array argument",
input: `<function_calls>process_items(items=["apple", "banana", "cherry"])</function_calls>`,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "process_items",
Arguments: map[string]any{"items": []any{"apple", "banana", "cherry"}},
},
},
},
},
{
name: "tool call with dict argument",
input: `<function_calls>update_config(settings={"theme": "dark", "fontSize": 14})</function_calls>`,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "update_config",
Arguments: map[string]any{
"settings": map[string]any{
"theme": "dark",
"fontSize": int64(14),
},
},
},
},
},
},
{
name: "tool call with nested dict",
input: `<function_calls>create_request(data={"user": {"name": "John", "age": 30}, "active": true})</function_calls>`,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "create_request",
Arguments: map[string]any{
"data": map[string]any{
"user": map[string]any{
"name": "John",
"age": int64(30),
},
"active": true,
},
},
},
},
},
},
{
name: "tool call with no arguments",
input: `<function_calls>get_current_time()</function_calls>`,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_current_time",
Arguments: map[string]any{},
},
},
},
},
{
name: "tool call with single quotes",
input: `<function_calls>search(query='hello world')</function_calls>`,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "search",
Arguments: map[string]any{"query": "hello world"},
},
},
},
},
{
name: "tool call with escaped quotes",
input: `<function_calls>search(query="say \"hello\"")</function_calls>`,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "search",
Arguments: map[string]any{"query": `say "hello"`},
},
},
},
},
{
name: "tool call with mixed argument types",
input: `<function_calls>create_user(name="John", age=30, active=true)</function_calls>`,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "create_user",
Arguments: map[string]any{
"name": "John",
"age": int64(30),
"active": true,
},
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &Olmo3Parser{}
p.Init(nil, nil, nil)
content, thinking, calls, err := p.Add(tt.input, false)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Drain remaining content
finalContent, finalThinking, finalCalls, err := p.Add("", true)
if err != nil {
t.Fatalf("unexpected error on done: %v", err)
}
content += finalContent
thinking += finalThinking
calls = append(calls, finalCalls...)
if diff := cmp.Diff(content, tt.expectedContent); diff != "" {
t.Errorf("content mismatch (-got +want):\n%s", diff)
}
if diff := cmp.Diff(thinking, tt.expectedThinking); diff != "" {
t.Errorf("thinking mismatch (-got +want):\n%s", diff)
}
if diff := cmp.Diff(calls, tt.expectedCalls); diff != "" {
t.Errorf("calls mismatch (-got +want):\n%s", diff)
}
})
}
}
func TestOlmo3Parser_Streaming(t *testing.T) {
tests := []struct {
name string
chunks []string
expectedContent string
expectedCalls []api.ToolCall
}{
{
name: "streaming content",
chunks: []string{"Hello, ", "how ", "can I help?"},
expectedContent: "Hello, how can I help?",
},
{
name: "streaming tool call",
chunks: []string{"<function_", "calls>get_weather", "(location=\"SF\")", "</function_calls>"},
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: map[string]any{"location": "SF"},
},
},
},
},
{
name: "streaming content then tool call",
chunks: []string{"Let me check.", "<function_calls>", "get_weather(location=\"NYC\")", "</function_calls>"},
expectedContent: "Let me check.",
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: map[string]any{"location": "NYC"},
},
},
},
},
{
name: "tool call tag split across chunks",
chunks: []string{"<func", "tion_calls>test()</function_calls>"},
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "test",
Arguments: map[string]any{},
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &Olmo3Parser{}
p.Init(nil, nil, nil)
var allContent string
var allCalls []api.ToolCall
for _, chunk := range tt.chunks {
content, _, calls, err := p.Add(chunk, false)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
allContent += content
allCalls = append(allCalls, calls...)
}
// Drain
content, _, calls, err := p.Add("", true)
if err != nil {
t.Fatalf("unexpected error on done: %v", err)
}
allContent += content
allCalls = append(allCalls, calls...)
if diff := cmp.Diff(allContent, tt.expectedContent); diff != "" {
t.Errorf("content mismatch (-got +want):\n%s", diff)
}
if diff := cmp.Diff(allCalls, tt.expectedCalls); diff != "" {
t.Errorf("calls mismatch (-got +want):\n%s", diff)
}
})
}
}
func TestOlmo3Parser_HasToolSupport(t *testing.T) {
p := &Olmo3Parser{}
if !p.HasToolSupport() {
t.Error("expected HasToolSupport to return true")
}
}
func TestOlmo3Parser_HasThinkingSupport(t *testing.T) {
p := &Olmo3Parser{}
if p.HasThinkingSupport() {
t.Error("expected HasThinkingSupport to return false")
}
}
func TestParseOlmo3FunctionCalls(t *testing.T) {
tests := []struct {
name string
input string
expected []api.ToolCall
wantErr bool
}{
{
name: "simple call",
input: `get_weather(location="SF")`,
expected: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: map[string]any{"location": "SF"},
},
},
},
},
{
name: "multiple args",
input: `send_email(to="user@example.com", subject="Hello", body="Test message")`,
expected: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "send_email",
Arguments: map[string]any{
"to": "user@example.com",
"subject": "Hello",
"body": "Test message",
},
},
},
},
},
{
name: "multiple calls with newlines",
input: `get_weather(location="SF")
get_time(timezone="PST")`,
expected: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: map[string]any{"location": "SF"},
},
},
{
Function: api.ToolCallFunction{
Name: "get_time",
Arguments: map[string]any{"timezone": "PST"},
},
},
},
},
{
name: "empty input",
input: "",
expected: nil,
},
{
name: "whitespace only",
input: " \n ",
expected: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
calls, err := parseOlmo3FunctionCalls(tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("parseOlmo3FunctionCalls() error = %v, wantErr %v", err, tt.wantErr)
return
}
if diff := cmp.Diff(calls, tt.expected); diff != "" {
t.Errorf("calls mismatch (-got +want):\n%s", diff)
}
})
}
}
func TestParseOlmo3Value(t *testing.T) {
tests := []struct {
name string
input string
expected any
}{
{"string double quotes", `"hello"`, "hello"},
{"string single quotes", `'hello'`, "hello"},
{"integer", "42", int64(42)},
{"negative integer", "-10", int64(-10)},
{"float", "3.14", 3.14},
{"boolean true", "true", true},
{"boolean True", "True", true},
{"boolean false", "false", false},
{"null", "null", nil},
{"None", "None", nil},
{"empty array", "[]", []any{}},
{"array with strings", `["a", "b"]`, []any{"a", "b"}},
{"array with numbers", "[1, 2, 3]", []any{int64(1), int64(2), int64(3)}},
{"empty object", "{}", map[string]any{}},
{"simple object", `{"name": "John"}`, map[string]any{"name": "John"}},
{"object with number", `{"age": 30}`, map[string]any{"age": int64(30)}},
{"object with multiple keys", `{"a": 1, "b": 2}`, map[string]any{"a": int64(1), "b": int64(2)}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := parseOlmo3Value(tt.input)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if diff := cmp.Diff(result, tt.expected); diff != "" {
t.Errorf("value mismatch (-got +want):\n%s", diff)
}
})
}
}

View File

@@ -1,170 +0,0 @@
package parsers
import (
"context"
"log/slog"
"strings"
"unicode"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil"
)
type olmo3ThinkParserState int
const (
olmo3CollectingThink olmo3ThinkParserState = iota
olmo3CollectingContent
)
const (
olmo3ThinkCloseTag = "</think>"
)
type Olmo3ThinkParser struct {
state olmo3ThinkParserState
buffer strings.Builder
}
func (p *Olmo3ThinkParser) HasToolSupport() bool {
return false
}
func (p *Olmo3ThinkParser) HasThinkingSupport() bool {
return true
}
func (p *Olmo3ThinkParser) setInitialState(lastMessage *api.Message) {
prefill := lastMessage != nil && lastMessage.Role == "assistant"
// If prefilling with content, skip to content collection
if prefill && lastMessage.Content != "" {
p.state = olmo3CollectingContent
return
}
// Model always thinks first (the <think> tag is injected in the prompt)
p.state = olmo3CollectingThink
}
func (p *Olmo3ThinkParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.setInitialState(lastMessage)
return tools
}
// Event types for internal parser communication
type olmo3Event interface {
isOlmo3Event()
}
type olmo3EventThinkContent struct {
content string
}
type olmo3EventContent struct {
content string
}
func (olmo3EventThinkContent) isOlmo3Event() {}
func (olmo3EventContent) isOlmo3Event() {}
func (p *Olmo3ThinkParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
p.buffer.WriteString(s)
events := p.parseEvents()
var contentSb strings.Builder
var thinkingSb strings.Builder
for _, event := range events {
switch event := event.(type) {
case olmo3EventThinkContent:
thinkingSb.WriteString(event.content)
case olmo3EventContent:
contentSb.WriteString(event.content)
}
}
return contentSb.String(), thinkingSb.String(), nil, nil
}
func (p *Olmo3ThinkParser) parseEvents() []olmo3Event {
var all []olmo3Event
keepLooping := true
for keepLooping {
var events []olmo3Event
events, keepLooping = p.eat()
if len(events) > 0 {
all = append(all, events...)
}
}
if len(all) > 0 {
slog.Log(context.TODO(), logutil.LevelTrace, "olmo3 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
}
return all
}
func (p *Olmo3ThinkParser) eat() ([]olmo3Event, bool) {
var events []olmo3Event
bufStr := p.buffer.String()
if bufStr == "" {
return events, false
}
switch p.state {
case olmo3CollectingThink:
if strings.Contains(bufStr, olmo3ThinkCloseTag) {
// Found complete </think> tag
split := strings.SplitN(bufStr, olmo3ThinkCloseTag, 2)
thinking := strings.TrimRightFunc(split[0], unicode.IsSpace)
remaining := strings.TrimLeftFunc(split[1], unicode.IsSpace)
p.buffer.Reset()
p.buffer.WriteString(remaining)
p.state = olmo3CollectingContent
if len(thinking) > 0 {
events = append(events, olmo3EventThinkContent{content: thinking})
}
return events, true
} else if overlapLen := overlap(bufStr, olmo3ThinkCloseTag); overlapLen > 0 {
// Partial </think> tag - withhold ambiguous content
beforePartialTag := bufStr[:len(bufStr)-overlapLen]
trailingLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingLen
unambiguous := bufStr[:ambiguousStart]
ambiguous := bufStr[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, olmo3EventThinkContent{content: unambiguous})
}
return events, false
} else {
// Regular thinking content - withhold trailing whitespace in case </think> follows
whitespaceLen := trailingWhitespaceLen(bufStr)
ambiguousStart := len(bufStr) - whitespaceLen
unambiguous := bufStr[:ambiguousStart]
ambiguous := bufStr[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, olmo3EventThinkContent{content: unambiguous})
}
return events, false
}
case olmo3CollectingContent:
// Emit all content directly
p.buffer.Reset()
if len(bufStr) > 0 {
events = append(events, olmo3EventContent{content: bufStr})
}
return events, false
}
return events, false
}

View File

@@ -1,390 +0,0 @@
package parsers
import (
"strings"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
func TestOlmo3ThinkParser(t *testing.T) {
tests := []struct {
name string
input string
expectedContent string
expectedThinking string
lastMessage *api.Message
}{
{
name: "thinking_only",
input: "I need to think about this.</think>Here is my response.",
expectedContent: "Here is my response.",
expectedThinking: "I need to think about this.",
},
{
name: "thinking_with_newlines",
input: "Let me think step by step.\n\n1. First point\n2. Second point</think>The answer is 42.",
expectedContent: "The answer is 42.",
expectedThinking: "Let me think step by step.\n\n1. First point\n2. Second point",
},
{
name: "thinking_then_content",
input: "Deep thinking here.</think>Here is my detailed response with multiple sentences. I have thought carefully.",
expectedContent: "Here is my detailed response with multiple sentences. I have thought carefully.",
expectedThinking: "Deep thinking here.",
},
{
name: "empty_thinking",
input: "</think>Just content here.",
expectedContent: "Just content here.",
expectedThinking: "",
},
{
name: "prefill_skips_thinking",
input: "Continuing from previous content.",
expectedContent: "Continuing from previous content.",
lastMessage: &api.Message{
Role: "assistant",
Content: "Previous content",
},
},
{
name: "thinking_with_whitespace",
input: " Some thinking </think> Content here ",
expectedContent: "Content here ",
expectedThinking: " Some thinking",
},
{
name: "real_model_output_with_newlines",
input: "Yes, that should work. Let me go with that response.\n\n</think>\n\nHi! I'm all set and ready to assist. How about you? How are you today? 😊",
expectedThinking: "Yes, that should work. Let me go with that response.",
expectedContent: "Hi! I'm all set and ready to assist. How about you? How are you today? 😊",
},
// Edge cases
{
name: "nested_think_tags_in_thinking",
input: "I'm thinking <think>nested</think> more thinking</think>Final content.",
expectedContent: "more thinking</think>Final content.",
expectedThinking: "I'm thinking <think>nested",
},
{
name: "multiple_think_close_tags",
input: "First thinking</think>Content</think>More content.",
expectedContent: "Content</think>More content.",
expectedThinking: "First thinking",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parser := &Olmo3ThinkParser{}
parser.Init(nil, tt.lastMessage, nil)
content, thinking, toolCalls, err := parser.Add(tt.input, true)
if err != nil {
t.Fatalf("Add() error = %v", err)
}
if diff := cmp.Diff(tt.expectedContent, content); diff != "" {
t.Errorf("content mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(tt.expectedThinking, thinking); diff != "" {
t.Errorf("thinking mismatch (-want +got):\n%s", diff)
}
// No tool calls expected
if len(toolCalls) > 0 {
t.Errorf("expected no tool calls, got %d", len(toolCalls))
}
})
}
}
func TestOlmo3ThinkParser_Streaming(t *testing.T) {
parser := &Olmo3ThinkParser{}
parser.Init(nil, nil, nil)
chunks := []string{
"I am ",
"thinking about",
" this.</think>Here ",
"is the response.",
}
var finalContent, finalThinking strings.Builder
for i, chunk := range chunks {
done := i == len(chunks)-1
content, thinking, _, err := parser.Add(chunk, done)
if err != nil {
t.Fatalf("Add() error on chunk %d: %v", i, err)
}
finalContent.WriteString(content)
finalThinking.WriteString(thinking)
}
expectedContent := "Here is the response."
expectedThinking := "I am thinking about this."
if finalContent.String() != expectedContent {
t.Errorf("expected content %q, got %q", expectedContent, finalContent.String())
}
if finalThinking.String() != expectedThinking {
t.Errorf("expected thinking %q, got %q", expectedThinking, finalThinking.String())
}
}
func TestOlmo3ThinkParser_StreamingEdgeCases(t *testing.T) {
tests := []struct {
name string
chunks []string
expectedContent string
expectedThinking string
}{
{
name: "thinking_tag_split_across_chunks",
chunks: []string{
"This is thinking content",
"</think>",
"This is content.",
},
expectedContent: "This is content.",
expectedThinking: "This is thinking content",
},
{
name: "thinking_tag_split_mid_token",
chunks: []string{
"Thinking?</",
"think>",
"Content here.",
},
expectedContent: "Content here.",
expectedThinking: "Thinking?",
},
{
name: "thinking_tag_split_at_angle_bracket",
chunks: []string{
"Thinking<",
"/think>",
"Content.",
},
expectedContent: "Content.",
expectedThinking: "Thinking",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parser := &Olmo3ThinkParser{}
parser.Init(nil, nil, nil)
var finalContent, finalThinking strings.Builder
for i, chunk := range tt.chunks {
done := i == len(tt.chunks)-1
content, thinking, _, err := parser.Add(chunk, done)
if err != nil {
t.Fatalf("Add() error on chunk %d: %v", i, err)
}
finalContent.WriteString(content)
finalThinking.WriteString(thinking)
}
if finalContent.String() != tt.expectedContent {
t.Errorf("expected content %q, got %q", tt.expectedContent, finalContent.String())
}
if finalThinking.String() != tt.expectedThinking {
t.Errorf("expected thinking %q, got %q", tt.expectedThinking, finalThinking.String())
}
})
}
}
// TestOlmo3ThinkParser_ThinkBoundary tests streaming thinking content
// where thinking chunks come in succession before the </think> tag
func TestOlmo3ThinkParser_ThinkBoundary(t *testing.T) {
tests := []struct {
name string
chunks []string
expectedThinking string
expectedContent string
}{
{
name: "multiple_thinking_chunks",
chunks: []string{
"First part of thinking. ",
"Second part of thinking. ",
"Third part.</think>",
"Content here.",
},
expectedThinking: "First part of thinking. Second part of thinking. Third part.",
expectedContent: "Content here.",
},
{
name: "thinking_chunks_with_newlines",
chunks: []string{
"Step 1: Analyze the problem.\n",
"Step 2: Consider options.\n",
"Step 3: Make decision.</think>",
"Here is my answer.",
},
expectedThinking: "Step 1: Analyze the problem.\nStep 2: Consider options.\nStep 3: Make decision.",
expectedContent: "Here is my answer.",
},
{
name: "single_char_thinking_chunks",
chunks: []string{
"H", "e", "l", "l", "o", "</think>", "World",
},
expectedThinking: "Hello",
expectedContent: "World",
},
{
name: "thinking_with_special_chars",
chunks: []string{
"Let me think... ",
"Option A: $100 ",
"Option B: €200</think>",
"I recommend Option A.",
},
expectedThinking: "Let me think... Option A: $100 Option B: €200",
expectedContent: "I recommend Option A.",
},
{
name: "long_thinking_multiple_chunks",
chunks: []string{
"This is a very long thinking process. ",
"I need to consider many factors. ",
"First, let me look at the data. ",
"The numbers show interesting patterns. ",
"Based on my analysis, ",
"I can conclude that...</think>",
"The answer is 42.",
},
expectedThinking: "This is a very long thinking process. I need to consider many factors. First, let me look at the data. The numbers show interesting patterns. Based on my analysis, I can conclude that...",
expectedContent: "The answer is 42.",
},
{
name: "thinking_ends_exactly_at_chunk_boundary",
chunks: []string{
"Thinking content",
"</think>",
"Content",
},
expectedThinking: "Thinking content",
expectedContent: "Content",
},
{
name: "empty_chunks_between_thinking",
chunks: []string{
"Start thinking",
"",
" middle ",
"",
"end</think>",
"Content",
},
expectedThinking: "Start thinking middle end",
expectedContent: "Content",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parser := &Olmo3ThinkParser{}
parser.Init(nil, nil, nil)
var finalContent, finalThinking strings.Builder
for i, chunk := range tt.chunks {
done := i == len(tt.chunks)-1
content, thinking, _, err := parser.Add(chunk, done)
if err != nil {
t.Fatalf("Add() error on chunk %d: %v", i, err)
}
finalContent.WriteString(content)
finalThinking.WriteString(thinking)
}
if finalThinking.String() != tt.expectedThinking {
t.Errorf("thinking mismatch:\nexpected: %q\ngot: %q", tt.expectedThinking, finalThinking.String())
}
if finalContent.String() != tt.expectedContent {
t.Errorf("content mismatch:\nexpected: %q\ngot: %q", tt.expectedContent, finalContent.String())
}
})
}
}
// TestOlmo3ThinkParser_StateTransitions tests that state transitions work correctly
func TestOlmo3ThinkParser_StateTransitions(t *testing.T) {
t.Run("thinking_to_content", func(t *testing.T) {
parser := &Olmo3ThinkParser{}
parser.Init(nil, nil, nil)
if parser.state != olmo3CollectingThink {
t.Errorf("initial state should be olmo3CollectingThink, got %v", parser.state)
}
parser.Add("thinking</think>content", true)
if parser.state != olmo3CollectingContent {
t.Errorf("state after </think> should be olmo3CollectingContent, got %v", parser.state)
}
})
}
func TestOlmo3ThinkParser_HasToolSupport(t *testing.T) {
parser := &Olmo3ThinkParser{}
if parser.HasToolSupport() {
t.Error("Olmo3ThinkParser should NOT support tools")
}
}
func TestOlmo3ThinkParser_HasThinkingSupport(t *testing.T) {
parser := &Olmo3ThinkParser{}
if !parser.HasThinkingSupport() {
t.Error("Olmo3ThinkParser should support thinking")
}
}
func TestOlmo3ThinkParser_Init(t *testing.T) {
parser := &Olmo3ThinkParser{}
tools := []api.Tool{
{Function: api.ToolFunction{Name: "test_tool"}},
}
lastMessage := &api.Message{Role: "assistant", Content: "previous"}
returnedTools := parser.Init(tools, lastMessage, nil)
if len(returnedTools) != len(tools) {
t.Errorf("expected %d tools returned, got %d", len(tools), len(returnedTools))
}
// Should be in content collection mode due to prefill
if parser.state != olmo3CollectingContent {
t.Errorf("expected state olmo3CollectingContent, got %v", parser.state)
}
}
func TestOlmo3ThinkParser_InitWithoutPrefill(t *testing.T) {
parser := &Olmo3ThinkParser{}
parser.Init(nil, nil, nil)
// Should be in thinking collection mode (model always thinks first)
if parser.state != olmo3CollectingThink {
t.Errorf("expected state olmo3CollectingThink, got %v", parser.state)
}
}

View File

@@ -58,10 +58,6 @@ func ParserForName(name string) Parser {
return harmony.NewHarmonyMessageHandler()
case "cogito":
return &CogitoParser{}
case "olmo3-think":
return &Olmo3ThinkParser{}
case "olmo3":
return &Olmo3Parser{}
default:
return nil
}

View File

@@ -1,45 +0,0 @@
package renderers
import "encoding/json"
// marshalWithSpaces marshals v to JSON and adds a space after each ':' and ','
// that appears outside of string values. This matches the formatting expected
// by certain model architectures.
func marshalWithSpaces(v any) ([]byte, error) {
b, err := json.Marshal(v)
if err != nil {
return nil, err
}
out := make([]byte, 0, len(b)+len(b)/8)
inStr, esc := false, false
for _, c := range b {
if inStr {
out = append(out, c)
if esc {
esc = false
continue
}
if c == '\\' {
esc = true
continue
}
if c == '"' {
inStr = false
}
continue
}
switch c {
case '"':
inStr = true
out = append(out, c)
case ':':
out = append(out, ':', ' ')
case ',':
out = append(out, ',', ' ')
default:
out = append(out, c)
}
}
return out, nil
}

View File

@@ -1,148 +0,0 @@
package renderers
import (
"encoding/json"
"fmt"
"sort"
"strings"
"github.com/ollama/ollama/api"
)
const (
olmo3DefaultSystemMessage = "You are a helpful function-calling AI assistant. "
olmo3NoFunctionsMessage = "You do not currently have access to any functions. "
olmo3WithFunctionsMessage = "You are provided with function signatures within <functions></functions> XML tags. You may call one or more functions to assist with the user query. Output any function calls within <function_calls></function_calls> XML tags. Do not make assumptions about what values to plug into functions."
)
type Olmo3Renderer struct{}
func (r *Olmo3Renderer) Render(messages []api.Message, tools []api.Tool, _ *api.ThinkValue) (string, error) {
var sb strings.Builder
var systemMessage *api.Message
filteredMessages := make([]api.Message, 0, len(messages))
for i, message := range messages {
if message.Role == "system" {
if systemMessage == nil {
systemMessage = &messages[i]
}
continue
}
filteredMessages = append(filteredMessages, message)
}
// Render system message
if systemMessage != nil {
// Custom system message - single newline after "system"
sb.WriteString("<|im_start|>system\n")
sb.WriteString(systemMessage.Content)
if len(tools) > 0 {
functionsJSON, err := marshalWithSpaces(tools)
if err != nil {
return "", err
}
sb.WriteString("<functions>")
sb.WriteString(string(functionsJSON))
sb.WriteString("</functions>")
}
sb.WriteString("<|im_end|>\n")
} else {
// Default system message - single newline after "system"
sb.WriteString("<|im_start|>system\n")
sb.WriteString(olmo3DefaultSystemMessage)
if len(tools) > 0 {
functionsJSON, err := marshalWithSpaces(tools)
if err != nil {
return "", err
}
sb.WriteString(olmo3WithFunctionsMessage)
sb.WriteString("<functions>")
sb.WriteString(string(functionsJSON))
sb.WriteString("</functions>")
} else {
sb.WriteString(olmo3NoFunctionsMessage)
sb.WriteString("<functions></functions>")
}
sb.WriteString("<|im_end|>\n")
}
for i, message := range filteredMessages {
lastMessage := i == len(filteredMessages)-1
switch message.Role {
case "user":
sb.WriteString("<|im_start|>user\n")
sb.WriteString(message.Content)
sb.WriteString("<|im_end|>\n")
case "assistant":
sb.WriteString("<|im_start|>assistant\n")
if message.Content != "" {
sb.WriteString(message.Content)
}
if len(message.ToolCalls) > 0 {
sb.WriteString("<function_calls>")
for j, tc := range message.ToolCalls {
// Format as function_name(arg1="value1", arg2="value2")
sb.WriteString(tc.Function.Name)
sb.WriteString("(")
// Get sorted keys for deterministic output
keys := make([]string, 0, len(tc.Function.Arguments))
for k := range tc.Function.Arguments {
keys = append(keys, k)
}
sort.Strings(keys)
for k, key := range keys {
if k > 0 {
sb.WriteString(", ")
}
value, err := json.Marshal(tc.Function.Arguments[key])
if err != nil {
return "", err
}
sb.WriteString(fmt.Sprintf("%s=%s", key, string(value)))
}
sb.WriteString(")")
if j < len(message.ToolCalls)-1 {
sb.WriteString("\n")
}
}
sb.WriteString("</function_calls>")
}
// Add end tag unless it's the last message with content only (prefill)
if !lastMessage || len(message.ToolCalls) > 0 {
sb.WriteString("<|im_end|>\n")
}
case "tool":
sb.WriteString("<|im_start|>environment\n")
sb.WriteString(message.Content)
sb.WriteString("<|im_end|>\n")
}
}
// Add generation prompt if needed
needsGenerationPrompt := true
if len(filteredMessages) > 0 {
lastMsg := filteredMessages[len(filteredMessages)-1]
if lastMsg.Role == "assistant" && len(lastMsg.ToolCalls) == 0 && lastMsg.Content != "" {
needsGenerationPrompt = false
}
}
if needsGenerationPrompt {
sb.WriteString("<|im_start|>assistant\n\n")
}
return sb.String(), nil
}

View File

@@ -1,290 +0,0 @@
package renderers
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
func TestOlmo3Renderer(t *testing.T) {
tests := []struct {
name string
msgs []api.Message
tools []api.Tool
expected string
}{
{
name: "basic without system - adds default system",
msgs: []api.Message{
{Role: "user", Content: "Hello!"},
},
expected: "<|im_start|>system\n" +
"You are a helpful function-calling AI assistant. You do not currently have access to any functions. <functions></functions><|im_end|>\n" +
"<|im_start|>user\n" +
"Hello!<|im_end|>\n" +
"<|im_start|>assistant\n\n",
},
{
name: "with system message no tools",
msgs: []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Hello!"},
},
expected: "<|im_start|>system\n" +
"You are a helpful assistant.<|im_end|>\n" +
"<|im_start|>user\n" +
"Hello!<|im_end|>\n" +
"<|im_start|>assistant\n\n",
},
{
name: "with system message and tools",
msgs: []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "What is the weather?"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}, Description: "The city"},
},
},
},
},
},
expected: "<|im_start|>system\n" +
`You are a helpful assistant.<functions>[{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "The city"}}}}}]</functions><|im_end|>` + "\n" +
"<|im_start|>user\n" +
"What is the weather?<|im_end|>\n" +
"<|im_start|>assistant\n\n",
},
{
name: "default system with tools - includes function instruction",
msgs: []api.Message{
{Role: "user", Content: "What is the weather?"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}, Description: "The city"},
},
},
},
},
},
expected: "<|im_start|>system\n" +
"You are a helpful function-calling AI assistant. " +
"You are provided with function signatures within <functions></functions> XML tags. You may call one or more functions to assist with the user query. Output any function calls within <function_calls></function_calls> XML tags. Do not make assumptions about what values to plug into functions." +
`<functions>[{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "The city"}}}}}]</functions><|im_end|>` + "\n" +
"<|im_start|>user\n" +
"What is the weather?<|im_end|>\n" +
"<|im_start|>assistant\n\n",
},
{
name: "assistant with tool calls - function call syntax",
msgs: []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "What is the weather in SF?"},
{
Role: "assistant",
Content: "Let me check the weather.",
ToolCalls: []api.ToolCall{
{
ID: "call_1",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: map[string]any{
"location": "San Francisco",
},
},
},
},
},
{Role: "tool", Content: `{"temperature": 68}`, ToolName: "get_weather"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}, Description: "The city"},
},
},
},
},
},
expected: "<|im_start|>system\n" +
`You are a helpful assistant.<functions>[{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "The city"}}}}}]</functions><|im_end|>` + "\n" +
"<|im_start|>user\n" +
"What is the weather in SF?<|im_end|>\n" +
"<|im_start|>assistant\n" +
`Let me check the weather.<function_calls>get_weather(location="San Francisco")</function_calls><|im_end|>` + "\n" +
"<|im_start|>environment\n" +
`{"temperature": 68}<|im_end|>` + "\n" +
"<|im_start|>assistant\n\n",
},
{
name: "multi-turn conversation",
msgs: []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: "Hi there!"},
{Role: "user", Content: "How are you?"},
},
expected: "<|im_start|>system\n" +
"You are a helpful assistant.<|im_end|>\n" +
"<|im_start|>user\n" +
"Hello<|im_end|>\n" +
"<|im_start|>assistant\n" +
"Hi there!<|im_end|>\n" +
"<|im_start|>user\n" +
"How are you?<|im_end|>\n" +
"<|im_start|>assistant\n\n",
},
{
name: "parallel tool calls - newline separated",
msgs: []api.Message{
{Role: "user", Content: "Get weather in SF and NYC"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_1",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: map[string]any{"location": "San Francisco"},
},
},
{
ID: "call_2",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: map[string]any{"location": "New York"},
},
},
},
},
{Role: "tool", Content: `{"temperature": 68}`, ToolName: "get_weather"},
{Role: "tool", Content: `{"temperature": 55}`, ToolName: "get_weather"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}},
},
},
},
},
},
expected: "<|im_start|>system\n" +
"You are a helpful function-calling AI assistant. " +
"You are provided with function signatures within <functions></functions> XML tags. You may call one or more functions to assist with the user query. Output any function calls within <function_calls></function_calls> XML tags. Do not make assumptions about what values to plug into functions." +
`<functions>[{"type": "function", "function": {"name": "get_weather", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}}}}]</functions><|im_end|>` + "\n" +
"<|im_start|>user\n" +
"Get weather in SF and NYC<|im_end|>\n" +
"<|im_start|>assistant\n" +
`<function_calls>get_weather(location="San Francisco")` + "\n" +
`get_weather(location="New York")</function_calls><|im_end|>` + "\n" +
"<|im_start|>environment\n" +
`{"temperature": 68}<|im_end|>` + "\n" +
"<|im_start|>environment\n" +
`{"temperature": 55}<|im_end|>` + "\n" +
"<|im_start|>assistant\n\n",
},
{
name: "tool call with multiple arguments",
msgs: []api.Message{
{Role: "user", Content: "Book a flight"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_1",
Function: api.ToolCallFunction{
Name: "book_flight",
Arguments: map[string]any{
"from": "SFO",
"to": "NYC",
},
},
},
},
},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "book_flight",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: map[string]api.ToolProperty{
"from": {Type: api.PropertyType{"string"}},
"to": {Type: api.PropertyType{"string"}},
},
},
},
},
},
expected: "<|im_start|>system\n" +
"You are a helpful function-calling AI assistant. " +
"You are provided with function signatures within <functions></functions> XML tags. You may call one or more functions to assist with the user query. Output any function calls within <function_calls></function_calls> XML tags. Do not make assumptions about what values to plug into functions." +
`<functions>[{"type": "function", "function": {"name": "book_flight", "parameters": {"type": "object", "properties": {"from": {"type": "string"}, "to": {"type": "string"}}}}}]</functions><|im_end|>` + "\n" +
"<|im_start|>user\n" +
"Book a flight<|im_end|>\n" +
"<|im_start|>assistant\n" +
`<function_calls>book_flight(from="SFO", to="NYC")</function_calls><|im_end|>` + "\n" +
"<|im_start|>assistant\n\n",
},
{
name: "assistant prefill - no generation prompt",
msgs: []api.Message{
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: "Hi there!"},
},
expected: "<|im_start|>system\n" +
"You are a helpful function-calling AI assistant. You do not currently have access to any functions. <functions></functions><|im_end|>\n" +
"<|im_start|>user\n" +
"Hello<|im_end|>\n" +
"<|im_start|>assistant\n" +
"Hi there!",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rendered, err := (&Olmo3Renderer{}).Render(tt.msgs, tt.tools, nil)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(rendered, tt.expected); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
}
}

View File

@@ -1,130 +0,0 @@
package renderers
import (
"encoding/json"
"strings"
"github.com/ollama/ollama/api"
)
const (
olmo3ThinkDefaultSystemMessage = "You are OLMo, a helpful function-calling AI assistant built by Ai2. Your date cutoff is November 2024, and your model weights are available at https://huggingface.co/allenai."
olmo3ThinkNoFunctionsMessage = " You do not currently have access to any functions."
)
type Olmo3ThinkRenderer struct{}
type olmo3ThinkToolCall struct {
ID string `json:"id,omitempty"`
Type string `json:"type,omitempty"`
Function olmo3ThinkToolCallFunc `json:"function"`
}
type olmo3ThinkToolCallFunc struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
}
func (r *Olmo3ThinkRenderer) Render(messages []api.Message, tools []api.Tool, _ *api.ThinkValue) (string, error) {
var sb strings.Builder
var systemMessage *api.Message
filteredMessages := make([]api.Message, 0, len(messages))
for i, message := range messages {
if message.Role == "system" {
if systemMessage == nil {
systemMessage = &messages[i]
}
continue
}
filteredMessages = append(filteredMessages, message)
}
systemContent := olmo3ThinkDefaultSystemMessage
if systemMessage != nil {
systemContent = systemMessage.Content
}
sb.WriteString("<|im_start|>system\n")
sb.WriteString(systemContent)
if len(tools) > 0 {
functionsJSON, err := marshalWithSpaces(tools)
if err != nil {
return "", err
}
sb.WriteString(" <functions>")
sb.WriteString(string(functionsJSON))
sb.WriteString("</functions>")
} else {
sb.WriteString(olmo3ThinkNoFunctionsMessage)
sb.WriteString(" <functions></functions>")
}
sb.WriteString("<|im_end|>\n")
for i, message := range filteredMessages {
lastMessage := i == len(filteredMessages)-1
switch message.Role {
case "user":
sb.WriteString("<|im_start|>user\n")
sb.WriteString(message.Content)
sb.WriteString("<|im_end|>\n")
case "assistant":
sb.WriteString("<|im_start|>assistant\n")
if message.Content != "" {
sb.WriteString(message.Content)
}
if len(message.ToolCalls) > 0 {
toolCalls := make([]olmo3ThinkToolCall, len(message.ToolCalls))
for j, tc := range message.ToolCalls {
argsJSON, err := json.Marshal(tc.Function.Arguments)
if err != nil {
return "", err
}
toolCalls[j] = olmo3ThinkToolCall{
ID: tc.ID,
Type: "function",
Function: olmo3ThinkToolCallFunc{
Name: tc.Function.Name,
Arguments: string(argsJSON),
},
}
}
toolCallsJSON, err := marshalWithSpaces(toolCalls)
if err != nil {
return "", err
}
sb.WriteString("<function_calls>")
sb.WriteString(string(toolCallsJSON))
sb.WriteString("</function_calls>")
}
if !lastMessage {
sb.WriteString("<|im_end|>\n")
}
case "tool":
sb.WriteString("<|im_start|>environment\n")
sb.WriteString(message.Content)
sb.WriteString("<|im_end|>\n")
}
}
needsGenerationPrompt := true
if len(filteredMessages) > 0 {
lastMsg := filteredMessages[len(filteredMessages)-1]
if lastMsg.Role == "assistant" && len(lastMsg.ToolCalls) == 0 && lastMsg.Content != "" {
needsGenerationPrompt = false
}
}
if needsGenerationPrompt {
sb.WriteString("<|im_start|>assistant\n<think>")
}
return sb.String(), nil
}

View File

@@ -1,224 +0,0 @@
package renderers
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
func TestOlmo3ThinkRenderer(t *testing.T) {
tests := []struct {
name string
msgs []api.Message
tools []api.Tool
expected string
}{
{
name: "basic without system - adds default system",
msgs: []api.Message{
{Role: "user", Content: "Hello!"},
},
expected: "<|im_start|>system\n" +
"You are OLMo, a helpful function-calling AI assistant built by Ai2. Your date cutoff is November 2024, and your model weights are available at https://huggingface.co/allenai. You do not currently have access to any functions. <functions></functions><|im_end|>\n" +
"<|im_start|>user\n" +
"Hello!<|im_end|>\n" +
"<|im_start|>assistant\n" +
"<think>",
},
{
name: "with system message no tools",
msgs: []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Hello!"},
},
expected: "<|im_start|>system\n" +
"You are a helpful assistant. You do not currently have access to any functions. <functions></functions><|im_end|>\n" +
"<|im_start|>user\n" +
"Hello!<|im_end|>\n" +
"<|im_start|>assistant\n" +
"<think>",
},
{
name: "with system message and tools",
msgs: []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "What is the weather?"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}, Description: "The city"},
},
},
},
},
},
expected: "<|im_start|>system\n" +
`You are a helpful assistant. <functions>[{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "The city"}}}}}]</functions><|im_end|>` + "\n" +
"<|im_start|>user\n" +
"What is the weather?<|im_end|>\n" +
"<|im_start|>assistant\n" +
"<think>",
},
{
name: "assistant with tool calls",
msgs: []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "What is the weather in SF?"},
{
Role: "assistant",
Content: "Let me check the weather.",
ToolCalls: []api.ToolCall{
{
ID: "call_1",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: map[string]any{
"location": "San Francisco",
},
},
},
},
},
{Role: "tool", Content: `{"temperature": 68}`, ToolName: "get_weather"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}, Description: "The city"},
},
},
},
},
},
expected: "<|im_start|>system\n" +
`You are a helpful assistant. <functions>[{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "The city"}}}}}]</functions><|im_end|>` + "\n" +
"<|im_start|>user\n" +
"What is the weather in SF?<|im_end|>\n" +
"<|im_start|>assistant\n" +
`Let me check the weather.<function_calls>[{"id": "call_1", "type": "function", "function": {"name": "get_weather", "arguments": "{\"location\":\"San Francisco\"}"}}]</function_calls><|im_end|>` + "\n" +
"<|im_start|>environment\n" +
`{"temperature": 68}<|im_end|>` + "\n" +
"<|im_start|>assistant\n" +
"<think>",
},
{
name: "multi-turn conversation",
msgs: []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: "Hi there!"},
{Role: "user", Content: "How are you?"},
},
expected: "<|im_start|>system\n" +
"You are a helpful assistant. You do not currently have access to any functions. <functions></functions><|im_end|>\n" +
"<|im_start|>user\n" +
"Hello<|im_end|>\n" +
"<|im_start|>assistant\n" +
"Hi there!<|im_end|>\n" +
"<|im_start|>user\n" +
"How are you?<|im_end|>\n" +
"<|im_start|>assistant\n" +
"<think>",
},
{
name: "parallel tool calls",
msgs: []api.Message{
{Role: "user", Content: "Get weather in SF and NYC"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
ID: "call_1",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: map[string]any{"location": "San Francisco"},
},
},
{
ID: "call_2",
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: map[string]any{"location": "New York"},
},
},
},
},
{Role: "tool", Content: `{"temperature": 68}`, ToolName: "get_weather"},
{Role: "tool", Content: `{"temperature": 55}`, ToolName: "get_weather"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: map[string]api.ToolProperty{
"location": {Type: api.PropertyType{"string"}},
},
},
},
},
},
expected: "<|im_start|>system\n" +
`You are OLMo, a helpful function-calling AI assistant built by Ai2. Your date cutoff is November 2024, and your model weights are available at https://huggingface.co/allenai. <functions>[{"type": "function", "function": {"name": "get_weather", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}}}}]</functions><|im_end|>` + "\n" +
"<|im_start|>user\n" +
"Get weather in SF and NYC<|im_end|>\n" +
"<|im_start|>assistant\n" +
`<function_calls>[{"id": "call_1", "type": "function", "function": {"name": "get_weather", "arguments": "{\"location\":\"San Francisco\"}"}}, {"id": "call_2", "type": "function", "function": {"name": "get_weather", "arguments": "{\"location\":\"New York\"}"}}]</function_calls><|im_end|>` + "\n" +
"<|im_start|>environment\n" +
`{"temperature": 68}<|im_end|>` + "\n" +
"<|im_start|>environment\n" +
`{"temperature": 55}<|im_end|>` + "\n" +
"<|im_start|>assistant\n" +
"<think>",
},
{
name: "assistant message only content no tool calls",
msgs: []api.Message{
{Role: "user", Content: "Tell me a joke"},
{Role: "assistant", Content: "Why did the chicken cross the road?"},
{Role: "user", Content: "I don't know, why?"},
},
expected: "<|im_start|>system\n" +
"You are OLMo, a helpful function-calling AI assistant built by Ai2. Your date cutoff is November 2024, and your model weights are available at https://huggingface.co/allenai. You do not currently have access to any functions. <functions></functions><|im_end|>\n" +
"<|im_start|>user\n" +
"Tell me a joke<|im_end|>\n" +
"<|im_start|>assistant\n" +
"Why did the chicken cross the road?<|im_end|>\n" +
"<|im_start|>user\n" +
"I don't know, why?<|im_end|>\n" +
"<|im_start|>assistant\n" +
"<think>",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rendered, err := (&Olmo3ThinkRenderer{}).Render(tt.msgs, tt.tools, nil)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(rendered, tt.expected); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
}
}

View File

@@ -1,11 +1,51 @@
package renderers
import (
"encoding/json"
"strings"
"github.com/ollama/ollama/api"
)
func marshalWithSpaces(v any) ([]byte, error) {
b, err := json.Marshal(v)
if err != nil {
return nil, err
}
out := make([]byte, 0, len(b)+len(b)/8)
inStr, esc := false, false
for _, c := range b {
if inStr {
out = append(out, c)
if esc {
esc = false
continue
}
if c == '\\' {
esc = true
continue
}
if c == '"' {
inStr = false
}
continue
}
switch c {
case '"':
inStr = true
out = append(out, c)
case ':':
out = append(out, ':', ' ')
case ',':
out = append(out, ',', ' ')
default:
out = append(out, c)
}
}
return out, nil
}
type Qwen3VLRenderer struct {
isThinking bool

View File

@@ -6,6 +6,7 @@ import (
"github.com/google/go-cmp/cmp"
)
// TODO(drifkin): this will be moved to utils in the near future and used by other renderers as well
func TestMarshalWithSpaces(t *testing.T) {
tests := []struct {
name string

View File

@@ -59,12 +59,6 @@ func rendererForName(name string) Renderer {
case "cogito":
renderer := &CogitoRenderer{isThinking: true}
return renderer
case "olmo3-think":
renderer := &Olmo3ThinkRenderer{}
return renderer
case "olmo3":
renderer := &Olmo3Renderer{}
return renderer
default:
return nil
}

View File

@@ -300,13 +300,18 @@ func filesForModel(path string) ([]string, error) {
}
files = append(files, js...)
// add tokenizer.model if it exists (tokenizer.json is automatically picked up by the previous glob)
// tokenizer.model might be a unresolved git lfs reference; error if it is
if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 {
files = append(files, tks...)
} else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 {
// some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
files = append(files, tks...)
// only include tokenizer.model is tokenizer.json is not present
if !slices.ContainsFunc(files, func(s string) bool {
return slices.Contains(strings.Split(s, string(os.PathSeparator)), "tokenizer.json")
}) {
if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 {
// add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob
// tokenizer.model might be a unresolved git lfs reference; error if it is
files = append(files, tks...)
} else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 {
// some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
files = append(files, tks...)
}
}
return files, nil

View File

@@ -888,37 +888,6 @@ func TestFilesForModel(t *testing.T) {
"tokenizer.json",
},
},
{
name: "safetensors with both tokenizer.json and tokenizer.model",
setup: func(dir string) error {
// Create binary content for tokenizer.model (application/octet-stream)
binaryContent := make([]byte, 512)
for i := range binaryContent {
binaryContent[i] = byte(i % 256)
}
files := []string{
"model-00001-of-00001.safetensors",
"config.json",
"tokenizer.json",
}
for _, file := range files {
if err := os.WriteFile(filepath.Join(dir, file), []byte("test content"), 0o644); err != nil {
return err
}
}
// Write tokenizer.model as binary
if err := os.WriteFile(filepath.Join(dir, "tokenizer.model"), binaryContent, 0o644); err != nil {
return err
}
return nil
},
wantFiles: []string{
"model-00001-of-00001.safetensors",
"config.json",
"tokenizer.json",
"tokenizer.model",
},
},
{
name: "safetensors with consolidated files - prefers model files",
setup: func(dir string) error {