Compare commits
2 Commits
parth/olmo
...
mxyng/remo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3dcc31dfac | ||
|
|
260b165a1a |
@@ -200,8 +200,6 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
conv = &qwen25VLModel{}
|
||||
case "Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration":
|
||||
conv = &qwen3VLModel{}
|
||||
case "OLMo2ForCausalLM", "Olmo2ForCausalLM", "OLMo3ForCausalLM", "Olmo3ForCausalLM":
|
||||
conv = &olmoModel{}
|
||||
case "BertModel":
|
||||
conv = &bertModel{}
|
||||
case "CohereForCausalLM":
|
||||
|
||||
@@ -2,7 +2,6 @@ package convert
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
@@ -27,26 +26,16 @@ type gemma3Model struct {
|
||||
NumChannels uint32 `json:"num_channels"` // num_channels 3
|
||||
PatchSize uint32 `json:"patch_size"` // patch_size 14
|
||||
} `json:"vision_config"`
|
||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
RMSNormEPS float32 `json:"rms_norm_eps"`
|
||||
HeadDim uint32 `json:"head_dim"`
|
||||
FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
|
||||
RopeLocalTheta float32 `json:"rope_local_base_freq"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
SlidingWindow uint32 `json:"sliding_window"`
|
||||
SlidingWindowPattern *uint32 `json:"sliding_window_pattern"`
|
||||
LayerTypes []string `json:"layer_types"`
|
||||
MultiModalTokensPerImage uint32 `json:"mm_tokens_per_image"`
|
||||
RopeScaling *struct {
|
||||
Type string `json:"rope_type"`
|
||||
Factor float32 `json:"factor"`
|
||||
OriginalMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
|
||||
ExtrapolationFactor float32 `json:"extrapolation_factor"`
|
||||
BetaFast float32 `json:"beta_fast"`
|
||||
BetaSlow float32 `json:"beta_slow"`
|
||||
} `json:"rope_scaling"`
|
||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
RMSNormEPS float32 `json:"rms_norm_eps"`
|
||||
HeadDim uint32 `json:"head_dim"`
|
||||
FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
|
||||
RopeLocalTheta float32 `json:"rope_local_base_freq"`
|
||||
RopeGlobalTheta float32 `json:"rope_global_base_freq"`
|
||||
SlidingWindow uint32 `json:"sliding_window"`
|
||||
MultiModalTokensPerImage uint32 `json:"mm_tokens_per_image"`
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -92,38 +81,9 @@ func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv["gemma3.attention.key_length"] = p.HeadDim
|
||||
kv["gemma3.attention.value_length"] = p.HeadDim
|
||||
kv["gemma3.attention.sliding_window"] = p.SlidingWindow
|
||||
|
||||
// The sliding window pattern is either provided as the sliding_window_pattern
|
||||
// key (an int) or as the layer_types key (a list of strings).
|
||||
if p.SlidingWindowPattern != nil || len(p.LayerTypes) > 0 {
|
||||
kv["gemma3.attention.sliding_window_pattern"] = slices.Collect(func(yield func(bool) bool) {
|
||||
for i := range numBlocks {
|
||||
var isLocal bool
|
||||
if len(p.LayerTypes) > 0 && int(i) < len(p.LayerTypes) {
|
||||
isLocal = p.LayerTypes[i] == "sliding_attention"
|
||||
} else if p.SlidingWindowPattern != nil && *p.SlidingWindowPattern > 0 {
|
||||
isLocal = (i+1)%*p.SlidingWindowPattern != 0
|
||||
}
|
||||
if !yield(isLocal) {
|
||||
break
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
if p.FinalLogitSoftcap > 0 {
|
||||
kv["gemma3.final_logit_softcapping"] = p.FinalLogitSoftcap
|
||||
}
|
||||
kv["gemma3.final_logit_softcapping"] = cmp.Or(p.FinalLogitSoftcap, 30)
|
||||
kv["gemma3.rope.local.freq_base"] = cmp.Or(p.RopeLocalTheta, 10000.0)
|
||||
kv["gemma3.rope.freq_base"] = cmp.Or(p.RopeTheta, 1000000.0)
|
||||
if p.RopeScaling != nil && p.RopeScaling.Type == "yarn" && p.RopeScaling.Factor > 0 {
|
||||
kv["gemma3.rope.scaling.type"] = "yarn"
|
||||
kv["gemma3.rope.scaling.factor"] = p.RopeScaling.Factor
|
||||
kv["gemma3.rope.scaling.original_context_length"] = p.RopeScaling.OriginalMaxPositionEmbeddings
|
||||
kv["gemma3.rope.scaling.extrapolation_factor"] = cmp.Or(p.RopeScaling.ExtrapolationFactor, float32(1.0))
|
||||
kv["gemma3.rope.scaling.beta_fast"] = cmp.Or(p.RopeScaling.BetaFast, float32(64.0))
|
||||
kv["gemma3.rope.scaling.beta_slow"] = cmp.Or(p.RopeScaling.BetaSlow, float32(1.0))
|
||||
}
|
||||
|
||||
kv["gemma3.rope.global.freq_base"] = cmp.Or(p.RopeGlobalTheta, 1000000.0)
|
||||
kv["gemma3.embedding_length"] = p.HiddenSize
|
||||
kv["gemma3.feed_forward_length"] = p.IntermediateSize
|
||||
default:
|
||||
|
||||
@@ -1,124 +0,0 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
type ropeScaling struct {
|
||||
Factor float32 `json:"factor"`
|
||||
OriginalMaxPositionEmbeds uint32 `json:"original_max_position_embeddings"`
|
||||
AttentionFactor float32 `json:"attention_factor"`
|
||||
BetaFast float32 `json:"beta_fast"`
|
||||
BetaSlow float32 `json:"beta_slow"`
|
||||
RopeType string `json:"rope_type"`
|
||||
ExtrapolationFactor float32 `json:"extrapolation_factor"`
|
||||
}
|
||||
|
||||
type olmoModel struct {
|
||||
ModelParameters
|
||||
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||
RMSNormEPS float32 `json:"rms_norm_eps"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
RopeScaling *ropeScaling `json:"rope_scaling"`
|
||||
ClampKQV float32 `json:"f_clamp_kqv"`
|
||||
SlidingWindow uint32 `json:"sliding_window"`
|
||||
LayerTypes []string `json:"layer_types"`
|
||||
}
|
||||
|
||||
var _ ModelConverter = (*olmoModel)(nil)
|
||||
|
||||
func (p *olmoModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "olmo2"
|
||||
kv["olmo2.block_count"] = p.NumHiddenLayers
|
||||
kv["olmo2.context_length"] = p.MaxPositionEmbeddings
|
||||
kv["olmo2.embedding_length"] = p.HiddenSize
|
||||
kv["olmo2.feed_forward_length"] = p.IntermediateSize
|
||||
kv["olmo2.attention.head_count"] = p.NumAttentionHeads
|
||||
kv["olmo2.attention.head_count_kv"] = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)
|
||||
|
||||
if p.RopeTheta > 0 {
|
||||
kv["olmo2.rope.freq_base"] = p.RopeTheta
|
||||
} else {
|
||||
kv["olmo2.rope.freq_base"] = float32(10000.0)
|
||||
}
|
||||
|
||||
if p.RopeScaling != nil {
|
||||
if p.RopeScaling.Factor > 0 {
|
||||
kv["olmo2.rope.scaling.factor"] = p.RopeScaling.Factor
|
||||
}
|
||||
if p.RopeScaling.OriginalMaxPositionEmbeds > 0 {
|
||||
kv["olmo2.rope.scaling.original_context_length"] = p.RopeScaling.OriginalMaxPositionEmbeds
|
||||
}
|
||||
if p.RopeScaling.AttentionFactor > 0 {
|
||||
kv["olmo2.rope.scaling.attn_factor"] = p.RopeScaling.AttentionFactor
|
||||
}
|
||||
if p.RopeScaling.RopeType != "" {
|
||||
kv["olmo2.rope.scaling.type"] = p.RopeScaling.RopeType
|
||||
}
|
||||
}
|
||||
|
||||
if p.RMSNormEPS > 0 {
|
||||
kv["olmo2.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
|
||||
}
|
||||
|
||||
if p.ClampKQV > 0 {
|
||||
kv["olmo2.attention.clamp_kqv"] = p.ClampKQV
|
||||
}
|
||||
|
||||
if p.SlidingWindow > 0 {
|
||||
kv["olmo2.attention.sliding_window"] = p.SlidingWindow
|
||||
}
|
||||
|
||||
if len(p.LayerTypes) > 0 {
|
||||
slidingPattern := make([]bool, len(p.LayerTypes))
|
||||
for i, layerType := range p.LayerTypes {
|
||||
slidingPattern[i] = (layerType == "sliding_attention")
|
||||
}
|
||||
kv["olmo2.attention.sliding_window_pattern"] = slidingPattern
|
||||
}
|
||||
|
||||
return kv
|
||||
}
|
||||
|
||||
func (p *olmoModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
out := make([]*ggml.Tensor, 0, len(ts))
|
||||
for _, t := range ts {
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func (p *olmoModel) Replacements() []string {
|
||||
return []string{
|
||||
"lm_head", "output",
|
||||
"model.embed_tokens", "token_embd",
|
||||
"model.layers", "blk",
|
||||
"model.norm", "output_norm",
|
||||
"self_attn.q_proj", "attn_q",
|
||||
"self_attn.k_proj", "attn_k",
|
||||
"self_attn.v_proj", "attn_v",
|
||||
"self_attn.o_proj", "attn_output",
|
||||
"self_attn.q_norm", "attn_q_norm",
|
||||
"self_attn.k_norm", "attn_k_norm",
|
||||
"post_attention_layernorm", "post_attention_norm",
|
||||
"post_feedforward_layernorm", "post_ffw_norm",
|
||||
"mlp.gate_proj", "ffn_gate",
|
||||
"mlp.down_proj", "ffn_down",
|
||||
"mlp.up_proj", "ffn_up",
|
||||
}
|
||||
}
|
||||
@@ -252,7 +252,6 @@ func (kv KV) OllamaEngineRequired() bool {
|
||||
"deepseekocr",
|
||||
"deepseek2",
|
||||
"nomic-bert",
|
||||
"olmo2",
|
||||
}, kv.Architecture())
|
||||
}
|
||||
|
||||
|
||||
9896
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal
vendored
9896
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal
vendored
File diff suppressed because it is too large
Load Diff
@@ -2,7 +2,7 @@
|
||||
|
||||
package metal
|
||||
|
||||
//go:generate sh -c "{ echo // Code generated by 'go generate'. DO NOT EDIT.; sed -e '/__embed_ggml-common.h__/r ../ggml-common.h' -e '/__embed_ggml-common.h__/d' -e '/#include \"ggml-metal-impl.h\"/r ggml-metal-impl.h' -e '/#include \"ggml-metal-impl.h\"/d' ggml-metal.metal; } >ggml-metal-embed.metal"
|
||||
//go:generate sh -c "{ echo // Code generated by 'go generate'. DO NOT EDIT.; sed -e '/__embed_ggml-common.h__/r ../ggml-common.h' -e '/__embed_ggml-common.h__/d' -e '/#include \"ggml-metal-impl.h\"/r ggml-metal-impl.h' -e '/#include \"ggml-metal-impl.h\"/d' ggml-metal.metal; } >ggml-metal-embed.metal; rm ggml-metal.metal"
|
||||
|
||||
// #cgo CXXFLAGS: -std=c++17
|
||||
// #cgo CPPFLAGS: -DGGML_METAL_NDEBUG -DGGML_METAL_EMBED_LIBRARY -DGGML_METAL_HAS_BF16 -I.. -I../../include
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
package nn
|
||||
// fast provides implementations of fast (fused) operations for increased performance.
|
||||
package fast
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/ml"
|
||||
@@ -7,7 +8,7 @@ import (
|
||||
|
||||
// fastRoPE is an interface for tensors that support fast rotary positional embedding.
|
||||
type fastRoPE interface {
|
||||
RoPE(ctx ml.Context, positions ml.Tensor, dim int, base, scale float32, options ...func(*rope.Options)) ml.Tensor
|
||||
RoPE(ctx ml.Context, positionIDs ml.Tensor, dim int, base, scale float32, options ...func(*rope.Options)) ml.Tensor
|
||||
}
|
||||
|
||||
// RoPE applies rotary positional embedding to tensor `t`.
|
||||
@@ -1,4 +1,3 @@
|
||||
// Package rope provides options for RoPE
|
||||
package rope
|
||||
|
||||
import "github.com/ollama/ollama/ml"
|
||||
@@ -58,18 +57,6 @@ func WithAttentionFactor(attentionFactor float32) func(*Options) {
|
||||
}
|
||||
}
|
||||
|
||||
func WithBetaFast(betaFast float32) func(*Options) {
|
||||
return func(opts *Options) {
|
||||
opts.YaRN.BetaFast = betaFast
|
||||
}
|
||||
}
|
||||
|
||||
func WithBetaSlow(betaSlow float32) func(*Options) {
|
||||
return func(opts *Options) {
|
||||
opts.YaRN.BetaSlow = betaSlow
|
||||
}
|
||||
}
|
||||
|
||||
func WithMRoPE(sections []int) func(*Options) {
|
||||
return func(opts *Options) {
|
||||
opts.Type |= 1 << 3
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/fast"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
@@ -41,12 +42,13 @@ type Options struct {
|
||||
kqScale float64
|
||||
}
|
||||
|
||||
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, t, p ml.Tensor) ml.Tensor {
|
||||
return nn.RoPE(ctx, t, p, o.qkRopeHeadDim, o.ropeBase, 1./o.ropeScale,
|
||||
func (o Options) RoPEOptions() []func(*rope.Options) {
|
||||
attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
|
||||
return []func(*rope.Options){
|
||||
rope.WithOriginalContextLength(o.originalContextLength),
|
||||
rope.WithExtrapolationFactor(1.),
|
||||
rope.WithAttentionFactor(float32(1.0/(1.0+0.1*math.Log(float64(o.ropeScale))))),
|
||||
)
|
||||
rope.WithAttentionFactor(attnFactor),
|
||||
}
|
||||
}
|
||||
|
||||
type Attention struct {
|
||||
@@ -89,8 +91,8 @@ func (attn *Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor
|
||||
compressedKV.Stride(1), compressedKV.Dim(1),
|
||||
)
|
||||
|
||||
qRot := opts.applyRotaryPositionEmbeddings(ctx, queryChunks[1], positions)
|
||||
kRot = opts.applyRotaryPositionEmbeddings(ctx, kRot, positions)
|
||||
qRot := fast.RoPE(ctx, queryChunks[1], positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
|
||||
kRot = fast.RoPE(ctx, kRot, positions, opts.qkRopeHeadDim, opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
|
||||
kPass = attn.KVANorm.Forward(ctx, kPass, opts.eps)
|
||||
|
||||
var attention ml.Tensor
|
||||
@@ -325,7 +327,7 @@ func New(c fs.Config) (model.Model, error) {
|
||||
}
|
||||
|
||||
func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
|
||||
return fast.RoPE(ctx, key, shift, m.qkRopeHeadDim, m.ropeBase, 1./m.ropeScale, m.RoPEOptions()...), nil
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/fast"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
)
|
||||
|
||||
@@ -19,7 +20,7 @@ type textModel struct {
|
||||
}
|
||||
|
||||
func (m *textModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return m.Options.applyRotaryPositionEmbeddings(ctx, key, shift), nil
|
||||
return m.Options.applyRotaryPositionalEmbedding(ctx, key, shift), nil
|
||||
}
|
||||
|
||||
type textOptions struct {
|
||||
@@ -37,8 +38,8 @@ func (o textOptions) headDim() int {
|
||||
return o.hiddenSize / o.numHeads
|
||||
}
|
||||
|
||||
func (o textOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
|
||||
return nn.RoPE(ctx, states, positions, o.headDim(), o.ropeBase, 1/o.ropeScale, rope.WithTypeNeoX())
|
||||
func (o textOptions) applyRotaryPositionalEmbedding(ctx ml.Context, t, p ml.Tensor) ml.Tensor {
|
||||
return fast.RoPE(ctx, t, p, o.headDim(), o.ropeBase, 1/o.ropeScale, rope.WithTypeNeoX())
|
||||
}
|
||||
|
||||
type textBlock struct {
|
||||
@@ -82,8 +83,8 @@ func (m *textAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tenso
|
||||
value := m.Value.Forward(ctx, hiddenStates)
|
||||
value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, -1)
|
||||
|
||||
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
|
||||
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
|
||||
query = opts.applyRotaryPositionalEmbedding(ctx, query, positions)
|
||||
key = opts.applyRotaryPositionalEmbedding(ctx, key, positions)
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
|
||||
attention = attention.Reshape(ctx, -1, attention.Dim(2))
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/fast"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
@@ -21,10 +22,6 @@ type Options struct {
|
||||
largeModelScaling bool
|
||||
}
|
||||
|
||||
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
|
||||
return nn.RoPE(ctx, states, positions, o.attnKeyLen, o.ropeBase, 1./o.ropeScale, rope.WithTypeNeoX())
|
||||
}
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.SentencePiece
|
||||
@@ -91,7 +88,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
||||
|
||||
q := sa.Query.Forward(ctx, hiddenState)
|
||||
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
|
||||
q = opts.applyRotaryPositionEmbeddings(ctx, q, positionIDs)
|
||||
q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
|
||||
|
||||
if opts.largeModelScaling {
|
||||
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
|
||||
@@ -101,7 +98,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
||||
|
||||
k := sa.Key.Forward(ctx, hiddenState)
|
||||
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
|
||||
k = opts.applyRotaryPositionEmbeddings(ctx, k, positionIDs)
|
||||
k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
|
||||
|
||||
v := sa.Value.Forward(ctx, hiddenState)
|
||||
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
|
||||
@@ -131,7 +128,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
||||
}
|
||||
|
||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
|
||||
return fast.RoPE(ctx, key, shift, m.Options.attnKeyLen, m.Options.ropeBase, 1/m.Options.ropeScale, rope.WithTypeNeoX()), nil
|
||||
}
|
||||
|
||||
type MLP struct {
|
||||
|
||||
@@ -16,7 +16,7 @@ import (
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
model.SentencePiece
|
||||
|
||||
*VisionModel `gguf:"v"`
|
||||
*TextModel
|
||||
@@ -54,35 +54,24 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
vocabulary := model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{
|
||||
int32(c.Uint("tokenizer.ggml.eos_token_id")),
|
||||
},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
}
|
||||
|
||||
var processor model.TextProcessor
|
||||
switch c.String("tokenizer.ggml.model") {
|
||||
case "gpt2":
|
||||
processor = model.NewBytePairEncoding(&vocabulary)
|
||||
default:
|
||||
// Previous uploads of Gemma 3 on Ollama did not have token 106
|
||||
// (i.e. "<end_of_turn>") so we need to add in case it's not already present
|
||||
vocabulary.EOS = append(vocabulary.EOS, int32(c.Uint("tokenizer.ggml.eot_token_id", 106)))
|
||||
processor = model.NewSentencePiece(&vocabulary)
|
||||
}
|
||||
|
||||
m := Model{
|
||||
TextProcessor: processor,
|
||||
SentencePiece: model.NewSentencePiece(
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{
|
||||
int32(c.Uint("tokenizer.ggml.eos_token_id")),
|
||||
int32(c.Uint("tokenizer.ggml.eot_token_id", 106)),
|
||||
},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
},
|
||||
),
|
||||
ImageProcessor: newImageProcessor(c),
|
||||
VisionModel: newVisionModel(c),
|
||||
TextModel: newTextModel(c),
|
||||
@@ -152,16 +141,8 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
hiddenState := m.TextModel.Forward(ctx, batch, m.Cache)
|
||||
hiddenState = m.Output.Forward(ctx, hiddenState)
|
||||
|
||||
if m.TextConfig.finalLogitSoftcap > 0.0 {
|
||||
hiddenState = hiddenState.Scale(ctx, 1.0/float64(m.TextConfig.finalLogitSoftcap))
|
||||
hiddenState = hiddenState.Tanh(ctx)
|
||||
hiddenState = hiddenState.Scale(ctx, float64(m.TextConfig.finalLogitSoftcap))
|
||||
}
|
||||
|
||||
return hiddenState, nil
|
||||
hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
|
||||
return m.Output.Forward(ctx, hiddenStates), nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
||||
@@ -2,12 +2,12 @@ package gemma3
|
||||
|
||||
import (
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/fast"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
@@ -16,32 +16,8 @@ type TextConfig struct {
|
||||
hiddenSize, numHeads, numKVHeads int
|
||||
attnKeyLen, attnValLen int
|
||||
eps, ropeScale float32
|
||||
ropeLocalBase float32
|
||||
ropeLocalBase, ropeGlobalBase float32
|
||||
largeModelScaling bool
|
||||
slidingWindowPattern []bool
|
||||
ropeBase float32
|
||||
ropeType string
|
||||
ropeOriginalContext int
|
||||
ropeExtrapolation float32
|
||||
ropeBetaFast float32
|
||||
ropeBetaSlow float32
|
||||
finalLogitSoftcap float32
|
||||
}
|
||||
|
||||
func (o TextConfig) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor, base float32) ml.Tensor {
|
||||
ropeOpts := []func(*rope.Options){rope.WithTypeNeoX()}
|
||||
if o.ropeType == "yarn" {
|
||||
attnFactor := float32(1.0 / (1.0 + 0.1*math.Log(float64(o.ropeScale))))
|
||||
ropeOpts = append(ropeOpts,
|
||||
rope.WithOriginalContextLength(o.ropeOriginalContext),
|
||||
rope.WithExtrapolationFactor(o.ropeExtrapolation),
|
||||
rope.WithAttentionFactor(attnFactor),
|
||||
rope.WithBetaFast(o.ropeBetaFast),
|
||||
rope.WithBetaSlow(o.ropeBetaSlow),
|
||||
)
|
||||
}
|
||||
|
||||
return nn.RoPE(ctx, states, positions, o.attnKeyLen, base, 1./o.ropeScale, ropeOpts...)
|
||||
}
|
||||
|
||||
type TextModel struct {
|
||||
@@ -69,35 +45,21 @@ func newTextModel(c fs.Config) *TextModel {
|
||||
m := TextModel{
|
||||
Layers: make([]TextLayer, numBlocks),
|
||||
TextConfig: &TextConfig{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||
attnKeyLen: int(c.Uint("attention.key_length", 256)),
|
||||
attnValLen: int(c.Uint("attention.value_length", 256)),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
|
||||
ropeLocalBase: c.Float("rope.local.freq_base", 10000.0),
|
||||
ropeBase: c.Float("rope.freq_base", 1000000.0),
|
||||
slidingWindowPattern: c.Bools("attention.sliding_window_pattern"),
|
||||
ropeType: c.String("rope.scaling.type"),
|
||||
ropeOriginalContext: int(c.Uint("rope.scaling.original_context_length")),
|
||||
ropeExtrapolation: c.Float("rope.scaling.extrapolation_factor", 1.0),
|
||||
ropeBetaFast: c.Float("rope.scaling.beta_fast", 64.0),
|
||||
ropeBetaSlow: c.Float("rope.scaling.beta_slow", 1.0),
|
||||
ropeScale: c.Float("rope.scaling.factor", 1.0),
|
||||
finalLogitSoftcap: c.Float("final_logit_softcapping", 0.0),
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||
attnKeyLen: int(c.Uint("attention.key_length", 256)),
|
||||
attnValLen: int(c.Uint("attention.value_length", 256)),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
|
||||
ropeLocalBase: c.Float("rope.local.freq_base", 10000.0),
|
||||
ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
|
||||
ropeScale: 1,
|
||||
// NOTE: the rope.scaling.factor is set incorrectly in the official QAT weights
|
||||
// (8 instead of 1)
|
||||
// ropeScale: c.Float("rope.scaling.factor", 1.0),
|
||||
},
|
||||
}
|
||||
|
||||
// Google's Gemma 3 release with sliding window attention does
|
||||
// not use final logit softcapping, and so force it to 0.0
|
||||
// TODO (jmorganca): this should ideally be set to 0.0 in the
|
||||
// model configuration instead of here, as future versions of
|
||||
// models may include both sliding window attention and final
|
||||
// logit softcapping.
|
||||
if slices.Contains(m.TextConfig.slidingWindowPattern, true) {
|
||||
m.TextConfig.finalLogitSoftcap = 0.0
|
||||
}
|
||||
|
||||
if numBlocks == gemma27BLayerCount {
|
||||
m.largeModelScaling = true
|
||||
}
|
||||
@@ -114,31 +76,18 @@ type TextSelfAttention struct {
|
||||
Output *nn.Linear `gguf:"attn_output"`
|
||||
}
|
||||
|
||||
func (opts *TextConfig) ropeBaseForLayer(layer int) float32 {
|
||||
if opts.slidingWindowPattern != nil && opts.slidingWindowPattern[layer] {
|
||||
return opts.ropeLocalBase
|
||||
}
|
||||
|
||||
// Standard Gemma3: only every n-th layer is global,
|
||||
// where n = gemmaGlobalCacheCount, otherwise use
|
||||
// the local rope base
|
||||
if (layer+1)%gemmaGlobalCacheCount > 0 {
|
||||
return opts.ropeLocalBase
|
||||
}
|
||||
|
||||
// default to global rope base
|
||||
return opts.ropeBase
|
||||
}
|
||||
|
||||
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextConfig) ml.Tensor {
|
||||
batchSize := hiddenState.Dim(1)
|
||||
|
||||
ropeBase := opts.ropeBaseForLayer(layer)
|
||||
ropeBase := opts.ropeLocalBase
|
||||
if (layer+1)%gemmaGlobalCacheCount == 0 {
|
||||
ropeBase = opts.ropeGlobalBase
|
||||
}
|
||||
|
||||
q := sa.Query.Forward(ctx, hiddenState)
|
||||
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
|
||||
q = sa.QueryNorm.Forward(ctx, q, opts.eps)
|
||||
q = opts.applyRotaryPositionEmbeddings(ctx, q, positionIDs, ropeBase)
|
||||
q = fast.RoPE(ctx, q, positionIDs, opts.attnKeyLen, ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
|
||||
|
||||
if opts.largeModelScaling {
|
||||
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
|
||||
@@ -149,7 +98,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
|
||||
k := sa.Key.Forward(ctx, hiddenState)
|
||||
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
|
||||
k = sa.KeyNorm.Forward(ctx, k, opts.eps)
|
||||
k = opts.applyRotaryPositionEmbeddings(ctx, k, positionIDs, ropeBase)
|
||||
k = fast.RoPE(ctx, k, positionIDs, opts.attnKeyLen, ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
|
||||
|
||||
v := sa.Value.Forward(ctx, hiddenState)
|
||||
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
|
||||
@@ -162,7 +111,12 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
|
||||
}
|
||||
|
||||
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return m.applyRotaryPositionEmbeddings(ctx, key, shift, m.TextConfig.ropeBaseForLayer(layer)), nil
|
||||
ropeBase := m.TextConfig.ropeLocalBase
|
||||
if (layer+1)%gemmaGlobalCacheCount == 0 {
|
||||
ropeBase = m.TextConfig.ropeGlobalBase
|
||||
}
|
||||
|
||||
return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, 1/m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil
|
||||
}
|
||||
|
||||
type TextMLP struct {
|
||||
@@ -250,5 +204,6 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
|
||||
hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig)
|
||||
}
|
||||
|
||||
return m.OutputNorm.Forward(ctx, hiddenState, m.eps)
|
||||
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
|
||||
return hiddenState
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/fast"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
@@ -94,7 +95,7 @@ func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.T
|
||||
ropeBase = m.ropeBaseLocal
|
||||
}
|
||||
|
||||
return m.applyRotaryPositionEmbeddings(ctx, key, shift, ropeBase), nil
|
||||
return fast.RoPE(ctx, key, shift, m.headDim(), ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil
|
||||
}
|
||||
|
||||
type TextScaledWordEmbedding struct {
|
||||
@@ -255,14 +256,14 @@ func (attn TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Ten
|
||||
query := attn.Query.Forward(ctx, hiddenStates)
|
||||
query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize)
|
||||
query = attn.QueryNorm.Forward(ctx, query, opts.eps)
|
||||
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions, ropeBase)
|
||||
query = fast.RoPE(ctx, query, positions, opts.headDim(), ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
|
||||
|
||||
var key, value ml.Tensor
|
||||
if !sharedKV {
|
||||
key = attn.Key.Forward(ctx, hiddenStates)
|
||||
key = key.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
|
||||
key = attn.KeyNorm.Forward(ctx, key, opts.eps)
|
||||
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions, ropeBase)
|
||||
key = fast.RoPE(ctx, key, positions, opts.headDim(), ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
|
||||
|
||||
value = attn.Value.Forward(ctx, hiddenStates)
|
||||
value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
|
||||
@@ -329,10 +330,6 @@ func (o *TextOptions) isLocal(i int) bool {
|
||||
return o.slidingWindowPattern[i]
|
||||
}
|
||||
|
||||
func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, t, p ml.Tensor, base float32) ml.Tensor {
|
||||
return nn.RoPE(ctx, t, p, o.headDim(), base, 1./o.ropeScale, rope.WithTypeNeoX())
|
||||
}
|
||||
|
||||
func newTextModel(c fs.Config) *TextModel {
|
||||
return &TextModel{
|
||||
TextLayers: make([]TextLayer, c.Uint("block_count")),
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/fast"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
@@ -51,7 +52,7 @@ func (m *Transformer) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, err
|
||||
}
|
||||
|
||||
func (m *Transformer) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
|
||||
return fast.RoPE(ctx, key, shift, m.headDim(), m.ropeBase, 1./m.ropeScale, m.RoPEOptions()...), nil
|
||||
}
|
||||
|
||||
type Options struct {
|
||||
@@ -69,14 +70,14 @@ type Options struct {
|
||||
ropeScale float32
|
||||
}
|
||||
|
||||
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
|
||||
return nn.RoPE(ctx, states, positions, o.headDim(), o.ropeBase, 1./o.ropeScale,
|
||||
func (o Options) RoPEOptions() []func(*rope.Options) {
|
||||
return []func(*rope.Options){
|
||||
rope.WithTypeNeoX(),
|
||||
rope.WithOriginalContextLength(o.originalContextLength),
|
||||
rope.WithExtrapolationFactor(1.),
|
||||
// NOTE: ggml sets this implicitly so there's no need to set it here
|
||||
// rope.WithAttentionFactor(0.1*float32(math.Log(float64(o.ropeScale))) + 1.0),
|
||||
)
|
||||
// NOTE: ggml sets this implicitly so there's no need to set it here
|
||||
// rope.WithAttentionFactor(0.1*float32(math.Log(float64(o.ropeScale))) + 1.0),
|
||||
}
|
||||
}
|
||||
|
||||
func (o Options) headDim() int {
|
||||
@@ -134,8 +135,8 @@ func (attn *AttentionBlock) Forward(ctx ml.Context, hiddenStates, positions ml.T
|
||||
value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize)
|
||||
}
|
||||
|
||||
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
|
||||
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
|
||||
query = fast.RoPE(ctx, query, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
|
||||
key = fast.RoPE(ctx, key, positions, opts.headDim(), opts.ropeBase, 1./opts.ropeScale, opts.RoPEOptions()...)
|
||||
|
||||
attention := nn.AttentionWithSinks(ctx, query, key, value, attn.Sinks, 1/math.Sqrt(float64(opts.headDim())), cache)
|
||||
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/fast"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
@@ -19,10 +20,6 @@ type Options struct {
|
||||
eps, ropeBase, ropeScale float32
|
||||
}
|
||||
|
||||
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions, factors ml.Tensor) ml.Tensor {
|
||||
return nn.RoPE(ctx, states, positions, cmp.Or(o.ropeDim, o.headDim, o.hiddenSize/o.numHeads), o.ropeBase, 1./o.ropeScale, rope.WithFactors(factors))
|
||||
}
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
@@ -118,6 +115,7 @@ type SelfAttention struct {
|
||||
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||
batchSize := hiddenState.Dim(1)
|
||||
headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads)
|
||||
ropeDim := cmp.Or(opts.ropeDim, headDim)
|
||||
|
||||
query := sa.Query.Forward(ctx, hiddenState)
|
||||
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
||||
@@ -128,8 +126,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso
|
||||
value := sa.Value.Forward(ctx, hiddenState)
|
||||
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
|
||||
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions, sa.RopeFactors)
|
||||
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions, sa.RopeFactors)
|
||||
query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
|
||||
key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
|
||||
attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize)
|
||||
@@ -137,7 +135,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso
|
||||
}
|
||||
|
||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return m.applyRotaryPositionEmbeddings(ctx, key, shift, m.Layers[layer].SelfAttention.RopeFactors), nil
|
||||
ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads)
|
||||
return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(m.Layers[layer].SelfAttention.RopeFactors)), nil
|
||||
}
|
||||
|
||||
type MLP struct {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/fast"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
@@ -32,8 +33,8 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions, attent
|
||||
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
|
||||
if useRope {
|
||||
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions, sa.RopeFactors)
|
||||
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions, sa.RopeFactors)
|
||||
query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
|
||||
key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
|
||||
}
|
||||
|
||||
if opts.useQKNorm {
|
||||
@@ -151,10 +152,6 @@ type TextOptions struct {
|
||||
attentionFloorScale float64
|
||||
}
|
||||
|
||||
func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions, factors ml.Tensor) ml.Tensor {
|
||||
return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale, rope.WithFactors(factors))
|
||||
}
|
||||
|
||||
type TextModel struct {
|
||||
Layers []TextLayer `gguf:"blk"`
|
||||
|
||||
@@ -239,5 +236,5 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
|
||||
}
|
||||
|
||||
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return m.applyRotaryPositionEmbeddings(ctx, key, shift, m.Layers[layer].Attention.RopeFactors), nil
|
||||
return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(m.Layers[layer].Attention.RopeFactors)), nil
|
||||
}
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/fast"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
@@ -19,10 +20,6 @@ type TextOptions struct {
|
||||
ropeScalingBeta float32
|
||||
}
|
||||
|
||||
func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
|
||||
return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale)
|
||||
}
|
||||
|
||||
type TextModel struct {
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -45,11 +42,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs, posit
|
||||
|
||||
q := sa.Query.Forward(ctx, hiddenState)
|
||||
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
||||
q = opts.applyRotaryPositionEmbeddings(ctx, q, positionIDs)
|
||||
q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale)
|
||||
|
||||
k := sa.Key.Forward(ctx, hiddenState)
|
||||
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
k = opts.applyRotaryPositionEmbeddings(ctx, k, positionIDs)
|
||||
k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale)
|
||||
|
||||
v := sa.Value.Forward(ctx, hiddenState)
|
||||
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
@@ -64,7 +61,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs, posit
|
||||
}
|
||||
|
||||
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
|
||||
return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale), nil
|
||||
}
|
||||
|
||||
type MLP struct {
|
||||
|
||||
@@ -16,8 +16,8 @@ func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
|
||||
return x2.Scale(ctx, -1).Concat(ctx, x1, 0)
|
||||
}
|
||||
|
||||
func applyRotaryPositionEmbeddings(ctx ml.Context, states, cos, sin ml.Tensor) ml.Tensor {
|
||||
return states.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, states).Mul(ctx, sin))
|
||||
func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
|
||||
return t.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, t).Mul(ctx, sin))
|
||||
}
|
||||
|
||||
type VisionSelfAttention struct {
|
||||
@@ -36,8 +36,8 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml
|
||||
key = key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1), batchSize)
|
||||
value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1), batchSize)
|
||||
|
||||
query = applyRotaryPositionEmbeddings(ctx, query, cos, sin)
|
||||
key = applyRotaryPositionEmbeddings(ctx, key, cos, sin)
|
||||
query = applyRotaryPositionalEmbedding(ctx, query, cos, sin)
|
||||
key = applyRotaryPositionalEmbedding(ctx, key, cos, sin)
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim)), nil)
|
||||
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/fast"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
)
|
||||
|
||||
@@ -25,11 +26,11 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T
|
||||
|
||||
query := sa.Query.Forward(ctx, hiddenState)
|
||||
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
||||
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions, sa.RopeFactors)
|
||||
query = fast.RoPE(ctx, query, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
|
||||
|
||||
key := sa.Key.Forward(ctx, hiddenState)
|
||||
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions, sa.RopeFactors)
|
||||
key = fast.RoPE(ctx, key, positions, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors))
|
||||
|
||||
value := sa.Value.Forward(ctx, hiddenState)
|
||||
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
@@ -43,8 +44,8 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.T
|
||||
|
||||
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
// This will only get called for layers in the cache, which are just the self attention layers
|
||||
if layer, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
|
||||
return m.applyRotaryPositionEmbeddings(ctx, key, shift, layer.SelfAttention.RopeFactors), nil
|
||||
if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
|
||||
return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(sa.SelfAttention.RopeFactors)), nil
|
||||
}
|
||||
|
||||
return key, nil
|
||||
@@ -205,10 +206,6 @@ type TextModelOptions struct {
|
||||
crossAttentionLayers []int32
|
||||
}
|
||||
|
||||
func (o TextModelOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions, factors ml.Tensor) ml.Tensor {
|
||||
return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale, rope.WithFactors(factors))
|
||||
}
|
||||
|
||||
type TextModel struct {
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Transformer *TextDecoder `gguf:"blk"`
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
_ "github.com/ollama/ollama/model/models/mistral3"
|
||||
_ "github.com/ollama/ollama/model/models/mllama"
|
||||
_ "github.com/ollama/ollama/model/models/nomicbert"
|
||||
_ "github.com/ollama/ollama/model/models/olmo"
|
||||
_ "github.com/ollama/ollama/model/models/qwen2"
|
||||
_ "github.com/ollama/ollama/model/models/qwen25vl"
|
||||
_ "github.com/ollama/ollama/model/models/qwen3"
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/fast"
|
||||
"github.com/ollama/ollama/ml/nn/pooling"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
@@ -36,10 +37,6 @@ type Options struct {
|
||||
ropeFreqBase float32
|
||||
}
|
||||
|
||||
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
|
||||
return nn.RoPE(ctx, states, positions, o.headDim, o.ropeFreqBase, 1.0, rope.WithTypeNeoX())
|
||||
}
|
||||
|
||||
// Single Encoder Layer
|
||||
type EncoderLayer struct {
|
||||
*Attention
|
||||
@@ -108,8 +105,8 @@ func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, positions ml
|
||||
chunks := qkv.Chunk(ctx, 1, opts.numHeads)
|
||||
query, key, value := chunks[0], chunks[1], chunks[2]
|
||||
|
||||
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
|
||||
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
|
||||
query = fast.RoPE(ctx, query, positions, opts.headDim, opts.ropeFreqBase, 1.0, rope.WithTypeNeoX())
|
||||
key = fast.RoPE(ctx, key, positions, opts.headDim, opts.ropeFreqBase, 1.0, rope.WithTypeNeoX())
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(opts.headDim)), nil)
|
||||
|
||||
|
||||
@@ -1,298 +0,0 @@
|
||||
package olmo
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
const (
|
||||
cacheTypeSWA = iota
|
||||
cacheTypeCausal
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
hiddenSize, numHeads, numKVHeads int
|
||||
// headDim, ropeDim int
|
||||
eps, ropeBase, ropeScale float32
|
||||
|
||||
originalContextLength int
|
||||
attnFactor float32
|
||||
|
||||
ropeType string
|
||||
ropeExtrapolation float32
|
||||
|
||||
ropeBetaFast float32
|
||||
ropeBetaSlow float32
|
||||
|
||||
slidingWindowPattern []bool
|
||||
}
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
||||
|
||||
Options
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
vocabulary := model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
}
|
||||
|
||||
if c.String("tokenizer.ggml.model") != "gpt2" {
|
||||
return nil, model.ErrUnsupportedTokenizer
|
||||
}
|
||||
|
||||
var pretokenizers []string
|
||||
if c.String("tokenizer.ggml.pre") != "default" {
|
||||
pretokenizers = []string{
|
||||
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
}
|
||||
}
|
||||
processor := model.NewBytePairEncoding(&vocabulary, pretokenizers...)
|
||||
|
||||
hiddenSize := int(c.Uint("embedding_length"))
|
||||
numHeads := int(c.Uint("attention.head_count"))
|
||||
numKVHeads := int(c.Uint("attention.head_count_kv"))
|
||||
eps := c.Float("attention.layer_norm_rms_epsilon")
|
||||
ropeBase := c.Float("rope.freq_base", 1e4)
|
||||
ropeScale := c.Float("rope.scaling.factor", 1)
|
||||
originalContextLength := int(c.Uint("rope.scaling.original_context_length"))
|
||||
attnFactor := c.Float("rope.scaling.attn_factor", 1)
|
||||
ropeType := c.String("rope.scaling.type")
|
||||
ropeExtrapolation := c.Float("rope.scaling.extrapolation_factor", 1.0)
|
||||
|
||||
fmt.Printf("hiddenSize: %d\n", hiddenSize)
|
||||
fmt.Printf("numHeads: %d\n", numHeads)
|
||||
fmt.Printf("numKVHeads: %d\n", numKVHeads)
|
||||
fmt.Printf("eps: %f\n", eps)
|
||||
fmt.Printf("ropeBase: %f\n", ropeBase)
|
||||
fmt.Printf("ropeScale: %f\n", ropeScale)
|
||||
fmt.Printf("originalContextLength: %d\n", originalContextLength)
|
||||
fmt.Printf("attnFactor: %f\n", attnFactor)
|
||||
fmt.Printf("ropeType: %s\n", ropeType)
|
||||
fmt.Printf("ropeExtrapolation: %f\n", ropeExtrapolation)
|
||||
fmt.Printf("sliding_window_pattern: %v\n", c.Bools("attention.sliding_window_pattern"))
|
||||
|
||||
m := Model{
|
||||
TextProcessor: processor,
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Options: Options{
|
||||
hiddenSize: hiddenSize,
|
||||
numHeads: numHeads,
|
||||
numKVHeads: numKVHeads,
|
||||
eps: eps,
|
||||
ropeBase: ropeBase,
|
||||
ropeScale: ropeScale,
|
||||
originalContextLength: originalContextLength,
|
||||
attnFactor: attnFactor,
|
||||
ropeType: ropeType,
|
||||
ropeExtrapolation: ropeExtrapolation,
|
||||
ropeBetaFast: 32.0,
|
||||
ropeBetaSlow: 1.0,
|
||||
slidingWindowPattern: c.Bools("attention.sliding_window_pattern"),
|
||||
},
|
||||
}
|
||||
|
||||
m.Cache = kvcache.NewWrapperCache(
|
||||
kvcache.NewSWACache(int32(c.Uint("attention.sliding_window")), m.Shift),
|
||||
kvcache.NewCausalCache(m.Shift),
|
||||
)
|
||||
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
type SelfAttention struct {
|
||||
Query *nn.Linear `gguf:"attn_q"`
|
||||
Key *nn.Linear `gguf:"attn_k"`
|
||||
Value *nn.Linear `gguf:"attn_v"`
|
||||
Output *nn.Linear `gguf:"attn_output"`
|
||||
QNorm *nn.RMSNorm `gguf:"attn_q_norm"`
|
||||
KNorm *nn.RMSNorm `gguf:"attn_k_norm"`
|
||||
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
|
||||
}
|
||||
|
||||
func (m *Model) applyRoPE(ctx ml.Context, states, positions ml.Tensor, ropeDim int, isSWA bool) ml.Tensor {
|
||||
|
||||
var ropeOpts []func(*rope.Options)
|
||||
|
||||
ropeOpts = append(ropeOpts, rope.WithTypeNeoX())
|
||||
|
||||
// Both SWA and non-SWA use beta_fast and beta_slow
|
||||
// defaults
|
||||
ropeOpts = append(ropeOpts,
|
||||
rope.WithBetaFast(m.ropeBetaFast),
|
||||
rope.WithBetaSlow(m.ropeBetaSlow),
|
||||
)
|
||||
|
||||
// SWA uses freq_scale=1.0, ext_factor=0.0, attn_factor=1.0
|
||||
// Non-SWA uses full yarn parameters
|
||||
if m.originalContextLength > 0 {
|
||||
ropeOpts = append(ropeOpts,
|
||||
rope.WithOriginalContextLength(m.originalContextLength),
|
||||
)
|
||||
|
||||
// no yarn for swa
|
||||
if isSWA {
|
||||
ropeOpts = append(ropeOpts,
|
||||
rope.WithExtrapolationFactor(0),
|
||||
rope.WithAttentionFactor(1.),
|
||||
)
|
||||
} else {
|
||||
ropeOpts = append(ropeOpts,
|
||||
rope.WithExtrapolationFactor(m.ropeExtrapolation),
|
||||
rope.WithAttentionFactor(m.attnFactor),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
freqScale := float32(1.0)
|
||||
if !isSWA {
|
||||
freqScale = 1. / m.ropeScale
|
||||
}
|
||||
|
||||
return nn.RoPE(ctx, states, positions, ropeDim, m.ropeBase, freqScale, ropeOpts...)
|
||||
}
|
||||
|
||||
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tensor, cache kvcache.Cache, m *Model, isSWA bool) ml.Tensor {
|
||||
batchSize := hiddenState.Dim(1)
|
||||
headDim := m.hiddenSize / m.numHeads
|
||||
ropeDim := headDim
|
||||
|
||||
query := sa.Query.Forward(ctx, hiddenState)
|
||||
// double check type
|
||||
query = sa.QNorm.Forward(ctx, query, m.eps)
|
||||
query = query.Reshape(ctx, headDim, m.numHeads, batchSize)
|
||||
//check here
|
||||
|
||||
query = m.applyRoPE(ctx, query, positions, ropeDim, isSWA)
|
||||
|
||||
key := sa.Key.Forward(ctx, hiddenState)
|
||||
key = sa.KNorm.Forward(ctx, key, m.eps)
|
||||
key = key.Reshape(ctx, headDim, m.numKVHeads, batchSize)
|
||||
key = m.applyRoPE(ctx, key, positions, ropeDim, isSWA)
|
||||
|
||||
value := sa.Value.Forward(ctx, hiddenState)
|
||||
value = value.Reshape(ctx, headDim, m.numKVHeads, batchSize)
|
||||
|
||||
// check attention scaling as well
|
||||
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
|
||||
attention = attention.Reshape(ctx, m.hiddenSize, batchSize)
|
||||
|
||||
return sa.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
ropeDim := m.hiddenSize / m.numHeads
|
||||
isSWA := m.isSWALayer(layer)
|
||||
return m.applyRoPE(ctx, key, shift, ropeDim, isSWA), nil
|
||||
}
|
||||
|
||||
type MLP struct {
|
||||
Up *nn.Linear `gguf:"ffn_up"`
|
||||
Down *nn.Linear `gguf:"ffn_down"`
|
||||
Gate *nn.Linear `gguf:"ffn_gate"`
|
||||
}
|
||||
|
||||
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, m *Model) ml.Tensor {
|
||||
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
|
||||
return mlp.Down.Forward(ctx, hiddenState)
|
||||
}
|
||||
|
||||
type Layer struct {
|
||||
SelfAttention *SelfAttention
|
||||
PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
|
||||
MLP *MLP
|
||||
PostFFWNorm *nn.RMSNorm `gguf:"post_ffw_norm"`
|
||||
}
|
||||
|
||||
func (l *Layer) Forward(ctx ml.Context, hiddenState, positions, outputs ml.Tensor, cache kvcache.Cache, m *Model, isSWA bool) ml.Tensor {
|
||||
residual := hiddenState
|
||||
|
||||
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positions, cache, m, isSWA)
|
||||
|
||||
// return hiddenState
|
||||
|
||||
if outputs != nil {
|
||||
hiddenState = hiddenState.Rows(ctx, outputs)
|
||||
residual = residual.Rows(ctx, outputs)
|
||||
}
|
||||
// i think this should be after getting the rows?
|
||||
hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, m.eps)
|
||||
|
||||
hiddenState = hiddenState.Add(ctx, residual)
|
||||
residual = hiddenState
|
||||
|
||||
hiddenState = l.MLP.Forward(ctx, hiddenState, m)
|
||||
hiddenState = l.PostFFWNorm.Forward(ctx, hiddenState, m.eps)
|
||||
|
||||
return hiddenState.Add(ctx, residual)
|
||||
}
|
||||
|
||||
// Olmo3 has Sliding Window Attention (SWA) 3 out of 4 layers.
|
||||
func (m *Model) isSWALayer(layerIdx int) bool {
|
||||
return m.Options.slidingWindowPattern[layerIdx]
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
|
||||
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
m.Cache.SetLayer(i)
|
||||
cacheType := cacheTypeSWA
|
||||
|
||||
isSWA := m.isSWALayer(i)
|
||||
if !isSWA {
|
||||
cacheType = cacheTypeCausal
|
||||
}
|
||||
|
||||
wc := m.Cache.(*kvcache.WrapperCache)
|
||||
wc.SetLayerType(cacheType)
|
||||
// would need to check the cache at the layer instead
|
||||
if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
|
||||
// TODO: not sure about the index here
|
||||
causal.SetCausal(ctx, kvcache.CausalOptions{Except: []int{}})
|
||||
}
|
||||
|
||||
var outputs ml.Tensor
|
||||
if i == len(m.Layers)-1 {
|
||||
outputs = batch.Outputs
|
||||
}
|
||||
|
||||
hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m, isSWA)
|
||||
|
||||
// return hiddenState, nil
|
||||
}
|
||||
|
||||
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
|
||||
return m.Output.Forward(ctx, hiddenState), nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
model.Register("olmo2", New)
|
||||
}
|
||||
@@ -1,568 +0,0 @@
|
||||
package olmo
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
typemodel "github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
var args struct {
|
||||
model,
|
||||
prompt string
|
||||
layers int
|
||||
}
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
flag.StringVar(&args.model, "model", "", "path to model (e.g., olmo3:latest)")
|
||||
flag.StringVar(&args.prompt, "prompt", "Hello, how are", "model prompt")
|
||||
flag.IntVar(&args.layers, "layers", math.MaxInt, "num of gpu layers")
|
||||
flag.Parse()
|
||||
|
||||
slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
|
||||
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func blob(tb testing.TB, modelName string) string {
|
||||
tb.Helper()
|
||||
|
||||
models := envconfig.Models()
|
||||
manifest, err := os.Open(filepath.Join(models, "manifests", typemodel.ParseName(modelName).Filepath()))
|
||||
if err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
defer manifest.Close()
|
||||
|
||||
var m struct {
|
||||
Layers []struct {
|
||||
MediaType string `json:"mediaType"`
|
||||
Digest string `json:"digest"`
|
||||
} `json:"layers"`
|
||||
}
|
||||
|
||||
if err := json.NewDecoder(manifest).Decode(&m); err != nil {
|
||||
tb.Fatal(err)
|
||||
}
|
||||
|
||||
for _, layer := range m.Layers {
|
||||
if layer.MediaType == "application/vnd.ollama.image.model" {
|
||||
tb.Log("using model blob", layer.Digest)
|
||||
return filepath.Join(models, "blobs", strings.ReplaceAll(layer.Digest, ":", "-"))
|
||||
}
|
||||
}
|
||||
|
||||
tb.Fatal("model blob not found")
|
||||
return ""
|
||||
}
|
||||
|
||||
func loadFloatsFromBinary(filename string) ([]float32, error) {
|
||||
f, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
fi, err := f.Stat()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if fi.Size()%4 != 0 {
|
||||
return nil, fmt.Errorf("file size %d not multiple of 4", fi.Size())
|
||||
}
|
||||
|
||||
n := int(fi.Size() / 4)
|
||||
floats := make([]float32, n)
|
||||
if err := binary.Read(f, binary.LittleEndian, floats); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return floats, nil
|
||||
}
|
||||
|
||||
func TestTokenization(t *testing.T) {
|
||||
if args.model == "" {
|
||||
t.Skip("no model specified, use -model flag")
|
||||
}
|
||||
|
||||
m, err := model.New(blob(t, args.model), ml.BackendParams{AllocMemory: true})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := m.Backend().Load(t.Context(), func(float32) {}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
prompt := args.prompt
|
||||
if prompt == "" {
|
||||
prompt = "hello, how are you?"
|
||||
}
|
||||
|
||||
tp := m.(model.TextProcessor)
|
||||
tokens, err := tp.Encode(prompt, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Logf("prompt: %q", prompt)
|
||||
t.Logf("tokens: %v", tokens)
|
||||
t.Logf("num tokens: %d", len(tokens))
|
||||
|
||||
decoded, err := tp.Decode(tokens)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("decoded: %q", decoded)
|
||||
}
|
||||
|
||||
func TestAttentionForward(t *testing.T) {
|
||||
if args.model == "" {
|
||||
t.Skip("no model specified, use -model flag")
|
||||
}
|
||||
|
||||
m, err := model.New(blob(t, args.model), ml.BackendParams{AllocMemory: true})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := m.Backend().Load(t.Context(), func(float32) {}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
olmoModel := m.(*Model)
|
||||
t.Logf("Model options: hiddenSize=%d, numHeads=%d, numKVHeads=%d",
|
||||
olmoModel.hiddenSize, olmoModel.numHeads, olmoModel.numKVHeads)
|
||||
t.Logf("Layer 0 attention: %+v", olmoModel.Layers[0].SelfAttention)
|
||||
|
||||
ctx := m.Backend().NewContext()
|
||||
|
||||
// Create test hidden states: (hiddenSize, batchSize)
|
||||
batchSize := 4
|
||||
hiddenSize := olmoModel.hiddenSize
|
||||
hsFloats := make([]float32, hiddenSize*batchSize)
|
||||
for i := range hsFloats {
|
||||
hsFloats[i] = float32(i%100) / 100.0 // Simple test values
|
||||
}
|
||||
|
||||
hiddenStates := ctx.Input().FromFloats(hsFloats, hiddenSize, batchSize)
|
||||
t.Logf("hiddenStates shape: %v", hiddenStates.Shape())
|
||||
|
||||
positions := ctx.Input().FromInts([]int32{0, 1, 2, 3}, batchSize)
|
||||
|
||||
// Test attention forward (without cache for simplicity)
|
||||
attentionBlock := olmoModel.Layers[0].SelfAttention
|
||||
isSWA := olmoModel.isSWALayer(0)
|
||||
t.Logf("Layer 0 isSWA: %v", isSWA)
|
||||
|
||||
result := attentionBlock.Forward(ctx, hiddenStates, positions, nil, olmoModel, isSWA)
|
||||
result = result.Contiguous(ctx)
|
||||
ctx.Forward(result).Compute(result)
|
||||
|
||||
t.Logf("Attention result shape: %v dtype: %v", result.Shape(), result.DType())
|
||||
|
||||
// Optionally dump to file
|
||||
// if err := os.WriteFile("/tmp/olmo_attention_output.bin", result.Bytes(), 0644); err != nil {
|
||||
// t.Fatal(err)
|
||||
// }
|
||||
}
|
||||
|
||||
func TestMLPForward(t *testing.T) {
|
||||
if args.model == "" {
|
||||
t.Skip("no model specified, use -model flag")
|
||||
}
|
||||
|
||||
m, err := model.New(blob(t, args.model), ml.BackendParams{AllocMemory: true})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := m.Backend().Load(t.Context(), func(float32) {}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
olmoModel := m.(*Model)
|
||||
ctx := m.Backend().NewContext()
|
||||
|
||||
// Create test hidden states
|
||||
batchSize := 4
|
||||
hiddenSize := olmoModel.hiddenSize
|
||||
hsFloats := make([]float32, hiddenSize*batchSize)
|
||||
for i := range hsFloats {
|
||||
hsFloats[i] = float32(i%100) / 100.0
|
||||
}
|
||||
|
||||
hiddenStates := ctx.Input().FromFloats(hsFloats, hiddenSize, batchSize)
|
||||
t.Logf("hiddenStates shape: %v", hiddenStates.Shape())
|
||||
|
||||
mlpBlock := olmoModel.Layers[0].MLP
|
||||
result := mlpBlock.Forward(ctx, hiddenStates, olmoModel)
|
||||
result = result.Contiguous(ctx)
|
||||
ctx.Forward(result).Compute(result)
|
||||
|
||||
t.Logf("MLP result shape: %v dtype: %v", result.Shape(), result.DType())
|
||||
|
||||
// Parse result bytes to float32
|
||||
resultBytes := result.Bytes()
|
||||
resultFloats := make([]float32, len(resultBytes)/4)
|
||||
for i := range resultFloats {
|
||||
bits := binary.LittleEndian.Uint32(resultBytes[i*4 : (i+1)*4])
|
||||
resultFloats[i] = math.Float32frombits(bits)
|
||||
}
|
||||
|
||||
// Compute statistics
|
||||
var minVal, maxVal, sum float32
|
||||
minVal = resultFloats[0]
|
||||
maxVal = resultFloats[0]
|
||||
for _, v := range resultFloats {
|
||||
if v < minVal {
|
||||
minVal = v
|
||||
}
|
||||
if v > maxVal {
|
||||
maxVal = v
|
||||
}
|
||||
sum += v
|
||||
}
|
||||
mean := sum / float32(len(resultFloats))
|
||||
|
||||
// Build readable output
|
||||
var sb strings.Builder
|
||||
sb.WriteString("# MLP Forward Output\n\n")
|
||||
sb.WriteString(fmt.Sprintf("# Input Shape: [%d, %d] (hiddenSize, batchSize)\n", hiddenSize, batchSize))
|
||||
sb.WriteString(fmt.Sprintf("# Output Shape: %v\n", result.Shape()))
|
||||
sb.WriteString(fmt.Sprintf("# DType: %v\n", result.DType()))
|
||||
sb.WriteString(fmt.Sprintf("# Layer: 0\n\n"))
|
||||
|
||||
sb.WriteString("## Statistics\n\n")
|
||||
sb.WriteString(fmt.Sprintf(" Total elements: %d\n", len(resultFloats)))
|
||||
sb.WriteString(fmt.Sprintf(" Min: %v\n", minVal))
|
||||
sb.WriteString(fmt.Sprintf(" Max: %v\n", maxVal))
|
||||
sb.WriteString(fmt.Sprintf(" Mean: %v\n\n", mean))
|
||||
|
||||
sb.WriteString("## Input Hidden States (first 20 values)\n\n")
|
||||
sb.WriteString(" [")
|
||||
for i := 0; i < min(20, len(hsFloats)); i++ {
|
||||
if i > 0 {
|
||||
sb.WriteString(", ")
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("%v", hsFloats[i]))
|
||||
}
|
||||
sb.WriteString("]\n\n")
|
||||
|
||||
sb.WriteString("## Output Values\n\n")
|
||||
|
||||
// Per-position output (each position in batch)
|
||||
for pos := 0; pos < batchSize; pos++ {
|
||||
sb.WriteString(fmt.Sprintf("Position %d (hiddenSize=%d values):\n", pos, hiddenSize))
|
||||
|
||||
// Extract values for this position
|
||||
posStart := pos * hiddenSize
|
||||
posEnd := posStart + hiddenSize
|
||||
if posEnd > len(resultFloats) {
|
||||
posEnd = len(resultFloats)
|
||||
}
|
||||
posValues := resultFloats[posStart:posEnd]
|
||||
|
||||
// Full tensor values
|
||||
sb.WriteString(" [")
|
||||
for i, v := range posValues {
|
||||
if i > 0 {
|
||||
sb.WriteString(", ")
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("%v", v))
|
||||
}
|
||||
sb.WriteString("]\n\n")
|
||||
}
|
||||
|
||||
// Save to file
|
||||
if err := os.WriteFile("/tmp/olmo_mlp_forward.txt", []byte(sb.String()), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("Saved /tmp/olmo_mlp_forward.txt")
|
||||
|
||||
// Also save binary
|
||||
if err := os.WriteFile("/tmp/olmo_mlp_forward.bin", resultBytes, 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("Saved /tmp/olmo_mlp_forward.bin")
|
||||
|
||||
// Print summary to console
|
||||
fmt.Println(sb.String())
|
||||
}
|
||||
|
||||
func TestFullForward(t *testing.T) {
|
||||
if args.model == "" {
|
||||
t.Skip("no model specified, use -model flag")
|
||||
}
|
||||
|
||||
m, err := model.New(blob(t, args.model), ml.BackendParams{AllocMemory: true})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := m.Backend().Load(t.Context(), func(float32) {}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ctx := m.Backend().NewContext()
|
||||
|
||||
prompt := args.prompt
|
||||
if prompt == "" {
|
||||
prompt = "Hello, how are you?"
|
||||
}
|
||||
|
||||
tp := m.(model.TextProcessor)
|
||||
tokens, err := tp.Encode(prompt, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Logf("prompt: %q", prompt)
|
||||
t.Logf("tokens: %v", tokens)
|
||||
|
||||
decoded, err := tp.Decode(tokens)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("decoded: %q", decoded)
|
||||
|
||||
seqLen := len(tokens)
|
||||
inputsTensor := ctx.Input().FromInts(tokens, seqLen)
|
||||
positions := make([]int32, seqLen)
|
||||
sequences := make([]int, seqLen)
|
||||
for i := range tokens {
|
||||
positions[i] = int32(i)
|
||||
sequences[i] = 0
|
||||
}
|
||||
|
||||
// Output ALL positions
|
||||
outputIndices := make([]int32, seqLen)
|
||||
for i := range outputIndices {
|
||||
outputIndices[i] = int32(i)
|
||||
}
|
||||
outputs := ctx.Input().FromInts(outputIndices, seqLen)
|
||||
|
||||
batch := input.Batch{
|
||||
Inputs: inputsTensor,
|
||||
Positions: positions,
|
||||
Sequences: sequences,
|
||||
Outputs: outputs,
|
||||
}
|
||||
|
||||
// Initialize cache
|
||||
if cache := m.Config().Cache; cache != nil {
|
||||
cache.Init(m.Backend(), ml.DTypeF16, 1, 4096, seqLen)
|
||||
}
|
||||
|
||||
result, err := model.Forward(ctx, m, batch)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
result = result.Contiguous(ctx)
|
||||
ctx.Forward(result).Compute(result)
|
||||
|
||||
t.Logf("Forward pass completed, result shape: %v", result.Shape())
|
||||
|
||||
// Dump logits to binary file
|
||||
if err := os.WriteFile("/tmp/olmo_logits.bin", result.Bytes(), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("Saved /tmp/olmo_logits.bin")
|
||||
|
||||
// Parse logits from bytes for detailed analysis
|
||||
logitsBytes := result.Bytes()
|
||||
vocabSize := result.Shape()[0]
|
||||
|
||||
// Read float32 values - shape is (vocab_size, seq_len)
|
||||
allLogits := make([]float32, len(logitsBytes)/4)
|
||||
for i := range allLogits {
|
||||
bits := binary.LittleEndian.Uint32(logitsBytes[i*4 : (i+1)*4])
|
||||
allLogits[i] = math.Float32frombits(bits)
|
||||
}
|
||||
|
||||
// Create detailed text dump matching Python format
|
||||
var sb strings.Builder
|
||||
sb.WriteString("# Full Forward Logits\n\n")
|
||||
sb.WriteString(fmt.Sprintf("# Shape: [1, %d, %d]\n", seqLen, vocabSize))
|
||||
sb.WriteString(fmt.Sprintf("# Layout: (batch=1, seq_len=%d, vocab_size=%d)\n", seqLen, vocabSize))
|
||||
sb.WriteString(fmt.Sprintf("# Prompt: '%s'\n", prompt))
|
||||
sb.WriteString(fmt.Sprintf("# Tokens: %v\n\n", tokens))
|
||||
|
||||
type logitPair struct {
|
||||
tokenID int
|
||||
value float32
|
||||
}
|
||||
|
||||
// Process each position
|
||||
for pos := 0; pos < seqLen; pos++ {
|
||||
// Extract logits for this position
|
||||
// Shape is (vocab_size, seq_len), so logits[v*seqLen + pos] gives logit for vocab v at position pos
|
||||
posLogits := make([]float32, vocabSize)
|
||||
for v := 0; v < vocabSize; v++ {
|
||||
posLogits[v] = allLogits[v*seqLen+pos]
|
||||
}
|
||||
|
||||
// Find top 10 logits
|
||||
pairs := make([]logitPair, len(posLogits))
|
||||
for i, v := range posLogits {
|
||||
pairs[i] = logitPair{tokenID: i, value: v}
|
||||
}
|
||||
// Sort by value descending (simple bubble sort for small top-k)
|
||||
for i := 0; i < min(10, len(pairs)); i++ {
|
||||
for j := i + 1; j < len(pairs); j++ {
|
||||
if pairs[j].value > pairs[i].value {
|
||||
pairs[i], pairs[j] = pairs[j], pairs[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tokenStr, _ := tp.Decode([]int32{tokens[pos]})
|
||||
sb.WriteString(fmt.Sprintf("Position %d (token_id=%d, token='%s'):\n", pos, tokens[pos], tokenStr))
|
||||
sb.WriteString(" Top 10 logits:\n")
|
||||
for i := 0; i < min(10, len(pairs)); i++ {
|
||||
tokStr, _ := tp.Decode([]int32{int32(pairs[i].tokenID)})
|
||||
// Pad token string to 20 chars for alignment
|
||||
paddedTok := fmt.Sprintf("%-20s", fmt.Sprintf("'%s'", tokStr))
|
||||
sb.WriteString(fmt.Sprintf(" %d. token_id=%6d (%s): %f\n", i+1, pairs[i].tokenID, paddedTok, pairs[i].value))
|
||||
}
|
||||
|
||||
// First 20 logits
|
||||
sb.WriteString(" Full logits (first 20): [")
|
||||
for i := 0; i < min(20, len(posLogits)); i++ {
|
||||
if i > 0 {
|
||||
sb.WriteString(", ")
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("%v", posLogits[i]))
|
||||
}
|
||||
sb.WriteString("]\n")
|
||||
|
||||
// Last 20 logits
|
||||
sb.WriteString(" Full logits (last 20): [")
|
||||
start := max(0, len(posLogits)-20)
|
||||
for i := start; i < len(posLogits); i++ {
|
||||
if i > start {
|
||||
sb.WriteString(", ")
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("%v", posLogits[i]))
|
||||
}
|
||||
sb.WriteString("]\n\n")
|
||||
}
|
||||
|
||||
if err := os.WriteFile("/tmp/olmo_logits.txt", []byte(sb.String()), 0644); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Log("Saved /tmp/olmo_logits.txt")
|
||||
|
||||
// Print to console as well
|
||||
fmt.Println(sb.String())
|
||||
}
|
||||
|
||||
func TestRoPE(t *testing.T) {
|
||||
if args.model == "" {
|
||||
t.Skip("no model specified, use -model flag")
|
||||
}
|
||||
|
||||
m, err := model.New(blob(t, args.model), ml.BackendParams{AllocMemory: true})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := m.Backend().Load(t.Context(), func(float32) {}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
olmoModel := m.(*Model)
|
||||
|
||||
// Test RoPE on a simple tensor
|
||||
headDim := olmoModel.hiddenSize / olmoModel.numHeads
|
||||
batchSize := 4
|
||||
numHeads := olmoModel.numHeads
|
||||
|
||||
t.Logf("headDim: %d, numHeads: %d", headDim, numHeads)
|
||||
t.Logf("ropeBase: %f, ropeScale: %f, originalContextLength: %d",
|
||||
olmoModel.ropeBase, olmoModel.ropeScale, olmoModel.originalContextLength)
|
||||
|
||||
// Create test query tensor: (headDim, numHeads, batchSize)
|
||||
queryFloats := make([]float32, headDim*numHeads*batchSize)
|
||||
for i := range queryFloats {
|
||||
queryFloats[i] = float32(i%100) / 100.0
|
||||
}
|
||||
|
||||
// Test 1: Dump initial query values (fresh context)
|
||||
{
|
||||
ctx := m.Backend().NewContext()
|
||||
query := ctx.Input().FromFloats(queryFloats, headDim, numHeads, batchSize)
|
||||
t.Logf("query shape: %v", query.Shape())
|
||||
query = query.Contiguous(ctx)
|
||||
ctx.Forward(query).Compute(query)
|
||||
dump := ml.Dump(ctx, query, ml.DumpWithPrecision(6), ml.DumpWithThreshold(1000000))
|
||||
t.Logf("Query BEFORE RoPE sample values: %s", dump[:min(500, len(dump))])
|
||||
|
||||
// Write to file
|
||||
header := fmt.Sprintf("Shape: %v\nDType: %v\n\n", query.Shape(), query.DType())
|
||||
if err := os.WriteFile("/tmp/olmo_query_before_rope.txt", []byte(header+dump), 0644); err != nil {
|
||||
t.Errorf("Failed to write file: %v", err)
|
||||
}
|
||||
if err := os.WriteFile("/tmp/olmo_query_before_rope.bin", query.Bytes(), 0644); err != nil {
|
||||
t.Errorf("Failed to write binary file: %v", err)
|
||||
}
|
||||
t.Log("Wrote /tmp/olmo_query_before_rope.txt and .bin")
|
||||
}
|
||||
|
||||
// Test 2: SWA RoPE (fresh context)
|
||||
{
|
||||
ctx := m.Backend().NewContext()
|
||||
query := ctx.Input().FromFloats(queryFloats, headDim, numHeads, batchSize)
|
||||
positions := ctx.Input().FromInts([]int32{0, 1, 2, 3}, batchSize)
|
||||
resultSWA := olmoModel.applyRoPE(ctx, query, positions, headDim, true)
|
||||
resultSWA = resultSWA.Contiguous(ctx)
|
||||
ctx.Forward(resultSWA).Compute(resultSWA)
|
||||
t.Logf("SWA RoPE result shape: %v", resultSWA.Shape())
|
||||
dump := ml.Dump(ctx, resultSWA, ml.DumpWithPrecision(6), ml.DumpWithThreshold(1000000))
|
||||
t.Logf("Query AFTER SWA RoPE sample values: %s", dump[:min(500, len(dump))])
|
||||
|
||||
// Write to file
|
||||
header := fmt.Sprintf("Shape: %v\nDType: %v\nfreqScale: 1.0 (SWA)\n\n", resultSWA.Shape(), resultSWA.DType())
|
||||
if err := os.WriteFile("/tmp/olmo_query_after_swa_rope.txt", []byte(header+dump), 0644); err != nil {
|
||||
t.Errorf("Failed to write file: %v", err)
|
||||
}
|
||||
if err := os.WriteFile("/tmp/olmo_query_after_swa_rope.bin", resultSWA.Bytes(), 0644); err != nil {
|
||||
t.Errorf("Failed to write binary file: %v", err)
|
||||
}
|
||||
t.Log("Wrote /tmp/olmo_query_after_swa_rope.txt and .bin")
|
||||
}
|
||||
|
||||
// Test 3: Global (non-SWA) RoPE (fresh context)
|
||||
{
|
||||
ctx := m.Backend().NewContext()
|
||||
query := ctx.Input().FromFloats(queryFloats, headDim, numHeads, batchSize)
|
||||
positions := ctx.Input().FromInts([]int32{0, 1, 2, 3}, batchSize)
|
||||
resultGlobal := olmoModel.applyRoPE(ctx, query, positions, headDim, false)
|
||||
resultGlobal = resultGlobal.Contiguous(ctx)
|
||||
ctx.Forward(resultGlobal).Compute(resultGlobal)
|
||||
t.Logf("Global RoPE result shape: %v", resultGlobal.Shape())
|
||||
dump := ml.Dump(ctx, resultGlobal, ml.DumpWithPrecision(6), ml.DumpWithThreshold(1000000))
|
||||
t.Logf("Query AFTER Global RoPE sample values: %s", dump[:min(500, len(dump))])
|
||||
|
||||
// Write to file
|
||||
header := fmt.Sprintf("Shape: %v\nDType: %v\nfreqScale: %f (Global, 1/ropeScale)\n\n",
|
||||
resultGlobal.Shape(), resultGlobal.DType(), 1.0/olmoModel.ropeScale)
|
||||
if err := os.WriteFile("/tmp/olmo_query_after_global_rope.txt", []byte(header+dump), 0644); err != nil {
|
||||
t.Errorf("Failed to write file: %v", err)
|
||||
}
|
||||
if err := os.WriteFile("/tmp/olmo_query_after_global_rope.bin", resultGlobal.Bytes(), 0644); err != nil {
|
||||
t.Errorf("Failed to write binary file: %v", err)
|
||||
}
|
||||
t.Log("Wrote /tmp/olmo_query_after_global_rope.txt and .bin")
|
||||
}
|
||||
}
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/fast"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
@@ -21,10 +22,6 @@ type Options struct {
|
||||
eps, ropeBase, ropeScale float32
|
||||
}
|
||||
|
||||
func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
|
||||
return nn.RoPE(ctx, states, positions, cmp.Or(o.ropeDim, o.headDim, o.hiddenSize/o.numHeads), o.ropeBase, 1./o.ropeScale, rope.WithTypeNeoX())
|
||||
}
|
||||
|
||||
type Attention struct {
|
||||
Query *nn.Linear `gguf:"attn_q"`
|
||||
Key *nn.Linear `gguf:"attn_k"`
|
||||
@@ -35,6 +32,7 @@ type Attention struct {
|
||||
func (attn Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||
batchSize := hiddenStates.Dim(1)
|
||||
headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads)
|
||||
ropeDim := cmp.Or(opts.ropeDim, headDim)
|
||||
|
||||
query := attn.Query.Forward(ctx, hiddenStates)
|
||||
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
||||
@@ -45,8 +43,8 @@ func (attn Attention) Forward(ctx ml.Context, hiddenStates, positions ml.Tensor,
|
||||
value := attn.Value.Forward(ctx, hiddenStates)
|
||||
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
|
||||
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
|
||||
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
|
||||
query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
|
||||
key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithTypeNeoX())
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
|
||||
attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize)
|
||||
@@ -125,7 +123,8 @@ func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
}
|
||||
|
||||
func (m Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
|
||||
ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads)
|
||||
return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithTypeNeoX()), nil
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/fast"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
@@ -17,13 +18,6 @@ type TextOptions struct {
|
||||
eps, ropeBase, ropeScale float32
|
||||
}
|
||||
|
||||
func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
|
||||
return nn.RoPE(ctx, states, positions, o.ropeDim, o.ropeBase, 1./o.ropeScale,
|
||||
rope.WithOriginalContextLength(o.originalContextLength),
|
||||
rope.WithTypeNeoX(),
|
||||
)
|
||||
}
|
||||
|
||||
type TextModel struct {
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
@@ -66,11 +60,11 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
||||
|
||||
q := sa.Query.Forward(ctx, hiddenState)
|
||||
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
||||
q = opts.applyRotaryPositionEmbeddings(ctx, q, positionIDs)
|
||||
q = fast.RoPE(ctx, q, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX())
|
||||
|
||||
k := sa.Key.Forward(ctx, hiddenState)
|
||||
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
k = opts.applyRotaryPositionEmbeddings(ctx, k, positionIDs)
|
||||
k = fast.RoPE(ctx, k, positionIDs, opts.ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithOriginalContextLength(opts.originalContextLength), rope.WithTypeNeoX())
|
||||
|
||||
v := sa.Value.Forward(ctx, hiddenState)
|
||||
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
@@ -84,7 +78,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
||||
|
||||
// Shift applies rotary position embeddings to the key tensor for causal attention caching
|
||||
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return m.applyRotaryPositionEmbeddings(ctx, key, shift), nil
|
||||
return fast.RoPE(ctx, key, shift, m.ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithOriginalContextLength(m.originalContextLength), rope.WithTypeNeoX()), nil
|
||||
}
|
||||
|
||||
// MLP implements the feed-forward network component with SwiGLU activation
|
||||
|
||||
@@ -18,8 +18,8 @@ func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
|
||||
return x2.Scale(ctx, -1).Concat(ctx, x1, 0)
|
||||
}
|
||||
|
||||
func applyRotaryPositionEmbeddings(ctx ml.Context, states, cos, sin ml.Tensor) ml.Tensor {
|
||||
return states.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, states).Mul(ctx, sin))
|
||||
func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
|
||||
return t.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, t).Mul(ctx, sin))
|
||||
}
|
||||
|
||||
func blockDiagonalMask(ctx ml.Context, seqLength int, bounds []int, numHeads int) ml.Tensor {
|
||||
@@ -67,8 +67,8 @@ func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenStates, cos, sin, m
|
||||
key = key.Reshape(ctx, opts.headDim, opts.numHeads, key.Dim(1), batchSize)
|
||||
value = value.Reshape(ctx, opts.headDim, opts.numHeads, value.Dim(1), batchSize)
|
||||
|
||||
query = applyRotaryPositionEmbeddings(ctx, query, cos, sin)
|
||||
key = applyRotaryPositionEmbeddings(ctx, key, cos, sin)
|
||||
query = applyRotaryPositionalEmbedding(ctx, query, cos, sin)
|
||||
key = applyRotaryPositionalEmbedding(ctx, key, cos, sin)
|
||||
|
||||
// Scale factor for scaled dot-product attention
|
||||
scale := 1.0 / math.Sqrt(float64(opts.headDim))
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/fast"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
@@ -45,7 +46,7 @@ func (o Options) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions
|
||||
rope.WithAttentionFactor(attnFactor),
|
||||
)
|
||||
}
|
||||
return nn.RoPE(ctx, states, positions, o.headDim(), o.ropeBase, 1./o.ropeScale, opts...)
|
||||
return fast.RoPE(ctx, states, positions, o.headDim(), o.ropeBase, 1./o.ropeScale, opts...)
|
||||
}
|
||||
|
||||
type Attention struct {
|
||||
|
||||
@@ -195,7 +195,7 @@ func New(c fs.Config) (model.Model, error) {
|
||||
m.Cache = kvcache.NewCausalCache(func(ctx ml.Context, layer int, key, positions ml.Tensor) (ml.Tensor, error) {
|
||||
m.positionCache = nil
|
||||
positions = positions.Repeat(ctx, 1, 4).Reshape(ctx, -1)
|
||||
return m.Options.applyRotaryPositionEmbeddings(ctx, key, positions), nil
|
||||
return m.Options.applyRotaryPositionalEmbedding(ctx, key, positions), nil
|
||||
})
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/fast"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
)
|
||||
@@ -34,8 +35,8 @@ func (o TextOptions) headDim() int {
|
||||
return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads)
|
||||
}
|
||||
|
||||
func (o TextOptions) applyRotaryPositionEmbeddings(ctx ml.Context, states, positions ml.Tensor) ml.Tensor {
|
||||
return nn.RoPE(ctx, states, positions, o.headDim(), o.ropeBase, 1/float32(math.Sqrt(float64(o.ropeScale))),
|
||||
func (o TextOptions) applyRotaryPositionalEmbedding(ctx ml.Context, t, p ml.Tensor) ml.Tensor {
|
||||
return fast.RoPE(ctx, t, p, o.headDim(), o.ropeBase, 1/float32(math.Sqrt(float64(o.ropeScale))),
|
||||
rope.WithInterleaveMRoPE(o.mropeSections),
|
||||
)
|
||||
}
|
||||
@@ -63,8 +64,8 @@ func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, positions ml.Tens
|
||||
query = sa.QueryNorm.Forward(ctx, query, opts.eps)
|
||||
key = sa.KeyNorm.Forward(ctx, key, opts.eps)
|
||||
|
||||
query = opts.applyRotaryPositionEmbeddings(ctx, query, positions)
|
||||
key = opts.applyRotaryPositionEmbeddings(ctx, key, positions)
|
||||
query = opts.applyRotaryPositionalEmbedding(ctx, query, positions)
|
||||
key = opts.applyRotaryPositionalEmbedding(ctx, key, positions)
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache)
|
||||
attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize)
|
||||
|
||||
@@ -23,18 +23,18 @@ func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor {
|
||||
return x2.Scale(ctx, -1).Concat(ctx, x1, 0)
|
||||
}
|
||||
|
||||
func applyRotaryPositionEmbeddings(ctx ml.Context, states, cos, sin ml.Tensor) ml.Tensor {
|
||||
return states.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, states).Mul(ctx, sin))
|
||||
func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor {
|
||||
return t.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, t).Mul(ctx, sin))
|
||||
}
|
||||
|
||||
func (sa *VisionAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts VisionOptions) ml.Tensor {
|
||||
query := sa.Query.Forward(ctx, hiddenStates)
|
||||
query = query.Reshape(ctx, opts.headDim(), opts.numHeads, query.Dim(1))
|
||||
query = applyRotaryPositionEmbeddings(ctx, query, cos, sin)
|
||||
query = applyRotaryPositionalEmbedding(ctx, query, cos, sin)
|
||||
|
||||
key := sa.Key.Forward(ctx, hiddenStates)
|
||||
key = key.Reshape(ctx, opts.headDim(), opts.numHeads, key.Dim(1))
|
||||
key = applyRotaryPositionEmbeddings(ctx, key, cos, sin)
|
||||
key = applyRotaryPositionalEmbedding(ctx, key, cos, sin)
|
||||
|
||||
value := sa.Value.Forward(ctx, hiddenStates)
|
||||
value = value.Reshape(ctx, opts.headDim(), opts.numHeads, value.Dim(1))
|
||||
|
||||
@@ -1,469 +0,0 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
type olmo3ParserState int
|
||||
|
||||
const (
|
||||
olmo3StateContent olmo3ParserState = iota
|
||||
olmo3StateToolCalls
|
||||
olmo3StateToolCallsDone
|
||||
)
|
||||
|
||||
const (
|
||||
olmo3FuncCallsOpenTag = "<function_calls>"
|
||||
olmo3FuncCallsCloseTag = "</function_calls>"
|
||||
)
|
||||
|
||||
type Olmo3Parser struct {
|
||||
state olmo3ParserState
|
||||
buffer strings.Builder
|
||||
}
|
||||
|
||||
func (p *Olmo3Parser) HasToolSupport() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *Olmo3Parser) HasThinkingSupport() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *Olmo3Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||
p.state = olmo3StateContent
|
||||
return tools
|
||||
}
|
||||
|
||||
type olmo3ParserEvent interface {
|
||||
isOlmo3ParserEvent()
|
||||
}
|
||||
|
||||
type olmo3ParserEventContent struct {
|
||||
content string
|
||||
}
|
||||
|
||||
type olmo3ParserEventToolCalls struct {
|
||||
calls []api.ToolCall
|
||||
}
|
||||
|
||||
func (olmo3ParserEventContent) isOlmo3ParserEvent() {}
|
||||
func (olmo3ParserEventToolCalls) isOlmo3ParserEvent() {}
|
||||
|
||||
func (p *Olmo3Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||
p.buffer.WriteString(s)
|
||||
|
||||
if done {
|
||||
// Drain any remaining content
|
||||
bufStr := p.buffer.String()
|
||||
p.buffer.Reset()
|
||||
if p.state == olmo3StateContent && len(bufStr) > 0 {
|
||||
return bufStr, "", nil, nil
|
||||
}
|
||||
return "", "", nil, nil
|
||||
}
|
||||
|
||||
events := p.parseEvents()
|
||||
|
||||
var contentSb strings.Builder
|
||||
var allCalls []api.ToolCall
|
||||
for _, event := range events {
|
||||
switch event := event.(type) {
|
||||
case olmo3ParserEventContent:
|
||||
contentSb.WriteString(event.content)
|
||||
case olmo3ParserEventToolCalls:
|
||||
allCalls = append(allCalls, event.calls...)
|
||||
}
|
||||
}
|
||||
|
||||
return contentSb.String(), "", allCalls, nil
|
||||
}
|
||||
|
||||
func (p *Olmo3Parser) parseEvents() []olmo3ParserEvent {
|
||||
var all []olmo3ParserEvent
|
||||
|
||||
keepLooping := true
|
||||
for keepLooping {
|
||||
var events []olmo3ParserEvent
|
||||
events, keepLooping = p.eat()
|
||||
if len(events) > 0 {
|
||||
all = append(all, events...)
|
||||
}
|
||||
}
|
||||
|
||||
if len(all) > 0 {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "olmo3 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
|
||||
}
|
||||
|
||||
return all
|
||||
}
|
||||
|
||||
func (p *Olmo3Parser) eat() ([]olmo3ParserEvent, bool) {
|
||||
var events []olmo3ParserEvent
|
||||
bufStr := p.buffer.String()
|
||||
if bufStr == "" {
|
||||
return events, false
|
||||
}
|
||||
|
||||
switch p.state {
|
||||
case olmo3StateContent:
|
||||
if strings.Contains(bufStr, olmo3FuncCallsOpenTag) {
|
||||
// Found <function_calls> tag
|
||||
split := strings.SplitN(bufStr, olmo3FuncCallsOpenTag, 2)
|
||||
content := split[0]
|
||||
remaining := split[1]
|
||||
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(remaining)
|
||||
p.state = olmo3StateToolCalls
|
||||
|
||||
if len(content) > 0 {
|
||||
events = append(events, olmo3ParserEventContent{content: content})
|
||||
}
|
||||
return events, true
|
||||
} else if overlapLen := overlap(bufStr, olmo3FuncCallsOpenTag); overlapLen > 0 {
|
||||
// Partial <function_calls> tag - withhold ambiguous content
|
||||
unambiguous := bufStr[:len(bufStr)-overlapLen]
|
||||
ambiguous := bufStr[len(bufStr)-overlapLen:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, olmo3ParserEventContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
} else {
|
||||
// Regular content - emit all
|
||||
p.buffer.Reset()
|
||||
if len(bufStr) > 0 {
|
||||
events = append(events, olmo3ParserEventContent{content: bufStr})
|
||||
}
|
||||
return events, false
|
||||
}
|
||||
|
||||
case olmo3StateToolCalls:
|
||||
if strings.Contains(bufStr, olmo3FuncCallsCloseTag) {
|
||||
// Found </function_calls> tag
|
||||
split := strings.SplitN(bufStr, olmo3FuncCallsCloseTag, 2)
|
||||
toolCallsStr := split[0]
|
||||
remaining := split[1]
|
||||
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(remaining)
|
||||
p.state = olmo3StateToolCallsDone
|
||||
|
||||
// Parse the function calls
|
||||
calls, err := parseOlmo3FunctionCalls(toolCallsStr)
|
||||
if err != nil {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "failed to parse olmo3 function calls", "error", err, "content", toolCallsStr)
|
||||
} else if len(calls) > 0 {
|
||||
events = append(events, olmo3ParserEventToolCalls{calls: calls})
|
||||
}
|
||||
return events, true
|
||||
} else if overlapLen := overlap(bufStr, olmo3FuncCallsCloseTag); overlapLen > 0 {
|
||||
// Partial </function_calls> tag - wait for more
|
||||
return events, false
|
||||
}
|
||||
// Still collecting tool calls, wait for close tag
|
||||
return events, false
|
||||
|
||||
case olmo3StateToolCallsDone:
|
||||
// After tool calls, emit remaining content
|
||||
p.buffer.Reset()
|
||||
p.state = olmo3StateContent
|
||||
if len(bufStr) > 0 {
|
||||
events = append(events, olmo3ParserEventContent{content: bufStr})
|
||||
}
|
||||
return events, false
|
||||
}
|
||||
|
||||
return events, false
|
||||
}
|
||||
|
||||
// parseOlmo3FunctionCalls parses function calls in Python-esque format:
|
||||
// func_name(arg1="value1", arg2=123)
|
||||
// Multiple calls are separated by newlines
|
||||
func parseOlmo3FunctionCalls(s string) ([]api.ToolCall, error) {
|
||||
var calls []api.ToolCall
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return calls, nil
|
||||
}
|
||||
|
||||
// Split by newlines for multiple function calls
|
||||
lines := strings.Split(s, "\n")
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
call, err := parseOlmo3SingleFunctionCall(line)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse function call %q: %w", line, err)
|
||||
}
|
||||
calls = append(calls, call)
|
||||
}
|
||||
|
||||
return calls, nil
|
||||
}
|
||||
|
||||
// Regex to match function call: func_name(args)
|
||||
var funcCallRegex = regexp.MustCompile(`^(\w+)\((.*)\)$`)
|
||||
|
||||
// Regex to match a single argument: key=value
|
||||
// Value can be: "string", 'string', number, true, false, null, or nested structures
|
||||
var argRegex = regexp.MustCompile(`^(\w+)=(.+)$`)
|
||||
|
||||
func parseOlmo3SingleFunctionCall(s string) (api.ToolCall, error) {
|
||||
matches := funcCallRegex.FindStringSubmatch(s)
|
||||
if matches == nil {
|
||||
return api.ToolCall{}, fmt.Errorf("invalid function call format")
|
||||
}
|
||||
|
||||
funcName := matches[1]
|
||||
argsStr := matches[2]
|
||||
|
||||
args, err := parseOlmo3Arguments(argsStr)
|
||||
if err != nil {
|
||||
return api.ToolCall{}, fmt.Errorf("failed to parse arguments: %w", err)
|
||||
}
|
||||
|
||||
return api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: funcName,
|
||||
Arguments: args,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// parseOlmo3Arguments parses comma-separated key=value pairs
|
||||
// Handles nested parentheses, brackets, braces, and quoted strings
|
||||
func parseOlmo3Arguments(s string) (map[string]any, error) {
|
||||
args := make(map[string]any)
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return args, nil
|
||||
}
|
||||
|
||||
// Split by commas, but respect nested structures and quotes
|
||||
parts := splitArguments(s)
|
||||
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Find the first = sign
|
||||
eqIdx := strings.Index(part, "=")
|
||||
if eqIdx == -1 {
|
||||
return nil, fmt.Errorf("invalid argument format: %s", part)
|
||||
}
|
||||
|
||||
key := strings.TrimSpace(part[:eqIdx])
|
||||
valueStr := strings.TrimSpace(part[eqIdx+1:])
|
||||
|
||||
value, err := parseOlmo3Value(valueStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse value for %s: %w", key, err)
|
||||
}
|
||||
|
||||
args[key] = value
|
||||
}
|
||||
|
||||
return args, nil
|
||||
}
|
||||
|
||||
// splitArguments splits arguments by commas, respecting quotes and nested structures
|
||||
func splitArguments(s string) []string {
|
||||
var parts []string
|
||||
var current strings.Builder
|
||||
depth := 0
|
||||
inString := false
|
||||
stringChar := byte(0)
|
||||
escaped := false
|
||||
|
||||
for i := 0; i < len(s); i++ {
|
||||
c := s[i]
|
||||
|
||||
if escaped {
|
||||
current.WriteByte(c)
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
|
||||
if c == '\\' && inString {
|
||||
current.WriteByte(c)
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
|
||||
if (c == '"' || c == '\'') && !inString {
|
||||
inString = true
|
||||
stringChar = c
|
||||
current.WriteByte(c)
|
||||
continue
|
||||
}
|
||||
|
||||
if c == stringChar && inString {
|
||||
inString = false
|
||||
stringChar = 0
|
||||
current.WriteByte(c)
|
||||
continue
|
||||
}
|
||||
|
||||
if !inString {
|
||||
switch c {
|
||||
case '(', '[', '{':
|
||||
depth++
|
||||
current.WriteByte(c)
|
||||
case ')', ']', '}':
|
||||
depth--
|
||||
current.WriteByte(c)
|
||||
case ',':
|
||||
if depth == 0 {
|
||||
parts = append(parts, current.String())
|
||||
current.Reset()
|
||||
continue
|
||||
}
|
||||
current.WriteByte(c)
|
||||
default:
|
||||
current.WriteByte(c)
|
||||
}
|
||||
} else {
|
||||
current.WriteByte(c)
|
||||
}
|
||||
}
|
||||
|
||||
if current.Len() > 0 {
|
||||
parts = append(parts, current.String())
|
||||
}
|
||||
|
||||
return parts
|
||||
}
|
||||
|
||||
// parseOlmo3Value parses a value which can be a string, number, boolean, null, array, or object
|
||||
func parseOlmo3Value(s string) (any, error) {
|
||||
s = strings.TrimSpace(s)
|
||||
|
||||
// Check for quoted string
|
||||
if (strings.HasPrefix(s, `"`) && strings.HasSuffix(s, `"`)) ||
|
||||
(strings.HasPrefix(s, `'`) && strings.HasSuffix(s, `'`)) {
|
||||
// Remove quotes and unescape
|
||||
inner := s[1 : len(s)-1]
|
||||
return unescapeString(inner), nil
|
||||
}
|
||||
|
||||
// Check for boolean
|
||||
if s == "true" || s == "True" {
|
||||
return true, nil
|
||||
}
|
||||
if s == "false" || s == "False" {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Check for null/None
|
||||
if s == "null" || s == "None" || s == "nil" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Check for number
|
||||
if i, err := strconv.ParseInt(s, 10, 64); err == nil {
|
||||
return i, nil
|
||||
}
|
||||
if f, err := strconv.ParseFloat(s, 64); err == nil {
|
||||
return f, nil
|
||||
}
|
||||
|
||||
// Check for array [...]
|
||||
if strings.HasPrefix(s, "[") && strings.HasSuffix(s, "]") {
|
||||
return parseOlmo3Array(s[1 : len(s)-1])
|
||||
}
|
||||
|
||||
// Check for object {...}
|
||||
if strings.HasPrefix(s, "{") && strings.HasSuffix(s, "}") {
|
||||
return parseOlmo3Object(s[1 : len(s)-1])
|
||||
}
|
||||
|
||||
// Default to string without quotes
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func parseOlmo3Array(s string) ([]any, error) {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return []any{}, nil
|
||||
}
|
||||
|
||||
parts := splitArguments(s)
|
||||
var arr []any
|
||||
for _, part := range parts {
|
||||
val, err := parseOlmo3Value(part)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
arr = append(arr, val)
|
||||
}
|
||||
return arr, nil
|
||||
}
|
||||
|
||||
func parseOlmo3Object(s string) (map[string]any, error) {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return map[string]any{}, nil
|
||||
}
|
||||
|
||||
// Objects use key: value or "key": value format
|
||||
obj := make(map[string]any)
|
||||
parts := splitArguments(s)
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Find colon separator
|
||||
colonIdx := strings.Index(part, ":")
|
||||
if colonIdx == -1 {
|
||||
return nil, fmt.Errorf("invalid object entry: %s", part)
|
||||
}
|
||||
|
||||
keyStr := strings.TrimSpace(part[:colonIdx])
|
||||
valueStr := strings.TrimSpace(part[colonIdx+1:])
|
||||
|
||||
// Remove quotes from key if present
|
||||
if (strings.HasPrefix(keyStr, `"`) && strings.HasSuffix(keyStr, `"`)) ||
|
||||
(strings.HasPrefix(keyStr, `'`) && strings.HasSuffix(keyStr, `'`)) {
|
||||
keyStr = keyStr[1 : len(keyStr)-1]
|
||||
}
|
||||
|
||||
val, err := parseOlmo3Value(valueStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse value for key %s: %w", keyStr, err)
|
||||
}
|
||||
|
||||
obj[keyStr] = val
|
||||
}
|
||||
|
||||
return obj, nil
|
||||
}
|
||||
|
||||
func unescapeString(s string) string {
|
||||
// Handle common escape sequences
|
||||
s = strings.ReplaceAll(s, `\\`, "\x00") // Placeholder for backslash
|
||||
s = strings.ReplaceAll(s, `\"`, `"`)
|
||||
s = strings.ReplaceAll(s, `\'`, `'`)
|
||||
s = strings.ReplaceAll(s, `\n`, "\n")
|
||||
s = strings.ReplaceAll(s, `\t`, "\t")
|
||||
s = strings.ReplaceAll(s, `\r`, "\r")
|
||||
s = strings.ReplaceAll(s, "\x00", `\`) // Restore backslash
|
||||
return s
|
||||
}
|
||||
@@ -1,483 +0,0 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestOlmo3Parser(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectedContent string
|
||||
expectedThinking string
|
||||
expectedCalls []api.ToolCall
|
||||
}{
|
||||
{
|
||||
name: "simple content",
|
||||
input: "Hello, how can I help you?",
|
||||
expectedContent: "Hello, how can I help you?",
|
||||
},
|
||||
{
|
||||
name: "simple tool call",
|
||||
input: `<function_calls>get_weather(location="San Francisco")</function_calls>`,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "San Francisco"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "content then tool call",
|
||||
input: `Let me check the weather.<function_calls>get_weather(location="NYC")</function_calls>`,
|
||||
expectedContent: "Let me check the weather.",
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "NYC"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with multiple arguments",
|
||||
input: `<function_calls>book_flight(from="SFO", to="NYC", date="2024-01-15")</function_calls>`,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "book_flight",
|
||||
Arguments: map[string]any{
|
||||
"from": "SFO",
|
||||
"to": "NYC",
|
||||
"date": "2024-01-15",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple tool calls",
|
||||
input: `<function_calls>get_weather(location="San Francisco")
|
||||
get_weather(location="New York")</function_calls>`,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "San Francisco"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "New York"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with numeric argument",
|
||||
input: `<function_calls>set_temperature(value=72)</function_calls>`,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "set_temperature",
|
||||
Arguments: map[string]any{"value": int64(72)},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with float argument",
|
||||
input: `<function_calls>set_price(amount=19.99)</function_calls>`,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "set_price",
|
||||
Arguments: map[string]any{"amount": 19.99},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with boolean argument",
|
||||
input: `<function_calls>toggle_setting(enabled=true)</function_calls>`,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "toggle_setting",
|
||||
Arguments: map[string]any{"enabled": true},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with null argument",
|
||||
input: `<function_calls>clear_value(field=null)</function_calls>`,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "clear_value",
|
||||
Arguments: map[string]any{"field": nil},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with array argument",
|
||||
input: `<function_calls>process_items(items=["apple", "banana", "cherry"])</function_calls>`,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "process_items",
|
||||
Arguments: map[string]any{"items": []any{"apple", "banana", "cherry"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with dict argument",
|
||||
input: `<function_calls>update_config(settings={"theme": "dark", "fontSize": 14})</function_calls>`,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "update_config",
|
||||
Arguments: map[string]any{
|
||||
"settings": map[string]any{
|
||||
"theme": "dark",
|
||||
"fontSize": int64(14),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with nested dict",
|
||||
input: `<function_calls>create_request(data={"user": {"name": "John", "age": 30}, "active": true})</function_calls>`,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "create_request",
|
||||
Arguments: map[string]any{
|
||||
"data": map[string]any{
|
||||
"user": map[string]any{
|
||||
"name": "John",
|
||||
"age": int64(30),
|
||||
},
|
||||
"active": true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with no arguments",
|
||||
input: `<function_calls>get_current_time()</function_calls>`,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_time",
|
||||
Arguments: map[string]any{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with single quotes",
|
||||
input: `<function_calls>search(query='hello world')</function_calls>`,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "search",
|
||||
Arguments: map[string]any{"query": "hello world"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with escaped quotes",
|
||||
input: `<function_calls>search(query="say \"hello\"")</function_calls>`,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "search",
|
||||
Arguments: map[string]any{"query": `say "hello"`},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with mixed argument types",
|
||||
input: `<function_calls>create_user(name="John", age=30, active=true)</function_calls>`,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "create_user",
|
||||
Arguments: map[string]any{
|
||||
"name": "John",
|
||||
"age": int64(30),
|
||||
"active": true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := &Olmo3Parser{}
|
||||
p.Init(nil, nil, nil)
|
||||
|
||||
content, thinking, calls, err := p.Add(tt.input, false)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Drain remaining content
|
||||
finalContent, finalThinking, finalCalls, err := p.Add("", true)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error on done: %v", err)
|
||||
}
|
||||
content += finalContent
|
||||
thinking += finalThinking
|
||||
calls = append(calls, finalCalls...)
|
||||
|
||||
if diff := cmp.Diff(content, tt.expectedContent); diff != "" {
|
||||
t.Errorf("content mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
if diff := cmp.Diff(thinking, tt.expectedThinking); diff != "" {
|
||||
t.Errorf("thinking mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
if diff := cmp.Diff(calls, tt.expectedCalls); diff != "" {
|
||||
t.Errorf("calls mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOlmo3Parser_Streaming(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
chunks []string
|
||||
expectedContent string
|
||||
expectedCalls []api.ToolCall
|
||||
}{
|
||||
{
|
||||
name: "streaming content",
|
||||
chunks: []string{"Hello, ", "how ", "can I help?"},
|
||||
expectedContent: "Hello, how can I help?",
|
||||
},
|
||||
{
|
||||
name: "streaming tool call",
|
||||
chunks: []string{"<function_", "calls>get_weather", "(location=\"SF\")", "</function_calls>"},
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "SF"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "streaming content then tool call",
|
||||
chunks: []string{"Let me check.", "<function_calls>", "get_weather(location=\"NYC\")", "</function_calls>"},
|
||||
expectedContent: "Let me check.",
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "NYC"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call tag split across chunks",
|
||||
chunks: []string{"<func", "tion_calls>test()</function_calls>"},
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "test",
|
||||
Arguments: map[string]any{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := &Olmo3Parser{}
|
||||
p.Init(nil, nil, nil)
|
||||
|
||||
var allContent string
|
||||
var allCalls []api.ToolCall
|
||||
|
||||
for _, chunk := range tt.chunks {
|
||||
content, _, calls, err := p.Add(chunk, false)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
allContent += content
|
||||
allCalls = append(allCalls, calls...)
|
||||
}
|
||||
|
||||
// Drain
|
||||
content, _, calls, err := p.Add("", true)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error on done: %v", err)
|
||||
}
|
||||
allContent += content
|
||||
allCalls = append(allCalls, calls...)
|
||||
|
||||
if diff := cmp.Diff(allContent, tt.expectedContent); diff != "" {
|
||||
t.Errorf("content mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
if diff := cmp.Diff(allCalls, tt.expectedCalls); diff != "" {
|
||||
t.Errorf("calls mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOlmo3Parser_HasToolSupport(t *testing.T) {
|
||||
p := &Olmo3Parser{}
|
||||
if !p.HasToolSupport() {
|
||||
t.Error("expected HasToolSupport to return true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOlmo3Parser_HasThinkingSupport(t *testing.T) {
|
||||
p := &Olmo3Parser{}
|
||||
if p.HasThinkingSupport() {
|
||||
t.Error("expected HasThinkingSupport to return false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseOlmo3FunctionCalls(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected []api.ToolCall
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "simple call",
|
||||
input: `get_weather(location="SF")`,
|
||||
expected: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "SF"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple args",
|
||||
input: `send_email(to="user@example.com", subject="Hello", body="Test message")`,
|
||||
expected: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "send_email",
|
||||
Arguments: map[string]any{
|
||||
"to": "user@example.com",
|
||||
"subject": "Hello",
|
||||
"body": "Test message",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple calls with newlines",
|
||||
input: `get_weather(location="SF")
|
||||
get_time(timezone="PST")`,
|
||||
expected: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "SF"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_time",
|
||||
Arguments: map[string]any{"timezone": "PST"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty input",
|
||||
input: "",
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
name: "whitespace only",
|
||||
input: " \n ",
|
||||
expected: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
calls, err := parseOlmo3FunctionCalls(tt.input)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("parseOlmo3FunctionCalls() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if diff := cmp.Diff(calls, tt.expected); diff != "" {
|
||||
t.Errorf("calls mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseOlmo3Value(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected any
|
||||
}{
|
||||
{"string double quotes", `"hello"`, "hello"},
|
||||
{"string single quotes", `'hello'`, "hello"},
|
||||
{"integer", "42", int64(42)},
|
||||
{"negative integer", "-10", int64(-10)},
|
||||
{"float", "3.14", 3.14},
|
||||
{"boolean true", "true", true},
|
||||
{"boolean True", "True", true},
|
||||
{"boolean false", "false", false},
|
||||
{"null", "null", nil},
|
||||
{"None", "None", nil},
|
||||
{"empty array", "[]", []any{}},
|
||||
{"array with strings", `["a", "b"]`, []any{"a", "b"}},
|
||||
{"array with numbers", "[1, 2, 3]", []any{int64(1), int64(2), int64(3)}},
|
||||
{"empty object", "{}", map[string]any{}},
|
||||
{"simple object", `{"name": "John"}`, map[string]any{"name": "John"}},
|
||||
{"object with number", `{"age": 30}`, map[string]any{"age": int64(30)}},
|
||||
{"object with multiple keys", `{"a": 1, "b": 2}`, map[string]any{"a": int64(1), "b": int64(2)}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := parseOlmo3Value(tt.input)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(result, tt.expected); diff != "" {
|
||||
t.Errorf("value mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,170 +0,0 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
type olmo3ThinkParserState int
|
||||
|
||||
const (
|
||||
olmo3CollectingThink olmo3ThinkParserState = iota
|
||||
olmo3CollectingContent
|
||||
)
|
||||
|
||||
const (
|
||||
olmo3ThinkCloseTag = "</think>"
|
||||
)
|
||||
|
||||
type Olmo3ThinkParser struct {
|
||||
state olmo3ThinkParserState
|
||||
buffer strings.Builder
|
||||
}
|
||||
|
||||
func (p *Olmo3ThinkParser) HasToolSupport() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *Olmo3ThinkParser) HasThinkingSupport() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *Olmo3ThinkParser) setInitialState(lastMessage *api.Message) {
|
||||
prefill := lastMessage != nil && lastMessage.Role == "assistant"
|
||||
|
||||
// If prefilling with content, skip to content collection
|
||||
if prefill && lastMessage.Content != "" {
|
||||
p.state = olmo3CollectingContent
|
||||
return
|
||||
}
|
||||
|
||||
// Model always thinks first (the <think> tag is injected in the prompt)
|
||||
p.state = olmo3CollectingThink
|
||||
}
|
||||
|
||||
func (p *Olmo3ThinkParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||
p.setInitialState(lastMessage)
|
||||
return tools
|
||||
}
|
||||
|
||||
// Event types for internal parser communication
|
||||
type olmo3Event interface {
|
||||
isOlmo3Event()
|
||||
}
|
||||
|
||||
type olmo3EventThinkContent struct {
|
||||
content string
|
||||
}
|
||||
|
||||
type olmo3EventContent struct {
|
||||
content string
|
||||
}
|
||||
|
||||
func (olmo3EventThinkContent) isOlmo3Event() {}
|
||||
func (olmo3EventContent) isOlmo3Event() {}
|
||||
|
||||
func (p *Olmo3ThinkParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||
p.buffer.WriteString(s)
|
||||
events := p.parseEvents()
|
||||
|
||||
var contentSb strings.Builder
|
||||
var thinkingSb strings.Builder
|
||||
for _, event := range events {
|
||||
switch event := event.(type) {
|
||||
case olmo3EventThinkContent:
|
||||
thinkingSb.WriteString(event.content)
|
||||
case olmo3EventContent:
|
||||
contentSb.WriteString(event.content)
|
||||
}
|
||||
}
|
||||
|
||||
return contentSb.String(), thinkingSb.String(), nil, nil
|
||||
}
|
||||
|
||||
func (p *Olmo3ThinkParser) parseEvents() []olmo3Event {
|
||||
var all []olmo3Event
|
||||
|
||||
keepLooping := true
|
||||
for keepLooping {
|
||||
var events []olmo3Event
|
||||
events, keepLooping = p.eat()
|
||||
if len(events) > 0 {
|
||||
all = append(all, events...)
|
||||
}
|
||||
}
|
||||
|
||||
if len(all) > 0 {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "olmo3 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
|
||||
}
|
||||
|
||||
return all
|
||||
}
|
||||
|
||||
func (p *Olmo3ThinkParser) eat() ([]olmo3Event, bool) {
|
||||
var events []olmo3Event
|
||||
bufStr := p.buffer.String()
|
||||
if bufStr == "" {
|
||||
return events, false
|
||||
}
|
||||
|
||||
switch p.state {
|
||||
case olmo3CollectingThink:
|
||||
if strings.Contains(bufStr, olmo3ThinkCloseTag) {
|
||||
// Found complete </think> tag
|
||||
split := strings.SplitN(bufStr, olmo3ThinkCloseTag, 2)
|
||||
thinking := strings.TrimRightFunc(split[0], unicode.IsSpace)
|
||||
remaining := strings.TrimLeftFunc(split[1], unicode.IsSpace)
|
||||
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(remaining)
|
||||
p.state = olmo3CollectingContent
|
||||
|
||||
if len(thinking) > 0 {
|
||||
events = append(events, olmo3EventThinkContent{content: thinking})
|
||||
}
|
||||
return events, true
|
||||
} else if overlapLen := overlap(bufStr, olmo3ThinkCloseTag); overlapLen > 0 {
|
||||
// Partial </think> tag - withhold ambiguous content
|
||||
beforePartialTag := bufStr[:len(bufStr)-overlapLen]
|
||||
trailingLen := trailingWhitespaceLen(beforePartialTag)
|
||||
ambiguousStart := len(beforePartialTag) - trailingLen
|
||||
|
||||
unambiguous := bufStr[:ambiguousStart]
|
||||
ambiguous := bufStr[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, olmo3EventThinkContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
} else {
|
||||
// Regular thinking content - withhold trailing whitespace in case </think> follows
|
||||
whitespaceLen := trailingWhitespaceLen(bufStr)
|
||||
ambiguousStart := len(bufStr) - whitespaceLen
|
||||
|
||||
unambiguous := bufStr[:ambiguousStart]
|
||||
ambiguous := bufStr[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, olmo3EventThinkContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
}
|
||||
|
||||
case olmo3CollectingContent:
|
||||
// Emit all content directly
|
||||
p.buffer.Reset()
|
||||
if len(bufStr) > 0 {
|
||||
events = append(events, olmo3EventContent{content: bufStr})
|
||||
}
|
||||
return events, false
|
||||
}
|
||||
|
||||
return events, false
|
||||
}
|
||||
@@ -1,390 +0,0 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestOlmo3ThinkParser(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectedContent string
|
||||
expectedThinking string
|
||||
lastMessage *api.Message
|
||||
}{
|
||||
{
|
||||
name: "thinking_only",
|
||||
input: "I need to think about this.</think>Here is my response.",
|
||||
expectedContent: "Here is my response.",
|
||||
expectedThinking: "I need to think about this.",
|
||||
},
|
||||
{
|
||||
name: "thinking_with_newlines",
|
||||
input: "Let me think step by step.\n\n1. First point\n2. Second point</think>The answer is 42.",
|
||||
expectedContent: "The answer is 42.",
|
||||
expectedThinking: "Let me think step by step.\n\n1. First point\n2. Second point",
|
||||
},
|
||||
{
|
||||
name: "thinking_then_content",
|
||||
input: "Deep thinking here.</think>Here is my detailed response with multiple sentences. I have thought carefully.",
|
||||
expectedContent: "Here is my detailed response with multiple sentences. I have thought carefully.",
|
||||
expectedThinking: "Deep thinking here.",
|
||||
},
|
||||
{
|
||||
name: "empty_thinking",
|
||||
input: "</think>Just content here.",
|
||||
expectedContent: "Just content here.",
|
||||
expectedThinking: "",
|
||||
},
|
||||
{
|
||||
name: "prefill_skips_thinking",
|
||||
input: "Continuing from previous content.",
|
||||
expectedContent: "Continuing from previous content.",
|
||||
lastMessage: &api.Message{
|
||||
Role: "assistant",
|
||||
Content: "Previous content",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "thinking_with_whitespace",
|
||||
input: " Some thinking </think> Content here ",
|
||||
expectedContent: "Content here ",
|
||||
expectedThinking: " Some thinking",
|
||||
},
|
||||
{
|
||||
name: "real_model_output_with_newlines",
|
||||
input: "Yes, that should work. Let me go with that response.\n\n</think>\n\nHi! I'm all set and ready to assist. How about you? How are you today? 😊",
|
||||
expectedThinking: "Yes, that should work. Let me go with that response.",
|
||||
expectedContent: "Hi! I'm all set and ready to assist. How about you? How are you today? 😊",
|
||||
},
|
||||
// Edge cases
|
||||
{
|
||||
name: "nested_think_tags_in_thinking",
|
||||
input: "I'm thinking <think>nested</think> more thinking</think>Final content.",
|
||||
expectedContent: "more thinking</think>Final content.",
|
||||
expectedThinking: "I'm thinking <think>nested",
|
||||
},
|
||||
{
|
||||
name: "multiple_think_close_tags",
|
||||
input: "First thinking</think>Content</think>More content.",
|
||||
expectedContent: "Content</think>More content.",
|
||||
expectedThinking: "First thinking",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
parser := &Olmo3ThinkParser{}
|
||||
parser.Init(nil, tt.lastMessage, nil)
|
||||
|
||||
content, thinking, toolCalls, err := parser.Add(tt.input, true)
|
||||
if err != nil {
|
||||
t.Fatalf("Add() error = %v", err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.expectedContent, content); diff != "" {
|
||||
t.Errorf("content mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.expectedThinking, thinking); diff != "" {
|
||||
t.Errorf("thinking mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
// No tool calls expected
|
||||
if len(toolCalls) > 0 {
|
||||
t.Errorf("expected no tool calls, got %d", len(toolCalls))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOlmo3ThinkParser_Streaming(t *testing.T) {
|
||||
parser := &Olmo3ThinkParser{}
|
||||
parser.Init(nil, nil, nil)
|
||||
|
||||
chunks := []string{
|
||||
"I am ",
|
||||
"thinking about",
|
||||
" this.</think>Here ",
|
||||
"is the response.",
|
||||
}
|
||||
|
||||
var finalContent, finalThinking strings.Builder
|
||||
|
||||
for i, chunk := range chunks {
|
||||
done := i == len(chunks)-1
|
||||
content, thinking, _, err := parser.Add(chunk, done)
|
||||
if err != nil {
|
||||
t.Fatalf("Add() error on chunk %d: %v", i, err)
|
||||
}
|
||||
|
||||
finalContent.WriteString(content)
|
||||
finalThinking.WriteString(thinking)
|
||||
}
|
||||
|
||||
expectedContent := "Here is the response."
|
||||
expectedThinking := "I am thinking about this."
|
||||
|
||||
if finalContent.String() != expectedContent {
|
||||
t.Errorf("expected content %q, got %q", expectedContent, finalContent.String())
|
||||
}
|
||||
|
||||
if finalThinking.String() != expectedThinking {
|
||||
t.Errorf("expected thinking %q, got %q", expectedThinking, finalThinking.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestOlmo3ThinkParser_StreamingEdgeCases(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
chunks []string
|
||||
expectedContent string
|
||||
expectedThinking string
|
||||
}{
|
||||
{
|
||||
name: "thinking_tag_split_across_chunks",
|
||||
chunks: []string{
|
||||
"This is thinking content",
|
||||
"</think>",
|
||||
"This is content.",
|
||||
},
|
||||
expectedContent: "This is content.",
|
||||
expectedThinking: "This is thinking content",
|
||||
},
|
||||
{
|
||||
name: "thinking_tag_split_mid_token",
|
||||
chunks: []string{
|
||||
"Thinking?</",
|
||||
"think>",
|
||||
"Content here.",
|
||||
},
|
||||
expectedContent: "Content here.",
|
||||
expectedThinking: "Thinking?",
|
||||
},
|
||||
{
|
||||
name: "thinking_tag_split_at_angle_bracket",
|
||||
chunks: []string{
|
||||
"Thinking<",
|
||||
"/think>",
|
||||
"Content.",
|
||||
},
|
||||
expectedContent: "Content.",
|
||||
expectedThinking: "Thinking",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
parser := &Olmo3ThinkParser{}
|
||||
parser.Init(nil, nil, nil)
|
||||
|
||||
var finalContent, finalThinking strings.Builder
|
||||
|
||||
for i, chunk := range tt.chunks {
|
||||
done := i == len(tt.chunks)-1
|
||||
content, thinking, _, err := parser.Add(chunk, done)
|
||||
if err != nil {
|
||||
t.Fatalf("Add() error on chunk %d: %v", i, err)
|
||||
}
|
||||
|
||||
finalContent.WriteString(content)
|
||||
finalThinking.WriteString(thinking)
|
||||
}
|
||||
|
||||
if finalContent.String() != tt.expectedContent {
|
||||
t.Errorf("expected content %q, got %q", tt.expectedContent, finalContent.String())
|
||||
}
|
||||
|
||||
if finalThinking.String() != tt.expectedThinking {
|
||||
t.Errorf("expected thinking %q, got %q", tt.expectedThinking, finalThinking.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestOlmo3ThinkParser_ThinkBoundary tests streaming thinking content
|
||||
// where thinking chunks come in succession before the </think> tag
|
||||
func TestOlmo3ThinkParser_ThinkBoundary(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
chunks []string
|
||||
expectedThinking string
|
||||
expectedContent string
|
||||
}{
|
||||
{
|
||||
name: "multiple_thinking_chunks",
|
||||
chunks: []string{
|
||||
"First part of thinking. ",
|
||||
"Second part of thinking. ",
|
||||
"Third part.</think>",
|
||||
"Content here.",
|
||||
},
|
||||
expectedThinking: "First part of thinking. Second part of thinking. Third part.",
|
||||
expectedContent: "Content here.",
|
||||
},
|
||||
{
|
||||
name: "thinking_chunks_with_newlines",
|
||||
chunks: []string{
|
||||
"Step 1: Analyze the problem.\n",
|
||||
"Step 2: Consider options.\n",
|
||||
"Step 3: Make decision.</think>",
|
||||
"Here is my answer.",
|
||||
},
|
||||
expectedThinking: "Step 1: Analyze the problem.\nStep 2: Consider options.\nStep 3: Make decision.",
|
||||
expectedContent: "Here is my answer.",
|
||||
},
|
||||
{
|
||||
name: "single_char_thinking_chunks",
|
||||
chunks: []string{
|
||||
"H", "e", "l", "l", "o", "</think>", "World",
|
||||
},
|
||||
expectedThinking: "Hello",
|
||||
expectedContent: "World",
|
||||
},
|
||||
{
|
||||
name: "thinking_with_special_chars",
|
||||
chunks: []string{
|
||||
"Let me think... ",
|
||||
"Option A: $100 ",
|
||||
"Option B: €200</think>",
|
||||
"I recommend Option A.",
|
||||
},
|
||||
expectedThinking: "Let me think... Option A: $100 Option B: €200",
|
||||
expectedContent: "I recommend Option A.",
|
||||
},
|
||||
{
|
||||
name: "long_thinking_multiple_chunks",
|
||||
chunks: []string{
|
||||
"This is a very long thinking process. ",
|
||||
"I need to consider many factors. ",
|
||||
"First, let me look at the data. ",
|
||||
"The numbers show interesting patterns. ",
|
||||
"Based on my analysis, ",
|
||||
"I can conclude that...</think>",
|
||||
"The answer is 42.",
|
||||
},
|
||||
expectedThinking: "This is a very long thinking process. I need to consider many factors. First, let me look at the data. The numbers show interesting patterns. Based on my analysis, I can conclude that...",
|
||||
expectedContent: "The answer is 42.",
|
||||
},
|
||||
{
|
||||
name: "thinking_ends_exactly_at_chunk_boundary",
|
||||
chunks: []string{
|
||||
"Thinking content",
|
||||
"</think>",
|
||||
"Content",
|
||||
},
|
||||
expectedThinking: "Thinking content",
|
||||
expectedContent: "Content",
|
||||
},
|
||||
{
|
||||
name: "empty_chunks_between_thinking",
|
||||
chunks: []string{
|
||||
"Start thinking",
|
||||
"",
|
||||
" middle ",
|
||||
"",
|
||||
"end</think>",
|
||||
"Content",
|
||||
},
|
||||
expectedThinking: "Start thinking middle end",
|
||||
expectedContent: "Content",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
parser := &Olmo3ThinkParser{}
|
||||
parser.Init(nil, nil, nil)
|
||||
|
||||
var finalContent, finalThinking strings.Builder
|
||||
|
||||
for i, chunk := range tt.chunks {
|
||||
done := i == len(tt.chunks)-1
|
||||
content, thinking, _, err := parser.Add(chunk, done)
|
||||
if err != nil {
|
||||
t.Fatalf("Add() error on chunk %d: %v", i, err)
|
||||
}
|
||||
|
||||
finalContent.WriteString(content)
|
||||
finalThinking.WriteString(thinking)
|
||||
}
|
||||
|
||||
if finalThinking.String() != tt.expectedThinking {
|
||||
t.Errorf("thinking mismatch:\nexpected: %q\ngot: %q", tt.expectedThinking, finalThinking.String())
|
||||
}
|
||||
|
||||
if finalContent.String() != tt.expectedContent {
|
||||
t.Errorf("content mismatch:\nexpected: %q\ngot: %q", tt.expectedContent, finalContent.String())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestOlmo3ThinkParser_StateTransitions tests that state transitions work correctly
|
||||
func TestOlmo3ThinkParser_StateTransitions(t *testing.T) {
|
||||
t.Run("thinking_to_content", func(t *testing.T) {
|
||||
parser := &Olmo3ThinkParser{}
|
||||
parser.Init(nil, nil, nil)
|
||||
|
||||
if parser.state != olmo3CollectingThink {
|
||||
t.Errorf("initial state should be olmo3CollectingThink, got %v", parser.state)
|
||||
}
|
||||
|
||||
parser.Add("thinking</think>content", true)
|
||||
|
||||
if parser.state != olmo3CollectingContent {
|
||||
t.Errorf("state after </think> should be olmo3CollectingContent, got %v", parser.state)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestOlmo3ThinkParser_HasToolSupport(t *testing.T) {
|
||||
parser := &Olmo3ThinkParser{}
|
||||
if parser.HasToolSupport() {
|
||||
t.Error("Olmo3ThinkParser should NOT support tools")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOlmo3ThinkParser_HasThinkingSupport(t *testing.T) {
|
||||
parser := &Olmo3ThinkParser{}
|
||||
if !parser.HasThinkingSupport() {
|
||||
t.Error("Olmo3ThinkParser should support thinking")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOlmo3ThinkParser_Init(t *testing.T) {
|
||||
parser := &Olmo3ThinkParser{}
|
||||
|
||||
tools := []api.Tool{
|
||||
{Function: api.ToolFunction{Name: "test_tool"}},
|
||||
}
|
||||
|
||||
lastMessage := &api.Message{Role: "assistant", Content: "previous"}
|
||||
|
||||
returnedTools := parser.Init(tools, lastMessage, nil)
|
||||
|
||||
if len(returnedTools) != len(tools) {
|
||||
t.Errorf("expected %d tools returned, got %d", len(tools), len(returnedTools))
|
||||
}
|
||||
|
||||
// Should be in content collection mode due to prefill
|
||||
if parser.state != olmo3CollectingContent {
|
||||
t.Errorf("expected state olmo3CollectingContent, got %v", parser.state)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOlmo3ThinkParser_InitWithoutPrefill(t *testing.T) {
|
||||
parser := &Olmo3ThinkParser{}
|
||||
|
||||
parser.Init(nil, nil, nil)
|
||||
|
||||
// Should be in thinking collection mode (model always thinks first)
|
||||
if parser.state != olmo3CollectingThink {
|
||||
t.Errorf("expected state olmo3CollectingThink, got %v", parser.state)
|
||||
}
|
||||
}
|
||||
@@ -58,10 +58,6 @@ func ParserForName(name string) Parser {
|
||||
return harmony.NewHarmonyMessageHandler()
|
||||
case "cogito":
|
||||
return &CogitoParser{}
|
||||
case "olmo3-think":
|
||||
return &Olmo3ThinkParser{}
|
||||
case "olmo3":
|
||||
return &Olmo3Parser{}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,45 +0,0 @@
|
||||
package renderers
|
||||
|
||||
import "encoding/json"
|
||||
|
||||
// marshalWithSpaces marshals v to JSON and adds a space after each ':' and ','
|
||||
// that appears outside of string values. This matches the formatting expected
|
||||
// by certain model architectures.
|
||||
func marshalWithSpaces(v any) ([]byte, error) {
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out := make([]byte, 0, len(b)+len(b)/8)
|
||||
inStr, esc := false, false
|
||||
for _, c := range b {
|
||||
if inStr {
|
||||
out = append(out, c)
|
||||
if esc {
|
||||
esc = false
|
||||
continue
|
||||
}
|
||||
if c == '\\' {
|
||||
esc = true
|
||||
continue
|
||||
}
|
||||
if c == '"' {
|
||||
inStr = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
switch c {
|
||||
case '"':
|
||||
inStr = true
|
||||
out = append(out, c)
|
||||
case ':':
|
||||
out = append(out, ':', ' ')
|
||||
case ',':
|
||||
out = append(out, ',', ' ')
|
||||
default:
|
||||
out = append(out, c)
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
@@ -1,148 +0,0 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
const (
|
||||
olmo3DefaultSystemMessage = "You are a helpful function-calling AI assistant. "
|
||||
olmo3NoFunctionsMessage = "You do not currently have access to any functions. "
|
||||
olmo3WithFunctionsMessage = "You are provided with function signatures within <functions></functions> XML tags. You may call one or more functions to assist with the user query. Output any function calls within <function_calls></function_calls> XML tags. Do not make assumptions about what values to plug into functions."
|
||||
)
|
||||
|
||||
type Olmo3Renderer struct{}
|
||||
|
||||
func (r *Olmo3Renderer) Render(messages []api.Message, tools []api.Tool, _ *api.ThinkValue) (string, error) {
|
||||
var sb strings.Builder
|
||||
|
||||
var systemMessage *api.Message
|
||||
filteredMessages := make([]api.Message, 0, len(messages))
|
||||
for i, message := range messages {
|
||||
if message.Role == "system" {
|
||||
if systemMessage == nil {
|
||||
systemMessage = &messages[i]
|
||||
}
|
||||
continue
|
||||
}
|
||||
filteredMessages = append(filteredMessages, message)
|
||||
}
|
||||
|
||||
// Render system message
|
||||
if systemMessage != nil {
|
||||
// Custom system message - single newline after "system"
|
||||
sb.WriteString("<|im_start|>system\n")
|
||||
sb.WriteString(systemMessage.Content)
|
||||
|
||||
if len(tools) > 0 {
|
||||
functionsJSON, err := marshalWithSpaces(tools)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sb.WriteString("<functions>")
|
||||
sb.WriteString(string(functionsJSON))
|
||||
sb.WriteString("</functions>")
|
||||
}
|
||||
sb.WriteString("<|im_end|>\n")
|
||||
} else {
|
||||
// Default system message - single newline after "system"
|
||||
sb.WriteString("<|im_start|>system\n")
|
||||
sb.WriteString(olmo3DefaultSystemMessage)
|
||||
|
||||
if len(tools) > 0 {
|
||||
functionsJSON, err := marshalWithSpaces(tools)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sb.WriteString(olmo3WithFunctionsMessage)
|
||||
sb.WriteString("<functions>")
|
||||
sb.WriteString(string(functionsJSON))
|
||||
sb.WriteString("</functions>")
|
||||
} else {
|
||||
sb.WriteString(olmo3NoFunctionsMessage)
|
||||
sb.WriteString("<functions></functions>")
|
||||
}
|
||||
sb.WriteString("<|im_end|>\n")
|
||||
}
|
||||
|
||||
for i, message := range filteredMessages {
|
||||
lastMessage := i == len(filteredMessages)-1
|
||||
|
||||
switch message.Role {
|
||||
case "user":
|
||||
sb.WriteString("<|im_start|>user\n")
|
||||
sb.WriteString(message.Content)
|
||||
sb.WriteString("<|im_end|>\n")
|
||||
|
||||
case "assistant":
|
||||
sb.WriteString("<|im_start|>assistant\n")
|
||||
|
||||
if message.Content != "" {
|
||||
sb.WriteString(message.Content)
|
||||
}
|
||||
|
||||
if len(message.ToolCalls) > 0 {
|
||||
sb.WriteString("<function_calls>")
|
||||
for j, tc := range message.ToolCalls {
|
||||
// Format as function_name(arg1="value1", arg2="value2")
|
||||
sb.WriteString(tc.Function.Name)
|
||||
sb.WriteString("(")
|
||||
|
||||
// Get sorted keys for deterministic output
|
||||
keys := make([]string, 0, len(tc.Function.Arguments))
|
||||
for k := range tc.Function.Arguments {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
for k, key := range keys {
|
||||
if k > 0 {
|
||||
sb.WriteString(", ")
|
||||
}
|
||||
value, err := json.Marshal(tc.Function.Arguments[key])
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("%s=%s", key, string(value)))
|
||||
}
|
||||
sb.WriteString(")")
|
||||
|
||||
if j < len(message.ToolCalls)-1 {
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
}
|
||||
sb.WriteString("</function_calls>")
|
||||
}
|
||||
|
||||
// Add end tag unless it's the last message with content only (prefill)
|
||||
if !lastMessage || len(message.ToolCalls) > 0 {
|
||||
sb.WriteString("<|im_end|>\n")
|
||||
}
|
||||
|
||||
case "tool":
|
||||
sb.WriteString("<|im_start|>environment\n")
|
||||
sb.WriteString(message.Content)
|
||||
sb.WriteString("<|im_end|>\n")
|
||||
}
|
||||
}
|
||||
|
||||
// Add generation prompt if needed
|
||||
needsGenerationPrompt := true
|
||||
if len(filteredMessages) > 0 {
|
||||
lastMsg := filteredMessages[len(filteredMessages)-1]
|
||||
if lastMsg.Role == "assistant" && len(lastMsg.ToolCalls) == 0 && lastMsg.Content != "" {
|
||||
needsGenerationPrompt = false
|
||||
}
|
||||
}
|
||||
|
||||
if needsGenerationPrompt {
|
||||
sb.WriteString("<|im_start|>assistant\n\n")
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
}
|
||||
|
||||
@@ -1,290 +0,0 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestOlmo3Renderer(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msgs []api.Message
|
||||
tools []api.Tool
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "basic without system - adds default system",
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
expected: "<|im_start|>system\n" +
|
||||
"You are a helpful function-calling AI assistant. You do not currently have access to any functions. <functions></functions><|im_end|>\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"Hello!<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n\n",
|
||||
},
|
||||
{
|
||||
name: "with system message no tools",
|
||||
msgs: []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant."},
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
expected: "<|im_start|>system\n" +
|
||||
"You are a helpful assistant.<|im_end|>\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"Hello!<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n\n",
|
||||
},
|
||||
{
|
||||
name: "with system message and tools",
|
||||
msgs: []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant."},
|
||||
{Role: "user", Content: "What is the weather?"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {Type: api.PropertyType{"string"}, Description: "The city"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: "<|im_start|>system\n" +
|
||||
`You are a helpful assistant.<functions>[{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "The city"}}}}}]</functions><|im_end|>` + "\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"What is the weather?<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n\n",
|
||||
},
|
||||
{
|
||||
name: "default system with tools - includes function instruction",
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "What is the weather?"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {Type: api.PropertyType{"string"}, Description: "The city"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: "<|im_start|>system\n" +
|
||||
"You are a helpful function-calling AI assistant. " +
|
||||
"You are provided with function signatures within <functions></functions> XML tags. You may call one or more functions to assist with the user query. Output any function calls within <function_calls></function_calls> XML tags. Do not make assumptions about what values to plug into functions." +
|
||||
`<functions>[{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "The city"}}}}}]</functions><|im_end|>` + "\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"What is the weather?<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n\n",
|
||||
},
|
||||
{
|
||||
name: "assistant with tool calls - function call syntax",
|
||||
msgs: []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant."},
|
||||
{Role: "user", Content: "What is the weather in SF?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "Let me check the weather.",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{
|
||||
"location": "San Francisco",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: `{"temperature": 68}`, ToolName: "get_weather"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {Type: api.PropertyType{"string"}, Description: "The city"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: "<|im_start|>system\n" +
|
||||
`You are a helpful assistant.<functions>[{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "The city"}}}}}]</functions><|im_end|>` + "\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"What is the weather in SF?<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n" +
|
||||
`Let me check the weather.<function_calls>get_weather(location="San Francisco")</function_calls><|im_end|>` + "\n" +
|
||||
"<|im_start|>environment\n" +
|
||||
`{"temperature": 68}<|im_end|>` + "\n" +
|
||||
"<|im_start|>assistant\n\n",
|
||||
},
|
||||
{
|
||||
name: "multi-turn conversation",
|
||||
msgs: []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant."},
|
||||
{Role: "user", Content: "Hello"},
|
||||
{Role: "assistant", Content: "Hi there!"},
|
||||
{Role: "user", Content: "How are you?"},
|
||||
},
|
||||
expected: "<|im_start|>system\n" +
|
||||
"You are a helpful assistant.<|im_end|>\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"Hello<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n" +
|
||||
"Hi there!<|im_end|>\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"How are you?<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n\n",
|
||||
},
|
||||
{
|
||||
name: "parallel tool calls - newline separated",
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "Get weather in SF and NYC"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "San Francisco"},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "call_2",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "New York"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: `{"temperature": 68}`, ToolName: "get_weather"},
|
||||
{Role: "tool", Content: `{"temperature": 55}`, ToolName: "get_weather"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {Type: api.PropertyType{"string"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: "<|im_start|>system\n" +
|
||||
"You are a helpful function-calling AI assistant. " +
|
||||
"You are provided with function signatures within <functions></functions> XML tags. You may call one or more functions to assist with the user query. Output any function calls within <function_calls></function_calls> XML tags. Do not make assumptions about what values to plug into functions." +
|
||||
`<functions>[{"type": "function", "function": {"name": "get_weather", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}}}}]</functions><|im_end|>` + "\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"Get weather in SF and NYC<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n" +
|
||||
`<function_calls>get_weather(location="San Francisco")` + "\n" +
|
||||
`get_weather(location="New York")</function_calls><|im_end|>` + "\n" +
|
||||
"<|im_start|>environment\n" +
|
||||
`{"temperature": 68}<|im_end|>` + "\n" +
|
||||
"<|im_start|>environment\n" +
|
||||
`{"temperature": 55}<|im_end|>` + "\n" +
|
||||
"<|im_start|>assistant\n\n",
|
||||
},
|
||||
{
|
||||
name: "tool call with multiple arguments",
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "Book a flight"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "book_flight",
|
||||
Arguments: map[string]any{
|
||||
"from": "SFO",
|
||||
"to": "NYC",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "book_flight",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"from": {Type: api.PropertyType{"string"}},
|
||||
"to": {Type: api.PropertyType{"string"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: "<|im_start|>system\n" +
|
||||
"You are a helpful function-calling AI assistant. " +
|
||||
"You are provided with function signatures within <functions></functions> XML tags. You may call one or more functions to assist with the user query. Output any function calls within <function_calls></function_calls> XML tags. Do not make assumptions about what values to plug into functions." +
|
||||
`<functions>[{"type": "function", "function": {"name": "book_flight", "parameters": {"type": "object", "properties": {"from": {"type": "string"}, "to": {"type": "string"}}}}}]</functions><|im_end|>` + "\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"Book a flight<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n" +
|
||||
`<function_calls>book_flight(from="SFO", to="NYC")</function_calls><|im_end|>` + "\n" +
|
||||
"<|im_start|>assistant\n\n",
|
||||
},
|
||||
{
|
||||
name: "assistant prefill - no generation prompt",
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
{Role: "assistant", Content: "Hi there!"},
|
||||
},
|
||||
expected: "<|im_start|>system\n" +
|
||||
"You are a helpful function-calling AI assistant. You do not currently have access to any functions. <functions></functions><|im_end|>\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"Hello<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n" +
|
||||
"Hi there!",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rendered, err := (&Olmo3Renderer{}).Render(tt.msgs, tt.tools, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := cmp.Diff(rendered, tt.expected); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,130 +0,0 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
const (
|
||||
olmo3ThinkDefaultSystemMessage = "You are OLMo, a helpful function-calling AI assistant built by Ai2. Your date cutoff is November 2024, and your model weights are available at https://huggingface.co/allenai."
|
||||
olmo3ThinkNoFunctionsMessage = " You do not currently have access to any functions."
|
||||
)
|
||||
|
||||
type Olmo3ThinkRenderer struct{}
|
||||
|
||||
type olmo3ThinkToolCall struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Type string `json:"type,omitempty"`
|
||||
Function olmo3ThinkToolCallFunc `json:"function"`
|
||||
}
|
||||
|
||||
type olmo3ThinkToolCallFunc struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
}
|
||||
|
||||
func (r *Olmo3ThinkRenderer) Render(messages []api.Message, tools []api.Tool, _ *api.ThinkValue) (string, error) {
|
||||
var sb strings.Builder
|
||||
|
||||
var systemMessage *api.Message
|
||||
filteredMessages := make([]api.Message, 0, len(messages))
|
||||
for i, message := range messages {
|
||||
if message.Role == "system" {
|
||||
if systemMessage == nil {
|
||||
systemMessage = &messages[i]
|
||||
}
|
||||
continue
|
||||
}
|
||||
filteredMessages = append(filteredMessages, message)
|
||||
}
|
||||
|
||||
systemContent := olmo3ThinkDefaultSystemMessage
|
||||
if systemMessage != nil {
|
||||
systemContent = systemMessage.Content
|
||||
}
|
||||
|
||||
sb.WriteString("<|im_start|>system\n")
|
||||
sb.WriteString(systemContent)
|
||||
|
||||
if len(tools) > 0 {
|
||||
functionsJSON, err := marshalWithSpaces(tools)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sb.WriteString(" <functions>")
|
||||
sb.WriteString(string(functionsJSON))
|
||||
sb.WriteString("</functions>")
|
||||
} else {
|
||||
sb.WriteString(olmo3ThinkNoFunctionsMessage)
|
||||
sb.WriteString(" <functions></functions>")
|
||||
}
|
||||
sb.WriteString("<|im_end|>\n")
|
||||
|
||||
for i, message := range filteredMessages {
|
||||
lastMessage := i == len(filteredMessages)-1
|
||||
|
||||
switch message.Role {
|
||||
case "user":
|
||||
sb.WriteString("<|im_start|>user\n")
|
||||
sb.WriteString(message.Content)
|
||||
sb.WriteString("<|im_end|>\n")
|
||||
|
||||
case "assistant":
|
||||
sb.WriteString("<|im_start|>assistant\n")
|
||||
|
||||
if message.Content != "" {
|
||||
sb.WriteString(message.Content)
|
||||
}
|
||||
|
||||
if len(message.ToolCalls) > 0 {
|
||||
toolCalls := make([]olmo3ThinkToolCall, len(message.ToolCalls))
|
||||
for j, tc := range message.ToolCalls {
|
||||
argsJSON, err := json.Marshal(tc.Function.Arguments)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
toolCalls[j] = olmo3ThinkToolCall{
|
||||
ID: tc.ID,
|
||||
Type: "function",
|
||||
Function: olmo3ThinkToolCallFunc{
|
||||
Name: tc.Function.Name,
|
||||
Arguments: string(argsJSON),
|
||||
},
|
||||
}
|
||||
}
|
||||
toolCallsJSON, err := marshalWithSpaces(toolCalls)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sb.WriteString("<function_calls>")
|
||||
sb.WriteString(string(toolCallsJSON))
|
||||
sb.WriteString("</function_calls>")
|
||||
}
|
||||
|
||||
if !lastMessage {
|
||||
sb.WriteString("<|im_end|>\n")
|
||||
}
|
||||
|
||||
case "tool":
|
||||
sb.WriteString("<|im_start|>environment\n")
|
||||
sb.WriteString(message.Content)
|
||||
sb.WriteString("<|im_end|>\n")
|
||||
}
|
||||
}
|
||||
|
||||
needsGenerationPrompt := true
|
||||
if len(filteredMessages) > 0 {
|
||||
lastMsg := filteredMessages[len(filteredMessages)-1]
|
||||
if lastMsg.Role == "assistant" && len(lastMsg.ToolCalls) == 0 && lastMsg.Content != "" {
|
||||
needsGenerationPrompt = false
|
||||
}
|
||||
}
|
||||
|
||||
if needsGenerationPrompt {
|
||||
sb.WriteString("<|im_start|>assistant\n<think>")
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
}
|
||||
@@ -1,224 +0,0 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestOlmo3ThinkRenderer(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msgs []api.Message
|
||||
tools []api.Tool
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "basic without system - adds default system",
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
expected: "<|im_start|>system\n" +
|
||||
"You are OLMo, a helpful function-calling AI assistant built by Ai2. Your date cutoff is November 2024, and your model weights are available at https://huggingface.co/allenai. You do not currently have access to any functions. <functions></functions><|im_end|>\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"Hello!<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n" +
|
||||
"<think>",
|
||||
},
|
||||
{
|
||||
name: "with system message no tools",
|
||||
msgs: []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant."},
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
expected: "<|im_start|>system\n" +
|
||||
"You are a helpful assistant. You do not currently have access to any functions. <functions></functions><|im_end|>\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"Hello!<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n" +
|
||||
"<think>",
|
||||
},
|
||||
{
|
||||
name: "with system message and tools",
|
||||
msgs: []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant."},
|
||||
{Role: "user", Content: "What is the weather?"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {Type: api.PropertyType{"string"}, Description: "The city"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: "<|im_start|>system\n" +
|
||||
`You are a helpful assistant. <functions>[{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "The city"}}}}}]</functions><|im_end|>` + "\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"What is the weather?<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n" +
|
||||
"<think>",
|
||||
},
|
||||
{
|
||||
name: "assistant with tool calls",
|
||||
msgs: []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant."},
|
||||
{Role: "user", Content: "What is the weather in SF?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "Let me check the weather.",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{
|
||||
"location": "San Francisco",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: `{"temperature": 68}`, ToolName: "get_weather"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {Type: api.PropertyType{"string"}, Description: "The city"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: "<|im_start|>system\n" +
|
||||
`You are a helpful assistant. <functions>[{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "The city"}}}}}]</functions><|im_end|>` + "\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"What is the weather in SF?<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n" +
|
||||
`Let me check the weather.<function_calls>[{"id": "call_1", "type": "function", "function": {"name": "get_weather", "arguments": "{\"location\":\"San Francisco\"}"}}]</function_calls><|im_end|>` + "\n" +
|
||||
"<|im_start|>environment\n" +
|
||||
`{"temperature": 68}<|im_end|>` + "\n" +
|
||||
"<|im_start|>assistant\n" +
|
||||
"<think>",
|
||||
},
|
||||
{
|
||||
name: "multi-turn conversation",
|
||||
msgs: []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant."},
|
||||
{Role: "user", Content: "Hello"},
|
||||
{Role: "assistant", Content: "Hi there!"},
|
||||
{Role: "user", Content: "How are you?"},
|
||||
},
|
||||
expected: "<|im_start|>system\n" +
|
||||
"You are a helpful assistant. You do not currently have access to any functions. <functions></functions><|im_end|>\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"Hello<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n" +
|
||||
"Hi there!<|im_end|>\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"How are you?<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n" +
|
||||
"<think>",
|
||||
},
|
||||
{
|
||||
name: "parallel tool calls",
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "Get weather in SF and NYC"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "San Francisco"},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "call_2",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "New York"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: `{"temperature": 68}`, ToolName: "get_weather"},
|
||||
{Role: "tool", Content: `{"temperature": 55}`, ToolName: "get_weather"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {Type: api.PropertyType{"string"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: "<|im_start|>system\n" +
|
||||
`You are OLMo, a helpful function-calling AI assistant built by Ai2. Your date cutoff is November 2024, and your model weights are available at https://huggingface.co/allenai. <functions>[{"type": "function", "function": {"name": "get_weather", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}}}}]</functions><|im_end|>` + "\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"Get weather in SF and NYC<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n" +
|
||||
`<function_calls>[{"id": "call_1", "type": "function", "function": {"name": "get_weather", "arguments": "{\"location\":\"San Francisco\"}"}}, {"id": "call_2", "type": "function", "function": {"name": "get_weather", "arguments": "{\"location\":\"New York\"}"}}]</function_calls><|im_end|>` + "\n" +
|
||||
"<|im_start|>environment\n" +
|
||||
`{"temperature": 68}<|im_end|>` + "\n" +
|
||||
"<|im_start|>environment\n" +
|
||||
`{"temperature": 55}<|im_end|>` + "\n" +
|
||||
"<|im_start|>assistant\n" +
|
||||
"<think>",
|
||||
},
|
||||
{
|
||||
name: "assistant message only content no tool calls",
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "Tell me a joke"},
|
||||
{Role: "assistant", Content: "Why did the chicken cross the road?"},
|
||||
{Role: "user", Content: "I don't know, why?"},
|
||||
},
|
||||
expected: "<|im_start|>system\n" +
|
||||
"You are OLMo, a helpful function-calling AI assistant built by Ai2. Your date cutoff is November 2024, and your model weights are available at https://huggingface.co/allenai. You do not currently have access to any functions. <functions></functions><|im_end|>\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"Tell me a joke<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n" +
|
||||
"Why did the chicken cross the road?<|im_end|>\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"I don't know, why?<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n" +
|
||||
"<think>",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rendered, err := (&Olmo3ThinkRenderer{}).Render(tt.msgs, tt.tools, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := cmp.Diff(rendered, tt.expected); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,11 +1,51 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func marshalWithSpaces(v any) ([]byte, error) {
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out := make([]byte, 0, len(b)+len(b)/8)
|
||||
inStr, esc := false, false
|
||||
for _, c := range b {
|
||||
if inStr {
|
||||
out = append(out, c)
|
||||
if esc {
|
||||
esc = false
|
||||
continue
|
||||
}
|
||||
if c == '\\' {
|
||||
esc = true
|
||||
continue
|
||||
}
|
||||
if c == '"' {
|
||||
inStr = false
|
||||
}
|
||||
continue
|
||||
}
|
||||
switch c {
|
||||
case '"':
|
||||
inStr = true
|
||||
out = append(out, c)
|
||||
case ':':
|
||||
out = append(out, ':', ' ')
|
||||
case ',':
|
||||
out = append(out, ',', ' ')
|
||||
default:
|
||||
out = append(out, c)
|
||||
}
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
type Qwen3VLRenderer struct {
|
||||
isThinking bool
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
// TODO(drifkin): this will be moved to utils in the near future and used by other renderers as well
|
||||
func TestMarshalWithSpaces(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -59,12 +59,6 @@ func rendererForName(name string) Renderer {
|
||||
case "cogito":
|
||||
renderer := &CogitoRenderer{isThinking: true}
|
||||
return renderer
|
||||
case "olmo3-think":
|
||||
renderer := &Olmo3ThinkRenderer{}
|
||||
return renderer
|
||||
case "olmo3":
|
||||
renderer := &Olmo3Renderer{}
|
||||
return renderer
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -300,13 +300,18 @@ func filesForModel(path string) ([]string, error) {
|
||||
}
|
||||
files = append(files, js...)
|
||||
|
||||
// add tokenizer.model if it exists (tokenizer.json is automatically picked up by the previous glob)
|
||||
// tokenizer.model might be a unresolved git lfs reference; error if it is
|
||||
if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 {
|
||||
files = append(files, tks...)
|
||||
} else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 {
|
||||
// some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
|
||||
files = append(files, tks...)
|
||||
// only include tokenizer.model is tokenizer.json is not present
|
||||
if !slices.ContainsFunc(files, func(s string) bool {
|
||||
return slices.Contains(strings.Split(s, string(os.PathSeparator)), "tokenizer.json")
|
||||
}) {
|
||||
if tks, _ := glob(filepath.Join(path, "tokenizer.model"), "application/octet-stream"); len(tks) > 0 {
|
||||
// add tokenizer.model if it exists, tokenizer.json is automatically picked up by the previous glob
|
||||
// tokenizer.model might be a unresolved git lfs reference; error if it is
|
||||
files = append(files, tks...)
|
||||
} else if tks, _ := glob(filepath.Join(path, "**/tokenizer.model"), "text/plain"); len(tks) > 0 {
|
||||
// some times tokenizer.model is in a subdirectory (e.g. meta-llama/Meta-Llama-3-8B)
|
||||
files = append(files, tks...)
|
||||
}
|
||||
}
|
||||
|
||||
return files, nil
|
||||
|
||||
@@ -888,37 +888,6 @@ func TestFilesForModel(t *testing.T) {
|
||||
"tokenizer.json",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "safetensors with both tokenizer.json and tokenizer.model",
|
||||
setup: func(dir string) error {
|
||||
// Create binary content for tokenizer.model (application/octet-stream)
|
||||
binaryContent := make([]byte, 512)
|
||||
for i := range binaryContent {
|
||||
binaryContent[i] = byte(i % 256)
|
||||
}
|
||||
files := []string{
|
||||
"model-00001-of-00001.safetensors",
|
||||
"config.json",
|
||||
"tokenizer.json",
|
||||
}
|
||||
for _, file := range files {
|
||||
if err := os.WriteFile(filepath.Join(dir, file), []byte("test content"), 0o644); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// Write tokenizer.model as binary
|
||||
if err := os.WriteFile(filepath.Join(dir, "tokenizer.model"), binaryContent, 0o644); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
wantFiles: []string{
|
||||
"model-00001-of-00001.safetensors",
|
||||
"config.json",
|
||||
"tokenizer.json",
|
||||
"tokenizer.model",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "safetensors with consolidated files - prefers model files",
|
||||
setup: func(dir string) error {
|
||||
|
||||
Reference in New Issue
Block a user