diff --git a/convert/convert.go b/convert/convert.go index f6afd8a32..f5cef5567 100644 --- a/convert/convert.go +++ b/convert/convert.go @@ -200,6 +200,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error { conv = &qwen25VLModel{} case "Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration": conv = &qwen3VLModel{} + case "OlmoForCausalLM", "OLMoForCausalLM", "OLMo3ForCausalLM": + conv = &olmoModel{} case "BertModel": conv = &bertModel{} case "CohereForCausalLM": diff --git a/convert/convert_olmo.go b/convert/convert_olmo.go new file mode 100644 index 000000000..848b39475 --- /dev/null +++ b/convert/convert_olmo.go @@ -0,0 +1,82 @@ +package convert + +import ( + "cmp" + + "github.com/ollama/ollama/fs/ggml" +) + +type olmoModel struct { + ModelParameters + + HiddenSize uint32 `json:"hidden_size"` + NumHiddenLayers uint32 `json:"num_hidden_layers"` + IntermediateSize uint32 `json:"intermediate_size"` + NumAttentionHeads uint32 `json:"num_attention_heads"` + NumKeyValueHeads uint32 `json:"num_key_value_heads"` + MaxPositionEmbeddings uint32 `json:"max_position_embeddings"` + RMSNormEPS float32 `json:"rms_norm_eps"` + RopeTheta float32 `json:"rope_theta"` + ClampKQV float32 `json:"f_clamp_kqv"` +} + +var _ ModelConverter = (*olmoModel)(nil) + +func (p *olmoModel) KV(t *Tokenizer) ggml.KV { + kv := p.ModelParameters.KV(t) + kv["general.architecture"] = "olmo" + kv["olmo.block_count"] = p.NumHiddenLayers + kv["olmo.context_length"] = p.MaxPositionEmbeddings + kv["olmo.embedding_length"] = p.HiddenSize + kv["olmo.feed_forward_length"] = p.IntermediateSize + kv["olmo.attention.head_count"] = p.NumAttentionHeads + kv["olmo.attention.head_count_kv"] = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads) + + if p.RopeTheta > 0 { + kv["olmo.rope.freq_base"] = p.RopeTheta + } else { + kv["olmo.rope.freq_base"] = float32(10000.0) + } + + if p.RMSNormEPS > 0 { + kv["olmo.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS + } + + if p.ClampKQV > 0 { + kv["olmo.attention.clamp_kqv"] = p.ClampKQV + } + + return kv +} + +func (p *olmoModel) Tensors(ts []Tensor) []*ggml.Tensor { + var out []*ggml.Tensor + for _, t := range ts { + out = append(out, &ggml.Tensor{ + Name: t.Name(), + Kind: t.Kind(), + Shape: t.Shape(), + WriterTo: t, + }) + } + + return out +} + +func (p *olmoModel) Replacements() []string { + return []string{ + "lm_head", "output", + "model.embed_tokens", "token_embd", + "model.layers", "blk", + "input_layernorm", "attn_norm", + "post_attention_layernorm", "ffn_norm", + "model.norm", "output_norm", + "self_attn.q_proj", "attn_q", + "self_attn.k_proj", "attn_k", + "self_attn.v_proj", "attn_v", + "self_attn.o_proj", "attn_output", + "mlp.gate_proj", "ffn_gate", + "mlp.down_proj", "ffn_down", + "mlp.up_proj", "ffn_up", + } +} diff --git a/model/models/models.go b/model/models/models.go index 85bf9a7f3..def2658ce 100644 --- a/model/models/models.go +++ b/model/models/models.go @@ -13,6 +13,7 @@ import ( _ "github.com/ollama/ollama/model/models/mistral3" _ "github.com/ollama/ollama/model/models/mllama" _ "github.com/ollama/ollama/model/models/nomicbert" + _ "github.com/ollama/ollama/model/models/olmo" _ "github.com/ollama/ollama/model/models/qwen2" _ "github.com/ollama/ollama/model/models/qwen25vl" _ "github.com/ollama/ollama/model/models/qwen3" diff --git a/model/models/olmo/model.go b/model/models/olmo/model.go new file mode 100644 index 000000000..2f891935c --- /dev/null +++ b/model/models/olmo/model.go @@ -0,0 +1,188 @@ +package olmo + +import ( + "cmp" + "math" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/kvcache" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/ml/nn/fast" + "github.com/ollama/ollama/ml/nn/rope" + "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" +) + +type Options struct { + hiddenSize, numHeads, numKVHeads int + headDim, ropeDim int + eps, ropeBase, ropeScale float32 + clampKQV float32 +} + +type Model struct { + model.Base + model.TextProcessor + + TokenEmbedding *nn.Embedding `gguf:"token_embd"` + Layers []Layer `gguf:"blk"` + OutputNorm *nn.RMSNorm `gguf:"output_norm"` + Output *nn.Linear `gguf:"output,alt:token_embd"` + + Options +} + +func New(c fs.Config) (model.Model, error) { + var processor model.TextProcessor + vocabulary := model.Vocabulary{ + Values: c.Strings("tokenizer.ggml.tokens"), + Scores: c.Floats("tokenizer.ggml.scores"), + Types: c.Ints("tokenizer.ggml.token_type"), + Merges: c.Strings("tokenizer.ggml.merges"), + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, + AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))}, + c.Ints("tokenizer.ggml.eos_token_ids")..., + ), + } + + switch c.String("tokenizer.ggml.model") { + case "gpt2": + var pretokenizers []string + switch c.String("tokenizer.ggml.pre") { + case "default": + default: + pretokenizers = []string{ + "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + } + } + processor = model.NewBytePairEncoding(&vocabulary, pretokenizers...) + case "llama": + processor = model.NewSentencePiece(&vocabulary) + default: + return nil, model.ErrUnsupportedTokenizer + } + + m := Model{ + TextProcessor: processor, + Layers: make([]Layer, c.Uint("block_count")), + Options: Options{ + hiddenSize: int(c.Uint("embedding_length")), + numHeads: int(c.Uint("attention.head_count")), + numKVHeads: int(c.Uint("attention.head_count_kv")), + headDim: int(c.Uint("attention.key_length")), + ropeDim: int(c.Uint("rope.dimension_count")), + eps: c.Float("attention.layer_norm_rms_epsilon"), + ropeBase: c.Float("rope.freq_base", 1e4), + ropeScale: c.Float("rope.scaling.factor", 1), + clampKQV: c.Float("attention.clamp_kqv", 0), + }, + } + + m.Cache = kvcache.NewCausalCache(m.Shift) + + return &m, nil +} + +type SelfAttention struct { + Query *nn.Linear `gguf:"attn_q"` + Key *nn.Linear `gguf:"attn_k"` + Value *nn.Linear `gguf:"attn_v"` + Output *nn.Linear `gguf:"attn_output"` + RopeFactors ml.Tensor `gguf:"rope_freqs.weight"` +} + +func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { + batchSize := hiddenState.Dim(1) + headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads) + ropeDim := cmp.Or(opts.ropeDim, headDim) + + query := sa.Query.Forward(ctx, hiddenState) + query = query.Reshape(ctx, headDim, opts.numHeads, batchSize) + + key := sa.Key.Forward(ctx, hiddenState) + key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize) + + value := sa.Value.Forward(ctx, hiddenState) + value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize) + + // Apply RoPE (Rotary Position Embeddings) - OLMo uses NeoX-style rotation + query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors)) + key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, 1./opts.ropeScale, rope.WithFactors(sa.RopeFactors)) + + attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache) + attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize) + return sa.Output.Forward(ctx, attention) +} + +func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { + ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads) + return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, 1./m.ropeScale, rope.WithFactors(m.Layers[layer].SelfAttention.RopeFactors)), nil +} + +type MLP struct { + Up *nn.Linear `gguf:"ffn_up"` + Down *nn.Linear `gguf:"ffn_down"` + Gate *nn.Linear `gguf:"ffn_gate"` +} + +func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor { + hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState)) + return mlp.Down.Forward(ctx, hiddenState) +} + +// Layer represents a single transformer layer in OLMo +type Layer struct { + AttentionNorm *nn.RMSNorm `gguf:"attn_norm"` + SelfAttention *SelfAttention + MLPNorm *nn.RMSNorm `gguf:"ffn_norm"` + MLP *MLP +} + +func (l *Layer) Forward(ctx ml.Context, hiddenState, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { + residual := hiddenState + + hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps) + hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positions, cache, opts) + + if outputs != nil { + hiddenState = hiddenState.Rows(ctx, outputs) + residual = residual.Rows(ctx, outputs) + } + + hiddenState = hiddenState.Add(ctx, residual) + residual = hiddenState + + hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps) + hiddenState = l.MLP.Forward(ctx, hiddenState, opts) + + return hiddenState.Add(ctx, residual) +} + +func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { + positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions)) + + hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs) + + for i, layer := range m.Layers { + m.Cache.SetLayer(i) + + var outputs ml.Tensor + if i == len(m.Layers)-1 { + outputs = batch.Outputs + } + + hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, &m.Options) + } + + hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps) + return m.Output.Forward(ctx, hiddenState), nil +} + +func init() { + model.Register("olmo", New) + model.Register("olmo2", New) +}