Compare commits


8 Commits

Author SHA1 Message Date
Patrick Devine
c10a40db99 parser: tidy up parameter/message parsing
This change addresses how parameters and messages are handled while parsing a Modelfile.
Currently a `MESSAGE` command creates a string that looks like "<role>: <message>" which
then has to be re-parsed, and `PARAMETER` ends up being the default data type which makes
it difficult to add other multi-part commands.

This change introduces a Message and a Parameter type which properly handle properties
such as the role and the name of the parameter.
2025-09-15 18:09:05 -07:00
Daniel Hiltgen
93c64ea1b1 doc: show how to clear the cgo cache (#12298) 2025-09-15 15:45:35 -07:00
Michael Yang
3f6642f6fc model: implement bert in ollama engine (#9080)
* fix truncate

* s/SentencePieceModel/SentencePiece/

* bert

* wordpiece

* refactor pooling

* more tokenizers

* normalize embeddings
2025-09-15 15:35:59 -07:00
Michael Yang
6f7117145f batch: use tensors for outputs (#12185)
This cleans up the model interface slightly without much impact on
other areas
2025-09-15 14:33:06 -07:00
jmorganca
92b96d54ef Revert "runner: move harmony to runner (#12052)"
This reverts commit 1a558f98e2.
2025-09-12 20:40:14 -03:00
jmorganca
9d56e63dbf Revert "runner: simplify parser entrypoints in runner (#12233)"
This reverts commit 8d6fffaead.
2025-09-12 20:40:14 -03:00
tc-mb
053092185e Fix image cannot be seen with slice image on llama engine
Ollama's recent update to the llama.cpp engine caused all models requiring a slice schema to stop displaying images. The value of numTokens isn't always the length of the sliced image embed, but rather the final length of the schema, so the image embed was not correctly included when processing all slices.
2025-09-12 16:25:12 -07:00
Daniel Hiltgen
44a6792873 tests: tighten up a few flaky tests (#12271)
Sometimes the context test results are pure emojis.
The Thanksgiving prompt has too much variability, so swap it for a more straightforward one.
2025-09-12 13:59:34 -07:00
39 changed files with 987 additions and 618 deletions

View File

@@ -28,6 +28,7 @@ type bertModel struct {
LayerNormEPS float32 `json:"layer_norm_eps"`
LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
NormEpsilon float32 `json:"norm_epsilon"`
normalizeEmbeddings bool
PoolingType uint32
}
@@ -54,9 +55,11 @@ func (p *bertModel) parseMore(fsys fs.FS) error {
var pooling string
for _, m := range modules {
-		if m.Type == "sentence_transformers.models.Pooling" {
+		switch m.Type {
+		case "sentence_transformers.models.Pooling":
 			pooling = m.Path
-			break
+		case "sentence_transformers.models.Normalize":
+			p.normalizeEmbeddings = true
 		}
}
@@ -90,6 +93,7 @@ func (p *bertModel) KV(t *Tokenizer) ggml.KV {
kv["general.architecture"] = "bert"
kv["bert.attention.causal"] = false
kv["bert.pooling_type"] = p.PoolingType
kv["bert.normalize_embeddings"] = p.normalizeEmbeddings
kv["bert.block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers, p.NLayer)

View File

@@ -11,6 +11,10 @@ Then build and run Ollama from the root directory of the repository:
go run . serve
```
> [!NOTE]
> Ollama includes native code compiled with CGO. From time to time these data structures can change and CGO can get out of sync resulting in unexpected crashes. You can force a full build of the native code by running `go clean -cache` first.
## macOS (Apple Silicon)
macOS Apple Silicon supports Metal which is built-in to the Ollama binary. No additional steps are required.

View File

@@ -3,29 +3,15 @@ package harmony
import (
"fmt"
"log/slog"
"slices"
"strings"
"unicode"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/template"
)
type harmonyParserState int
-func ShouldUseHarmony(modelFamily string, template *template.Template) bool {
-	if slices.Contains([]string{"gptoss", "gpt-oss"}, modelFamily) {
-		// heuristic to check whether the template expects to be parsed via harmony:
-		// search for harmony tags that are nearly always used
-		if template.Contains("<|start|>") && template.Contains("<|end|>") {
-			return true
-		}
-	}
-	return false
-}
const (
harmonyParserState_LookingForMessageStart harmonyParserState = iota
harmonyParserState_ParsingHeader
@@ -89,28 +75,18 @@ func (s *HarmonyParser) AddImplicitStart() {
s.acc.WriteString("<|start|>assistant")
}
-func Prefill(lastMessage api.Message) string {
-	if lastMessage.Role != "assistant" {
-		return ""
-	}
-	switch {
-	case strings.TrimSpace(lastMessage.Content) != "":
-		return "<|start|>assistant<|channel|>final<|message|>"
-	case strings.TrimSpace(lastMessage.Thinking) != "":
-		return "<|start|>assistant<|channel|>analysis<|message|>"
-	default:
-		return ""
-	}
-}
-// AddImplicitStartOrPrefill adds an implicit start tag or prefill string if provided
-func (s *HarmonyParser) AddImplicitStartOrPrefill(prefillString string) {
-	if strings.TrimSpace(prefillString) != "" {
-		s.acc.WriteString(prefillString)
-	} else {
-		s.AddImplicitStart()
+func (s *HarmonyParser) AddImplicitStartOrPrefill(lastMessage *api.Message) {
+	if lastMessage != nil && lastMessage.Role == "assistant" {
+		// handle prefilling conditions
+		if lastMessage.Content != "" {
+			s.acc.WriteString("<|start|>assistant<|channel|>final<|message|>")
+			return
+		} else if lastMessage.Thinking != "" {
+			s.acc.WriteString("<|start|>assistant<|channel|>analysis<|message|>")
+			return
+		}
 	}
+	s.AddImplicitStart()
 }
func (s *HarmonyParser) AddContent(content string) []HarmonyEvent {
@@ -289,7 +265,6 @@ type HarmonyMessageHandler struct {
state harmonyMessageState
HarmonyParser *HarmonyParser
FunctionNameMap *FunctionNameMap
-	ToolParser      *HarmonyToolCallAccumulator
}
// NewHarmonyMessageHandler creates a new message handler
@@ -302,16 +277,12 @@ func NewHarmonyMessageHandler() *HarmonyMessageHandler {
HeaderEndTag: "<|message|>",
},
FunctionNameMap: NewFunctionNameMap(),
-		ToolParser: &HarmonyToolCallAccumulator{
-			state:           harmonyToolCallState_Normal,
-			currentToolName: nil,
-		},
}
}
// AddContent processes the content and returns the content, thinking, and tool content.
// content and thinking are already fully parsed, but tool content still needs to be passed to the tool parser
-func (h *HarmonyMessageHandler) AddContent(content string) (string, string, string) {
+func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyToolCallAccumulator) (string, string, string) {
contentSb := strings.Builder{}
thinkingSb := strings.Builder{}
toolContentSb := strings.Builder{}
@@ -328,14 +299,14 @@ func (h *HarmonyMessageHandler) AddContent(content string) (string, string, stri
// event.Header.Recipient is the tool name, something like
// "browser.search" for a built-in, or "functions.calc" for a
// custom one
h.ToolParser.SetToolName(event.Header.Recipient)
toolParser.SetToolName(event.Header.Recipient)
} else {
h.state = harmonyMessageState_Thinking
}
case "commentary":
if event.Header.Recipient != "" {
h.state = harmonyMessageState_ToolCalling
h.ToolParser.SetToolName(event.Header.Recipient)
toolParser.SetToolName(event.Header.Recipient)
} else {
h.state = harmonyMessageState_Normal
}
@@ -358,6 +329,13 @@ func (h *HarmonyMessageHandler) AddContent(content string) (string, string, stri
return contentSb.String(), thinkingSb.String(), toolContentSb.String()
}
func (h *HarmonyMessageHandler) CreateToolParser() *HarmonyToolCallAccumulator {
return &HarmonyToolCallAccumulator{
state: harmonyToolCallState_Normal,
currentToolName: nil,
}
}
type harmonyToolCallState int
const (
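To make the new prefill behavior concrete, here is a minimal standalone sketch of the decision AddImplicitStartOrPrefill now makes. `Message` below is a hypothetical stand-in for the api.Message fields the parser inspects:

```go
package main

import "fmt"

// Message is a stand-in for the fields of api.Message the parser inspects.
type Message struct {
	Role     string
	Content  string
	Thinking string
}

// prefill mirrors the logic above: resume an assistant message that already
// has content on the final channel, one with only thinking on the analysis
// channel, and otherwise fall back to the plain implicit start tag.
func prefill(last *Message) string {
	if last != nil && last.Role == "assistant" {
		if last.Content != "" {
			return "<|start|>assistant<|channel|>final<|message|>"
		} else if last.Thinking != "" {
			return "<|start|>assistant<|channel|>analysis<|message|>"
		}
	}
	return "<|start|>assistant"
}

func main() {
	fmt.Println(prefill(&Message{Role: "assistant", Content: "partial answer"}))
	fmt.Println(prefill(&Message{Role: "assistant", Thinking: "hmm"}))
	fmt.Println(prefill(nil))
}
```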

View File

@@ -3,7 +3,6 @@ package harmony
import (
"fmt"
"reflect"
"strings"
"testing"
)
@@ -536,202 +535,3 @@ func TestFunctionConvertAndAdd(t *testing.T) {
})
}
}
func TestHarmonyMessageHandlerStreamingScenarios(t *testing.T) {
t.Run("thinking_then_content_streams", func(t *testing.T) {
handler := NewHarmonyMessageHandler()
handler.HarmonyParser.AddImplicitStart()
tp := handler.ToolParser
type step struct {
in string
wantContent string
wantThinking string
}
steps := []step{
{in: "<|channel|>analysis<|message|>Thinking...", wantThinking: "Thinking..."},
{in: "<|end|>", wantThinking: ""},
{in: "<|start|>assistant<|message|>Answer", wantContent: "Answer"},
{in: "<|end|>", wantContent: ""},
}
for i, s := range steps {
content, thinking, tool := handler.AddContent(s.in)
if tool != "" {
tp.Add(tool)
}
if content != s.wantContent || thinking != s.wantThinking {
t.Fatalf("step %d: got (content=%q thinking=%q), want (content=%q thinking=%q)", i, content, thinking, s.wantContent, s.wantThinking)
}
}
})
t.Run("content_streams_as_it_arrives", func(t *testing.T) {
handler := NewHarmonyMessageHandler()
handler.HarmonyParser.AddImplicitStart()
tp := handler.ToolParser
inputs := []string{
"<|start|>assistant<|message|>Hello",
", world",
"!<|end|>",
}
var got []string
for _, in := range inputs {
content, thinking, tool := handler.AddContent(in)
if tool != "" {
tp.Add(tool)
}
if thinking != "" {
t.Fatalf("unexpected thinking %q", thinking)
}
if content != "" {
got = append(got, content)
}
}
want := []string{"Hello", ", world", "!"}
if !reflect.DeepEqual(got, want) {
t.Fatalf("content pieces mismatch: got %v want %v", got, want)
}
})
t.Run("thinking_streams_separately_from_content", func(t *testing.T) {
handler := NewHarmonyMessageHandler()
handler.HarmonyParser.AddImplicitStart()
tp := handler.ToolParser
inputs := []string{
"<|channel|>analysis<|message|>Thinking...",
"<|end|>",
"<|start|>assistant<|message|>Answer",
"<|end|>",
}
var got []string
for _, in := range inputs {
content, thinking, tool := handler.AddContent(in)
if tool != "" {
tp.Add(tool)
}
if thinking != "" {
got = append(got, thinking)
}
if content != "" {
got = append(got, content)
}
}
want := []string{"Thinking...", "Answer"}
if !reflect.DeepEqual(got, want) {
t.Fatalf("content pieces mismatch: got %v want %v", got, want)
}
})
t.Run("partial_tags_buffer_until_complete", func(t *testing.T) {
handler := NewHarmonyMessageHandler()
handler.HarmonyParser.AddImplicitStart()
tp := handler.ToolParser
inputs := []string{
"<|chan",
"nel|>analysis<|mess",
"age|>Deep ",
"thought",
"<|end|>",
"<|start|>assistant<|message|>Done",
"<|end|>",
}
var thinkingPieces []string
var contentPieces []string
for _, in := range inputs {
content, thinking, tool := handler.AddContent(in)
if tool != "" {
tp.Add(tool)
}
if thinking != "" {
thinkingPieces = append(thinkingPieces, thinking)
}
if content != "" {
contentPieces = append(contentPieces, content)
}
}
if want := []string{"Deep ", "thought"}; !reflect.DeepEqual(thinkingPieces, want) {
t.Fatalf("thinking pieces mismatch: got %v want %v", thinkingPieces, want)
}
if want := []string{"Done"}; !reflect.DeepEqual(contentPieces, want) {
t.Fatalf("content pieces mismatch: got %v want %v", contentPieces, want)
}
})
t.Run("simple_assistant_after_analysis", func(t *testing.T) {
handler := NewHarmonyMessageHandler()
handler.HarmonyParser.AddImplicitStart()
tp := handler.ToolParser
inputs := []string{
"<|channel|>analysis<|message|>Think",
"<|end|>",
"<|start|>assistant<|message|>Answer",
"<|end|>",
}
var contentSb, thinkingSb strings.Builder
for _, in := range inputs {
content, thinking, tool := handler.AddContent(in)
if tool != "" {
tp.Add(tool)
}
contentSb.WriteString(content)
thinkingSb.WriteString(thinking)
}
if contentSb.String() != "Answer" {
t.Fatalf("content mismatch: got %q want %q", contentSb.String(), "Answer")
}
if thinkingSb.String() != "Think" {
t.Fatalf("thinking mismatch: got %q want %q", thinkingSb.String(), "Think")
}
})
t.Run("tool_call_parsed_and_returned_correctly", func(t *testing.T) {
handler := NewHarmonyMessageHandler()
handler.HarmonyParser.AddImplicitStart()
tp := handler.ToolParser
inputs := []string{
"<|channel|>commentary to=functions.calculate<|message|>{\"expression\":\"2+2\"}<|end|>",
}
for _, in := range inputs {
content, thinking, tool := handler.AddContent(in)
if content != "" || thinking != "" {
continue
}
if tool != "" {
tp.Add(tool)
}
}
name, args := tp.Drain()
if name == nil || *name != "functions.calculate" {
t.Fatalf("unexpected tool name: %v", name)
}
if got, want := args, "{\"expression\":\"2+2\"}"; got != want {
t.Fatalf("unexpected tool args: got %s want %s", got, want)
}
})
t.Run("tool_call_across_chunks", func(t *testing.T) {
handler := NewHarmonyMessageHandler()
handler.HarmonyParser.AddImplicitStart()
tp := handler.ToolParser
inputs := []string{
"<|channel|>commentary to=functions.calculate<|message|>{\"expression\":\"2+",
"2\"}",
"<|end|>",
}
for _, in := range inputs {
content, thinking, tool := handler.AddContent(in)
if content != "" || thinking != "" {
continue
}
if tool != "" {
tp.Add(tool)
}
}
name, args := tp.Drain()
if name == nil || *name != "functions.calculate" {
t.Fatalf("unexpected tool name: %v", name)
}
if got, want := args, "{\"expression\":\"2+2\"}"; got != want {
t.Fatalf("unexpected tool args: got %s want %s", got, want)
}
})
}
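The deleted streaming tests above relied on the removed handler.ToolParser field. Under the new signatures shown in the harmony.go diff, the accumulator is created per request; a sketch of the equivalent flow, mirroring the deleted tool_call_across_chunks case and assuming the harmony package API exactly as shown in this diff:

```go
package main

import (
	"fmt"

	"github.com/ollama/ollama/harmony"
)

func main() {
	handler := harmony.NewHarmonyMessageHandler()
	handler.HarmonyParser.AddImplicitStart()
	toolParser := handler.CreateToolParser() // one accumulator per request

	chunks := []string{
		"<|channel|>commentary to=functions.calculate<|message|>{\"expression\":\"2+",
		"2\"}",
		"<|end|>",
	}
	for _, in := range chunks {
		content, thinking, tool := handler.AddContent(in, toolParser)
		if tool != "" {
			toolParser.Add(tool)
		}
		_, _ = content, thinking
	}

	if name, args := toolParser.Drain(); name != nil {
		fmt.Println(*name, args) // functions.calculate {"expression":"2+2"}
	}
}
```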

View File

@@ -50,7 +50,7 @@ func TestContextExhaustion(t *testing.T) {
// Set up the test data
req := api.GenerateRequest{
Model: smol,
Prompt: "Write me a story with a ton of emojis?",
Prompt: "Write me a story in english with a lot of emojis",
Stream: &stream,
Options: map[string]any{
"temperature": 0,

View File

@@ -561,7 +561,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
KeepAlive: &api.Duration{Duration: 10 * time.Second},
}, {
Model: smol,
Prompt: "what is the origin of the US thanksgiving holiday? Be brief but factual in your reply",
Prompt: "how do rainbows form? Be brief but factual in your reply",
Stream: &stream,
KeepAlive: &api.Duration{Duration: 10 * time.Second},
}, {
@@ -579,9 +579,9 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
[][]string{
{"sunlight", "scattering", "interact", "color", "surface", "depth", "red", "orange", "yellow", "absorbs", "wavelength"},
{"soil", "organic", "earth", "black", "tan", "chemical", "processes", "pigments", "particles", "iron oxide", "rust", "air", "water", "mixture", "mixing"},
{"england", "english", "massachusetts", "pilgrims", "colonists", "independence", "british", "feast", "family", "gatherings", "traditions", "turkey", "colonial", "period", "harvest", "agricultural", "european settlers", "american revolution", "civil war", "16th century", "17th century", "native american", "united states", "cultural", "hardship", "autumn", "festival"},
{"water", "droplet", "refracted", "reflect", "color", "spectrum"},
{"fourth", "july", "declaration", "independence"},
{"nitrogen", "oxygen", "carbon", "dioxide"},
{"nitrogen", "oxygen", "carbon", "dioxide", "water", "vapor"},
}
}

View File

@@ -515,33 +515,34 @@ func (c *MtmdContext) NewEmbed(llamaContext *Context, data []byte) ([][]float32,
}
nChunks := C.mtmd_input_chunks_size(ic)
numEmbed := llamaContext.Model().NEmbd()
-	lastChunkSize := 0
+	embed := make([][]float32, 0)
 	for i := range int(nChunks) {
 		chunk := C.mtmd_input_chunks_get(ic, C.size_t(i))
 		numTokens := int(C.mtmd_input_chunk_get_n_tokens(chunk))
-		lastChunkSize = numTokens
 		slog.Debug("chunk tokens", "index", i, "numTokens", numTokens)
 		// Encode the chunk
 		if C.int32_t(0) != C.mtmd_encode_chunk(c.c, chunk) {
 			return nil, errors.New("unable to encode mtmd image chunk")
 		}
-	}
-	// Get the embeddings
-	embed := make([][]float32, lastChunkSize)
-	embd := C.mtmd_get_output_embd(c.c)
-	if nil == embd {
-		return nil, errors.New("failed to get image embedding")
-	}
-	// Extend the embedding array for each token
-	s := unsafe.Slice((*float32)(embd), numEmbed*lastChunkSize)
-	rows := make([]float32, len(s))
-	copy(rows, s)
-	for i := range lastChunkSize {
-		embed[i] = rows[i*numEmbed : (i+1)*numEmbed]
+		// Get the embeddings for this chunk
+		chunkEmbed := make([][]float32, numTokens)
+		chunkEmbd := C.mtmd_get_output_embd(c.c)
+		if nil == chunkEmbd {
+			continue
+		}
+		// Extend the embedding array for each token
+		s := unsafe.Slice((*float32)(chunkEmbd), numTokens*numEmbed)
+		rows := make([]float32, len(s))
+		copy(rows, s)
+		for i := range numTokens {
+			chunkEmbed[i] = rows[i*numEmbed : (i+1)*numEmbed]
+		}
+		embed = append(embed, chunkEmbed...)
 	}
+	slog.Debug("image embeddings", "totalEmbeddings", len(embed))
return embed, nil
}
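The essence of the fix: embeddings are now fetched and sliced once per chunk instead of once at the end using only the last chunk's size. A standalone sketch of the row-slicing step, with `chunkRows` as a hypothetical helper:

```go
package main

import "fmt"

// chunkRows sketches the per-chunk copy the fix performs: a flat buffer of
// numTokens*numEmbed floats is copied once and re-sliced into one row per
// token, so every chunk contributes its own embeddings to the result.
func chunkRows(flat []float32, numTokens, numEmbed int) [][]float32 {
	rows := make([]float32, len(flat))
	copy(rows, flat)
	out := make([][]float32, numTokens)
	for i := range numTokens {
		out[i] = rows[i*numEmbed : (i+1)*numEmbed]
	}
	return out
}

func main() {
	flat := []float32{1, 2, 3, 4, 5, 6} // 3 tokens x 2 dims
	fmt.Println(chunkRows(flat, 3, 2))  // [[1 2] [3 4] [5 6]]
}
```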

View File

@@ -35,7 +35,6 @@ import (
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/parser"
)
type filteredEnv []string
@@ -1349,9 +1348,7 @@ type CompletionRequest struct {
Images []ImageData
Options *api.Options
-	Grammar       string // set before sending the request to the subprocess
-	ParserType    parser.TokenParserType
-	PrefillString string
+	Grammar string // set before sending the request to the subprocess
}
// DoneReason represents the reason why a completion response is done
@@ -1378,15 +1375,13 @@ func (d DoneReason) String() string {
}
type CompletionResponse struct {
-	Content            string         `json:"content"`
-	Thinking           string         `json:"thinking"`
-	ToolCalls          []api.ToolCall `json:"tool_calls"`
-	DoneReason         DoneReason     `json:"done_reason"`
-	Done               bool           `json:"done"`
-	PromptEvalCount    int            `json:"prompt_eval_count"`
-	PromptEvalDuration time.Duration  `json:"prompt_eval_duration"`
-	EvalCount          int            `json:"eval_count"`
-	EvalDuration       time.Duration  `json:"eval_duration"`
+	Content            string        `json:"content"`
+	DoneReason         DoneReason    `json:"done_reason"`
+	Done               bool          `json:"done"`
+	PromptEvalCount    int           `json:"prompt_eval_count"`
+	PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
+	EvalCount          int           `json:"eval_count"`
+	EvalDuration       time.Duration `json:"eval_duration"`
}
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
@@ -1504,8 +1499,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
}
switch {
-		// TODO(parthsareen): token repeat limit is now handled in the runner, this currently support legacy model and can be removed in the future
-		case strings.TrimSpace(c.Content) == lastToken && c.Content != "":
+		case strings.TrimSpace(c.Content) == lastToken:
tokenRepeat++
default:
lastToken = strings.TrimSpace(c.Content)
@@ -1518,14 +1512,16 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
return ctx.Err()
}
if c.Content != "" {
fn(CompletionResponse{
Content: c.Content,
})
}
if c.Done {
fn(c)
return nil
}
if c.Content != "" || c.Thinking != "" || len(c.ToolCalls) > 0 {
fn(c)
}
}
}

View File

@@ -416,6 +416,7 @@ type Tensor interface {
AddID(ctx Context, t2, ids Tensor) Tensor
Softmax(ctx Context) Tensor
L2Norm(ctx Context, eps float32) Tensor
LayerNorm(ctx Context, weight, bias Tensor, eps float32) Tensor
RMSNorm(ctx Context, weight Tensor, eps float32) Tensor
Scale(ctx Context, s float64) Tensor

View File

@@ -1205,6 +1205,13 @@ func (t *Tensor) AddID(ctx ml.Context, t2, ids ml.Tensor) ml.Tensor {
}
}
func (t *Tensor) L2Norm(ctx ml.Context, eps float32) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_l2_norm(ctx.(*Context).ctx, t.t, C.float(eps)),
}
}
func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
tt := C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))
if w != nil {
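For reference, a standalone sketch of the vector operation the new L2Norm method exposes: scaling a vector to unit length. This assumes the common formulation x / sqrt(sum(x_i^2) + eps); ggml's exact placement of eps may differ:

```go
package main

import (
	"fmt"
	"math"
)

// l2Norm scales x to unit Euclidean length; eps keeps a zero vector from
// producing a division by zero (assumed formulation, see note above).
func l2Norm(x []float32, eps float32) []float32 {
	var sum float64
	for _, v := range x {
		sum += float64(v) * float64(v)
	}
	scale := 1 / math.Sqrt(sum+float64(eps))
	out := make([]float32, len(x))
	for i, v := range x {
		out[i] = float32(float64(v) * scale)
	}
	return out
}

func main() {
	fmt.Println(l2Norm([]float32{3, 4}, 1e-12)) // ≈ [0.6 0.8]
}
```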

ml/nn/pooling/pooling.go (new file, +36 lines)
View File

@@ -0,0 +1,36 @@
package pooling
import (
"github.com/ollama/ollama/ml"
)
type Type uint32
const (
TypeNone Type = iota
TypeMean
TypeCLS
TypeLast
TypeRank
TypeUnknown = 0xFFFFFFFE
TypeUnspecified = 0xFFFFFFFF
)
func Pooling(ctx ml.Context, hiddenStates ml.Tensor, poolingType Type) ml.Tensor {
switch poolingType {
case TypeNone:
return hiddenStates
case TypeMean:
hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx)
return hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
case TypeCLS:
return hiddenStates.View(ctx, 0, hiddenStates.Dim(0))
case TypeLast:
panic("not implemented")
case TypeRank:
panic("not implemented")
default:
panic("not implemented")
}
}
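To make the pooling variants concrete, a plain-Go sketch of what TypeMean computes, using a hypothetical [token][dim] slice layout instead of the permuted tensor ops above: each hidden dimension is averaged over all tokens, collapsing the sequence to a single vector.

```go
package main

import "fmt"

// meanPool averages each hidden dimension over all tokens.
func meanPool(hidden [][]float32) []float32 { // hidden[token][dim]
	out := make([]float32, len(hidden[0]))
	for _, tok := range hidden {
		for d, v := range tok {
			out[d] += v
		}
	}
	for d := range out {
		out[d] /= float32(len(hidden))
	}
	return out
}

func main() {
	fmt.Println(meanPool([][]float32{{1, 2}, {3, 4}})) // [2 3]
}
```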

View File

@@ -54,10 +54,9 @@ type Batch struct {
// Inputs is the input tokens, including placeholders for multimodal inputs.
Inputs ml.Tensor
// Multimodal is a set of multimodal embeddings previously created by
// EncodeMultimodal, along with an index into Inputs. Unused for text-only
// models or for batches without multimodal elements.
Multimodal []MultimodalIndex
// Outputs are the set of indicies into Inputs for which output data should
// be returned.
Outputs ml.Tensor
// Positions is the position for each Input, relative to its sequence. Equal
// in length to Inputs.
@@ -66,7 +65,8 @@ type Batch struct {
// Sequences is the sequence for each Input. Equal in length to Inputs.
Sequences []int
// Outputs are the set of indicies into Inputs for which output data should
// be returned.
Outputs []int32
// Multimodal is a set of multimodal embeddings previously created by
// EncodeMultimodal, along with an index into Inputs. Unused for text-only
// models or for batches without multimodal elements.
Multimodal []MultimodalIndex
}

View File

@@ -24,7 +24,11 @@ import (
"github.com/ollama/ollama/model/input"
)
var ErrNoVisionModel = errors.New("this model is missing data required for image input")
var (
ErrNoVisionModel = errors.New("this model is missing data required for image input")
ErrUnsupportedModel = errors.New("model not supported")
ErrUnsupportedTokenizer = errors.New("tokenizer not supported")
)
// Model implements a specific model architecture, defining the forward pass and any model-specific configuration
type Model interface {
@@ -242,7 +246,7 @@ func setPointer(base Base, v reflect.Value, tags []Tag) {
vv = vv.Elem()
}
vv = vv.Elem()
vv = reflect.Indirect(vv)
if v.IsNil() {
vv = reflect.New(v.Type().Elem()).Elem()
}

model/models/bert/model.go (new file, +181 lines)
View File

@@ -0,0 +1,181 @@
package bert
import (
"cmp"
"math"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Model struct {
model.Base
model.TextProcessor
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
TypeEmbedding *nn.Embedding `gguf:"token_types"`
PositionEmbedding *nn.Embedding `gguf:"position_embd"`
TokenEmbeddingNorm *nn.LayerNorm `gguf:"token_embd_norm"`
Layers []EncoderLayer `gguf:"blk"`
Options
}
// Forward implements model.Model.
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
hiddenStates := m.TokenEmbedding.Forward(ctx, batch.Inputs)
hiddenStates = hiddenStates.Add(ctx, m.TypeEmbedding.Weight.View(ctx, 0, m.hiddenSize))
hiddenStates = hiddenStates.Add(ctx, m.PositionEmbedding.Forward(ctx, ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))))
hiddenStates = m.TokenEmbeddingNorm.Forward(ctx, hiddenStates, m.eps)
for _, layer := range m.Layers {
hiddenStates = layer.Forward(ctx, hiddenStates, &m.Options)
}
hiddenStates = pooling.Pooling(ctx, hiddenStates, m.poolingType)
if m.normalize {
hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
}
return hiddenStates, nil
}
type EncoderLayer struct {
*Attention
AttentionNorm *nn.LayerNorm `gguf:"attn_output_norm"`
*MLP
MLPNorm *nn.LayerNorm `gguf:"layer_output_norm"`
}
func (e *EncoderLayer) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
// Attention
residual := hiddenStates
hiddenStates = e.Attention.Forward(ctx, hiddenStates, opts)
hiddenStates = hiddenStates.Add(ctx, residual)
hiddenStates = e.AttentionNorm.Forward(ctx, hiddenStates, opts.eps)
// MLP
residual = hiddenStates
hiddenStates = e.MLP.Forward(ctx, hiddenStates, opts)
hiddenStates = hiddenStates.Add(ctx, residual)
hiddenStates = e.MLPNorm.Forward(ctx, hiddenStates, opts.eps)
return hiddenStates
}
type Attention struct {
Query *nn.Linear `gguf:"attn_q"`
QueryNorm *nn.LayerNorm `gguf:"attn_q_norm"`
Key *nn.Linear `gguf:"attn_k"`
KeyNorm *nn.LayerNorm `gguf:"attn_k_norm"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
}
func (a *Attention) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
batchSize := hiddenStates.Dim(1)
query := a.Query.Forward(ctx, hiddenStates)
if a.QueryNorm != nil {
query = a.QueryNorm.Forward(ctx, query, opts.eps)
}
query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize)
key := a.Key.Forward(ctx, hiddenStates)
if a.KeyNorm != nil {
key = a.KeyNorm.Forward(ctx, key, opts.eps)
}
key = key.Reshape(ctx, opts.headDim(), cmp.Or(opts.numKVHeads, opts.numHeads), batchSize)
value := a.Value.Forward(ctx, hiddenStates)
value = value.Reshape(ctx, opts.headDim(), cmp.Or(opts.numKVHeads, opts.numHeads), batchSize)
attention := nn.Attention(ctx, query, key, value, 1/math.Sqrt(float64(opts.headDim())), nil)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return a.Output.Forward(ctx, attention)
}
type MLP struct {
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (m *MLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *Options) ml.Tensor {
return m.Down.Forward(ctx, m.Up.Forward(ctx, hiddenStates).GELU(ctx))
}
type Options struct {
hiddenSize,
numHeads,
numKVHeads,
keyLength,
valueLength int
poolingType pooling.Type
eps float32
normalize bool
}
func (o Options) headDim() int {
return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads)
}
func New(c fs.Config) (model.Model, error) {
var processor model.TextProcessor
switch c.String("tokenizer.ggml.model", "bert") {
case "bert":
processor = model.NewWordPiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
BOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.cls_token_id"),
c.Uint("tokenizer.ggml.bos_token_id"),
)),
},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", true),
EOS: []int32{
int32(cmp.Or(
c.Uint("tokenizer.ggml.separator_token_id"),
//nolint:misspell
// NOTE: "seperator_token_id" is a typo in model metadata but we need to
// support it for compatibility.
c.Uint("tokenizer.ggml.seperator_token_id"),
c.Uint("tokenizer.ggml.eos_token_id"),
)),
},
},
)
default:
return nil, model.ErrUnsupportedTokenizer
}
return &Model{
TextProcessor: processor,
Layers: make([]EncoderLayer, c.Uint("block_count")),
Options: Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
eps: c.Float("attention.layer_norm_epsilon"),
poolingType: pooling.Type(c.Uint("pooling_type")),
normalize: c.Bool("normalize_embeddings", true),
},
}, nil
}
func init() {
model.Register("bert", New)
model.Register("bert_embed", New)
}
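The fallbacks above (headDim, the KV-head count, the BOS/EOS token ids) lean on cmp.Or from the standard library, which returns its first non-zero argument. A quick illustration of how headDim resolves:

```go
package main

import (
	"cmp"
	"fmt"
)

func main() {
	// headDim falls back through keyLength, valueLength, hiddenSize/numHeads.
	fmt.Println(cmp.Or(0, 0, 768/12)) // 64: no explicit key/value length set
	fmt.Println(cmp.Or(128, 0, 64))   // 128: an explicit keyLength wins
}
```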

View File

@@ -24,7 +24,7 @@ type Options struct {
type Model struct {
model.Base
model.SentencePieceModel
model.SentencePiece
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
@@ -40,7 +40,7 @@ const (
func New(c fs.Config) (model.Model, error) {
m := Model{
SentencePieceModel: model.NewSentencePieceModel(
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
@@ -176,7 +176,6 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.Options.hiddenSize)))
@@ -193,7 +192,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 {
lastLayerOutputs = outputs
lastLayerOutputs = batch.Outputs
}
hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)

View File

@@ -1,49 +1,38 @@
package gemma3
import (
"errors"
"github.com/ollama/ollama/fs"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/ml/nn/pooling"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type embedModel struct {
model.Base
model.SentencePieceModel
model.SentencePiece
*TextModel
PoolingType uint32
poolingType pooling.Type
Dense [2]*nn.Linear `gguf:"dense"`
}
func (m *embedModel) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
+	batch.Outputs = batch.Positions // return all positions
 	hiddenStates := m.TextModel.Forward(ctx, batch, m.Cache)
-	switch m.PoolingType {
-	case 0: // None
-	case 1: // Mean
-		hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx).Mean(ctx)
-		hiddenStates = hiddenStates.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
-	default:
-		return nil, errors.New("unsupported pooling type")
-	}
+	hiddenStates = pooling.Pooling(ctx, hiddenStates, m.poolingType)
for _, dense := range m.Dense {
hiddenStates = dense.Forward(ctx, hiddenStates)
}
hiddenStates = hiddenStates.L2Norm(ctx, 1e-12)
return hiddenStates, nil
}
func newEmbedModel(c fs.Config) (model.Model, error) {
m := &embedModel{
SentencePieceModel: model.NewSentencePieceModel(
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),
@@ -61,7 +50,7 @@ func newEmbedModel(c fs.Config) (model.Model, error) {
},
),
TextModel: newTextModel(c),
PoolingType: c.Uint("pooling_type", 0),
poolingType: pooling.Type(c.Uint("pooling_type", 0)),
}
m.Cache = kvcache.NewWrapperCache(

View File

@@ -16,7 +16,7 @@ import (
type Model struct {
model.Base
model.SentencePieceModel
model.SentencePiece
*VisionModel `gguf:"v"`
*TextModel
@@ -55,7 +55,7 @@ func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, i
func New(c fs.Config) (model.Model, error) {
m := Model{
SentencePieceModel: model.NewSentencePieceModel(
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),

View File

@@ -161,7 +161,6 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cache) ml.Tensor {
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextConfig.hiddenSize)))
@@ -194,7 +193,7 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 {
lastLayerOutputs = outputs
lastLayerOutputs = batch.Outputs
}
hiddenState = layer.Forward(ctx, i, hiddenState, positions, lastLayerOutputs, cache, m.TextConfig)

View File

@@ -10,7 +10,7 @@ import (
type Model struct {
model.Base
model.SentencePieceModel
model.SentencePiece
*TextModel
}
@@ -23,7 +23,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
func New(c fs.Config) (model.Model, error) {
m := Model{
TextModel: newTextModel(c),
SentencePieceModel: model.NewSentencePieceModel(
SentencePiece: model.NewSentencePiece(
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),

View File

@@ -83,7 +83,7 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
hiddenStates = hiddenStates.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx).Mean(ctx)
hiddenStates = hiddenStates.Permute(ctx, 2, 0, 1, 3).Contiguous(ctx)
hiddenStates = hiddenStates.Rows(ctx, ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs)))
hiddenStates = hiddenStates.Rows(ctx, batch.Outputs)
hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, m.eps)
return m.Output.Forward(ctx, hiddenStates), nil

View File

@@ -41,8 +41,8 @@ func (m *Transformer) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, err
}
var outputs ml.Tensor
if len(batch.Outputs) > 0 && i == len(m.TransformerBlocks)-1 {
outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
if i == len(m.TransformerBlocks)-1 {
outputs = batch.Outputs
}
hiddenStates = block.Forward(ctx, hiddenStates, positions, outputs, one, m.Cache, &m.Options)

View File

@@ -160,7 +160,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
var outputs ml.Tensor
if i == len(m.Layers)-1 {
outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
outputs = batch.Outputs
}
hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m.Options)

View File

@@ -176,9 +176,7 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil
}
func init() {

View File

@@ -159,9 +159,8 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache), nil
return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil
}
func init() {

View File

@@ -107,10 +107,9 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
}
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
// TODO: attention mask, cross attention mask
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
}
func init() {

View File

@@ -1,6 +1,7 @@
package models
import (
_ "github.com/ollama/ollama/model/models/bert"
_ "github.com/ollama/ollama/model/models/gemma2"
_ "github.com/ollama/ollama/model/models/gemma3"
_ "github.com/ollama/ollama/model/models/gemma3n"

View File

@@ -111,7 +111,7 @@ func (m Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
var outputs ml.Tensor
if i == len(m.Layers)-1 {
outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
outputs = batch.Outputs
}
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, &m.Options)

View File

@@ -140,9 +140,8 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromIntSlice(batch.Positions, len(batch.Positions))
outputs := ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
return m.TextModel.Forward(ctx, batch.Inputs, positions, outputs, batch, m.Cache)
return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache)
}
func init() {

View File

@@ -165,7 +165,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
var outputs ml.Tensor
if i == len(m.Layers)-1 {
outputs = ctx.Input().FromIntSlice(batch.Outputs, len(batch.Outputs))
outputs = batch.Outputs
}
hiddenStates = layer.Forward(ctx, hiddenStates, positions, outputs, m.Cache, m.Options)

View File

@@ -12,18 +12,18 @@ import (
const spmWhitespaceSep = "▁"
type SentencePieceModel struct {
type SentencePiece struct {
maxTokenLen int
vocab *Vocabulary
}
var _ TextProcessor = (*SentencePieceModel)(nil)
var _ TextProcessor = (*SentencePiece)(nil)
func (spm SentencePieceModel) Vocabulary() *Vocabulary {
func (spm SentencePiece) Vocabulary() *Vocabulary {
return spm.vocab
}
func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
func NewSentencePiece(vocab *Vocabulary) SentencePiece {
logutil.Trace("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
counter := map[int]int{}
@@ -42,17 +42,17 @@ func NewSentencePieceModel(vocab *Vocabulary) SentencePieceModel {
"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
"max token len", maxTokenLen)
return SentencePieceModel{
return SentencePiece{
maxTokenLen: maxTokenLen,
vocab: vocab,
}
}
func (spm SentencePieceModel) Is(id int32, special Special) bool {
func (spm SentencePiece) Is(id int32, special Special) bool {
return spm.vocab.Is(id, special)
}
func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) {
func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) {
fragments := []fragment{{value: s}}
for _, special := range spm.vocab.SpecialVocabulary() {
id := spm.vocab.Encode(special)
@@ -218,7 +218,7 @@ func (q *queue) Pop() interface{} {
return item
}
func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
func (spm SentencePiece) Decode(ids []int32) (string, error) {
var sb strings.Builder
for _, id := range ids {
data := spm.vocab.Decode(id)

View File

@@ -12,7 +12,7 @@ import (
"github.com/ollama/ollama/convert/sentencepiece"
)
func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
func loadSentencePieceVocab(t *testing.T) SentencePiece {
t.Helper()
bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model"))
@@ -45,7 +45,7 @@ func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
}
}
return NewSentencePieceModel(&v)
return NewSentencePiece(&v)
}
func TestSentencePieceEncode(t *testing.T) {
@@ -115,7 +115,7 @@ func TestSentencePieceEncode(t *testing.T) {
})
}
func TestSentencePieceModelDecodeByteTokens(t *testing.T) {
func TestSentencePieceDecodeByteTokens(t *testing.T) {
vocab := &Vocabulary{
Values: []string{
"normal",
@@ -134,7 +134,7 @@ func TestSentencePieceModelDecodeByteTokens(t *testing.T) {
Scores: []float32{0, 0, 0, 0, 0},
}
spm := NewSentencePieceModel(vocab)
spm := NewSentencePiece(vocab)
tests := []struct {
name string

model/wordpiece.go (new file, +167 lines)
View File

@@ -0,0 +1,167 @@
package model
import (
"fmt"
"iter"
"strings"
"unicode"
"github.com/ollama/ollama/logutil"
)
type WordPiece struct {
vocab *Vocabulary
}
// ggmlPrefix is the prefix used by GGML vocabularies to indicate word boundaries.
// this differs from original word piece which uses "##" to indicate subwords.
const ggmlPrefix = "▁"
var wordPieceReplacer = strings.NewReplacer(
" .", ".",
" ?", "?",
" !", "!",
" ,", ",",
" ' ", "'",
" n't", "n't",
" 'm", "'m",
" do not", " don't",
" 's", "'s",
" 've", "'ve",
" 're", "'re",
)
// Decode implements TextProcessor.
func (wpm WordPiece) Decode(ids []int32) (string, error) {
var sb strings.Builder
for i, id := range ids {
if id < 0 || int(id) >= len(wpm.vocab.Values) {
return "", fmt.Errorf("invalid token id: %d", id)
}
var separator string
piece := wpm.vocab.Values[id]
if i > 0 &&
(strings.HasPrefix(piece, ggmlPrefix) ||
(strings.HasPrefix(piece, "[") && strings.HasSuffix(piece, "]"))) {
separator = " "
}
sb.WriteString(wordPieceReplacer.Replace(separator + strings.TrimPrefix(piece, ggmlPrefix)))
}
return sb.String(), nil
}
// words splits a string into words, treating CJK characters as separate words.
// TODO: this is specifically for BERT and may need to be adjusted or refactored for other models.
func (wpm WordPiece) words(s string) iter.Seq[string] {
return func(yield func(string) bool) {
runes := make([]rune, 0, len(s)*3)
for _, r := range s {
switch {
case r >= 0x4E00 && r <= 0x9FFF,
r >= 0x3400 && r <= 0x4DBF,
r >= 0x20000 && r <= 0x2A6DF,
r >= 0x2A700 && r <= 0x2B73F,
r >= 0x2B740 && r <= 0x2B81F,
r >= 0x2B820 && r <= 0x2CEAF,
r >= 0xF900 && r <= 0xFAFF,
r >= 0x2F800 && r <= 0x2FA1F:
runes = append(runes, ' ', r, ' ')
default:
runes = append(runes, r)
}
}
for w := range strings.FieldsFuncSeq(string(runes), unicode.IsSpace) {
// split on punctuation, but keep it
var start int
for start < len(w) {
end := strings.IndexFunc(w[start:], unicode.IsPunct)
if end < 0 {
end = len(w) - start
} else if end == 0 {
end = 1
}
if !yield(w[start : start+end]) {
return
}
start += end
}
}
}
}
// Encode implements TextProcessor.
func (wpm WordPiece) Encode(s string, addSpecial bool) ([]int32, error) {
var ids []int32
// TODO: use [UNK] from config
unk := wpm.vocab.Encode("[UNK]")
for word := range wpm.words(s) {
var start int
var pieces []int32
for start < len(word) {
end := len(word)
var piece int32
for start < end {
subword := word[start:end]
if start == 0 {
subword = ggmlPrefix + subword
}
// TODO: some models might not want [ToLower]
piece = wpm.vocab.Encode(strings.ToLower(subword))
if piece >= 0 {
break
}
end--
}
if piece < 0 {
// Unknown token
pieces = pieces[:0]
break
}
pieces = append(pieces, piece)
start = end
}
if len(pieces) > 0 {
ids = append(ids, pieces...)
} else {
ids = append(ids, unk)
}
}
if addSpecial && len(ids) > 0 {
ids = wpm.vocab.addSpecials(ids)
}
logutil.Trace("encoded", "string", s, "ids", ids)
return ids, nil
}
// Is implements TextProcessor.
func (wpm WordPiece) Is(id int32, special Special) bool {
return wpm.vocab.Is(id, special)
}
// Vocabulary implements TextProcessor.
func (wpm WordPiece) Vocabulary() *Vocabulary {
return wpm.vocab
}
var _ TextProcessor = (*WordPiece)(nil)
func NewWordPiece(vocab *Vocabulary) WordPiece {
return WordPiece{
vocab: vocab,
}
}
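Encode above is a greedy longest-prefix match: for each word it repeatedly takes the longest vocabulary entry that matches from the current offset, with the ▁ prefix marking word starts, and falls back to [UNK] for the whole word if any span fails to match. A stripped-down sketch with a plain map standing in for Vocabulary (vocab contents are hypothetical):

```go
package main

import "fmt"

// encodeWord greedily matches the longest vocab entry from each offset.
// The real Encode also lowercases subwords before lookup.
func encodeWord(word string, vocab map[string]int32, unk int32) []int32 {
	var pieces []int32
	start := 0
	for start < len(word) {
		end := len(word)
		var piece int32 = -1
		for start < end {
			sub := word[start:end]
			if start == 0 {
				sub = "▁" + sub // GGML marks word starts with ▁, not "##"
			}
			if id, ok := vocab[sub]; ok {
				piece = id
				break
			}
			end--
		}
		if piece < 0 {
			return []int32{unk} // any unmatched span makes the whole word [UNK]
		}
		pieces = append(pieces, piece)
		start = end
	}
	return pieces
}

func main() {
	vocab := map[string]int32{"▁hello": 3, "▁world": 4, "s": 5}
	fmt.Println(encodeWord("hellos", vocab, 0)) // [3 5]
	fmt.Println(encodeWord("xyz", vocab, 0))    // [0] -> [UNK]
}
```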

model/wordpiece_test.go (new file, +51 lines)
View File

@@ -0,0 +1,51 @@
package model
import (
"slices"
"testing"
"github.com/google/go-cmp/cmp"
)
func TestWordPiece(t *testing.T) {
wpm := NewWordPiece(
&Vocabulary{
Values: []string{"[UNK]", "[CLS]", "[SEP]", "▁hello", "▁world", "s", "▁!", "▁@", "▁#"},
AddBOS: true,
AddEOS: true,
BOS: []int32{1},
EOS: []int32{2},
})
ids, err := wpm.Encode("Hello world!", true)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff([]int32{1, 3, 4, 6, 2}, ids); diff != "" {
t.Errorf("unexpected ids (-want +got):\n%s", diff)
}
words, err := wpm.Decode(ids)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff("[CLS] hello world! [SEP]", words); diff != "" {
t.Errorf("unexpected words (-want +got):\n%s", diff)
}
}
func TestWordPieceWords(t *testing.T) {
var wpm WordPiece
basic := slices.Collect(wpm.words("Hey friend! How are you?!?"))
if diff := cmp.Diff([]string{"Hey", "friend", "!", "How", "are", "you", "?", "!", "?"}, basic); diff != "" {
t.Errorf("unexpected words (-want +got):\n%s", diff)
}
chinese := slices.Collect(wpm.words("野口里佳 Noguchi Rika"))
if diff := cmp.Diff([]string{"野", "口", "里", "佳", "Noguchi", "Rika"}, chinese); diff != "" {
t.Errorf("unexpected words (-want +got):\n%s", diff)
}
}

View File

@@ -62,14 +62,15 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
for _, c := range f.Commands {
switch c.Name {
case "model":
path, err := expandPath(c.Args, relativeDir)
name := c.Args.(string)
path, err := expandPath(name, relativeDir)
if err != nil {
return nil, err
}
digestMap, err := fileDigestMap(path)
if errors.Is(err, os.ErrNotExist) {
req.From = c.Args
req.From = name
continue
} else if err != nil {
return nil, err
@@ -83,7 +84,8 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
}
}
case "adapter":
path, err := expandPath(c.Args, relativeDir)
adapter := c.Args.(string)
path, err := expandPath(adapter, relativeDir)
if err != nil {
return nil, err
}
@@ -95,21 +97,25 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
req.Adapters = digestMap
case "template":
req.Template = c.Args
template := c.Args.(string)
req.Template = template
case "system":
req.System = c.Args
system := c.Args.(string)
req.System = system
case "license":
licenses = append(licenses, c.Args)
license := c.Args.(string)
licenses = append(licenses, license)
case "message":
role, msg, _ := strings.Cut(c.Args, ": ")
messages = append(messages, api.Message{Role: role, Content: msg})
default:
msg := c.Args.(*Message)
messages = append(messages, api.Message{Role: msg.Role, Content: msg.Content})
case "parameter":
if slices.Contains(deprecatedParameters, c.Name) {
fmt.Printf("warning: parameter %s is deprecated\n", c.Name)
fmt.Printf("warning: parameter '%s' is deprecated\n", c.Name)
break
}
ps, err := api.FormatParams(map[string][]string{c.Name: {c.Args}})
param := c.Args.(*Parameter)
ps, err := api.FormatParams(map[string][]string{param.Name: {param.Value}})
if err != nil {
return nil, err
}
@@ -123,6 +129,8 @@ func (f Modelfile) CreateRequest(relativeDir string) (*api.CreateRequest, error)
params[k] = v
}
}
default:
return nil, fmt.Errorf("warning: unknown command '%s'", c.Name)
}
}
@@ -312,7 +320,17 @@ func filesForModel(path string) ([]string, error) {
type Command struct {
Name string
Args string
Args any
}
type Parameter struct {
Name string
Value string
}
type Message struct {
Role string
Content string
}
func (c Command) String() string {
@@ -321,12 +339,16 @@ func (c Command) String() string {
case "model":
fmt.Fprintf(&sb, "FROM %s", c.Args)
case "license", "template", "system", "adapter":
fmt.Fprintf(&sb, "%s %s", strings.ToUpper(c.Name), quote(c.Args))
data := c.Args.(string)
fmt.Fprintf(&sb, "%s %s", strings.ToUpper(c.Name), quote(data))
case "message":
role, message, _ := strings.Cut(c.Args, ": ")
fmt.Fprintf(&sb, "MESSAGE %s %s", role, quote(message))
data := c.Args.(*Message)
fmt.Fprintf(&sb, "MESSAGE %s %s", data.Role, quote(data.Content))
case "parameter":
data := c.Args.(*Parameter)
fmt.Fprintf(&sb, "PARAMETER %s %s", data.Name, quote(data.Value))
default:
fmt.Fprintf(&sb, "PARAMETER %s %s", c.Name, quote(c.Args))
fmt.Printf("unknown command '%s'\n", c.Name)
}
return sb.String()
@@ -366,7 +388,6 @@ func ParseFile(r io.Reader) (*Modelfile, error) {
var curr state
var currLine int = 1
var b bytes.Buffer
var role string
var f Modelfile
@@ -413,6 +434,7 @@ func ParseFile(r io.Reader) (*Modelfile, error) {
case "parameter":
// transition to stateParameter which sets command name
next = stateParameter
cmd.Name = s
case "message":
// transition to stateMessage which validates the message role
next = stateMessage
@@ -421,16 +443,37 @@ func ParseFile(r io.Reader) (*Modelfile, error) {
cmd.Name = s
}
 		case stateParameter:
-			cmd.Name = b.String()
+			s, ok := unquote(strings.TrimSpace(b.String()))
+			if !ok || isSpace(r) {
+				if _, err := b.WriteRune(r); err != nil {
+					return nil, err
+				}
+				continue
+			}
+			cmd.Args = &Parameter{
+				Name: s,
+			}
 		case stateMessage:
-			if !isValidMessageRole(b.String()) {
+			s, ok := unquote(strings.TrimSpace(b.String()))
+			if !ok || isSpace(r) {
+				if _, err := b.WriteRune(r); err != nil {
+					return nil, err
+				}
+				continue
+			}
+			if !isValidMessageRole(s) {
 				return nil, &ParserError{
 					LineNumber: currLine,
 					Msg:        errInvalidMessageRole.Error(),
 				}
 			}
-			role = b.String()
+			cmd.Args = &Message{
+				Role: s,
+			}
case stateComment, stateNil:
// pass
case stateValue:
@@ -443,12 +486,16 @@ func ParseFile(r io.Reader) (*Modelfile, error) {
continue
}
if role != "" {
s = role + ": " + s
role = ""
switch cmd.Name {
case "parameter":
p := cmd.Args.(*Parameter)
p.Value = s
case "message":
m := cmd.Args.(*Message)
m.Content = s
default:
cmd.Args = s
}
cmd.Args = s
f.Commands = append(f.Commands, cmd)
}
@@ -473,11 +520,16 @@ func ParseFile(r io.Reader) (*Modelfile, error) {
return nil, io.ErrUnexpectedEOF
}
if role != "" {
s = role + ": " + s
switch cmd.Name {
case "parameter":
c := cmd.Args.(*Parameter)
c.Value = s
case "message":
c := cmd.Args.(*Message)
c.Content = s
default:
cmd.Args = s
}
cmd.Args = s
f.Commands = append(f.Commands, cmd)
default:
return nil, io.ErrUnexpectedEOF

View File

@@ -47,8 +47,8 @@ TEMPLATE """{{ if .System }}<|start_header_id|>system<|end_header_id|>
{Name: "model", Args: "model1"},
{Name: "adapter", Args: "adapter1"},
{Name: "license", Args: "MIT"},
{Name: "param1", Args: "value1"},
{Name: "param2", Args: "value2"},
{Name: "parameter", Args: &Parameter{"param1", "value1"}},
{Name: "parameter", Args: &Parameter{"param2", "value2"}},
{Name: "template", Args: "{{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ .Response }}<|eot_id|>"},
}
@@ -80,8 +80,8 @@ TEMPLATE """ {{ if .System }}<|start_header_id|>system<|end_header_id|>
{Name: "model", Args: " model 1"},
{Name: "adapter", Args: "adapter3"},
{Name: "license", Args: "MIT "},
{Name: "param1", Args: "value1"},
{Name: "param2", Args: "value2"},
{Name: "parameter", Args: &Parameter{"param1", "value1"}},
{Name: "parameter", Args: &Parameter{"param2", "value2"}},
{Name: "template", Args: " {{ if .System }}<|start_header_id|>system<|end_header_id|>\n\n{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>\n\n{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>\n\n{{ .Response }}<|eot_id|> "},
}
@@ -101,7 +101,7 @@ func TestParseFileFrom(t *testing.T) {
},
{
"FROM \"FOO BAR\"\nPARAMETER param1 value1",
[]Command{{Name: "model", Args: "FOO BAR"}, {Name: "param1", Args: "value1"}},
[]Command{{Name: "model", Args: "FOO BAR"}, {Name: "parameter", Args: &Parameter{"param1", "value1"}}},
nil,
},
{
@@ -149,12 +149,12 @@ func TestParseFileFrom(t *testing.T) {
},
{
"PARAMETER param1 value1\nFROM foo",
[]Command{{Name: "param1", Args: "value1"}, {Name: "model", Args: "foo"}},
[]Command{{Name: "parameter", Args: &Parameter{"param1", "value1"}}, {Name: "model", Args: "foo"}},
nil,
},
{
"PARAMETER what the \nFROM lemons make lemonade ",
[]Command{{Name: "what", Args: "the"}, {Name: "model", Args: "lemons make lemonade"}},
[]Command{{Name: "parameter", Args: &Parameter{"what", "the"}}, {Name: "model", Args: "lemons make lemonade"}},
nil,
},
}
@@ -211,7 +211,7 @@ MESSAGE system You are a file parser. Always parse things.
`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "message", Args: "system: You are a file parser. Always parse things."},
{Name: "message", Args: &Message{"system", "You are a file parser. Always parse things."}},
},
nil,
},
@@ -221,7 +221,7 @@ FROM foo
MESSAGE system You are a file parser. Always parse things.`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "message", Args: "system: You are a file parser. Always parse things."},
{Name: "message", Args: &Message{"system", "You are a file parser. Always parse things."}},
},
nil,
},
@@ -234,9 +234,9 @@ MESSAGE assistant Hello, I want to parse all the things!
`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "message", Args: "system: You are a file parser. Always parse things."},
{Name: "message", Args: "user: Hey there!"},
{Name: "message", Args: "assistant: Hello, I want to parse all the things!"},
{Name: "message", Args: &Message{"system", "You are a file parser. Always parse things."}},
{Name: "message", Args: &Message{"user", "Hey there!"}},
{Name: "message", Args: &Message{"assistant", "Hello, I want to parse all the things!"}},
},
nil,
},
@@ -244,12 +244,12 @@ MESSAGE assistant Hello, I want to parse all the things!
`
FROM foo
MESSAGE system """
You are a multiline file parser. Always parse things.
You are a multiline file "parser". Always parse things.
"""
`,
[]Command{
{Name: "model", Args: "foo"},
{Name: "message", Args: "system: \nYou are a multiline file parser. Always parse things.\n"},
{Name: "message", Args: &Message{"system", "\nYou are a multiline file \"parser\". Always parse things.\n"}},
},
nil,
},
@@ -514,7 +514,7 @@ func TestParseFileParameters(t *testing.T) {
assert.Equal(t, []Command{
{Name: "model", Args: "foo"},
{Name: v.name, Args: v.value},
{Name: "parameter", Args: &Parameter{v.name, v.value}},
}, modelfile.Commands)
})
}
@@ -617,8 +617,8 @@ SYSTEM You are a utf16 file.
expected := []Command{
{Name: "model", Args: "bob"},
{Name: "param1", Args: "1"},
{Name: "param2", Args: "4096"},
{Name: "parameter", Args: &Parameter{"param1", "1"}},
{Name: "parameter", Args: &Parameter{"param2", "4096"}},
{Name: "system", Args: "You are a utf16 file."},
}

View File

@@ -1,126 +0,0 @@
package parser
import (
"encoding/json"
"errors"
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/harmony"
)
type TokenParserType int
const (
TokenParserTypeDefault TokenParserType = iota
TokenParserTypeHarmony
)
type TokenParser struct {
messageHandler MessageHandler
parserEngine ParserInternals
toolParser ToolParser
lastToken string
tokenRepeat int
repeatLimit int
}
const defaultTokenRepeatLimit = 30
type MessageHandler interface {
AddContent(token string) (content, thinking string, toolContent string)
}
type ParserInternals interface {
AddImplicitStartOrPrefill(prefillString string)
}
type ToolParser interface {
Add(token string)
Drain() (toolName *string, toolContent string)
}
// Default implementation for the TokenParser interface as a no-op passthrough
type defaultMessageHandler struct{}
func (defaultMessageHandler) AddContent(token string) (string, string, string) {
return token, "", ""
}
type defaultEngine struct{}
func (defaultEngine) AddImplicitStartOrPrefill(prefillString string) {}
type defaultToolParser struct{}
func (defaultToolParser) Add(token string) {}
func (defaultToolParser) Drain() (*string, string) { return nil, "" }
func NewTokenParser(parserType TokenParserType, prefillString string) TokenParser {
switch parserType {
case TokenParserTypeHarmony:
harmonyMessageHandler := harmony.NewHarmonyMessageHandler()
harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(prefillString)
return TokenParser{
messageHandler: harmonyMessageHandler,
parserEngine: harmonyMessageHandler.HarmonyParser,
toolParser: harmonyMessageHandler.ToolParser,
repeatLimit: defaultTokenRepeatLimit,
}
default:
return TokenParser{
messageHandler: defaultMessageHandler{},
parserEngine: defaultEngine{},
toolParser: defaultToolParser{},
repeatLimit: 30,
}
}
}
func (p *TokenParser) AddContent(token string) (string, string, error) {
if p.repeatLimitReached(token) {
return "", "", errors.New("token repeat limit reached")
}
content, thinking, toolContent := p.messageHandler.AddContent(token)
p.toolParser.Add(toolContent)
return content, thinking, nil
}
// repeatLimitReached updates repeat counters and returns true if the repeat limit is reached.
func (p *TokenParser) repeatLimitReached(token string) bool {
if p == nil {
return false
}
trimmed := strings.TrimSpace(token)
if trimmed == p.lastToken {
p.tokenRepeat++
} else {
p.tokenRepeat = 0
}
p.lastToken = trimmed
return p.tokenRepeat >= p.repeatLimit
}
// TODO: update to work with multiple toolcalls - unmarshalling should also happen on parser level
func (p *TokenParser) Drain() []api.ToolCall {
toolName, toolContent := p.toolParser.Drain()
if toolName != nil {
*toolName = strings.TrimPrefix(*toolName, "functions.")
var args api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
return nil
}
return []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: *toolName,
Arguments: args,
},
},
}
}
return nil
}

View File

@@ -34,7 +34,6 @@ import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/runner/common"
"github.com/ollama/ollama/sample"
@@ -468,6 +467,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
// Prepare the seqs and batch, but defer the input token values as we may not be ready yet
var batchInputs []*input.Input
var batchOutputs []int32
var batch input.Batch
resumeSeq := -1
@@ -550,9 +550,9 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
batch.Sequences = append(batch.Sequences, seq.cache.Id)
-		seq.iBatch = len(batch.Outputs)
-		if i+1 == len(seq.inputs) {
-			batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
+		seq.iBatch = len(batchOutputs)
+		if i+1 == len(seq.inputs) || seq.embeddingOnly {
+			batchOutputs = append(batchOutputs, int32(len(batchInputs)-1))
 		}
logutil.Trace("forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs))
seq.pendingInputs = append(seq.pendingInputs, inp)
@@ -577,6 +577,7 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
// Actual batchInputs values will be injected into the batch.Inputs tensor before calling Compute
batch.Inputs = nextBatch.ctx.Input().Empty(ml.DTypeI32, len(batchInputs))
batch.Outputs = nextBatch.ctx.Input().FromIntSlice(batchOutputs, len(batchOutputs))
nextBatch.modelOutput, err = model.Forward(nextBatch.ctx, s.model, batch)
if err != nil {
err = fmt.Errorf("failed to build graph: %w", err)
@@ -704,8 +705,8 @@ func (s *Server) computeBatch(activeBatch batchState) {
}
// sample a token
vocabSize := len(outputs) / len(activeBatch.batch.Outputs)
logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", len(activeBatch.batch.Outputs), "vocabSize", vocabSize, "iBatches", iBatches)
vocabSize := len(outputs) / activeBatch.batch.Outputs.Dim(0)
logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", activeBatch.batch.Outputs.Dim(0), "vocabSize", vocabSize, "iBatches", iBatches)
token, err := seq.sampler.Sample(outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
if err != nil {
s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)
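With outputs now carried as a tensor, the vocabulary size falls out of `len(logits) / Outputs.Dim(0)`, and each sequence reads its rows through its `iBatch` index. A toy sketch of that indexing, with hypothetical values:

```go
package main

import "fmt"

// sliceLogits mirrors the sampling math above: with one logits row per entry
// in batch.Outputs, the vocabulary size is len(logits) divided by the number
// of output rows, and sequence i reads its row via its iBatch index.
func sliceLogits(logits []float32, numOutputs, iBatch int) []float32 {
	vocabSize := len(logits) / numOutputs
	return logits[iBatch*vocabSize : (iBatch+1)*vocabSize]
}

func main() {
	// Hypothetical toy values: 3 output rows over a 4-token vocabulary.
	logits := make([]float32, 12)
	for i := range logits {
		logits[i] = float32(i)
	}
	fmt.Println(sliceLogits(logits, 3, 1)) // [4 5 6 7]
}
```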
@@ -781,8 +782,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
}
tokenParser := parser.NewTokenParser(req.ParserType, req.PrefillString)
if req.Options == nil {
opts := api.DefaultOptions()
req.Options = &opts
@@ -873,18 +872,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
case content, ok := <-seq.responses:
if ok {
var thinking string
var err error
content, thinking, err = tokenParser.AddContent(content)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
close(seq.quit)
return
}
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
Content: content,
Thinking: thinking,
Content: content,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
close(seq.quit)
@@ -893,9 +882,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
flusher.Flush()
} else {
toolCalls := tokenParser.Drain()
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
ToolCalls: toolCalls,
Done: true,
DoneReason: seq.doneReason,
PromptEvalCount: seq.numPromptInputs,
@@ -1061,12 +1048,8 @@ func (s *Server) reserveWorstCaseGraph() error {
batch.Positions[i] = int32(i)
}
batch.Outputs = make([]int32, s.parallel)
for i := range batch.Outputs {
batch.Outputs[i] = int32(i)
}
batch.Inputs = ctx.Input().FromIntSlice(batchInputs, len(batchInputs))
batch.Outputs = ctx.Input().Empty(ml.DTypeI32, s.parallel)
cache := s.model.Config().Cache
if cache != nil {

View File

@@ -36,7 +36,6 @@ import (
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/openai"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/server/internal/registry"
"github.com/ollama/ollama/template"
@@ -47,6 +46,18 @@ import (
"github.com/ollama/ollama/version"
)
func shouldUseHarmony(model *Model) bool {
if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) {
// heuristic to check whether the template expects to be parsed via harmony:
// search for harmony tags that are nearly always used
if model.Template.Contains("<|start|>") && model.Template.Contains("<|end|>") {
return true
}
}
return false
}
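A string-based restatement of the heuristic, with `looksLikeHarmony` as a hypothetical stand-in for the template check above:

```go
package main

import (
	"fmt"
	"strings"
)

// looksLikeHarmony restates the heuristic: a gpt-oss family model whose
// template contains the tag pair that harmony templates nearly always use.
func looksLikeHarmony(family, template string) bool {
	isGptOss := family == "gptoss" || family == "gpt-oss"
	return isGptOss && strings.Contains(template, "<|start|>") && strings.Contains(template, "<|end|>")
}

func main() {
	fmt.Println(looksLikeHarmony("gpt-oss", "<|start|>{{ .Prompt }}<|end|>")) // true
	fmt.Println(looksLikeHarmony("llama", "{{ .Prompt }}"))                   // false
}
```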
func experimentEnabled(name string) bool {
return slices.Contains(strings.Split(os.Getenv("OLLAMA_EXPERIMENT"), ","), name)
}
@@ -196,17 +207,13 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template) && !req.Raw
var parserType parser.TokenParserType
useHarmony := shouldUseHarmony(m) && !req.Raw
var harmonyMessageHandler *harmony.HarmonyMessageHandler
var harmonyToolParser *harmony.HarmonyToolCallAccumulator
if useHarmony {
parserType = parser.TokenParserTypeHarmony
} else {
parserType = parser.TokenParserTypeDefault
}
var functionNameMap *harmony.FunctionNameMap
if useHarmony {
functionNameMap = harmony.NewFunctionNameMap()
harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
harmonyMessageHandler.HarmonyParser.AddImplicitStart()
harmonyToolParser = harmonyMessageHandler.CreateToolParser()
}
// Validate Think value: string values currently only allowed for gptoss models
@@ -350,19 +357,16 @@ func (s *Server) GenerateHandler(c *gin.Context) {
var sb strings.Builder
defer close(ch)
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
ParserType: parserType,
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
}, func(cr llm.CompletionResponse) {
res := api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Response: cr.Content,
Done: cr.Done,
Thinking: cr.Thinking,
ToolCalls: cr.ToolCalls,
Metrics: api.Metrics{
PromptEvalCount: cr.PromptEvalCount,
PromptEvalDuration: cr.PromptEvalDuration,
@@ -371,22 +375,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
},
}
if res.Done {
res.DoneReason = cr.DoneReason.String()
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}
if useHarmony {
for i, tool := range res.ToolCalls {
res.ToolCalls[i].Function.Name = functionNameMap.OriginalFromConverted(tool.Function.Name)
}
if res.Response != "" || res.Thinking != "" || len(res.ToolCalls) > 0 || res.Done {
ch <- res
}
return
}
if thinkingState != nil {
content, thinking, toolContent := harmonyMessageHandler.AddContent(cr.Content, harmonyToolParser)
res.Response = content
res.Thinking = thinking
harmonyToolParser.Add(toolContent)
} else if thinkingState != nil {
thinking, content := thinkingState.AddContent(cr.Content)
res.Thinking = thinking
res.Response = content
@@ -397,6 +391,30 @@ func (s *Server) GenerateHandler(c *gin.Context) {
}
if cr.Done {
if useHarmony {
toolName, toolContent := harmonyToolParser.Drain()
if toolName != nil {
*toolName = strings.TrimPrefix(*toolName, "functions.")
var args api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
ch <- gin.H{"error": errStr}
return
}
res.ToolCalls = append(res.ToolCalls, api.ToolCall{
Function: api.ToolCallFunction{
Name: *toolName,
Arguments: args,
},
})
}
}
res.DoneReason = cr.DoneReason.String()
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
if !req.Raw {
tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String())
if err != nil {
@@ -470,7 +488,6 @@ func (s *Server) EmbedHandler(c *gin.Context) {
}
truncate := true
if req.Truncate != nil && !*req.Truncate {
truncate = false
}
@@ -537,7 +554,16 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return
}
if bos := kvData.Uint("tokenizer.ggml.bos_token_id"); tokens[0] != int(bos) && kvData.Bool("add_bos_token", true) {
ctxLen--
}
if eos := kvData.Uint("tokenizer.ggml.eos_token_id"); tokens[len(tokens)-1] != int(eos) && kvData.Bool("add_eos_token", true) {
ctxLen--
}
tokens = tokens[:ctxLen]
s, err = r.Detokenize(c.Request.Context(), tokens)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
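The truncation above reserves one context slot for each special token the tokenizer will still add: if the prompt does not already begin with BOS (or end with EOS) and the model is configured to add it, the usable window shrinks by one before the cut. A minimal sketch with hypothetical token IDs:

```go
package main

import "fmt"

// reserveSpecials mirrors the logic above. The caller guarantees that
// tokens is longer than ctxLen (we are in the truncation branch).
func reserveSpecials(tokens []int, ctxLen, bos, eos int, addBOS, addEOS bool) []int {
	if addBOS && tokens[0] != bos {
		ctxLen--
	}
	if addEOS && tokens[len(tokens)-1] != eos {
		ctxLen--
	}
	return tokens[:ctxLen]
}

func main() {
	// Hypothetical values: context of 5, BOS=1, EOS=2, neither present yet,
	// so two slots are reserved and only three prompt tokens survive.
	tokens := []int{10, 11, 12, 13, 14, 15, 16}
	fmt.Println(reserveSpecials(tokens, 5, 1, 2, true, true)) // [10 11 12]
}
```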
@@ -1599,27 +1625,27 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
msgs = filterThinkTags(msgs, m)
useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template)
var parserType parser.TokenParserType
if useHarmony {
parserType = parser.TokenParserTypeHarmony
} else {
parserType = parser.TokenParserTypeDefault
}
var harmonyMessageHandler *harmony.HarmonyMessageHandler
var harmonyToolParser *harmony.HarmonyToolCallAccumulator
useHarmony := shouldUseHarmony(m)
processedTools := req.Tools
var functionNameMap *harmony.FunctionNameMap
var prefillString string
// TODO(parthsareen): this can be abstracted to not be model specific and potentially moved to the runner
if useHarmony {
prefillString = harmony.Prefill(msgs[len(msgs)-1])
functionNameMap = harmony.NewFunctionNameMap()
harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
var lastMessage *api.Message
if len(msgs) > 0 {
lastMessage = &msgs[len(msgs)-1]
}
harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(lastMessage)
harmonyToolParser = harmonyMessageHandler.CreateToolParser()
// make a copy of tools to pass to the chat prompt. Function names may be
// renamed to be valid Harmony function names.
processedTools = make([]api.Tool, len(req.Tools))
copy(processedTools, req.Tools)
for i, tool := range processedTools {
processedTools[i].Function.Name = functionNameMap.ConvertAndAdd(tool.Function.Name)
processedTools[i].Function.Name = harmonyMessageHandler.FunctionNameMap.ConvertAndAdd(tool.Function.Name)
}
}
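The copy-then-rename step matters because Harmony restricts what counts as a valid function name, so model output has to be mapped back to the caller's original spelling. A hypothetical sketch of such a round-trip map (the actual sanitization rule in `harmony.FunctionNameMap` may differ):

```go
package main

import (
	"fmt"
	"strings"
)

// nameMap is a hypothetical stand-in for harmony.FunctionNameMap: tool names
// are sanitized into Harmony-safe identifiers on the way in, and the original
// spelling is recovered from model output on the way out.
type nameMap struct {
	toOriginal map[string]string
}

func (m *nameMap) convertAndAdd(name string) string {
	converted := strings.ReplaceAll(name, "-", "_") // assumed sanitization rule
	m.toOriginal[converted] = name
	return converted
}

func (m *nameMap) originalFromConverted(converted string) string {
	if orig, ok := m.toOriginal[converted]; ok {
		return orig
	}
	return converted
}

func main() {
	m := &nameMap{toOriginal: map[string]string{}}
	harmonyName := m.convertAndAdd("get-weather")
	fmt.Println(harmonyName)                          // get_weather
	fmt.Println(m.originalFromConverted(harmonyName)) // get-weather
}
```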
@@ -1672,17 +1698,15 @@ func (s *Server) ChatHandler(c *gin.Context) {
defer close(ch)
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
ParserType: parserType,
PrefillString: prefillString,
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
}, func(r llm.CompletionResponse) {
res := api.ChatResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Message: api.Message{Role: "assistant", Content: r.Content, Thinking: r.Thinking, ToolCalls: r.ToolCalls},
Message: api.Message{Role: "assistant", Content: r.Content},
Done: r.Done,
Metrics: api.Metrics{
PromptEvalCount: r.PromptEvalCount,
@@ -1698,13 +1722,31 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
if useHarmony {
for i, tool := range res.Message.ToolCalls {
res.Message.ToolCalls[i].Function.Name = functionNameMap.OriginalFromConverted(tool.Function.Name)
content, thinking, toolContent := harmonyMessageHandler.AddContent(r.Content, harmonyToolParser)
res.Message.Content = content
res.Message.Thinking = thinking
harmonyToolParser.Add(toolContent)
if r.Done {
toolName, toolContent := harmonyToolParser.Drain()
if toolName != nil {
*toolName = strings.TrimPrefix(*toolName, "functions.")
*toolName = harmonyMessageHandler.FunctionNameMap.OriginalFromConverted(*toolName)
var args api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
ch <- gin.H{"error": errStr}
return
}
res.Message.ToolCalls = []api.ToolCall{{Function: api.ToolCallFunction{Name: *toolName, Arguments: args}}}
}
}
// only send messages with meaningful content (empty messages confuse clients)
if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || res.Done {
ch <- res
}
return
}

View File

@@ -7,6 +7,7 @@ import (
"bytes"
"context"
"encoding/json"
"net/http"
"strings"
"testing"
"time"
@@ -117,7 +118,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
name: "content streams as it arrives",
steps: []step{
{
input: llm.CompletionResponse{Content: "Hello", Done: false},
input: llm.CompletionResponse{Content: "<|message|>Hello", Done: false},
wantContent: "Hello",
},
{
@@ -125,7 +126,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
wantContent: ", world",
},
{
input: llm.CompletionResponse{Content: "!", Done: true, DoneReason: llm.DoneReasonStop},
input: llm.CompletionResponse{Content: "!<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
wantContent: "!",
},
},
@@ -134,15 +135,20 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
name: "thinking streams separately from content",
steps: []step{
{
input: llm.CompletionResponse{Thinking: "Thinking...", Done: false},
input: llm.CompletionResponse{Content: "<|channel|>analysis<|message|>Thinking...", Done: false},
wantThinking: "Thinking...",
},
{
input: llm.CompletionResponse{Content: "Answer", Done: false},
wantContent: "Answer",
input: llm.CompletionResponse{Content: "<|end|>", Done: false},
// No output expected - just closes the analysis message and resets state to normal
},
{
input: llm.CompletionResponse{Done: true, DoneReason: llm.DoneReasonStop},
input: llm.CompletionResponse{Content: "<|start|>assistant<|message|>Answer", Done: false},
wantContent: "Answer", // After message end, state is reset to normal
},
{
input: llm.CompletionResponse{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
// No output expected - just closes the assistant message
},
},
},
@@ -150,16 +156,24 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
name: "partial tags buffer until complete",
steps: []step{
{
input: llm.CompletionResponse{Thinking: "Deep ", Done: false},
input: llm.CompletionResponse{Content: "<|chan", Done: false},
// No output - partial tag
},
{
input: llm.CompletionResponse{Content: "nel|>analysis<|mess", Done: false},
// No output - still building tags
},
{
input: llm.CompletionResponse{Content: "age|>Deep ", Done: false},
wantThinking: "Deep ",
},
{
input: llm.CompletionResponse{Thinking: "thought", Done: false},
input: llm.CompletionResponse{Content: "thought<|end|>", Done: false},
wantThinking: "thought",
},
{
input: llm.CompletionResponse{Content: "Done", Done: true, DoneReason: llm.DoneReasonStop},
wantContent: "Done",
input: llm.CompletionResponse{Content: "<|start|>assistant<|message|>Done<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
wantContent: "Done", // After message end, state is reset to normal
},
},
},
@@ -167,7 +181,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
name: "simple assistant after analysis",
steps: []step{
{
input: llm.CompletionResponse{Thinking: "Think", Content: "Answer", Done: true, DoneReason: llm.DoneReasonStop},
input: llm.CompletionResponse{Content: "<|channel|>analysis<|message|>Think<|end|><|start|>assistant<|message|>Answer<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
wantContent: "Answer",
wantThinking: "Think",
},
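These cases pin down the harmony wire format: `<|channel|>analysis<|message|>…<|end|>` carries thinking, while `<|start|>assistant<|message|>…<|end|>` carries content. A deliberately simplified, non-streaming splitter that reproduces the expected output for the case above (the real parser additionally handles partial tags across chunks and tool channels):

```go
package main

import (
	"fmt"
	"strings"
)

// splitHarmony routes analysis-channel message bodies to thinking and all
// other message bodies to content. Non-streaming only: it assumes the full
// transcript is available.
func splitHarmony(s string) (content, thinking string) {
	for _, seg := range strings.Split(s, "<|end|>") {
		_, body, ok := strings.Cut(seg, "<|message|>")
		if !ok {
			continue
		}
		if strings.Contains(seg[:len(seg)-len(body)], "<|channel|>analysis") {
			thinking += body
		} else {
			content += body
		}
	}
	return content, thinking
}

func main() {
	c, th := splitHarmony("<|channel|>analysis<|message|>Think<|end|><|start|>assistant<|message|>Answer<|end|>")
	fmt.Printf("content=%q thinking=%q\n", c, th) // content="Answer" thinking="Think"
}
```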
@@ -177,7 +191,7 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
name: "tool call parsed and returned correctly",
steps: []step{
{
input: llm.CompletionResponse{Content: "The weather is sunny", ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "get_weather", Arguments: api.ToolCallFunctionArguments{"location": "San Francisco"}}}}, Done: true, DoneReason: llm.DoneReasonStop},
input: llm.CompletionResponse{Content: "<|channel|>commentary to=functions.get_weather<|message|>{\"location\":\"San Francisco\"}<|end|><|start|>assistant<|message|>The weather is sunny<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
wantContent: "The weather is sunny",
wantToolCalls: []api.ToolCall{
{
@@ -196,10 +210,15 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
name: "tool call with streaming JSON across chunks",
steps: []step{
{
input: llm.CompletionResponse{Done: false},
input: llm.CompletionResponse{Content: "<|channel|>commentary to=functions.calculate<|message|>{\"expr", Done: false},
// No output yet - incomplete JSON
},
{
input: llm.CompletionResponse{ToolCalls: []api.ToolCall{{Function: api.ToolCallFunction{Name: "calculate", Arguments: api.ToolCallFunctionArguments{"expression": "2+2"}}}}, Done: true},
input: llm.CompletionResponse{Content: "ession\":\"2+", Done: false},
// Still no output - incomplete JSON
},
{
input: llm.CompletionResponse{Content: "2\"}", Done: true},
wantToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
@@ -381,9 +400,9 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {
gin.SetMode(gin.TestMode)
mockResponses := []llm.CompletionResponse{
{Content: "First ", Done: false},
{Content: "<|message|>First ", Done: false},
{Content: "chunk ", Done: false},
{Content: "here", Done: true, DoneReason: llm.DoneReasonStop},
{Content: "here<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
}
mock := mockRunner{
@@ -488,3 +507,189 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {
t.Errorf("expected at least 2 content chunks for streaming, got %d", contentChunks)
}
}
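The JSON-across-chunks case earlier in this file feeds tool-call arguments as fragments; a sketch of the accumulate-until-valid approach it exercises:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Streamed tool-call arguments arrive as JSON fragments: keep appending to a
// buffer and attempt to unmarshal after each chunk, emitting the tool call
// only once the buffer parses cleanly.
func main() {
	var buf string
	var args map[string]any
	for _, chunk := range []string{`{"expr`, `ession":"2+`, `2"}`} {
		buf += chunk
		if err := json.Unmarshal([]byte(buf), &args); err == nil {
			fmt.Println("complete:", args) // complete: map[expression:2+2]
			return
		}
		fmt.Println("incomplete, waiting for more")
	}
}
```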
func TestChatHarmonyParserStreaming(t *testing.T) {
gin.SetMode(gin.TestMode)
type expectedChunk struct {
afterResponse int // Which mock response this chunk should appear after
content string // Expected content in this chunk
thinking string // Expected thinking in this chunk
}
testCases := []struct {
name string
mockResponses []llm.CompletionResponse
expectedChunks []expectedChunk
wantContent string
wantThinking string
}{
{
name: "simple message without thinking",
mockResponses: []llm.CompletionResponse{
{Content: "<|start|>assistant<|message|>Hello, ", Done: false},
{Content: "how can I help?", Done: false},
{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
},
expectedChunks: []expectedChunk{
{afterResponse: 1, content: "Hello, "},
{afterResponse: 2, content: "how can I help?"},
},
wantContent: "Hello, how can I help?",
},
{
name: "message with analysis channel for thinking",
mockResponses: []llm.CompletionResponse{
{Content: "<|channel|>analysis<|message|>", Done: false},
{Content: "Let me think ", Done: false},
{Content: "about this problem...", Done: false},
{Content: "<|end|>", Done: false},
{Content: "<|start|>assistant<|message|>", Done: false},
{Content: "The answer ", Done: false},
{Content: "is 42", Done: false},
{Content: "<|end|>", Done: true, DoneReason: llm.DoneReasonStop},
},
expectedChunks: []expectedChunk{
{afterResponse: 2, thinking: "Let me think "},
{afterResponse: 3, thinking: "about this problem..."},
{afterResponse: 6, content: "The answer "},
{afterResponse: 7, content: "is 42"},
},
wantContent: "The answer is 42",
wantThinking: "Let me think about this problem...",
},
{
name: "streaming with partial tags across boundaries",
mockResponses: []llm.CompletionResponse{
{Content: "<|chan", Done: false},
{Content: "nel|>analy", Done: false},
{Content: "sis<|mess", Done: false},
{Content: "age|>Think", Done: false},
{Content: "ing deeply...<|end|>", Done: false},
{Content: "<|start|>assi", Done: false},
{Content: "stant<|message|>Result ", Done: false},
{Content: "computed<|e", Done: false},
{Content: "nd|>", Done: true, DoneReason: llm.DoneReasonStop},
},
expectedChunks: []expectedChunk{
{afterResponse: 4, thinking: "Think"},
{afterResponse: 5, thinking: "ing deeply..."},
{afterResponse: 7, content: "Result "},
{afterResponse: 8, content: "computed"},
},
wantContent: "Result computed",
wantThinking: "Thinking deeply...",
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Channel to synchronize mock responses with chunk verification
responsesSent := make(chan int, len(tc.mockResponses))
mock := mockRunner{
CompletionFn: func(ctx context.Context, r llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
// Send mock responses one at a time, notifying when each is sent
for i, resp := range tc.mockResponses {
fn(resp)
responsesSent <- i + 1
}
close(responsesSent)
return nil
},
}
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: discover.GetGPUInfo,
getCpuFn: discover.GetCPUInfo,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
req.successCh <- &runnerRef{
llama: &mock,
}
return false
},
},
}
go s.sched.Run(t.Context())
// Create a minimal model
_, digest := createHarmonyTestModel(t)
// Create model with passthrough template
stream := false
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "harmony-test",
Files: map[string]string{"file.gguf": digest},
Template: `<|start|><|end|>{{ with .Tools }}{{ end }}{{ .Prompt }}`,
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("failed to create model: %d", w.Code)
}
// Test chat endpoint with streaming
streamTrue := true
w = createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "harmony-test",
Messages: []api.Message{{Role: "user", Content: "Hello"}},
Stream: &streamTrue,
Tools: getTestTools(),
})
if w.Code != http.StatusOK {
t.Fatalf("chat request failed: %d - %s", w.Code, w.Body.String())
}
// Parse streaming response
var chunks []api.ChatResponse
var content, thinking strings.Builder
decoder := json.NewDecoder(w.Body)
for decoder.More() {
var chunk api.ChatResponse
if err := decoder.Decode(&chunk); err != nil {
t.Fatalf("failed to decode chunk: %v", err)
}
chunks = append(chunks, chunk)
// Accumulate content and thinking from each chunk
content.WriteString(chunk.Message.Content)
thinking.WriteString(chunk.Message.Thinking)
// Debug output
t.Logf("Chunk %d: content=%q thinking=%q done=%v", len(chunks), chunk.Message.Content, chunk.Message.Thinking, chunk.Done)
}
// Verify we got streaming chunks
if len(chunks) == 0 {
t.Fatal("expected streaming chunks, got none")
}
gotContent := content.String()
gotThinking := thinking.String()
if gotContent != tc.wantContent {
t.Errorf("content mismatch: got %q, want %q", gotContent, tc.wantContent)
}
if gotThinking != tc.wantThinking {
t.Errorf("thinking mismatch: got %q, want %q", gotThinking, tc.wantThinking)
}
// Verify last chunk has done=true
lastChunk := chunks[len(chunks)-1]
if !lastChunk.Done {
t.Error("expected last chunk to have done=true")
}
})
}
}