Compare commits
10 Commits
v0.13.1
...
parth/olmo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5c3bf414ef | ||
|
|
0a9862a383 | ||
|
|
f475cc365a | ||
|
|
dd3306d3a0 | ||
|
|
57c1d7db9a | ||
|
|
91d6370a62 | ||
|
|
38a2a6468f | ||
|
|
064ec63ddf | ||
|
|
fd959fbf7a | ||
|
|
cfc9729edf |
2
.gitattributes
vendored
2
.gitattributes
vendored
@@ -19,8 +19,6 @@ ml/backend/**/*.comp linguist-vendored
|
||||
ml/backend/**/*.glsl linguist-vendored
|
||||
ml/backend/**/CMakeLists.txt linguist-vendored
|
||||
|
||||
app/webview linguist-vendored
|
||||
|
||||
llama/build-info.cpp linguist-generated
|
||||
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.s linguist-generated
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ linters:
|
||||
- errorlint
|
||||
- exptostd
|
||||
- gocheckcompilerdirectives
|
||||
- gocritic
|
||||
- govet
|
||||
- ineffassign
|
||||
- intrange
|
||||
|
||||
@@ -226,14 +226,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
|
||||
|
||||
bts := scanner.Bytes()
|
||||
if err := json.Unmarshal(bts, &errorResponse); err != nil {
|
||||
if response.StatusCode >= http.StatusBadRequest {
|
||||
return StatusError{
|
||||
StatusCode: response.StatusCode,
|
||||
Status: response.Status,
|
||||
ErrorMessage: string(bts),
|
||||
}
|
||||
}
|
||||
return errors.New(string(bts))
|
||||
return fmt.Errorf("unmarshal: %w", err)
|
||||
}
|
||||
|
||||
if response.StatusCode == http.StatusUnauthorized {
|
||||
|
||||
@@ -55,7 +55,6 @@ func TestClientFromEnvironment(t *testing.T) {
|
||||
type testError struct {
|
||||
message string
|
||||
statusCode int
|
||||
raw bool // if true, write message as-is instead of JSON encoding
|
||||
}
|
||||
|
||||
func (e testError) Error() string {
|
||||
@@ -112,20 +111,6 @@ func TestClientStream(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "plain text error response",
|
||||
responses: []any{
|
||||
"internal server error",
|
||||
},
|
||||
wantErr: "internal server error",
|
||||
},
|
||||
{
|
||||
name: "HTML error page",
|
||||
responses: []any{
|
||||
"<html><body>404 Not Found</body></html>",
|
||||
},
|
||||
wantErr: "404 Not Found",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -150,12 +135,6 @@ func TestClientStream(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
if str, ok := resp.(string); ok {
|
||||
fmt.Fprintln(w, str)
|
||||
flusher.Flush()
|
||||
continue
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
||||
t.Fatalf("failed to encode response: %v", err)
|
||||
}
|
||||
@@ -194,10 +173,9 @@ func TestClientStream(t *testing.T) {
|
||||
|
||||
func TestClientDo(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
response any
|
||||
wantErr string
|
||||
wantStatusCode int
|
||||
name string
|
||||
response any
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "immediate error response",
|
||||
@@ -205,8 +183,7 @@ func TestClientDo(t *testing.T) {
|
||||
message: "test error message",
|
||||
statusCode: http.StatusBadRequest,
|
||||
},
|
||||
wantErr: "test error message",
|
||||
wantStatusCode: http.StatusBadRequest,
|
||||
wantErr: "test error message",
|
||||
},
|
||||
{
|
||||
name: "server error response",
|
||||
@@ -214,8 +191,7 @@ func TestClientDo(t *testing.T) {
|
||||
message: "internal error",
|
||||
statusCode: http.StatusInternalServerError,
|
||||
},
|
||||
wantErr: "internal error",
|
||||
wantStatusCode: http.StatusInternalServerError,
|
||||
wantErr: "internal error",
|
||||
},
|
||||
{
|
||||
name: "successful response",
|
||||
@@ -227,26 +203,6 @@ func TestClientDo(t *testing.T) {
|
||||
Success: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "plain text error response",
|
||||
response: testError{
|
||||
message: "internal server error",
|
||||
statusCode: http.StatusInternalServerError,
|
||||
raw: true,
|
||||
},
|
||||
wantErr: "internal server error",
|
||||
wantStatusCode: http.StatusInternalServerError,
|
||||
},
|
||||
{
|
||||
name: "HTML error page",
|
||||
response: testError{
|
||||
message: "<html><body>404 Not Found</body></html>",
|
||||
statusCode: http.StatusNotFound,
|
||||
raw: true,
|
||||
},
|
||||
wantErr: "<html><body>404 Not Found</body></html>",
|
||||
wantStatusCode: http.StatusNotFound,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
@@ -254,16 +210,11 @@ func TestClientDo(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if errResp, ok := tc.response.(testError); ok {
|
||||
w.WriteHeader(errResp.statusCode)
|
||||
if !errResp.raw {
|
||||
err := json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": errResp.message,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal("failed to encode error response:", err)
|
||||
}
|
||||
} else {
|
||||
// Write raw message (simulates non-JSON error responses)
|
||||
fmt.Fprint(w, errResp.message)
|
||||
err := json.NewEncoder(w).Encode(map[string]string{
|
||||
"error": errResp.message,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal("failed to encode error response:", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -290,15 +241,6 @@ func TestClientDo(t *testing.T) {
|
||||
if err.Error() != tc.wantErr {
|
||||
t.Errorf("error message mismatch: got %q, want %q", err.Error(), tc.wantErr)
|
||||
}
|
||||
if tc.wantStatusCode != 0 {
|
||||
if statusErr, ok := err.(StatusError); ok {
|
||||
if statusErr.StatusCode != tc.wantStatusCode {
|
||||
t.Errorf("status code mismatch: got %d, want %d", statusErr.StatusCode, tc.wantStatusCode)
|
||||
}
|
||||
} else {
|
||||
t.Errorf("expected StatusError, got %T", err)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@@ -1430,7 +1430,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
latest.Summary()
|
||||
}
|
||||
|
||||
return &api.Message{Role: role, Thinking: thinkingContent.String(), Content: fullResponse.String()}, nil
|
||||
return &api.Message{Role: role, Content: fullResponse.String()}, nil
|
||||
}
|
||||
|
||||
func generate(cmd *cobra.Command, opts runOptions) error {
|
||||
|
||||
148
cmd/testolmo/main.go
Normal file
148
cmd/testolmo/main.go
Normal file
@@ -0,0 +1,148 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
_ "github.com/ollama/ollama/model/models" // Register all models
|
||||
"github.com/ollama/ollama/model/renderers"
|
||||
"github.com/ollama/ollama/sample"
|
||||
)
|
||||
|
||||
func main() {
|
||||
modelPath := "/Users/parth/.ollama/models/blobs/sha256-a87e10578f328b087f888ac7bd1018555e26028a1130980f20312b4de3a10d70"
|
||||
|
||||
fmt.Println("Loading OLMo model...")
|
||||
m, err := model.New(modelPath, ml.BackendParams{AllocMemory: true})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
if err := m.Backend().Load(context.Background(), func(f float32) {}); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Println("✅ Model loaded successfully!")
|
||||
|
||||
// Initialize the cache
|
||||
cache := m.Config().Cache
|
||||
if cache != nil {
|
||||
// Initialize with reasonable defaults:
|
||||
// - dtype: F16
|
||||
// - maxSequences: 1 (single sequence)
|
||||
// - capacity: 2048 (context length)
|
||||
// - maxBatch: 512
|
||||
cache.Init(m.Backend(), ml.DTypeF16, 1, 2048, 512)
|
||||
fmt.Printf("✅ Cache initialized (type: %T)\n", cache)
|
||||
}
|
||||
|
||||
// Use the olmo3 renderer to format the prompt properly
|
||||
messages := []api.Message{
|
||||
{Role: "user", Content: "wagwan"},
|
||||
}
|
||||
// prompt := "Question: What is machine learning? Answer:"
|
||||
prompt, err := renderers.RenderWithRenderer("olmo3", messages, nil, nil)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
// prompt = prompt[:len(prompt)]
|
||||
// prompt := "Question: What is machine learning? Answer:"
|
||||
fmt.Printf("\nRendered prompt:\n%s\n", prompt)
|
||||
|
||||
tp := m.(model.TextProcessor)
|
||||
tokens, err := tp.Encode(prompt, false)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Printf("Tokens: %v (count: %d)\n", tokens, len(tokens))
|
||||
|
||||
// Generate 20 tokens
|
||||
maxTokens := 20
|
||||
generated := make([]int32, 0, maxTokens)
|
||||
|
||||
// Create sampler (temperature=0 for greedy sampling)
|
||||
sampler := sample.NewSampler(0, 0, 0, 0, -1, nil)
|
||||
|
||||
for i := 0; i < maxTokens; i++ {
|
||||
// Create a new context for each generation step to avoid memory buildup
|
||||
ctx := m.Backend().NewContext()
|
||||
|
||||
var inputTokens []int32
|
||||
var positions []int32
|
||||
|
||||
if i == 0 {
|
||||
// First iteration: process all prompt tokens
|
||||
inputTokens = tokens
|
||||
positions = make([]int32, len(tokens))
|
||||
for j := range positions {
|
||||
positions[j] = int32(j)
|
||||
}
|
||||
} else {
|
||||
// Subsequent iterations: only process the newly generated token
|
||||
// The last token is at position len(tokens)-1 (its index in the sequence)
|
||||
inputTokens = []int32{tokens[len(tokens)-1]}
|
||||
positions = []int32{int32(len(tokens) - 1)}
|
||||
}
|
||||
|
||||
sequences := make([]int, len(inputTokens))
|
||||
// All tokens belong to sequence 0
|
||||
|
||||
inputsTensor := ctx.Input().FromInts(inputTokens, len(inputTokens))
|
||||
outputs := ctx.Input().FromInts([]int32{int32(len(inputTokens) - 1)}, 1)
|
||||
|
||||
batch := input.Batch{
|
||||
Inputs: inputsTensor,
|
||||
Positions: positions,
|
||||
Sequences: sequences,
|
||||
Outputs: outputs,
|
||||
}
|
||||
|
||||
// Forward pass (model.Forward handles cache.StartForward internally)
|
||||
logits, err := model.Forward(ctx, m, batch)
|
||||
if err != nil {
|
||||
ctx.Close()
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
logits = logits.Contiguous(ctx)
|
||||
ctx.Forward(logits).Compute(logits)
|
||||
|
||||
logitValues := logits.Floats()
|
||||
|
||||
// Sample next token
|
||||
nextToken, err := sampler.Sample(logitValues)
|
||||
if err != nil {
|
||||
ctx.Close()
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Close context before moving to next iteration
|
||||
ctx.Close()
|
||||
|
||||
generated = append(generated, nextToken)
|
||||
tokens = append(tokens, nextToken)
|
||||
|
||||
// Decode and print
|
||||
decoded, _ := tp.Decode([]int32{nextToken})
|
||||
fmt.Print(decoded)
|
||||
|
||||
// Stop on EOS or <|im_end|>
|
||||
if nextToken == 2 || nextToken == 1 { // Common EOS tokens
|
||||
break
|
||||
}
|
||||
// Check if we generated <|im_end|> (stop token for chat)
|
||||
if decoded == "<|im_end|>" {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println("\n\n✅ Generation completed!")
|
||||
fullText, _ := tp.Decode(generated)
|
||||
fmt.Printf("Generated: %s\n", fullText)
|
||||
}
|
||||
@@ -200,6 +200,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
|
||||
conv = &qwen25VLModel{}
|
||||
case "Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration":
|
||||
conv = &qwen3VLModel{}
|
||||
case "OLMo2ForCausalLM", "Olmo2ForCausalLM", "OLMo3ForCausalLM", "Olmo3ForCausalLM":
|
||||
conv = &olmoModel{}
|
||||
case "BertModel":
|
||||
conv = &bertModel{}
|
||||
case "CohereForCausalLM":
|
||||
|
||||
@@ -29,15 +29,6 @@ type mistral3Model struct {
|
||||
SlidingWindow *uint32 `json:"sliding_window"`
|
||||
HiddenAct string `json:"hidden_act"`
|
||||
VocabSize uint32 `json:"vocab_size"`
|
||||
RopeParameters struct {
|
||||
BetaFast float32 `json:"beta_fast"`
|
||||
BetaSlow float32 `json:"beta_slow"`
|
||||
Factor float32 `json:"factor"`
|
||||
ScalingBeta float32 `json:"llama_4_scaling_beta"`
|
||||
OrigMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
|
||||
RopeType string `json:"rope_type"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
} `json:"rope_parameters"`
|
||||
} `json:"text_config"`
|
||||
VisionModel struct {
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
@@ -70,13 +61,8 @@ func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
|
||||
kv["mistral3.attention.layer_norm_rms_epsilon"] = p.TextModel.RMSNormEPS
|
||||
kv["mistral3.attention.key_length"] = p.TextModel.HeadDim
|
||||
kv["mistral3.attention.value_length"] = p.TextModel.HeadDim
|
||||
kv["mistral3.rope.dimension_count"] = cmp.Or(p.TextModel.HeadDim, p.TextModel.HiddenSize/p.TextModel.NumAttentionHeads)
|
||||
kv["mistral3.rope.freq_base"] = cmp.Or(p.TextModel.RopeTheta, p.TextModel.RopeParameters.RopeTheta)
|
||||
|
||||
if p.TextModel.RopeParameters.OrigMaxPositionEmbeddings > 0 {
|
||||
kv["mistral3.rope.scaling.original_context_length"] = p.TextModel.RopeParameters.OrigMaxPositionEmbeddings
|
||||
kv["mistral3.rope.scaling_beta"] = p.TextModel.RopeParameters.ScalingBeta
|
||||
}
|
||||
kv["mistral3.rope.dimension_count"] = p.TextModel.HiddenSize / p.TextModel.NumHiddenLayers
|
||||
kv["mistral3.rope.freq_base"] = p.TextModel.RopeTheta
|
||||
|
||||
// Vision configuration
|
||||
kv["mistral3.vision.block_count"] = p.VisionModel.NumHiddenLayers
|
||||
|
||||
94
convert/convert_olmo.go
Normal file
94
convert/convert_olmo.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
type olmoModel struct {
|
||||
ModelParameters
|
||||
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||
RMSNormEPS float32 `json:"rms_norm_eps"`
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
ClampKQV float32 `json:"f_clamp_kqv"`
|
||||
SlidingWindow uint32 `json:"sliding_window"`
|
||||
LayerTypes []string `json:"layer_types"`
|
||||
}
|
||||
|
||||
var _ ModelConverter = (*olmoModel)(nil)
|
||||
|
||||
func (p *olmoModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := p.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "olmo"
|
||||
kv["olmo.block_count"] = p.NumHiddenLayers
|
||||
kv["olmo.context_length"] = p.MaxPositionEmbeddings
|
||||
kv["olmo.embedding_length"] = p.HiddenSize
|
||||
kv["olmo.feed_forward_length"] = p.IntermediateSize
|
||||
kv["olmo.attention.head_count"] = p.NumAttentionHeads
|
||||
kv["olmo.attention.head_count_kv"] = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)
|
||||
|
||||
if p.RopeTheta > 0 {
|
||||
kv["olmo.rope.freq_base"] = p.RopeTheta
|
||||
} else {
|
||||
kv["olmo.rope.freq_base"] = float32(10000.0)
|
||||
}
|
||||
|
||||
if p.RMSNormEPS > 0 {
|
||||
kv["olmo.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
|
||||
}
|
||||
|
||||
if p.ClampKQV > 0 {
|
||||
kv["olmo.attention.clamp_kqv"] = p.ClampKQV
|
||||
}
|
||||
|
||||
if p.SlidingWindow > 0 {
|
||||
kv["olmo.attention.sliding_window"] = p.SlidingWindow
|
||||
}
|
||||
|
||||
if len(p.LayerTypes) > 0 {
|
||||
kv["olmo.attention.layer_types"] = p.LayerTypes
|
||||
}
|
||||
|
||||
return kv
|
||||
}
|
||||
|
||||
func (p *olmoModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
out := make([]*ggml.Tensor, 0, len(ts))
|
||||
for _, t := range ts {
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
|
||||
func (p *olmoModel) Replacements() []string {
|
||||
return []string{
|
||||
"lm_head", "output",
|
||||
"model.embed_tokens", "token_embd",
|
||||
"model.layers", "blk",
|
||||
"model.norm", "output_norm",
|
||||
"self_attn.q_proj", "attn_q",
|
||||
"self_attn.k_proj", "attn_k",
|
||||
"self_attn.v_proj", "attn_v",
|
||||
"self_attn.o_proj", "attn_output",
|
||||
"self_attn.q_norm", "attn_q_norm",
|
||||
"self_attn.k_norm", "attn_k_norm",
|
||||
"post_attention_layernorm", "post_attention_norm",
|
||||
"post_feedforward_layernorm", "post_ffw_norm",
|
||||
"mlp.gate_proj", "ffn_gate",
|
||||
"mlp.down_proj", "ffn_down",
|
||||
"mlp.up_proj", "ffn_up",
|
||||
}
|
||||
}
|
||||
@@ -65,7 +65,6 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
|
||||
}
|
||||
|
||||
slog.Info("discovering available GPUs...")
|
||||
detectIncompatibleLibraries()
|
||||
|
||||
// Warn if any user-overrides are set which could lead to incorrect GPU discovery
|
||||
overrideWarnings()
|
||||
@@ -99,9 +98,6 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
|
||||
continue
|
||||
} else if jetpack != "" && filepath.Base(dir) != "cuda_"+jetpack {
|
||||
continue
|
||||
} else if jetpack == "" && strings.Contains(filepath.Base(dir), "cuda_jetpack") {
|
||||
slog.Debug("jetpack not detected (set JETSON_JETPACK or OLLAMA_LLM_LIBRARY to override), skipping", "libDir", dir)
|
||||
continue
|
||||
} else if !envconfig.EnableVulkan() && strings.Contains(filepath.Base(dir), "vulkan") {
|
||||
slog.Info("experimental Vulkan support disabled. To enable, set OLLAMA_VULKAN=1")
|
||||
continue
|
||||
@@ -488,16 +484,3 @@ func overrideWarnings() {
|
||||
slog.Warn("if GPUs are not correctly discovered, unset and try again")
|
||||
}
|
||||
}
|
||||
|
||||
func detectIncompatibleLibraries() {
|
||||
if runtime.GOOS != "windows" {
|
||||
return
|
||||
}
|
||||
basePath, err := exec.LookPath("ggml-base.dll")
|
||||
if err != nil || basePath == "" {
|
||||
return
|
||||
}
|
||||
if !strings.HasPrefix(basePath, ml.LibOllamaPath) {
|
||||
slog.Warn("potentially incompatible library detected in PATH", "location", basePath)
|
||||
}
|
||||
}
|
||||
|
||||
11
docs/faq.mdx
11
docs/faq.mdx
@@ -57,13 +57,8 @@ ollama ps
|
||||
```
|
||||
|
||||
<Info>
|
||||
|
||||
**Output**:
|
||||
|
||||
```
|
||||
NAME ID SIZE PROCESSOR UNTIL
|
||||
llama3:70b bcfb190ca3a7 42 GB 100% GPU 4 minutes from now
|
||||
```
|
||||
**Output**: ``` NAME ID SIZE PROCESSOR UNTIL llama3:70b bcfb190ca3a7 42 GB
|
||||
100% GPU 4 minutes from now ```
|
||||
</Info>
|
||||
|
||||
The `Processor` column will show which memory the model was loaded in to:
|
||||
@@ -390,4 +385,4 @@ Ollama for Windows and macOS register as a login item during installation. You
|
||||
- In `Task Manager` go to the `Startup apps` tab, search for `ollama` then click `Disable`
|
||||
|
||||
**MacOS**
|
||||
- Open `Settings` and search for "Login Items", find the `Ollama` entry under "Allow in the Background`, then click the slider to disable.
|
||||
- Open `Settings` and search for "Login Items", find the `Ollama` entry under "Allow in the Background`, then click the slider to disable.
|
||||
@@ -149,6 +149,9 @@ PARAMETER <parameter> <parametervalue>
|
||||
|
||||
| Parameter | Description | Value Type | Example Usage |
|
||||
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------- | -------------------- |
|
||||
| mirostat | Enable Mirostat sampling for controlling perplexity. (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0) | int | mirostat 0 |
|
||||
| mirostat_eta | Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. (Default: 0.1) | float | mirostat_eta 0.1 |
|
||||
| mirostat_tau | Controls the balance between coherence and diversity of the output. A lower value will result in more focused and coherent text. (Default: 5.0) | float | mirostat_tau 5.0 |
|
||||
| num_ctx | Sets the size of the context window used to generate the next token. (Default: 2048) | int | num_ctx 4096 |
|
||||
| repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
|
||||
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 |
|
||||
|
||||
@@ -252,6 +252,7 @@ func (kv KV) OllamaEngineRequired() bool {
|
||||
"deepseekocr",
|
||||
"deepseek2",
|
||||
"nomic-bert",
|
||||
"olmo2",
|
||||
}, kv.Architecture())
|
||||
}
|
||||
|
||||
|
||||
@@ -33,9 +33,6 @@ func TestVisionModels(t *testing.T) {
|
||||
// Qwen 3 VL mixture of experts
|
||||
model: "qwen3-vl:30b",
|
||||
},
|
||||
{
|
||||
model: "ministral-3",
|
||||
},
|
||||
}
|
||||
|
||||
for _, v := range testCases {
|
||||
|
||||
@@ -38,7 +38,6 @@ var (
|
||||
|
||||
// Note: add newer models at the top of the list to test them first
|
||||
ollamaEngineChatModels = []string{
|
||||
"ministral-3",
|
||||
"qwen3-coder:30b",
|
||||
"gpt-oss:20b",
|
||||
"gemma3n:e2b",
|
||||
@@ -168,7 +167,6 @@ var (
|
||||
"medllama2",
|
||||
"megadolphin",
|
||||
"minicpm-v",
|
||||
"ministral-3",
|
||||
"mistral-large",
|
||||
"mistral-nemo",
|
||||
"mistral-openorca",
|
||||
@@ -272,7 +270,6 @@ var (
|
||||
"mistral",
|
||||
"qwen2.5",
|
||||
"qwen2",
|
||||
"ministral-3",
|
||||
"mistral-nemo",
|
||||
"mistral-small",
|
||||
"mixtral:8x22b",
|
||||
|
||||
@@ -874,7 +874,7 @@ func (s *llmServer) createLayout(systemInfo ml.SystemInfo, systemGPUs []ml.Devic
|
||||
}}
|
||||
}
|
||||
gpuLayers, layers := s.buildLayout(systemGPUs, memory, requireFull, backoff)
|
||||
err := s.verifyLayout(systemInfo, systemGPUs, memory, requireFull, gpuLayers, layers)
|
||||
err := s.verifyLayout(systemInfo, memory, requireFull, gpuLayers, layers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -943,7 +943,7 @@ func (s *llmServer) buildLayout(systemGPUs []ml.DeviceInfo, memory *ml.BackendMe
|
||||
}
|
||||
|
||||
// verifyLayout ensures that we don't exceed limits, such as requirements about partial offloading or system memory
|
||||
func (s *llmServer) verifyLayout(systemInfo ml.SystemInfo, systemGPUs []ml.DeviceInfo, memory *ml.BackendMemory, requireFull bool, gpuLayers ml.GPULayersList, layers []uint64) error {
|
||||
func (s *llmServer) verifyLayout(systemInfo ml.SystemInfo, memory *ml.BackendMemory, requireFull bool, gpuLayers ml.GPULayersList, layers []uint64) error {
|
||||
// These sizes will only increase as we go through additional iterations and get additional information.
|
||||
cpuSize := memory.InputWeights + memory.CPU.Graph
|
||||
var vramSize uint64
|
||||
@@ -970,8 +970,8 @@ nextLayer:
|
||||
}
|
||||
|
||||
if requireFull {
|
||||
if len(systemGPUs) > 0 && gpuLayers.Sum() < len(layers) && (s.options.NumGPU < 0 || gpuLayers.Sum() < s.options.NumGPU) {
|
||||
slog.Info("model requires more gpu memory than is currently available, evicting a model to make space", "loaded layers", gpuLayers.Sum())
|
||||
if gpuLayers.Sum() < len(layers) && (s.options.NumGPU < 0 || gpuLayers.Sum() < s.options.NumGPU) {
|
||||
slog.Info("model requires more memory than is currently available, evicting a model to make space", "loaded layers", gpuLayers.Sum())
|
||||
return ErrLoadRequiredFull
|
||||
}
|
||||
|
||||
@@ -998,7 +998,7 @@ nextLayer:
|
||||
}
|
||||
}
|
||||
|
||||
if len(systemGPUs) > 0 && gpuLayers.Sum() == 0 {
|
||||
if gpuLayers.Sum() == 0 {
|
||||
slog.Debug("insufficient VRAM to load any model layers")
|
||||
}
|
||||
|
||||
|
||||
@@ -26,11 +26,10 @@ func TestLLMServerFitGPU(t *testing.T) {
|
||||
expectedErr error
|
||||
}{
|
||||
{
|
||||
name: "No GPU",
|
||||
layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
|
||||
numGPU: -1,
|
||||
expected: ml.GPULayersList{},
|
||||
requireFull: true, // Should not try to evict even though we can't load any layers
|
||||
name: "No GPU",
|
||||
layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
|
||||
numGPU: -1,
|
||||
expected: ml.GPULayersList{},
|
||||
},
|
||||
{
|
||||
name: "Full single GPU",
|
||||
|
||||
@@ -509,9 +509,11 @@ func GetVisibleDevicesEnv(l []DeviceInfo) map[string]string {
|
||||
// to crash at inference time and requires deeper validation before we include
|
||||
// it in the supported devices list.
|
||||
func (d DeviceInfo) NeedsInitValidation() bool {
|
||||
// ROCm: rocblas will crash on unsupported devices.
|
||||
// CUDA: verify CC is supported by the version of the library
|
||||
return d.Library == "ROCm" || d.Library == "CUDA"
|
||||
// At this time the only library we know needs a 2nd pass is ROCm since
|
||||
// rocblas will crash on unsupported devices. We want to find those crashes
|
||||
// during bootstrap discovery so we can eliminate those GPUs before the user
|
||||
// tries to run inference on them
|
||||
return d.Library == "ROCm"
|
||||
}
|
||||
|
||||
// Set the init validation environment variable
|
||||
|
||||
@@ -159,9 +159,8 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
|
||||
positionsScale := m.getScale(ctx, batch.Positions)
|
||||
|
||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, positionsScale, batch.Outputs, batch, m.Cache), nil
|
||||
return m.TextModel.Forward(ctx, batch.Inputs, positions, batch.Outputs, batch, m.Cache), nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
|
||||
@@ -16,8 +16,6 @@ type TextOptions struct {
|
||||
hiddenSize, numHeads, numKVHeads int
|
||||
headDim, ropeDim int
|
||||
eps, ropeBase, ropeScale float32
|
||||
ropeOrigPosEmbeddings int
|
||||
ropeScalingBeta float32
|
||||
}
|
||||
|
||||
type TextModel struct {
|
||||
@@ -36,7 +34,7 @@ type SelfAttention struct {
|
||||
Output *nn.Linear `gguf:"attn_output"`
|
||||
}
|
||||
|
||||
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs, positionsScale ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
|
||||
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
|
||||
batchSize := hiddenState.Dim(1)
|
||||
headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads)
|
||||
|
||||
@@ -51,10 +49,6 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs, posit
|
||||
v := sa.Value.Forward(ctx, hiddenState)
|
||||
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
|
||||
if opts.ropeOrigPosEmbeddings > 0 {
|
||||
q = q.Mul(ctx, positionsScale)
|
||||
}
|
||||
|
||||
kqv := nn.Attention(ctx, q, k, v, 1.0/math.Sqrt(float64(headDim)), cache)
|
||||
kqv = kqv.Reshape(ctx, headDim*opts.numHeads, batchSize)
|
||||
return sa.Output.Forward(ctx, kqv)
|
||||
@@ -82,11 +76,11 @@ type Layer struct {
|
||||
MLP *MLP
|
||||
}
|
||||
|
||||
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, positionsScale, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
|
||||
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
|
||||
residual := hiddenState
|
||||
|
||||
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, positionsScale, cache, opts)
|
||||
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
|
||||
|
||||
// In the final layer (outputs != nil), optimize by pruning to just the token positions
|
||||
// we need logits for.
|
||||
@@ -103,7 +97,7 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, positionsScale
|
||||
return hiddenState.Add(ctx, residual)
|
||||
}
|
||||
|
||||
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, positionsScale, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
|
||||
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) ml.Tensor {
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, inputs).Duplicate(ctx)
|
||||
|
||||
// image embeddings
|
||||
@@ -120,36 +114,25 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, positionsScale, o
|
||||
lastLayerOutputs = outputs
|
||||
}
|
||||
|
||||
hiddenState = layer.Forward(ctx, hiddenState, positions, positionsScale, lastLayerOutputs, cache, m.TextOptions)
|
||||
hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, cache, m.TextOptions)
|
||||
}
|
||||
|
||||
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
|
||||
return m.Output.Forward(ctx, hiddenState)
|
||||
}
|
||||
|
||||
func (m *TextModel) getScale(ctx ml.Context, positions []int32) ml.Tensor {
|
||||
posScale := make([]float32, len(positions))
|
||||
for n, pos := range positions {
|
||||
interval := math.Floor(float64(pos) / float64(m.ropeOrigPosEmbeddings))
|
||||
posScale[n] = float32(1.0 + float64(m.ropeScalingBeta)*math.Log(1.0+interval))
|
||||
}
|
||||
return ctx.Input().FromFloats(posScale, 1, 1, len(posScale))
|
||||
}
|
||||
|
||||
func newTextModel(c fs.Config) *TextModel {
|
||||
return &TextModel{
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
TextOptions: &TextOptions{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||
headDim: int(c.Uint("attention.key_length")),
|
||||
ropeDim: int(c.Uint("rope.dimension_count")),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
ropeBase: c.Float("rope.freq_base"),
|
||||
ropeScale: c.Float("rope.scaling.factor", 1),
|
||||
ropeOrigPosEmbeddings: int(c.Uint("rope.scaling.original_context_length")),
|
||||
ropeScalingBeta: c.Float("rope.scaling_beta"),
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||
headDim: int(c.Uint("attention.key_length")),
|
||||
ropeDim: int(c.Uint("rope.dimension_count")),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
ropeBase: c.Float("rope.freq_base"),
|
||||
ropeScale: c.Float("rope.scaling.factor", 1),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
_ "github.com/ollama/ollama/model/models/mistral3"
|
||||
_ "github.com/ollama/ollama/model/models/mllama"
|
||||
_ "github.com/ollama/ollama/model/models/nomicbert"
|
||||
_ "github.com/ollama/ollama/model/models/olmo"
|
||||
_ "github.com/ollama/ollama/model/models/qwen2"
|
||||
_ "github.com/ollama/ollama/model/models/qwen25vl"
|
||||
_ "github.com/ollama/ollama/model/models/qwen3"
|
||||
|
||||
271
model/models/olmo/model.go
Normal file
271
model/models/olmo/model.go
Normal file
@@ -0,0 +1,271 @@
|
||||
package olmo
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/ml/nn"
|
||||
"github.com/ollama/ollama/ml/nn/fast"
|
||||
"github.com/ollama/ollama/ml/nn/rope"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
hiddenSize, numHeads, numKVHeads int
|
||||
headDim, ropeDim int
|
||||
eps, ropeBase, ropeScale float32
|
||||
clampKQV float32
|
||||
|
||||
originalContextLength int
|
||||
attnFactor float32
|
||||
slidingWindow int32
|
||||
slidingWindowPattern []bool // per-layer SWA pattern (true = SWA, false = full attention)
|
||||
}
|
||||
|
||||
type Model struct {
|
||||
model.Base
|
||||
model.TextProcessor
|
||||
|
||||
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||
Layers []Layer `gguf:"blk"`
|
||||
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
||||
|
||||
Options
|
||||
}
|
||||
|
||||
func New(c fs.Config) (model.Model, error) {
|
||||
vocabulary := model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
EOS: append(
|
||||
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||
),
|
||||
}
|
||||
|
||||
if c.String("tokenizer.ggml.model") != "gpt2" {
|
||||
return nil, model.ErrUnsupportedTokenizer
|
||||
}
|
||||
|
||||
var pretokenizers []string
|
||||
if c.String("tokenizer.ggml.pre") != "default" {
|
||||
pretokenizers = []string{
|
||||
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||
}
|
||||
}
|
||||
processor := model.NewBytePairEncoding(&vocabulary, pretokenizers...)
|
||||
|
||||
slidingWindow := int32(c.Uint("attention.sliding_window"))
|
||||
slidingWindowPattern := c.Bools("attention.sliding_window_pattern")
|
||||
|
||||
m := Model{
|
||||
TextProcessor: processor,
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
Options: Options{
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||
headDim: int(c.Uint("attention.key_length")),
|
||||
ropeDim: int(c.Uint("rope.dimension_count")),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
ropeBase: c.Float("rope.freq_base", 1e4),
|
||||
ropeScale: c.Float("rope.scaling.factor", 1),
|
||||
clampKQV: c.Float("attention.clamp_kqv", 0),
|
||||
originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
|
||||
attnFactor: c.Float("rope.scaling.attn_factor", 1),
|
||||
slidingWindow: slidingWindow,
|
||||
slidingWindowPattern: slidingWindowPattern,
|
||||
},
|
||||
}
|
||||
|
||||
// OLMo3 uses interleaved sliding window attention (every 4th layer is full attention)
|
||||
m.Cache = kvcache.NewWrapperCache(
|
||||
kvcache.NewSWACache(slidingWindow, m.Shift),
|
||||
kvcache.NewCausalCache(m.Shift),
|
||||
)
|
||||
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
type SelfAttention struct {
|
||||
Query *nn.Linear `gguf:"attn_q"`
|
||||
Key *nn.Linear `gguf:"attn_k"`
|
||||
Value *nn.Linear `gguf:"attn_v"`
|
||||
Output *nn.Linear `gguf:"attn_output"`
|
||||
QNorm *nn.RMSNorm `gguf:"attn_q_norm"`
|
||||
KNorm *nn.RMSNorm `gguf:"attn_k_norm"`
|
||||
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
|
||||
}
|
||||
|
||||
func (o *Options) ropeOptions(factors ml.Tensor, isSWA bool) []func(*rope.Options) {
|
||||
opts := []func(*rope.Options){
|
||||
rope.WithFactors(factors),
|
||||
}
|
||||
|
||||
if o.originalContextLength > 0 {
|
||||
if isSWA {
|
||||
// For SWA layers, use regular rope with no YaRN scaling
|
||||
// ext_factor=0.0, attn_factor=1.0 per llama.cpp
|
||||
opts = append(opts,
|
||||
rope.WithOriginalContextLength(o.originalContextLength),
|
||||
rope.WithExtrapolationFactor(0.),
|
||||
rope.WithAttentionFactor(1.),
|
||||
)
|
||||
} else {
|
||||
// For full attention layers, use YaRN scaling
|
||||
opts = append(opts,
|
||||
rope.WithOriginalContextLength(o.originalContextLength),
|
||||
rope.WithExtrapolationFactor(1.),
|
||||
rope.WithAttentionFactor(o.attnFactor),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
return opts
|
||||
}
|
||||
|
||||
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tensor, cache kvcache.Cache, opts *Options, isSWA bool) ml.Tensor {
|
||||
batchSize := hiddenState.Dim(1)
|
||||
headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads)
|
||||
ropeDim := cmp.Or(opts.ropeDim, headDim)
|
||||
|
||||
query := sa.Query.Forward(ctx, hiddenState)
|
||||
if sa.QNorm != nil {
|
||||
query = sa.QNorm.Forward(ctx, query, opts.eps)
|
||||
}
|
||||
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
||||
|
||||
key := sa.Key.Forward(ctx, hiddenState)
|
||||
if sa.KNorm != nil {
|
||||
key = sa.KNorm.Forward(ctx, key, opts.eps)
|
||||
}
|
||||
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
|
||||
value := sa.Value.Forward(ctx, hiddenState)
|
||||
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
|
||||
freqScale := float32(1.0)
|
||||
if !isSWA {
|
||||
freqScale = 1. / opts.ropeScale
|
||||
}
|
||||
|
||||
ropeOpts := opts.ropeOptions(sa.RopeFactors, isSWA)
|
||||
query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, freqScale, ropeOpts...)
|
||||
key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, freqScale, ropeOpts...)
|
||||
|
||||
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
|
||||
attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize)
|
||||
return sa.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads)
|
||||
isSWA := m.isSWALayer(layer)
|
||||
|
||||
freqScale := float32(1.0)
|
||||
if !isSWA {
|
||||
freqScale = 1. / m.ropeScale
|
||||
}
|
||||
|
||||
ropeOpts := m.Options.ropeOptions(m.Layers[layer].SelfAttention.RopeFactors, isSWA)
|
||||
return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, freqScale, ropeOpts...), nil
|
||||
}
|
||||
|
||||
type MLP struct {
|
||||
Up *nn.Linear `gguf:"ffn_up"`
|
||||
Down *nn.Linear `gguf:"ffn_down"`
|
||||
Gate *nn.Linear `gguf:"ffn_gate"`
|
||||
}
|
||||
|
||||
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
|
||||
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
|
||||
return mlp.Down.Forward(ctx, hiddenState)
|
||||
}
|
||||
|
||||
type Layer struct {
|
||||
SelfAttention *SelfAttention
|
||||
PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
|
||||
MLP *MLP
|
||||
PostFFWNorm *nn.RMSNorm `gguf:"post_ffw_norm"`
|
||||
}
|
||||
|
||||
func (l *Layer) Forward(ctx ml.Context, hiddenState, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options, isSWA bool) ml.Tensor {
|
||||
residual := hiddenState
|
||||
|
||||
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positions, cache, opts, isSWA)
|
||||
|
||||
if outputs != nil {
|
||||
hiddenState = hiddenState.Rows(ctx, outputs)
|
||||
residual = residual.Rows(ctx, outputs)
|
||||
}
|
||||
|
||||
if l.PostAttentionNorm != nil {
|
||||
hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
}
|
||||
|
||||
ffnInput := hiddenState.Add(ctx, residual)
|
||||
|
||||
hiddenState = l.MLP.Forward(ctx, ffnInput, opts)
|
||||
|
||||
if l.PostFFWNorm != nil {
|
||||
hiddenState = l.PostFFWNorm.Forward(ctx, hiddenState, opts.eps)
|
||||
}
|
||||
|
||||
return hiddenState.Add(ctx, ffnInput)
|
||||
}
|
||||
|
||||
// isSWALayer returns true if the layer uses sliding window attention.
|
||||
// Uses the sliding_window_pattern from the model config if available,
|
||||
// otherwise falls back to the default OLMo3 pattern (every 4th layer is full attention).
|
||||
func (m *Model) isSWALayer(layerIdx int) bool {
|
||||
if len(m.slidingWindowPattern) > layerIdx {
|
||||
return m.slidingWindowPattern[layerIdx]
|
||||
}
|
||||
// Fallback: OLMo3 pattern where every 4th layer (indices 3, 7, 11, ...) uses full attention
|
||||
return (layerIdx+1)%4 != 0
|
||||
}
|
||||
|
||||
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
|
||||
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||
|
||||
for i, layer := range m.Layers {
|
||||
m.Cache.SetLayer(i)
|
||||
|
||||
isSWA := m.isSWALayer(i)
|
||||
|
||||
// Set cache type for interleaved SWA (OLMo3)
|
||||
if wc, ok := m.Cache.(*kvcache.WrapperCache); ok {
|
||||
if isSWA {
|
||||
wc.SetLayerType(0) // SWA cache
|
||||
} else {
|
||||
wc.SetLayerType(1) // Causal cache
|
||||
}
|
||||
}
|
||||
|
||||
var outputs ml.Tensor
|
||||
if i == len(m.Layers)-1 {
|
||||
outputs = batch.Outputs
|
||||
}
|
||||
|
||||
hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, &m.Options, isSWA)
|
||||
}
|
||||
|
||||
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
|
||||
return m.Output.Forward(ctx, hiddenState), nil
|
||||
}
|
||||
|
||||
func init() {
|
||||
model.Register("olmo2", New)
|
||||
}
|
||||
132
model/models/olmo/testolmo.go
Normal file
132
model/models/olmo/testolmo.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package olmo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
"github.com/ollama/ollama/sample"
|
||||
)
|
||||
|
||||
func main() {
|
||||
modelPath := "/Users/nicole/models/Olmo-3-7B-Think/olmo-3-7b-think-q8_0.gguf"
|
||||
|
||||
fmt.Println("Loading OLMo model...")
|
||||
m, err := model.New(modelPath, ml.BackendParams{AllocMemory: true})
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
if err := m.Backend().Load(context.Background(), func(f float32) {}); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Println("✅ Model loaded successfully!")
|
||||
|
||||
// Initialize the cache
|
||||
cache := m.Config().Cache
|
||||
if cache != nil {
|
||||
// Initialize with reasonable defaults:
|
||||
// - dtype: F16
|
||||
// - maxSequences: 1 (single sequence)
|
||||
// - capacity: 2048 (context length)
|
||||
// - maxBatch: 512
|
||||
cache.Init(m.Backend(), ml.DTypeF16, 1, 2048, 512)
|
||||
fmt.Printf("✅ Cache initialized (type: %T)\n", cache)
|
||||
}
|
||||
|
||||
// Test generation
|
||||
prompt := "Question: What is machine learning? Answer:"
|
||||
fmt.Printf("\nPrompt: %s\n", prompt)
|
||||
|
||||
tp := m.(model.TextProcessor)
|
||||
tokens, err := tp.Encode(prompt, true)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Printf("Tokens: %v (count: %d)\n", tokens, len(tokens))
|
||||
|
||||
// Generate 20 tokens
|
||||
maxTokens := 20
|
||||
generated := make([]int32, 0, maxTokens)
|
||||
|
||||
// Create sampler (temperature=0 for greedy sampling)
|
||||
sampler := sample.NewSampler(0, 0, 0, 0, -1, nil)
|
||||
|
||||
for i := 0; i < maxTokens; i++ {
|
||||
// Create a new context for each generation step to avoid memory buildup
|
||||
ctx := m.Backend().NewContext()
|
||||
|
||||
var inputTokens []int32
|
||||
var positions []int32
|
||||
|
||||
if i == 0 {
|
||||
// First iteration: process all prompt tokens
|
||||
inputTokens = tokens
|
||||
positions = make([]int32, len(tokens))
|
||||
for j := range positions {
|
||||
positions[j] = int32(j)
|
||||
}
|
||||
} else {
|
||||
// Subsequent iterations: only process the newly generated token
|
||||
// The last token is at position len(tokens)-1 (its index in the sequence)
|
||||
inputTokens = []int32{tokens[len(tokens)-1]}
|
||||
positions = []int32{int32(len(tokens) - 1)}
|
||||
}
|
||||
|
||||
sequences := make([]int, len(inputTokens))
|
||||
// All tokens belong to sequence 0
|
||||
|
||||
inputsTensor := ctx.Input().FromInts(inputTokens, len(inputTokens))
|
||||
outputs := ctx.Input().FromInts([]int32{int32(len(inputTokens) - 1)}, 1)
|
||||
|
||||
batch := input.Batch{
|
||||
Inputs: inputsTensor,
|
||||
Positions: positions,
|
||||
Sequences: sequences,
|
||||
Outputs: outputs,
|
||||
}
|
||||
|
||||
// Forward pass (model.Forward handles cache.StartForward internally)
|
||||
logits, err := model.Forward(ctx, m, batch)
|
||||
if err != nil {
|
||||
ctx.Close()
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
logits = logits.Contiguous(ctx)
|
||||
ctx.Forward(logits).Compute(logits)
|
||||
|
||||
logitValues := logits.Floats()
|
||||
|
||||
// Sample next token
|
||||
nextToken, err := sampler.Sample(logitValues)
|
||||
if err != nil {
|
||||
ctx.Close()
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
// Close context before moving to next iteration
|
||||
ctx.Close()
|
||||
|
||||
generated = append(generated, nextToken)
|
||||
tokens = append(tokens, nextToken)
|
||||
|
||||
// Decode and print
|
||||
decoded, _ := tp.Decode([]int32{nextToken})
|
||||
fmt.Print(decoded)
|
||||
|
||||
// Stop on EOS
|
||||
if nextToken == 2 || nextToken == 1 { // Common EOS tokens
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println("\n\n✅ Generation completed!")
|
||||
fullText, _ := tp.Decode(generated)
|
||||
fmt.Printf("Generated: %s\n", fullText)
|
||||
}
|
||||
@@ -1,136 +0,0 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
type ministralParserState int
|
||||
|
||||
const (
|
||||
ministralCollectingContent = iota
|
||||
ministralCollectingThinkingContent
|
||||
ministralCollectingToolName
|
||||
ministralCollectingToolArgs
|
||||
)
|
||||
|
||||
type MinistralParser struct {
|
||||
state ministralParserState
|
||||
buffer strings.Builder
|
||||
tools []api.Tool
|
||||
hasThinkingSupport bool
|
||||
currentTool *api.Tool
|
||||
}
|
||||
|
||||
func (p *MinistralParser) HasToolSupport() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *MinistralParser) HasThinkingSupport() bool {
|
||||
return p.hasThinkingSupport
|
||||
}
|
||||
|
||||
func (p *MinistralParser) setInitialState(lastMessage *api.Message) {
|
||||
prefill := lastMessage != nil && lastMessage.Role == "assistant"
|
||||
if !p.HasThinkingSupport() {
|
||||
p.state = ministralCollectingContent
|
||||
return
|
||||
}
|
||||
|
||||
if prefill && lastMessage.Content != "" {
|
||||
p.state = ministralCollectingContent
|
||||
return
|
||||
}
|
||||
|
||||
p.state = ministralCollectingThinkingContent
|
||||
}
|
||||
|
||||
func (p *MinistralParser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||
p.tools = tools
|
||||
p.setInitialState(lastMessage)
|
||||
return tools
|
||||
}
|
||||
|
||||
func toolByName(tools []api.Tool, n string) (*api.Tool, error) {
|
||||
for i := range tools {
|
||||
if tools[i].Function.Name == n {
|
||||
return &tools[i], nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("tool '%s' not found", n)
|
||||
}
|
||||
|
||||
func (p *MinistralParser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||
p.buffer.WriteString(s)
|
||||
|
||||
switch p.state {
|
||||
case ministralCollectingContent:
|
||||
if strings.Contains(p.buffer.String(), "[TOOL_CALLS]") {
|
||||
before, _ := splitAtTag(&p.buffer, "[TOOL_CALLS]", false)
|
||||
if before != "" {
|
||||
return before, "", calls, nil
|
||||
}
|
||||
p.state = ministralCollectingToolName
|
||||
} else if strings.Contains(p.buffer.String(), "[THINK]") {
|
||||
p.state = ministralCollectingThinkingContent
|
||||
return "", "", calls, nil
|
||||
} else {
|
||||
p.buffer.Reset()
|
||||
return s, "", calls, nil
|
||||
}
|
||||
case ministralCollectingThinkingContent:
|
||||
if strings.Contains(p.buffer.String(), "[/THINK]") {
|
||||
thinkingContent, after := splitAtTag(&p.buffer, "[/THINK]", true)
|
||||
p.state = ministralCollectingContent
|
||||
if after != "" {
|
||||
p.buffer.Reset()
|
||||
return after, thinkingContent, calls, nil
|
||||
}
|
||||
return "", thinkingContent, calls, nil
|
||||
} else {
|
||||
p.buffer.Reset()
|
||||
return "", s, calls, nil
|
||||
}
|
||||
case ministralCollectingToolName:
|
||||
if strings.Contains(p.buffer.String(), "[ARGS]") {
|
||||
name, _ := splitAtTag(&p.buffer, "[ARGS]", false)
|
||||
|
||||
t, err := toolByName(p.tools, name)
|
||||
if err != nil {
|
||||
return "", "", calls, err
|
||||
}
|
||||
p.currentTool = t
|
||||
p.state = ministralCollectingToolArgs
|
||||
return "", "", calls, nil
|
||||
}
|
||||
return "", "", calls, nil
|
||||
case ministralCollectingToolArgs:
|
||||
if strings.Contains(p.buffer.String(), "}") {
|
||||
before, _ := splitAtTag(&p.buffer, "}", false)
|
||||
before += "}"
|
||||
|
||||
var data map[string]any
|
||||
if err := json.Unmarshal([]byte(before), &data); err != nil {
|
||||
// todo - throw a better error
|
||||
return "", "", calls, err
|
||||
}
|
||||
|
||||
p.state = ministralCollectingContent
|
||||
|
||||
call := api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: p.currentTool.Function.Name,
|
||||
Arguments: api.ToolCallFunctionArguments(data),
|
||||
},
|
||||
}
|
||||
calls = append(calls, call)
|
||||
return "", "", calls, nil
|
||||
}
|
||||
return "", "", calls, nil
|
||||
}
|
||||
|
||||
return p.buffer.String(), thinking, calls, nil
|
||||
}
|
||||
469
model/parsers/olmo3.go
Normal file
469
model/parsers/olmo3.go
Normal file
@@ -0,0 +1,469 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
type olmo3ParserState int
|
||||
|
||||
const (
|
||||
olmo3StateContent olmo3ParserState = iota
|
||||
olmo3StateToolCalls
|
||||
olmo3StateToolCallsDone
|
||||
)
|
||||
|
||||
const (
|
||||
olmo3FuncCallsOpenTag = "<function_calls>"
|
||||
olmo3FuncCallsCloseTag = "</function_calls>"
|
||||
)
|
||||
|
||||
type Olmo3Parser struct {
|
||||
state olmo3ParserState
|
||||
buffer strings.Builder
|
||||
}
|
||||
|
||||
func (p *Olmo3Parser) HasToolSupport() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *Olmo3Parser) HasThinkingSupport() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (p *Olmo3Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||
p.state = olmo3StateContent
|
||||
return tools
|
||||
}
|
||||
|
||||
type olmo3ParserEvent interface {
|
||||
isOlmo3ParserEvent()
|
||||
}
|
||||
|
||||
type olmo3ParserEventContent struct {
|
||||
content string
|
||||
}
|
||||
|
||||
type olmo3ParserEventToolCalls struct {
|
||||
calls []api.ToolCall
|
||||
}
|
||||
|
||||
func (olmo3ParserEventContent) isOlmo3ParserEvent() {}
|
||||
func (olmo3ParserEventToolCalls) isOlmo3ParserEvent() {}
|
||||
|
||||
func (p *Olmo3Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||
p.buffer.WriteString(s)
|
||||
|
||||
if done {
|
||||
// Drain any remaining content
|
||||
bufStr := p.buffer.String()
|
||||
p.buffer.Reset()
|
||||
if p.state == olmo3StateContent && len(bufStr) > 0 {
|
||||
return bufStr, "", nil, nil
|
||||
}
|
||||
return "", "", nil, nil
|
||||
}
|
||||
|
||||
events := p.parseEvents()
|
||||
|
||||
var contentSb strings.Builder
|
||||
var allCalls []api.ToolCall
|
||||
for _, event := range events {
|
||||
switch event := event.(type) {
|
||||
case olmo3ParserEventContent:
|
||||
contentSb.WriteString(event.content)
|
||||
case olmo3ParserEventToolCalls:
|
||||
allCalls = append(allCalls, event.calls...)
|
||||
}
|
||||
}
|
||||
|
||||
return contentSb.String(), "", allCalls, nil
|
||||
}
|
||||
|
||||
func (p *Olmo3Parser) parseEvents() []olmo3ParserEvent {
|
||||
var all []olmo3ParserEvent
|
||||
|
||||
keepLooping := true
|
||||
for keepLooping {
|
||||
var events []olmo3ParserEvent
|
||||
events, keepLooping = p.eat()
|
||||
if len(events) > 0 {
|
||||
all = append(all, events...)
|
||||
}
|
||||
}
|
||||
|
||||
if len(all) > 0 {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "olmo3 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
|
||||
}
|
||||
|
||||
return all
|
||||
}
|
||||
|
||||
func (p *Olmo3Parser) eat() ([]olmo3ParserEvent, bool) {
|
||||
var events []olmo3ParserEvent
|
||||
bufStr := p.buffer.String()
|
||||
if bufStr == "" {
|
||||
return events, false
|
||||
}
|
||||
|
||||
switch p.state {
|
||||
case olmo3StateContent:
|
||||
if strings.Contains(bufStr, olmo3FuncCallsOpenTag) {
|
||||
// Found <function_calls> tag
|
||||
split := strings.SplitN(bufStr, olmo3FuncCallsOpenTag, 2)
|
||||
content := split[0]
|
||||
remaining := split[1]
|
||||
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(remaining)
|
||||
p.state = olmo3StateToolCalls
|
||||
|
||||
if len(content) > 0 {
|
||||
events = append(events, olmo3ParserEventContent{content: content})
|
||||
}
|
||||
return events, true
|
||||
} else if overlapLen := overlap(bufStr, olmo3FuncCallsOpenTag); overlapLen > 0 {
|
||||
// Partial <function_calls> tag - withhold ambiguous content
|
||||
unambiguous := bufStr[:len(bufStr)-overlapLen]
|
||||
ambiguous := bufStr[len(bufStr)-overlapLen:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, olmo3ParserEventContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
} else {
|
||||
// Regular content - emit all
|
||||
p.buffer.Reset()
|
||||
if len(bufStr) > 0 {
|
||||
events = append(events, olmo3ParserEventContent{content: bufStr})
|
||||
}
|
||||
return events, false
|
||||
}
|
||||
|
||||
case olmo3StateToolCalls:
|
||||
if strings.Contains(bufStr, olmo3FuncCallsCloseTag) {
|
||||
// Found </function_calls> tag
|
||||
split := strings.SplitN(bufStr, olmo3FuncCallsCloseTag, 2)
|
||||
toolCallsStr := split[0]
|
||||
remaining := split[1]
|
||||
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(remaining)
|
||||
p.state = olmo3StateToolCallsDone
|
||||
|
||||
// Parse the function calls
|
||||
calls, err := parseOlmo3FunctionCalls(toolCallsStr)
|
||||
if err != nil {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "failed to parse olmo3 function calls", "error", err, "content", toolCallsStr)
|
||||
} else if len(calls) > 0 {
|
||||
events = append(events, olmo3ParserEventToolCalls{calls: calls})
|
||||
}
|
||||
return events, true
|
||||
} else if overlapLen := overlap(bufStr, olmo3FuncCallsCloseTag); overlapLen > 0 {
|
||||
// Partial </function_calls> tag - wait for more
|
||||
return events, false
|
||||
}
|
||||
// Still collecting tool calls, wait for close tag
|
||||
return events, false
|
||||
|
||||
case olmo3StateToolCallsDone:
|
||||
// After tool calls, emit remaining content
|
||||
p.buffer.Reset()
|
||||
p.state = olmo3StateContent
|
||||
if len(bufStr) > 0 {
|
||||
events = append(events, olmo3ParserEventContent{content: bufStr})
|
||||
}
|
||||
return events, false
|
||||
}
|
||||
|
||||
return events, false
|
||||
}
|
||||
|
||||
// parseOlmo3FunctionCalls parses function calls in Python-esque format:
|
||||
// func_name(arg1="value1", arg2=123)
|
||||
// Multiple calls are separated by newlines
|
||||
func parseOlmo3FunctionCalls(s string) ([]api.ToolCall, error) {
|
||||
var calls []api.ToolCall
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return calls, nil
|
||||
}
|
||||
|
||||
// Split by newlines for multiple function calls
|
||||
lines := strings.Split(s, "\n")
|
||||
for _, line := range lines {
|
||||
line = strings.TrimSpace(line)
|
||||
if line == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
call, err := parseOlmo3SingleFunctionCall(line)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse function call %q: %w", line, err)
|
||||
}
|
||||
calls = append(calls, call)
|
||||
}
|
||||
|
||||
return calls, nil
|
||||
}
|
||||
|
||||
// Regex to match function call: func_name(args)
|
||||
var funcCallRegex = regexp.MustCompile(`^(\w+)\((.*)\)$`)
|
||||
|
||||
// Regex to match a single argument: key=value
|
||||
// Value can be: "string", 'string', number, true, false, null, or nested structures
|
||||
var argRegex = regexp.MustCompile(`^(\w+)=(.+)$`)
|
||||
|
||||
func parseOlmo3SingleFunctionCall(s string) (api.ToolCall, error) {
|
||||
matches := funcCallRegex.FindStringSubmatch(s)
|
||||
if matches == nil {
|
||||
return api.ToolCall{}, fmt.Errorf("invalid function call format")
|
||||
}
|
||||
|
||||
funcName := matches[1]
|
||||
argsStr := matches[2]
|
||||
|
||||
args, err := parseOlmo3Arguments(argsStr)
|
||||
if err != nil {
|
||||
return api.ToolCall{}, fmt.Errorf("failed to parse arguments: %w", err)
|
||||
}
|
||||
|
||||
return api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: funcName,
|
||||
Arguments: args,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// parseOlmo3Arguments parses comma-separated key=value pairs
|
||||
// Handles nested parentheses, brackets, braces, and quoted strings
|
||||
func parseOlmo3Arguments(s string) (map[string]any, error) {
|
||||
args := make(map[string]any)
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return args, nil
|
||||
}
|
||||
|
||||
// Split by commas, but respect nested structures and quotes
|
||||
parts := splitArguments(s)
|
||||
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Find the first = sign
|
||||
eqIdx := strings.Index(part, "=")
|
||||
if eqIdx == -1 {
|
||||
return nil, fmt.Errorf("invalid argument format: %s", part)
|
||||
}
|
||||
|
||||
key := strings.TrimSpace(part[:eqIdx])
|
||||
valueStr := strings.TrimSpace(part[eqIdx+1:])
|
||||
|
||||
value, err := parseOlmo3Value(valueStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse value for %s: %w", key, err)
|
||||
}
|
||||
|
||||
args[key] = value
|
||||
}
|
||||
|
||||
return args, nil
|
||||
}
|
||||
|
||||
// splitArguments splits arguments by commas, respecting quotes and nested structures
|
||||
func splitArguments(s string) []string {
|
||||
var parts []string
|
||||
var current strings.Builder
|
||||
depth := 0
|
||||
inString := false
|
||||
stringChar := byte(0)
|
||||
escaped := false
|
||||
|
||||
for i := 0; i < len(s); i++ {
|
||||
c := s[i]
|
||||
|
||||
if escaped {
|
||||
current.WriteByte(c)
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
|
||||
if c == '\\' && inString {
|
||||
current.WriteByte(c)
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
|
||||
if (c == '"' || c == '\'') && !inString {
|
||||
inString = true
|
||||
stringChar = c
|
||||
current.WriteByte(c)
|
||||
continue
|
||||
}
|
||||
|
||||
if c == stringChar && inString {
|
||||
inString = false
|
||||
stringChar = 0
|
||||
current.WriteByte(c)
|
||||
continue
|
||||
}
|
||||
|
||||
if !inString {
|
||||
switch c {
|
||||
case '(', '[', '{':
|
||||
depth++
|
||||
current.WriteByte(c)
|
||||
case ')', ']', '}':
|
||||
depth--
|
||||
current.WriteByte(c)
|
||||
case ',':
|
||||
if depth == 0 {
|
||||
parts = append(parts, current.String())
|
||||
current.Reset()
|
||||
continue
|
||||
}
|
||||
current.WriteByte(c)
|
||||
default:
|
||||
current.WriteByte(c)
|
||||
}
|
||||
} else {
|
||||
current.WriteByte(c)
|
||||
}
|
||||
}
|
||||
|
||||
if current.Len() > 0 {
|
||||
parts = append(parts, current.String())
|
||||
}
|
||||
|
||||
return parts
|
||||
}
|
||||
|
||||
// parseOlmo3Value parses a value which can be a string, number, boolean, null, array, or object
|
||||
func parseOlmo3Value(s string) (any, error) {
|
||||
s = strings.TrimSpace(s)
|
||||
|
||||
// Check for quoted string
|
||||
if (strings.HasPrefix(s, `"`) && strings.HasSuffix(s, `"`)) ||
|
||||
(strings.HasPrefix(s, `'`) && strings.HasSuffix(s, `'`)) {
|
||||
// Remove quotes and unescape
|
||||
inner := s[1 : len(s)-1]
|
||||
return unescapeString(inner), nil
|
||||
}
|
||||
|
||||
// Check for boolean
|
||||
if s == "true" || s == "True" {
|
||||
return true, nil
|
||||
}
|
||||
if s == "false" || s == "False" {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Check for null/None
|
||||
if s == "null" || s == "None" || s == "nil" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Check for number
|
||||
if i, err := strconv.ParseInt(s, 10, 64); err == nil {
|
||||
return i, nil
|
||||
}
|
||||
if f, err := strconv.ParseFloat(s, 64); err == nil {
|
||||
return f, nil
|
||||
}
|
||||
|
||||
// Check for array [...]
|
||||
if strings.HasPrefix(s, "[") && strings.HasSuffix(s, "]") {
|
||||
return parseOlmo3Array(s[1 : len(s)-1])
|
||||
}
|
||||
|
||||
// Check for object {...}
|
||||
if strings.HasPrefix(s, "{") && strings.HasSuffix(s, "}") {
|
||||
return parseOlmo3Object(s[1 : len(s)-1])
|
||||
}
|
||||
|
||||
// Default to string without quotes
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func parseOlmo3Array(s string) ([]any, error) {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return []any{}, nil
|
||||
}
|
||||
|
||||
parts := splitArguments(s)
|
||||
var arr []any
|
||||
for _, part := range parts {
|
||||
val, err := parseOlmo3Value(part)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
arr = append(arr, val)
|
||||
}
|
||||
return arr, nil
|
||||
}
|
||||
|
||||
func parseOlmo3Object(s string) (map[string]any, error) {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return map[string]any{}, nil
|
||||
}
|
||||
|
||||
// Objects use key: value or "key": value format
|
||||
obj := make(map[string]any)
|
||||
parts := splitArguments(s)
|
||||
for _, part := range parts {
|
||||
part = strings.TrimSpace(part)
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Find colon separator
|
||||
colonIdx := strings.Index(part, ":")
|
||||
if colonIdx == -1 {
|
||||
return nil, fmt.Errorf("invalid object entry: %s", part)
|
||||
}
|
||||
|
||||
keyStr := strings.TrimSpace(part[:colonIdx])
|
||||
valueStr := strings.TrimSpace(part[colonIdx+1:])
|
||||
|
||||
// Remove quotes from key if present
|
||||
if (strings.HasPrefix(keyStr, `"`) && strings.HasSuffix(keyStr, `"`)) ||
|
||||
(strings.HasPrefix(keyStr, `'`) && strings.HasSuffix(keyStr, `'`)) {
|
||||
keyStr = keyStr[1 : len(keyStr)-1]
|
||||
}
|
||||
|
||||
val, err := parseOlmo3Value(valueStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse value for key %s: %w", keyStr, err)
|
||||
}
|
||||
|
||||
obj[keyStr] = val
|
||||
}
|
||||
|
||||
return obj, nil
|
||||
}
|
||||
|
||||
func unescapeString(s string) string {
|
||||
// Handle common escape sequences
|
||||
s = strings.ReplaceAll(s, `\\`, "\x00") // Placeholder for backslash
|
||||
s = strings.ReplaceAll(s, `\"`, `"`)
|
||||
s = strings.ReplaceAll(s, `\'`, `'`)
|
||||
s = strings.ReplaceAll(s, `\n`, "\n")
|
||||
s = strings.ReplaceAll(s, `\t`, "\t")
|
||||
s = strings.ReplaceAll(s, `\r`, "\r")
|
||||
s = strings.ReplaceAll(s, "\x00", `\`) // Restore backslash
|
||||
return s
|
||||
}
|
||||
483
model/parsers/olmo3_test.go
Normal file
483
model/parsers/olmo3_test.go
Normal file
@@ -0,0 +1,483 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestOlmo3Parser(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expectedContent string
|
||||
expectedThinking string
|
||||
expectedCalls []api.ToolCall
|
||||
}{
|
||||
{
|
||||
name: "simple content",
|
||||
input: "Hello, how can I help you?",
|
||||
expectedContent: "Hello, how can I help you?",
|
||||
},
|
||||
{
|
||||
name: "simple tool call",
|
||||
input: `<function_calls>get_weather(location="San Francisco")</function_calls>`,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "San Francisco"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "content then tool call",
|
||||
input: `Let me check the weather.<function_calls>get_weather(location="NYC")</function_calls>`,
|
||||
expectedContent: "Let me check the weather.",
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "NYC"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with multiple arguments",
|
||||
input: `<function_calls>book_flight(from="SFO", to="NYC", date="2024-01-15")</function_calls>`,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "book_flight",
|
||||
Arguments: map[string]any{
|
||||
"from": "SFO",
|
||||
"to": "NYC",
|
||||
"date": "2024-01-15",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple tool calls",
|
||||
input: `<function_calls>get_weather(location="San Francisco")
|
||||
get_weather(location="New York")</function_calls>`,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "San Francisco"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "New York"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with numeric argument",
|
||||
input: `<function_calls>set_temperature(value=72)</function_calls>`,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "set_temperature",
|
||||
Arguments: map[string]any{"value": int64(72)},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with float argument",
|
||||
input: `<function_calls>set_price(amount=19.99)</function_calls>`,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "set_price",
|
||||
Arguments: map[string]any{"amount": 19.99},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with boolean argument",
|
||||
input: `<function_calls>toggle_setting(enabled=true)</function_calls>`,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "toggle_setting",
|
||||
Arguments: map[string]any{"enabled": true},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with null argument",
|
||||
input: `<function_calls>clear_value(field=null)</function_calls>`,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "clear_value",
|
||||
Arguments: map[string]any{"field": nil},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with array argument",
|
||||
input: `<function_calls>process_items(items=["apple", "banana", "cherry"])</function_calls>`,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "process_items",
|
||||
Arguments: map[string]any{"items": []any{"apple", "banana", "cherry"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with dict argument",
|
||||
input: `<function_calls>update_config(settings={"theme": "dark", "fontSize": 14})</function_calls>`,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "update_config",
|
||||
Arguments: map[string]any{
|
||||
"settings": map[string]any{
|
||||
"theme": "dark",
|
||||
"fontSize": int64(14),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with nested dict",
|
||||
input: `<function_calls>create_request(data={"user": {"name": "John", "age": 30}, "active": true})</function_calls>`,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "create_request",
|
||||
Arguments: map[string]any{
|
||||
"data": map[string]any{
|
||||
"user": map[string]any{
|
||||
"name": "John",
|
||||
"age": int64(30),
|
||||
},
|
||||
"active": true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with no arguments",
|
||||
input: `<function_calls>get_current_time()</function_calls>`,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_time",
|
||||
Arguments: map[string]any{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with single quotes",
|
||||
input: `<function_calls>search(query='hello world')</function_calls>`,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "search",
|
||||
Arguments: map[string]any{"query": "hello world"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with escaped quotes",
|
||||
input: `<function_calls>search(query="say \"hello\"")</function_calls>`,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "search",
|
||||
Arguments: map[string]any{"query": `say "hello"`},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call with mixed argument types",
|
||||
input: `<function_calls>create_user(name="John", age=30, active=true)</function_calls>`,
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "create_user",
|
||||
Arguments: map[string]any{
|
||||
"name": "John",
|
||||
"age": int64(30),
|
||||
"active": true,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := &Olmo3Parser{}
|
||||
p.Init(nil, nil, nil)
|
||||
|
||||
content, thinking, calls, err := p.Add(tt.input, false)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// Drain remaining content
|
||||
finalContent, finalThinking, finalCalls, err := p.Add("", true)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error on done: %v", err)
|
||||
}
|
||||
content += finalContent
|
||||
thinking += finalThinking
|
||||
calls = append(calls, finalCalls...)
|
||||
|
||||
if diff := cmp.Diff(content, tt.expectedContent); diff != "" {
|
||||
t.Errorf("content mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
if diff := cmp.Diff(thinking, tt.expectedThinking); diff != "" {
|
||||
t.Errorf("thinking mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
if diff := cmp.Diff(calls, tt.expectedCalls); diff != "" {
|
||||
t.Errorf("calls mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOlmo3Parser_Streaming(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
chunks []string
|
||||
expectedContent string
|
||||
expectedCalls []api.ToolCall
|
||||
}{
|
||||
{
|
||||
name: "streaming content",
|
||||
chunks: []string{"Hello, ", "how ", "can I help?"},
|
||||
expectedContent: "Hello, how can I help?",
|
||||
},
|
||||
{
|
||||
name: "streaming tool call",
|
||||
chunks: []string{"<function_", "calls>get_weather", "(location=\"SF\")", "</function_calls>"},
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "SF"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "streaming content then tool call",
|
||||
chunks: []string{"Let me check.", "<function_calls>", "get_weather(location=\"NYC\")", "</function_calls>"},
|
||||
expectedContent: "Let me check.",
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "NYC"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool call tag split across chunks",
|
||||
chunks: []string{"<func", "tion_calls>test()</function_calls>"},
|
||||
expectedCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "test",
|
||||
Arguments: map[string]any{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := &Olmo3Parser{}
|
||||
p.Init(nil, nil, nil)
|
||||
|
||||
var allContent string
|
||||
var allCalls []api.ToolCall
|
||||
|
||||
for _, chunk := range tt.chunks {
|
||||
content, _, calls, err := p.Add(chunk, false)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
allContent += content
|
||||
allCalls = append(allCalls, calls...)
|
||||
}
|
||||
|
||||
// Drain
|
||||
content, _, calls, err := p.Add("", true)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error on done: %v", err)
|
||||
}
|
||||
allContent += content
|
||||
allCalls = append(allCalls, calls...)
|
||||
|
||||
if diff := cmp.Diff(allContent, tt.expectedContent); diff != "" {
|
||||
t.Errorf("content mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
if diff := cmp.Diff(allCalls, tt.expectedCalls); diff != "" {
|
||||
t.Errorf("calls mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOlmo3Parser_HasToolSupport(t *testing.T) {
|
||||
p := &Olmo3Parser{}
|
||||
if !p.HasToolSupport() {
|
||||
t.Error("expected HasToolSupport to return true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOlmo3Parser_HasThinkingSupport(t *testing.T) {
|
||||
p := &Olmo3Parser{}
|
||||
if p.HasThinkingSupport() {
|
||||
t.Error("expected HasThinkingSupport to return false")
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseOlmo3FunctionCalls(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected []api.ToolCall
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "simple call",
|
||||
input: `get_weather(location="SF")`,
|
||||
expected: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "SF"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple args",
|
||||
input: `send_email(to="user@example.com", subject="Hello", body="Test message")`,
|
||||
expected: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "send_email",
|
||||
Arguments: map[string]any{
|
||||
"to": "user@example.com",
|
||||
"subject": "Hello",
|
||||
"body": "Test message",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple calls with newlines",
|
||||
input: `get_weather(location="SF")
|
||||
get_time(timezone="PST")`,
|
||||
expected: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "SF"},
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_time",
|
||||
Arguments: map[string]any{"timezone": "PST"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty input",
|
||||
input: "",
|
||||
expected: nil,
|
||||
},
|
||||
{
|
||||
name: "whitespace only",
|
||||
input: " \n ",
|
||||
expected: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
calls, err := parseOlmo3FunctionCalls(tt.input)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("parseOlmo3FunctionCalls() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if diff := cmp.Diff(calls, tt.expected); diff != "" {
|
||||
t.Errorf("calls mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseOlmo3Value(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected any
|
||||
}{
|
||||
{"string double quotes", `"hello"`, "hello"},
|
||||
{"string single quotes", `'hello'`, "hello"},
|
||||
{"integer", "42", int64(42)},
|
||||
{"negative integer", "-10", int64(-10)},
|
||||
{"float", "3.14", 3.14},
|
||||
{"boolean true", "true", true},
|
||||
{"boolean True", "True", true},
|
||||
{"boolean false", "false", false},
|
||||
{"null", "null", nil},
|
||||
{"None", "None", nil},
|
||||
{"empty array", "[]", []any{}},
|
||||
{"array with strings", `["a", "b"]`, []any{"a", "b"}},
|
||||
{"array with numbers", "[1, 2, 3]", []any{int64(1), int64(2), int64(3)}},
|
||||
{"empty object", "{}", map[string]any{}},
|
||||
{"simple object", `{"name": "John"}`, map[string]any{"name": "John"}},
|
||||
{"object with number", `{"age": 30}`, map[string]any{"age": int64(30)}},
|
||||
{"object with multiple keys", `{"a": 1, "b": 2}`, map[string]any{"a": int64(1), "b": int64(2)}},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := parseOlmo3Value(tt.input)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if diff := cmp.Diff(result, tt.expected); diff != "" {
|
||||
t.Errorf("value mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,9 +1,6 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/harmony"
|
||||
)
|
||||
@@ -41,27 +38,27 @@ func ParserForName(name string) Parser {
|
||||
if parser, ok := registry.constructors[name]; ok {
|
||||
return parser()
|
||||
}
|
||||
var p Parser
|
||||
|
||||
switch name {
|
||||
case "qwen3-coder":
|
||||
p = &Qwen3CoderParser{}
|
||||
parser := &Qwen3CoderParser{}
|
||||
return parser
|
||||
case "qwen3-vl-instruct":
|
||||
p = &Qwen3VLParser{hasThinkingSupport: false}
|
||||
parser := &Qwen3VLParser{hasThinkingSupport: false}
|
||||
return parser
|
||||
case "qwen3-vl-thinking":
|
||||
p = &Qwen3VLParser{hasThinkingSupport: true}
|
||||
case "ministral":
|
||||
p = &MinistralParser{hasThinkingSupport: false}
|
||||
parser := &Qwen3VLParser{hasThinkingSupport: true}
|
||||
return parser
|
||||
case "passthrough":
|
||||
return &PassthroughParser{}
|
||||
case "harmony":
|
||||
return harmony.NewHarmonyMessageHandler()
|
||||
case "cogito":
|
||||
return &CogitoParser{}
|
||||
case "olmo3":
|
||||
return &Olmo3Parser{}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
type PassthroughParser struct{}
|
||||
@@ -81,20 +78,3 @@ func (p *PassthroughParser) HasToolSupport() bool {
|
||||
func (p *PassthroughParser) HasThinkingSupport() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func splitAtTag(sb *strings.Builder, tag string, trimAfter bool) (string, string) {
|
||||
split := strings.SplitN(sb.String(), tag, 2)
|
||||
if len(split) == 1 {
|
||||
sb.Reset()
|
||||
return split[0], ""
|
||||
}
|
||||
before := split[0]
|
||||
before = strings.TrimRightFunc(before, unicode.IsSpace)
|
||||
after := split[1]
|
||||
if trimAfter {
|
||||
after = strings.TrimLeftFunc(after, unicode.IsSpace)
|
||||
}
|
||||
sb.Reset()
|
||||
sb.WriteString(after)
|
||||
return before, after // return events
|
||||
}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
@@ -96,164 +95,3 @@ func TestUnknownParserReturnsNil(t *testing.T) {
|
||||
t.Error("expected nil for unknown parser")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitAtTag(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
tag string
|
||||
trimAfter bool
|
||||
wantBefore string
|
||||
wantAfter string
|
||||
wantSB string // expected content of strings.Builder after operation
|
||||
}{
|
||||
{
|
||||
name: "basic split with trimAfter true",
|
||||
input: "hello <!-- split --> world",
|
||||
tag: "<!-- split -->",
|
||||
trimAfter: true,
|
||||
wantBefore: "hello",
|
||||
wantAfter: "world",
|
||||
wantSB: "world",
|
||||
},
|
||||
{
|
||||
name: "basic split with trimAfter false",
|
||||
input: "hello <!-- split --> world",
|
||||
tag: "<!-- split -->",
|
||||
trimAfter: false,
|
||||
wantBefore: "hello",
|
||||
wantAfter: " world",
|
||||
wantSB: " world",
|
||||
},
|
||||
{
|
||||
name: "tag at beginning with trimAfter true",
|
||||
input: "<!-- split -->world",
|
||||
tag: "<!-- split -->",
|
||||
trimAfter: true,
|
||||
wantBefore: "",
|
||||
wantAfter: "world",
|
||||
wantSB: "world",
|
||||
},
|
||||
{
|
||||
name: "tag at beginning with trimAfter false",
|
||||
input: "<!-- split --> world",
|
||||
tag: "<!-- split -->",
|
||||
trimAfter: false,
|
||||
wantBefore: "",
|
||||
wantAfter: " world",
|
||||
wantSB: " world",
|
||||
},
|
||||
{
|
||||
name: "tag at end with trimAfter true",
|
||||
input: "hello <!-- split -->",
|
||||
tag: "<!-- split -->",
|
||||
trimAfter: true,
|
||||
wantBefore: "hello",
|
||||
wantAfter: "",
|
||||
wantSB: "",
|
||||
},
|
||||
{
|
||||
name: "tag at end with trimAfter false",
|
||||
input: "hello <!-- split -->",
|
||||
tag: "<!-- split -->",
|
||||
trimAfter: false,
|
||||
wantBefore: "hello",
|
||||
wantAfter: "",
|
||||
wantSB: "",
|
||||
},
|
||||
{
|
||||
name: "multiple tags splits at first occurrence",
|
||||
input: "hello <!-- split --> world <!-- split --> end",
|
||||
tag: "<!-- split -->",
|
||||
trimAfter: true,
|
||||
wantBefore: "hello",
|
||||
wantAfter: "world <!-- split --> end",
|
||||
wantSB: "world <!-- split --> end",
|
||||
},
|
||||
{
|
||||
name: "tag not present",
|
||||
input: "hello world",
|
||||
tag: "<!-- split -->",
|
||||
trimAfter: true,
|
||||
wantBefore: "hello world",
|
||||
wantAfter: "",
|
||||
wantSB: "",
|
||||
},
|
||||
{
|
||||
name: "empty input",
|
||||
input: "",
|
||||
tag: "<!-- split -->",
|
||||
trimAfter: true,
|
||||
wantBefore: "",
|
||||
wantAfter: "",
|
||||
wantSB: "",
|
||||
},
|
||||
{
|
||||
name: "only whitespace before tag",
|
||||
input: " \t\n<!-- split -->world",
|
||||
tag: "<!-- split -->",
|
||||
trimAfter: true,
|
||||
wantBefore: "",
|
||||
wantAfter: "world",
|
||||
wantSB: "world",
|
||||
},
|
||||
{
|
||||
name: "only whitespace after tag with trimAfter true",
|
||||
input: "hello<!-- split --> \t\n",
|
||||
tag: "<!-- split -->",
|
||||
trimAfter: true,
|
||||
wantBefore: "hello",
|
||||
wantAfter: "",
|
||||
wantSB: "",
|
||||
},
|
||||
{
|
||||
name: "only whitespace after tag with trimAfter false",
|
||||
input: "hello<!-- split --> \t\n",
|
||||
tag: "<!-- split -->",
|
||||
trimAfter: false,
|
||||
wantBefore: "hello",
|
||||
wantAfter: " \t\n",
|
||||
wantSB: " \t\n",
|
||||
},
|
||||
{
|
||||
name: "complex whitespace trimming",
|
||||
input: " hello \t\n <!-- split --> \n\t world ",
|
||||
tag: "<!-- split -->",
|
||||
trimAfter: true,
|
||||
wantBefore: " hello",
|
||||
wantAfter: "world ",
|
||||
wantSB: "world ",
|
||||
},
|
||||
{
|
||||
name: "tag with special characters",
|
||||
input: "text <tag attr=\"value\"> more text",
|
||||
tag: "<tag attr=\"value\">",
|
||||
trimAfter: true,
|
||||
wantBefore: "text",
|
||||
wantAfter: "more text",
|
||||
wantSB: "more text",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
sb := &strings.Builder{}
|
||||
sb.WriteString(tt.input)
|
||||
|
||||
before, after := splitAtTag(sb, tt.tag, tt.trimAfter)
|
||||
|
||||
// Check return values
|
||||
if before != tt.wantBefore {
|
||||
t.Errorf("splitAtTag() before = %q, want %q", before, tt.wantBefore)
|
||||
}
|
||||
if after != tt.wantAfter {
|
||||
t.Errorf("splitAtTag() after = %q, want %q", after, tt.wantAfter)
|
||||
}
|
||||
|
||||
// Check strings.Builder state
|
||||
if sb.String() != tt.wantSB {
|
||||
t.Errorf("strings.Builder after split = %q, want %q", sb.String(), tt.wantSB)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -70,6 +70,7 @@ func (p *Qwen3VLParser) Add(s string, done bool) (content string, thinking strin
|
||||
p.buffer.WriteString(s)
|
||||
events := p.parseEvents()
|
||||
|
||||
var toolCalls []api.ToolCall
|
||||
var contentSb strings.Builder
|
||||
var thinkingSb strings.Builder
|
||||
for _, event := range events {
|
||||
@@ -80,7 +81,7 @@ func (p *Qwen3VLParser) Add(s string, done bool) (content string, thinking strin
|
||||
slog.Warn("qwen tool call parsing failed", "error", err)
|
||||
return "", "", nil, err
|
||||
}
|
||||
calls = append(calls, toolCall)
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
case qwenEventThinkingContent:
|
||||
thinkingSb.WriteString(event.content)
|
||||
case qwenEventContent:
|
||||
@@ -90,7 +91,7 @@ func (p *Qwen3VLParser) Add(s string, done bool) (content string, thinking strin
|
||||
}
|
||||
}
|
||||
|
||||
return contentSb.String(), thinkingSb.String(), calls, nil
|
||||
return contentSb.String(), thinkingSb.String(), toolCalls, nil
|
||||
}
|
||||
|
||||
func (p *Qwen3VLParser) parseEvents() []qwenEvent {
|
||||
@@ -112,6 +113,19 @@ func (p *Qwen3VLParser) parseEvents() []qwenEvent {
|
||||
return all
|
||||
}
|
||||
|
||||
func splitAtTag(p *Qwen3VLParser, tag string, trimAfter bool) (string, string) {
|
||||
split := strings.SplitN(p.buffer.String(), tag, 2)
|
||||
before := split[0]
|
||||
before = strings.TrimRightFunc(before, unicode.IsSpace)
|
||||
after := split[1]
|
||||
if trimAfter {
|
||||
after = strings.TrimLeftFunc(after, unicode.IsSpace)
|
||||
}
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(after)
|
||||
return before, after // return events
|
||||
}
|
||||
|
||||
func (p *Qwen3VLParser) eatLeadingWhitespaceAndTransitionTo(nextState qwenParserState) ([]qwenEvent, bool) {
|
||||
trimmed := strings.TrimLeftFunc(p.buffer.String(), unicode.IsSpace)
|
||||
p.buffer.Reset()
|
||||
@@ -130,7 +144,7 @@ func (p *Qwen3VLParser) eat() ([]qwenEvent, bool) {
|
||||
case CollectingContent:
|
||||
if strings.Contains(p.buffer.String(), toolOpenTag) {
|
||||
// events = emitContentBeforeTag(p, events, toolOpenTag)
|
||||
before, _ := splitAtTag(&p.buffer, toolOpenTag, false)
|
||||
before, _ := splitAtTag(p, toolOpenTag, false)
|
||||
if len(before) > 0 {
|
||||
events = append(events, qwenEventContent{content: before})
|
||||
}
|
||||
@@ -181,7 +195,7 @@ func (p *Qwen3VLParser) eat() ([]qwenEvent, bool) {
|
||||
}
|
||||
case CollectingThinkingContent:
|
||||
if strings.Contains(p.buffer.String(), thinkingCloseTag) {
|
||||
thinking, remaining := splitAtTag(&p.buffer, thinkingCloseTag, true)
|
||||
thinking, remaining := splitAtTag(p, thinkingCloseTag, true)
|
||||
if len(thinking) > 0 {
|
||||
events = append(events, qwenEventThinkingContent{content: thinking})
|
||||
}
|
||||
|
||||
147
model/renderers/olmo3.go
Normal file
147
model/renderers/olmo3.go
Normal file
@@ -0,0 +1,147 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
const (
|
||||
olmo3DefaultSystemMessage = "You are a helpful function-calling AI assistant. "
|
||||
olmo3NoFunctionsMessage = "You do not currently have access to any functions. "
|
||||
olmo3WithFunctionsMessage = "You are provided with function signatures within <functions></functions> XML tags. You may call one or more functions to assist with the user query. Output any function calls within <function_calls></function_calls> XML tags. Do not make assumptions about what values to plug into functions."
|
||||
)
|
||||
|
||||
type Olmo3Renderer struct{}
|
||||
|
||||
func (r *Olmo3Renderer) Render(messages []api.Message, tools []api.Tool, _ *api.ThinkValue) (string, error) {
|
||||
var sb strings.Builder
|
||||
|
||||
var systemMessage *api.Message
|
||||
filteredMessages := make([]api.Message, 0, len(messages))
|
||||
for i, message := range messages {
|
||||
if message.Role == "system" {
|
||||
if systemMessage == nil {
|
||||
systemMessage = &messages[i]
|
||||
}
|
||||
continue
|
||||
}
|
||||
filteredMessages = append(filteredMessages, message)
|
||||
}
|
||||
|
||||
// Render system message
|
||||
if systemMessage != nil {
|
||||
// Custom system message - single newline after "system"
|
||||
sb.WriteString("<|im_start|>system\n")
|
||||
sb.WriteString(systemMessage.Content)
|
||||
|
||||
if len(tools) > 0 {
|
||||
functionsJSON, err := marshalWithSpaces(tools)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sb.WriteString("<functions>")
|
||||
sb.WriteString(string(functionsJSON))
|
||||
sb.WriteString("</functions>")
|
||||
}
|
||||
sb.WriteString("<|im_end|>\n")
|
||||
} else {
|
||||
// Default system message - single newline after "system"
|
||||
sb.WriteString("<|im_start|>system\n")
|
||||
sb.WriteString(olmo3DefaultSystemMessage)
|
||||
|
||||
if len(tools) > 0 {
|
||||
functionsJSON, err := marshalWithSpaces(tools)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sb.WriteString(olmo3WithFunctionsMessage)
|
||||
sb.WriteString("<functions>")
|
||||
sb.WriteString(string(functionsJSON))
|
||||
sb.WriteString("</functions>")
|
||||
} else {
|
||||
sb.WriteString(olmo3NoFunctionsMessage)
|
||||
sb.WriteString("<functions></functions>")
|
||||
}
|
||||
sb.WriteString("<|im_end|>\n")
|
||||
}
|
||||
|
||||
for i, message := range filteredMessages {
|
||||
lastMessage := i == len(filteredMessages)-1
|
||||
|
||||
switch message.Role {
|
||||
case "user":
|
||||
sb.WriteString("<|im_start|>user\n")
|
||||
sb.WriteString(message.Content)
|
||||
sb.WriteString("<|im_end|>\n")
|
||||
|
||||
case "assistant":
|
||||
sb.WriteString("<|im_start|>assistant\n")
|
||||
|
||||
if message.Content != "" {
|
||||
sb.WriteString(message.Content)
|
||||
}
|
||||
|
||||
if len(message.ToolCalls) > 0 {
|
||||
sb.WriteString("<function_calls>")
|
||||
for j, tc := range message.ToolCalls {
|
||||
// Format as function_name(arg1="value1", arg2="value2")
|
||||
sb.WriteString(tc.Function.Name)
|
||||
sb.WriteString("(")
|
||||
|
||||
// Get sorted keys for deterministic output
|
||||
keys := make([]string, 0, len(tc.Function.Arguments))
|
||||
for k := range tc.Function.Arguments {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
for k, key := range keys {
|
||||
if k > 0 {
|
||||
sb.WriteString(", ")
|
||||
}
|
||||
value, err := json.Marshal(tc.Function.Arguments[key])
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("%s=%s", key, string(value)))
|
||||
}
|
||||
sb.WriteString(")")
|
||||
|
||||
if j < len(message.ToolCalls)-1 {
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
}
|
||||
sb.WriteString("</function_calls>")
|
||||
}
|
||||
|
||||
// Add end tag unless it's the last message with content only (prefill)
|
||||
if !lastMessage || len(message.ToolCalls) > 0 {
|
||||
sb.WriteString("<|im_end|>\n")
|
||||
}
|
||||
|
||||
case "tool":
|
||||
sb.WriteString("<|im_start|>environment\n")
|
||||
sb.WriteString(message.Content)
|
||||
sb.WriteString("<|im_end|>\n")
|
||||
}
|
||||
}
|
||||
|
||||
// Add generation prompt if needed
|
||||
needsGenerationPrompt := true
|
||||
if len(filteredMessages) > 0 {
|
||||
lastMsg := filteredMessages[len(filteredMessages)-1]
|
||||
if lastMsg.Role == "assistant" && len(lastMsg.ToolCalls) == 0 && lastMsg.Content != "" {
|
||||
needsGenerationPrompt = false
|
||||
}
|
||||
}
|
||||
|
||||
if needsGenerationPrompt {
|
||||
sb.WriteString("<|im_start|>assistant\n")
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
}
|
||||
290
model/renderers/olmo3_test.go
Normal file
290
model/renderers/olmo3_test.go
Normal file
@@ -0,0 +1,290 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestOlmo3Renderer(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
msgs []api.Message
|
||||
tools []api.Tool
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "basic without system - adds default system",
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
expected: "<|im_start|>system\n" +
|
||||
"You are a helpful function-calling AI assistant. You do not currently have access to any functions. <functions></functions><|im_end|>\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"Hello!<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n\n",
|
||||
},
|
||||
{
|
||||
name: "with system message no tools",
|
||||
msgs: []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant."},
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
expected: "<|im_start|>system\n" +
|
||||
"You are a helpful assistant.<|im_end|>\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"Hello!<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n\n",
|
||||
},
|
||||
{
|
||||
name: "with system message and tools",
|
||||
msgs: []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant."},
|
||||
{Role: "user", Content: "What is the weather?"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {Type: api.PropertyType{"string"}, Description: "The city"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: "<|im_start|>system\n" +
|
||||
`You are a helpful assistant.<functions>[{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "The city"}}}}}]</functions><|im_end|>` + "\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"What is the weather?<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n\n",
|
||||
},
|
||||
{
|
||||
name: "default system with tools - includes function instruction",
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "What is the weather?"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {Type: api.PropertyType{"string"}, Description: "The city"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: "<|im_start|>system\n" +
|
||||
"You are a helpful function-calling AI assistant. " +
|
||||
"You are provided with function signatures within <functions></functions> XML tags. You may call one or more functions to assist with the user query. Output any function calls within <function_calls></function_calls> XML tags. Do not make assumptions about what values to plug into functions." +
|
||||
`<functions>[{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "The city"}}}}}]</functions><|im_end|>` + "\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"What is the weather?<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n\n",
|
||||
},
|
||||
{
|
||||
name: "assistant with tool calls - function call syntax",
|
||||
msgs: []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant."},
|
||||
{Role: "user", Content: "What is the weather in SF?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "Let me check the weather.",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{
|
||||
"location": "San Francisco",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: `{"temperature": 68}`, ToolName: "get_weather"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {Type: api.PropertyType{"string"}, Description: "The city"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: "<|im_start|>system\n" +
|
||||
`You are a helpful assistant.<functions>[{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "The city"}}}}}]</functions><|im_end|>` + "\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"What is the weather in SF?<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n" +
|
||||
`Let me check the weather.<function_calls>get_weather(location="San Francisco")</function_calls><|im_end|>` + "\n" +
|
||||
"<|im_start|>environment\n" +
|
||||
`{"temperature": 68}<|im_end|>` + "\n" +
|
||||
"<|im_start|>assistant\n\n",
|
||||
},
|
||||
{
|
||||
name: "multi-turn conversation",
|
||||
msgs: []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant."},
|
||||
{Role: "user", Content: "Hello"},
|
||||
{Role: "assistant", Content: "Hi there!"},
|
||||
{Role: "user", Content: "How are you?"},
|
||||
},
|
||||
expected: "<|im_start|>system\n" +
|
||||
"You are a helpful assistant.<|im_end|>\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"Hello<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n" +
|
||||
"Hi there!<|im_end|>\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"How are you?<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n\n",
|
||||
},
|
||||
{
|
||||
name: "parallel tool calls - newline separated",
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "Get weather in SF and NYC"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "San Francisco"},
|
||||
},
|
||||
},
|
||||
{
|
||||
ID: "call_2",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: map[string]any{"location": "New York"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{Role: "tool", Content: `{"temperature": 68}`, ToolName: "get_weather"},
|
||||
{Role: "tool", Content: `{"temperature": 55}`, ToolName: "get_weather"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {Type: api.PropertyType{"string"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: "<|im_start|>system\n" +
|
||||
"You are a helpful function-calling AI assistant. " +
|
||||
"You are provided with function signatures within <functions></functions> XML tags. You may call one or more functions to assist with the user query. Output any function calls within <function_calls></function_calls> XML tags. Do not make assumptions about what values to plug into functions." +
|
||||
`<functions>[{"type": "function", "function": {"name": "get_weather", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}}}}]</functions><|im_end|>` + "\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"Get weather in SF and NYC<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n" +
|
||||
`<function_calls>get_weather(location="San Francisco")` + "\n" +
|
||||
`get_weather(location="New York")</function_calls><|im_end|>` + "\n" +
|
||||
"<|im_start|>environment\n" +
|
||||
`{"temperature": 68}<|im_end|>` + "\n" +
|
||||
"<|im_start|>environment\n" +
|
||||
`{"temperature": 55}<|im_end|>` + "\n" +
|
||||
"<|im_start|>assistant\n\n",
|
||||
},
|
||||
{
|
||||
name: "tool call with multiple arguments",
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "Book a flight"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
ID: "call_1",
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "book_flight",
|
||||
Arguments: map[string]any{
|
||||
"from": "SFO",
|
||||
"to": "NYC",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "book_flight",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"from": {Type: api.PropertyType{"string"}},
|
||||
"to": {Type: api.PropertyType{"string"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: "<|im_start|>system\n" +
|
||||
"You are a helpful function-calling AI assistant. " +
|
||||
"You are provided with function signatures within <functions></functions> XML tags. You may call one or more functions to assist with the user query. Output any function calls within <function_calls></function_calls> XML tags. Do not make assumptions about what values to plug into functions." +
|
||||
`<functions>[{"type": "function", "function": {"name": "book_flight", "parameters": {"type": "object", "properties": {"from": {"type": "string"}, "to": {"type": "string"}}}}}]</functions><|im_end|>` + "\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"Book a flight<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n" +
|
||||
`<function_calls>book_flight(from="SFO", to="NYC")</function_calls><|im_end|>` + "\n" +
|
||||
"<|im_start|>assistant\n\n",
|
||||
},
|
||||
{
|
||||
name: "assistant prefill - no generation prompt",
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "Hello"},
|
||||
{Role: "assistant", Content: "Hi there!"},
|
||||
},
|
||||
expected: "<|im_start|>system\n" +
|
||||
"You are a helpful function-calling AI assistant. You do not currently have access to any functions. <functions></functions><|im_end|>\n" +
|
||||
"<|im_start|>user\n" +
|
||||
"Hello<|im_end|>\n" +
|
||||
"<|im_start|>assistant\n" +
|
||||
"Hi there!",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rendered, err := (&Olmo3Renderer{}).Render(tt.msgs, tt.tools, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := cmp.Diff(rendered, tt.expected); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -59,6 +59,9 @@ func rendererForName(name string) Renderer {
|
||||
case "cogito":
|
||||
renderer := &CogitoRenderer{isThinking: true}
|
||||
return renderer
|
||||
case "olmo3":
|
||||
renderer := &Olmo3Renderer{}
|
||||
return renderer
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -110,6 +110,7 @@ func renderPrompt(m *Model, msgs []api.Message, tools []api.Tool, think *api.Thi
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
slog.Debug("rendered prompt", "renderer", m.Config.Renderer, "prompt", rendered)
|
||||
return rendered, nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user