Compare commits
10 Commits
jmorganca/
...
parth/olmo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5c3bf414ef | ||
|
|
0a9862a383 | ||
|
|
f475cc365a | ||
|
|
dd3306d3a0 | ||
|
|
57c1d7db9a | ||
|
|
91d6370a62 | ||
|
|
38a2a6468f | ||
|
|
064ec63ddf | ||
|
|
fd959fbf7a | ||
|
|
cfc9729edf |
@@ -1,50 +0,0 @@
|
|||||||
# eval
|
|
||||||
|
|
||||||
Evaluation tool for testing Ollama models.
|
|
||||||
|
|
||||||
## Usage
|
|
||||||
|
|
||||||
Run all tests:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
go run . -model llama3.2:latest
|
|
||||||
```
|
|
||||||
|
|
||||||
Run specific suite:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
go run . -model llama3.2:latest -suite tool-calling-basic -v
|
|
||||||
```
|
|
||||||
|
|
||||||
List available suites:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
go run . -list
|
|
||||||
```
|
|
||||||
|
|
||||||
## Adding Tests
|
|
||||||
|
|
||||||
Edit `suites.go` to add new test suites. Each test needs:
|
|
||||||
|
|
||||||
- `Name`: test identifier
|
|
||||||
- `Prompt`: what to send to the model
|
|
||||||
- `Check`: function to validate the response
|
|
||||||
|
|
||||||
Example:
|
|
||||||
|
|
||||||
```go
|
|
||||||
{
|
|
||||||
Name: "my-test",
|
|
||||||
Prompt: "What is 2+2?",
|
|
||||||
Check: Contains("4"),
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
Available check functions:
|
|
||||||
|
|
||||||
- `HasResponse()` - response is non-empty
|
|
||||||
- `Contains(s)` - response contains substring
|
|
||||||
- `CallsTool(name)` - model called specific tool
|
|
||||||
- `NoTools()` - model called no tools
|
|
||||||
- `MinTools(n)` - model called at least n tools
|
|
||||||
- `All(checks...)` - all checks pass
|
|
||||||
151
cmd/eval/eval.go
151
cmd/eval/eval.go
@@ -1,151 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Test is a single evaluation test
|
|
||||||
type Test struct {
|
|
||||||
Name string
|
|
||||||
Prompt string
|
|
||||||
System string
|
|
||||||
Tools []api.Tool
|
|
||||||
Think bool
|
|
||||||
Options map[string]any
|
|
||||||
Check func(response string, tools []api.ToolCall) bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// Suite is a collection of tests
|
|
||||||
type Suite struct {
|
|
||||||
Name string
|
|
||||||
Tests []Test
|
|
||||||
}
|
|
||||||
|
|
||||||
// Result holds test execution results
|
|
||||||
type Result struct {
|
|
||||||
Name string
|
|
||||||
Passed bool
|
|
||||||
Error error
|
|
||||||
Duration time.Duration
|
|
||||||
Response string
|
|
||||||
Tools []string
|
|
||||||
ToolCalls []api.ToolCall
|
|
||||||
Thinking bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run executes a test against a model
|
|
||||||
func Run(ctx context.Context, client *api.Client, model string, test Test) Result {
|
|
||||||
result := Result{Name: test.Name}
|
|
||||||
|
|
||||||
req := &api.ChatRequest{
|
|
||||||
Model: model,
|
|
||||||
Messages: []api.Message{
|
|
||||||
{Role: "user", Content: test.Prompt},
|
|
||||||
},
|
|
||||||
Options: test.Options,
|
|
||||||
}
|
|
||||||
|
|
||||||
if test.System != "" {
|
|
||||||
req.Messages = append([]api.Message{
|
|
||||||
{Role: "system", Content: test.System},
|
|
||||||
}, req.Messages...)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(test.Tools) > 0 {
|
|
||||||
req.Tools = test.Tools
|
|
||||||
}
|
|
||||||
|
|
||||||
if test.Think {
|
|
||||||
req.Think = &api.ThinkValue{Value: true}
|
|
||||||
}
|
|
||||||
|
|
||||||
var resp strings.Builder
|
|
||||||
var toolCalls []api.ToolCall
|
|
||||||
|
|
||||||
start := time.Now()
|
|
||||||
err := client.Chat(ctx, req, func(r api.ChatResponse) error {
|
|
||||||
resp.WriteString(r.Message.Content)
|
|
||||||
if r.Message.Thinking != "" {
|
|
||||||
result.Thinking = true
|
|
||||||
}
|
|
||||||
toolCalls = append(toolCalls, r.Message.ToolCalls...)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
result.Duration = time.Since(start)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
result.Error = err
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
result.Response = resp.String()
|
|
||||||
result.Tools = uniqueToolNames(toolCalls)
|
|
||||||
result.ToolCalls = toolCalls
|
|
||||||
result.Passed = test.Check(result.Response, toolCalls)
|
|
||||||
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
func uniqueToolNames(calls []api.ToolCall) []string {
|
|
||||||
seen := make(map[string]bool)
|
|
||||||
var names []string
|
|
||||||
for _, c := range calls {
|
|
||||||
if !seen[c.Function.Name] {
|
|
||||||
seen[c.Function.Name] = true
|
|
||||||
names = append(names, c.Function.Name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return names
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check functions for common test patterns
|
|
||||||
|
|
||||||
func HasResponse() func(string, []api.ToolCall) bool {
|
|
||||||
return func(resp string, _ []api.ToolCall) bool {
|
|
||||||
return strings.TrimSpace(resp) != ""
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Contains(s string) func(string, []api.ToolCall) bool {
|
|
||||||
return func(resp string, _ []api.ToolCall) bool {
|
|
||||||
return strings.Contains(strings.ToLower(resp), strings.ToLower(s))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func CallsTool(name string) func(string, []api.ToolCall) bool {
|
|
||||||
return func(_ string, tools []api.ToolCall) bool {
|
|
||||||
for _, t := range tools {
|
|
||||||
if t.Function.Name == name {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func NoTools() func(string, []api.ToolCall) bool {
|
|
||||||
return func(_ string, tools []api.ToolCall) bool {
|
|
||||||
return len(tools) == 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func MinTools(n int) func(string, []api.ToolCall) bool {
|
|
||||||
return func(_ string, tools []api.ToolCall) bool {
|
|
||||||
return len(tools) >= n
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func All(checks ...func(string, []api.ToolCall) bool) func(string, []api.ToolCall) bool {
|
|
||||||
return func(resp string, tools []api.ToolCall) bool {
|
|
||||||
for _, check := range checks {
|
|
||||||
if !check(resp, tools) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
217
cmd/eval/main.go
217
cmd/eval/main.go
@@ -1,217 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
"flag"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
|
||||||
)
|
|
||||||
|
|
||||||
func main() {
|
|
||||||
model := flag.String("model", "", "model to evaluate")
|
|
||||||
suite := flag.String("suite", "", "comma-separated list of suites to run (empty runs all)")
|
|
||||||
list := flag.Bool("list", false, "list available suites")
|
|
||||||
verbose := flag.Bool("v", false, "verbose output")
|
|
||||||
timeout := flag.Int("timeout", 60, "timeout per test in seconds")
|
|
||||||
export := flag.String("export", "eval-results.json", "export results to file")
|
|
||||||
flag.Parse()
|
|
||||||
|
|
||||||
if *list {
|
|
||||||
for _, s := range suites {
|
|
||||||
fmt.Printf("%s (%d tests)\n", s.Name, len(s.Tests))
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if *model == "" {
|
|
||||||
fmt.Fprintf(os.Stderr, "error: -model parameter is required\n")
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
client, err := api.ClientFromEnvironment()
|
|
||||||
if err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "error: %v\n", err)
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
if err := client.Heartbeat(ctx); err != nil {
|
|
||||||
cancel()
|
|
||||||
fmt.Fprintf(os.Stderr, "error: cannot connect to ollama\n")
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
cancel()
|
|
||||||
|
|
||||||
selected := suites
|
|
||||||
if *suite != "" {
|
|
||||||
suiteNames := strings.Split(*suite, ",")
|
|
||||||
selected = []Suite{}
|
|
||||||
var notFound []string
|
|
||||||
|
|
||||||
for _, name := range suiteNames {
|
|
||||||
name = strings.TrimSpace(name)
|
|
||||||
if name == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
found := false
|
|
||||||
for _, s := range suites {
|
|
||||||
if s.Name == name {
|
|
||||||
selected = append(selected, s)
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !found {
|
|
||||||
notFound = append(notFound, name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(notFound) > 0 {
|
|
||||||
fmt.Fprintf(os.Stderr, "error: suite(s) not found: %s\n", strings.Join(notFound, ", "))
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var results []Result
|
|
||||||
for _, s := range selected {
|
|
||||||
if *verbose {
|
|
||||||
fmt.Printf("\n%s (%d tests)\n", s.Name, len(s.Tests))
|
|
||||||
}
|
|
||||||
for i, test := range s.Tests {
|
|
||||||
if test.Options == nil {
|
|
||||||
test.Options = map[string]any{"temperature": 0.1}
|
|
||||||
}
|
|
||||||
if test.Check == nil {
|
|
||||||
test.Check = HasResponse()
|
|
||||||
}
|
|
||||||
|
|
||||||
if *verbose {
|
|
||||||
fmt.Printf(" [%d/%d] %s... ", i+1, len(s.Tests), test.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*timeout)*time.Second)
|
|
||||||
result := Run(ctx, client, *model, test)
|
|
||||||
cancel()
|
|
||||||
|
|
||||||
results = append(results, result)
|
|
||||||
|
|
||||||
if *verbose {
|
|
||||||
if result.Error != nil {
|
|
||||||
fmt.Printf("ERROR: %v\n", result.Error)
|
|
||||||
} else if result.Passed {
|
|
||||||
fmt.Printf("PASS (%.2fs)", result.Duration.Seconds())
|
|
||||||
if len(result.Tools) > 0 || result.Thinking {
|
|
||||||
fmt.Printf(" [")
|
|
||||||
if len(result.Tools) > 0 {
|
|
||||||
fmt.Printf("tools: %s", strings.Join(result.Tools, ","))
|
|
||||||
}
|
|
||||||
if result.Thinking {
|
|
||||||
if len(result.Tools) > 0 {
|
|
||||||
fmt.Printf(", ")
|
|
||||||
}
|
|
||||||
fmt.Printf("thinking")
|
|
||||||
}
|
|
||||||
fmt.Printf("]")
|
|
||||||
}
|
|
||||||
fmt.Println()
|
|
||||||
|
|
||||||
// Print tool calls with details
|
|
||||||
if len(result.ToolCalls) > 0 {
|
|
||||||
fmt.Printf(" Tool Calls:\n")
|
|
||||||
for _, tc := range result.ToolCalls {
|
|
||||||
argsJSON, _ := json.Marshal(tc.Function.Arguments)
|
|
||||||
fmt.Printf(" - %s: %s\n", tc.Function.Name, string(argsJSON))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Print response if there is one
|
|
||||||
if result.Response != "" {
|
|
||||||
fmt.Printf(" Response: %s\n", result.Response)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
fmt.Printf("FAIL (%.2fs)\n", result.Duration.Seconds())
|
|
||||||
|
|
||||||
// Print tool calls with details even on failure
|
|
||||||
if len(result.ToolCalls) > 0 {
|
|
||||||
fmt.Printf(" Tool Calls:\n")
|
|
||||||
for _, tc := range result.ToolCalls {
|
|
||||||
argsJSON, _ := json.Marshal(tc.Function.Arguments)
|
|
||||||
fmt.Printf(" - %s: %s\n", tc.Function.Name, string(argsJSON))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Print response even on failure
|
|
||||||
if result.Response != "" {
|
|
||||||
fmt.Printf(" Response: %s\n", result.Response)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
printSummary(results)
|
|
||||||
|
|
||||||
if *export != "" {
|
|
||||||
if err := writeJSON(*export, results); err != nil {
|
|
||||||
fmt.Fprintf(os.Stderr, "warning: export failed: %v\n", err)
|
|
||||||
} else if *verbose {
|
|
||||||
fmt.Printf("\nResults: %s\n", *export)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if anyFailed(results) {
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func printSummary(results []Result) {
|
|
||||||
var passed, failed, errors int
|
|
||||||
for _, r := range results {
|
|
||||||
if r.Error != nil {
|
|
||||||
errors++
|
|
||||||
} else if r.Passed {
|
|
||||||
passed++
|
|
||||||
} else {
|
|
||||||
failed++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
total := len(results)
|
|
||||||
rate := 0.0
|
|
||||||
if total > 0 {
|
|
||||||
rate = float64(passed) / float64(total) * 100
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Printf("\n%d/%d passed (%.1f%%)", passed, total, rate)
|
|
||||||
if errors > 0 {
|
|
||||||
fmt.Printf(", %d errors", errors)
|
|
||||||
}
|
|
||||||
fmt.Println()
|
|
||||||
}
|
|
||||||
|
|
||||||
func anyFailed(results []Result) bool {
|
|
||||||
for _, r := range results {
|
|
||||||
if !r.Passed || r.Error != nil {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func writeJSON(path string, results []Result) error {
|
|
||||||
f, err := os.Create(path)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer f.Close()
|
|
||||||
|
|
||||||
enc := json.NewEncoder(f)
|
|
||||||
enc.SetIndent("", " ")
|
|
||||||
return enc.Encode(results)
|
|
||||||
}
|
|
||||||
@@ -1,178 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import "github.com/ollama/ollama/api"
|
|
||||||
|
|
||||||
var suites = []Suite{
|
|
||||||
{
|
|
||||||
Name: "basic-qa",
|
|
||||||
Tests: []Test{
|
|
||||||
{
|
|
||||||
Name: "simple-math",
|
|
||||||
Prompt: "What is 2+2? Reply with just the number.",
|
|
||||||
Check: Contains("4"),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Name: "capital-city",
|
|
||||||
Prompt: "What is the capital of France? Reply with just the city name.",
|
|
||||||
Check: Contains("Paris"),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Name: "greeting",
|
|
||||||
Prompt: "Say hello",
|
|
||||||
Check: HasResponse(),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Name: "reasoning",
|
|
||||||
Tests: []Test{
|
|
||||||
{
|
|
||||||
Name: "logic-puzzle",
|
|
||||||
Prompt: "If all roses are flowers and some flowers fade quickly, can we conclude that some roses fade quickly? Answer yes or no.",
|
|
||||||
Check: Contains("no"),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Name: "counting",
|
|
||||||
Prompt: "How many letters are in the word 'HELLO'?",
|
|
||||||
Check: Contains("5"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Name: "instruction-following",
|
|
||||||
Tests: []Test{
|
|
||||||
{
|
|
||||||
Name: "json-output",
|
|
||||||
Prompt: "Reply with a JSON object containing a 'status' field set to 'ok'.",
|
|
||||||
Check: All(Contains("status"), Contains("ok")),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Name: "system-prompt",
|
|
||||||
Prompt: "What is your name?",
|
|
||||||
System: "You are a helpful assistant named TestBot. When asked your name, always respond with 'TestBot'.",
|
|
||||||
Check: Contains("TestBot"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Name: "tool-calling-basic",
|
|
||||||
Tests: []Test{
|
|
||||||
{
|
|
||||||
Name: "single-tool",
|
|
||||||
Prompt: "What's the weather like in San Francisco?",
|
|
||||||
Tools: []api.Tool{weatherTool},
|
|
||||||
Check: CallsTool("get_weather"),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Name: "tool-selection",
|
|
||||||
Prompt: "What time is it in Tokyo?",
|
|
||||||
Tools: []api.Tool{weatherTool, timeTool},
|
|
||||||
Check: CallsTool("get_time"),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Name: "no-tool-needed",
|
|
||||||
Prompt: "What is 2+2?",
|
|
||||||
Tools: []api.Tool{weatherTool, timeTool},
|
|
||||||
Check: NoTools(),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Name: "tool-calling-advanced",
|
|
||||||
Tests: []Test{
|
|
||||||
{
|
|
||||||
Name: "parallel-calls",
|
|
||||||
Prompt: "Get the weather in both New York and Los Angeles.",
|
|
||||||
Tools: []api.Tool{weatherTool},
|
|
||||||
Check: All(CallsTool("get_weather"), MinTools(2)),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Name: "multi-param",
|
|
||||||
Prompt: "Search for Italian restaurants with prices between $20 and $40.",
|
|
||||||
Tools: []api.Tool{restaurantTool},
|
|
||||||
Check: CallsTool("search_restaurants"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Name: "tool-calling-thinking",
|
|
||||||
Tests: []Test{
|
|
||||||
{
|
|
||||||
Name: "thinking-before-tool",
|
|
||||||
Prompt: "I need to know the weather in Paris before I decide what to pack.",
|
|
||||||
Tools: []api.Tool{weatherTool},
|
|
||||||
Think: true,
|
|
||||||
Check: CallsTool("get_weather"),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Name: "thinking-multi-tool",
|
|
||||||
Prompt: "I'm planning a trip to London. I need to know what time it is there and what the weather is like.",
|
|
||||||
Tools: []api.Tool{weatherTool, timeTool},
|
|
||||||
Think: true,
|
|
||||||
Check: MinTools(1),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
var weatherTool = api.Tool{
|
|
||||||
Type: "function",
|
|
||||||
Function: api.ToolFunction{
|
|
||||||
Name: "get_weather",
|
|
||||||
Description: "Get the current weather in a given location",
|
|
||||||
Parameters: api.ToolFunctionParameters{
|
|
||||||
Type: "object",
|
|
||||||
Required: []string{"location"},
|
|
||||||
Properties: map[string]api.ToolProperty{
|
|
||||||
"location": {
|
|
||||||
Type: api.PropertyType{"string"},
|
|
||||||
Description: "The city and state",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
var timeTool = api.Tool{
|
|
||||||
Type: "function",
|
|
||||||
Function: api.ToolFunction{
|
|
||||||
Name: "get_time",
|
|
||||||
Description: "Get the current time in a timezone",
|
|
||||||
Parameters: api.ToolFunctionParameters{
|
|
||||||
Type: "object",
|
|
||||||
Required: []string{"timezone"},
|
|
||||||
Properties: map[string]api.ToolProperty{
|
|
||||||
"timezone": {
|
|
||||||
Type: api.PropertyType{"string"},
|
|
||||||
Description: "The timezone name",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
var restaurantTool = api.Tool{
|
|
||||||
Type: "function",
|
|
||||||
Function: api.ToolFunction{
|
|
||||||
Name: "search_restaurants",
|
|
||||||
Description: "Search for restaurants",
|
|
||||||
Parameters: api.ToolFunctionParameters{
|
|
||||||
Type: "object",
|
|
||||||
Required: []string{"cuisine"},
|
|
||||||
Properties: map[string]api.ToolProperty{
|
|
||||||
"cuisine": {
|
|
||||||
Type: api.PropertyType{"string"},
|
|
||||||
Description: "Type of cuisine",
|
|
||||||
},
|
|
||||||
"min_price": {
|
|
||||||
Type: api.PropertyType{"number"},
|
|
||||||
Description: "Minimum price",
|
|
||||||
},
|
|
||||||
"max_price": {
|
|
||||||
Type: api.PropertyType{"number"},
|
|
||||||
Description: "Maximum price",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
148
cmd/testolmo/main.go
Normal file
148
cmd/testolmo/main.go
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/ml"
|
||||||
|
"github.com/ollama/ollama/model"
|
||||||
|
"github.com/ollama/ollama/model/input"
|
||||||
|
_ "github.com/ollama/ollama/model/models" // Register all models
|
||||||
|
"github.com/ollama/ollama/model/renderers"
|
||||||
|
"github.com/ollama/ollama/sample"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
modelPath := "/Users/parth/.ollama/models/blobs/sha256-a87e10578f328b087f888ac7bd1018555e26028a1130980f20312b4de3a10d70"
|
||||||
|
|
||||||
|
fmt.Println("Loading OLMo model...")
|
||||||
|
m, err := model.New(modelPath, ml.BackendParams{AllocMemory: true})
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.Backend().Load(context.Background(), func(f float32) {}); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("✅ Model loaded successfully!")
|
||||||
|
|
||||||
|
// Initialize the cache
|
||||||
|
cache := m.Config().Cache
|
||||||
|
if cache != nil {
|
||||||
|
// Initialize with reasonable defaults:
|
||||||
|
// - dtype: F16
|
||||||
|
// - maxSequences: 1 (single sequence)
|
||||||
|
// - capacity: 2048 (context length)
|
||||||
|
// - maxBatch: 512
|
||||||
|
cache.Init(m.Backend(), ml.DTypeF16, 1, 2048, 512)
|
||||||
|
fmt.Printf("✅ Cache initialized (type: %T)\n", cache)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use the olmo3 renderer to format the prompt properly
|
||||||
|
messages := []api.Message{
|
||||||
|
{Role: "user", Content: "wagwan"},
|
||||||
|
}
|
||||||
|
// prompt := "Question: What is machine learning? Answer:"
|
||||||
|
prompt, err := renderers.RenderWithRenderer("olmo3", messages, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
// prompt = prompt[:len(prompt)]
|
||||||
|
// prompt := "Question: What is machine learning? Answer:"
|
||||||
|
fmt.Printf("\nRendered prompt:\n%s\n", prompt)
|
||||||
|
|
||||||
|
tp := m.(model.TextProcessor)
|
||||||
|
tokens, err := tp.Encode(prompt, false)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Tokens: %v (count: %d)\n", tokens, len(tokens))
|
||||||
|
|
||||||
|
// Generate 20 tokens
|
||||||
|
maxTokens := 20
|
||||||
|
generated := make([]int32, 0, maxTokens)
|
||||||
|
|
||||||
|
// Create sampler (temperature=0 for greedy sampling)
|
||||||
|
sampler := sample.NewSampler(0, 0, 0, 0, -1, nil)
|
||||||
|
|
||||||
|
for i := 0; i < maxTokens; i++ {
|
||||||
|
// Create a new context for each generation step to avoid memory buildup
|
||||||
|
ctx := m.Backend().NewContext()
|
||||||
|
|
||||||
|
var inputTokens []int32
|
||||||
|
var positions []int32
|
||||||
|
|
||||||
|
if i == 0 {
|
||||||
|
// First iteration: process all prompt tokens
|
||||||
|
inputTokens = tokens
|
||||||
|
positions = make([]int32, len(tokens))
|
||||||
|
for j := range positions {
|
||||||
|
positions[j] = int32(j)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Subsequent iterations: only process the newly generated token
|
||||||
|
// The last token is at position len(tokens)-1 (its index in the sequence)
|
||||||
|
inputTokens = []int32{tokens[len(tokens)-1]}
|
||||||
|
positions = []int32{int32(len(tokens) - 1)}
|
||||||
|
}
|
||||||
|
|
||||||
|
sequences := make([]int, len(inputTokens))
|
||||||
|
// All tokens belong to sequence 0
|
||||||
|
|
||||||
|
inputsTensor := ctx.Input().FromInts(inputTokens, len(inputTokens))
|
||||||
|
outputs := ctx.Input().FromInts([]int32{int32(len(inputTokens) - 1)}, 1)
|
||||||
|
|
||||||
|
batch := input.Batch{
|
||||||
|
Inputs: inputsTensor,
|
||||||
|
Positions: positions,
|
||||||
|
Sequences: sequences,
|
||||||
|
Outputs: outputs,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forward pass (model.Forward handles cache.StartForward internally)
|
||||||
|
logits, err := model.Forward(ctx, m, batch)
|
||||||
|
if err != nil {
|
||||||
|
ctx.Close()
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logits = logits.Contiguous(ctx)
|
||||||
|
ctx.Forward(logits).Compute(logits)
|
||||||
|
|
||||||
|
logitValues := logits.Floats()
|
||||||
|
|
||||||
|
// Sample next token
|
||||||
|
nextToken, err := sampler.Sample(logitValues)
|
||||||
|
if err != nil {
|
||||||
|
ctx.Close()
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close context before moving to next iteration
|
||||||
|
ctx.Close()
|
||||||
|
|
||||||
|
generated = append(generated, nextToken)
|
||||||
|
tokens = append(tokens, nextToken)
|
||||||
|
|
||||||
|
// Decode and print
|
||||||
|
decoded, _ := tp.Decode([]int32{nextToken})
|
||||||
|
fmt.Print(decoded)
|
||||||
|
|
||||||
|
// Stop on EOS or <|im_end|>
|
||||||
|
if nextToken == 2 || nextToken == 1 { // Common EOS tokens
|
||||||
|
break
|
||||||
|
}
|
||||||
|
// Check if we generated <|im_end|> (stop token for chat)
|
||||||
|
if decoded == "<|im_end|>" {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("\n\n✅ Generation completed!")
|
||||||
|
fullText, _ := tp.Decode(generated)
|
||||||
|
fmt.Printf("Generated: %s\n", fullText)
|
||||||
|
}
|
||||||
@@ -200,6 +200,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
|
|||||||
conv = &qwen25VLModel{}
|
conv = &qwen25VLModel{}
|
||||||
case "Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration":
|
case "Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration":
|
||||||
conv = &qwen3VLModel{}
|
conv = &qwen3VLModel{}
|
||||||
|
case "OLMo2ForCausalLM", "Olmo2ForCausalLM", "OLMo3ForCausalLM", "Olmo3ForCausalLM":
|
||||||
|
conv = &olmoModel{}
|
||||||
case "BertModel":
|
case "BertModel":
|
||||||
conv = &bertModel{}
|
conv = &bertModel{}
|
||||||
case "CohereForCausalLM":
|
case "CohereForCausalLM":
|
||||||
|
|||||||
94
convert/convert_olmo.go
Normal file
94
convert/convert_olmo.go
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
package convert
|
||||||
|
|
||||||
|
import (
|
||||||
|
"cmp"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/fs/ggml"
|
||||||
|
)
|
||||||
|
|
||||||
|
type olmoModel struct {
|
||||||
|
ModelParameters
|
||||||
|
|
||||||
|
HiddenSize uint32 `json:"hidden_size"`
|
||||||
|
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||||
|
IntermediateSize uint32 `json:"intermediate_size"`
|
||||||
|
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||||
|
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||||
|
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
|
||||||
|
RMSNormEPS float32 `json:"rms_norm_eps"`
|
||||||
|
RopeTheta float32 `json:"rope_theta"`
|
||||||
|
ClampKQV float32 `json:"f_clamp_kqv"`
|
||||||
|
SlidingWindow uint32 `json:"sliding_window"`
|
||||||
|
LayerTypes []string `json:"layer_types"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ ModelConverter = (*olmoModel)(nil)
|
||||||
|
|
||||||
|
func (p *olmoModel) KV(t *Tokenizer) ggml.KV {
|
||||||
|
kv := p.ModelParameters.KV(t)
|
||||||
|
kv["general.architecture"] = "olmo"
|
||||||
|
kv["olmo.block_count"] = p.NumHiddenLayers
|
||||||
|
kv["olmo.context_length"] = p.MaxPositionEmbeddings
|
||||||
|
kv["olmo.embedding_length"] = p.HiddenSize
|
||||||
|
kv["olmo.feed_forward_length"] = p.IntermediateSize
|
||||||
|
kv["olmo.attention.head_count"] = p.NumAttentionHeads
|
||||||
|
kv["olmo.attention.head_count_kv"] = cmp.Or(p.NumKeyValueHeads, p.NumAttentionHeads)
|
||||||
|
|
||||||
|
if p.RopeTheta > 0 {
|
||||||
|
kv["olmo.rope.freq_base"] = p.RopeTheta
|
||||||
|
} else {
|
||||||
|
kv["olmo.rope.freq_base"] = float32(10000.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.RMSNormEPS > 0 {
|
||||||
|
kv["olmo.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.ClampKQV > 0 {
|
||||||
|
kv["olmo.attention.clamp_kqv"] = p.ClampKQV
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.SlidingWindow > 0 {
|
||||||
|
kv["olmo.attention.sliding_window"] = p.SlidingWindow
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(p.LayerTypes) > 0 {
|
||||||
|
kv["olmo.attention.layer_types"] = p.LayerTypes
|
||||||
|
}
|
||||||
|
|
||||||
|
return kv
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *olmoModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||||
|
out := make([]*ggml.Tensor, 0, len(ts))
|
||||||
|
for _, t := range ts {
|
||||||
|
out = append(out, &ggml.Tensor{
|
||||||
|
Name: t.Name(),
|
||||||
|
Kind: t.Kind(),
|
||||||
|
Shape: t.Shape(),
|
||||||
|
WriterTo: t,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *olmoModel) Replacements() []string {
|
||||||
|
return []string{
|
||||||
|
"lm_head", "output",
|
||||||
|
"model.embed_tokens", "token_embd",
|
||||||
|
"model.layers", "blk",
|
||||||
|
"model.norm", "output_norm",
|
||||||
|
"self_attn.q_proj", "attn_q",
|
||||||
|
"self_attn.k_proj", "attn_k",
|
||||||
|
"self_attn.v_proj", "attn_v",
|
||||||
|
"self_attn.o_proj", "attn_output",
|
||||||
|
"self_attn.q_norm", "attn_q_norm",
|
||||||
|
"self_attn.k_norm", "attn_k_norm",
|
||||||
|
"post_attention_layernorm", "post_attention_norm",
|
||||||
|
"post_feedforward_layernorm", "post_ffw_norm",
|
||||||
|
"mlp.gate_proj", "ffn_gate",
|
||||||
|
"mlp.down_proj", "ffn_down",
|
||||||
|
"mlp.up_proj", "ffn_up",
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -149,6 +149,9 @@ PARAMETER <parameter> <parametervalue>
|
|||||||
|
|
||||||
| Parameter | Description | Value Type | Example Usage |
|
| Parameter | Description | Value Type | Example Usage |
|
||||||
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------- | -------------------- |
|
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------- | -------------------- |
|
||||||
|
| mirostat | Enable Mirostat sampling for controlling perplexity. (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0) | int | mirostat 0 |
|
||||||
|
| mirostat_eta | Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. (Default: 0.1) | float | mirostat_eta 0.1 |
|
||||||
|
| mirostat_tau | Controls the balance between coherence and diversity of the output. A lower value will result in more focused and coherent text. (Default: 5.0) | float | mirostat_tau 5.0 |
|
||||||
| num_ctx | Sets the size of the context window used to generate the next token. (Default: 2048) | int | num_ctx 4096 |
|
| num_ctx | Sets the size of the context window used to generate the next token. (Default: 2048) | int | num_ctx 4096 |
|
||||||
| repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
|
| repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
|
||||||
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 |
|
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 |
|
||||||
|
|||||||
@@ -252,6 +252,7 @@ func (kv KV) OllamaEngineRequired() bool {
|
|||||||
"deepseekocr",
|
"deepseekocr",
|
||||||
"deepseek2",
|
"deepseek2",
|
||||||
"nomic-bert",
|
"nomic-bert",
|
||||||
|
"olmo2",
|
||||||
}, kv.Architecture())
|
}, kv.Architecture())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
_ "github.com/ollama/ollama/model/models/mistral3"
|
_ "github.com/ollama/ollama/model/models/mistral3"
|
||||||
_ "github.com/ollama/ollama/model/models/mllama"
|
_ "github.com/ollama/ollama/model/models/mllama"
|
||||||
_ "github.com/ollama/ollama/model/models/nomicbert"
|
_ "github.com/ollama/ollama/model/models/nomicbert"
|
||||||
|
_ "github.com/ollama/ollama/model/models/olmo"
|
||||||
_ "github.com/ollama/ollama/model/models/qwen2"
|
_ "github.com/ollama/ollama/model/models/qwen2"
|
||||||
_ "github.com/ollama/ollama/model/models/qwen25vl"
|
_ "github.com/ollama/ollama/model/models/qwen25vl"
|
||||||
_ "github.com/ollama/ollama/model/models/qwen3"
|
_ "github.com/ollama/ollama/model/models/qwen3"
|
||||||
|
|||||||
271
model/models/olmo/model.go
Normal file
271
model/models/olmo/model.go
Normal file
@@ -0,0 +1,271 @@
|
|||||||
|
package olmo
|
||||||
|
|
||||||
|
import (
|
||||||
|
"cmp"
|
||||||
|
"math"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/fs"
|
||||||
|
"github.com/ollama/ollama/kvcache"
|
||||||
|
"github.com/ollama/ollama/ml"
|
||||||
|
"github.com/ollama/ollama/ml/nn"
|
||||||
|
"github.com/ollama/ollama/ml/nn/fast"
|
||||||
|
"github.com/ollama/ollama/ml/nn/rope"
|
||||||
|
"github.com/ollama/ollama/model"
|
||||||
|
"github.com/ollama/ollama/model/input"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Options struct {
|
||||||
|
hiddenSize, numHeads, numKVHeads int
|
||||||
|
headDim, ropeDim int
|
||||||
|
eps, ropeBase, ropeScale float32
|
||||||
|
clampKQV float32
|
||||||
|
|
||||||
|
originalContextLength int
|
||||||
|
attnFactor float32
|
||||||
|
slidingWindow int32
|
||||||
|
slidingWindowPattern []bool // per-layer SWA pattern (true = SWA, false = full attention)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Model struct {
|
||||||
|
model.Base
|
||||||
|
model.TextProcessor
|
||||||
|
|
||||||
|
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
|
||||||
|
Layers []Layer `gguf:"blk"`
|
||||||
|
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
|
||||||
|
Output *nn.Linear `gguf:"output,alt:token_embd"`
|
||||||
|
|
||||||
|
Options
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(c fs.Config) (model.Model, error) {
|
||||||
|
vocabulary := model.Vocabulary{
|
||||||
|
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||||
|
Scores: c.Floats("tokenizer.ggml.scores"),
|
||||||
|
Types: c.Ints("tokenizer.ggml.token_type"),
|
||||||
|
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||||
|
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||||
|
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
|
||||||
|
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||||
|
EOS: append(
|
||||||
|
[]int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))},
|
||||||
|
c.Ints("tokenizer.ggml.eos_token_ids")...,
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.String("tokenizer.ggml.model") != "gpt2" {
|
||||||
|
return nil, model.ErrUnsupportedTokenizer
|
||||||
|
}
|
||||||
|
|
||||||
|
var pretokenizers []string
|
||||||
|
if c.String("tokenizer.ggml.pre") != "default" {
|
||||||
|
pretokenizers = []string{
|
||||||
|
"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
processor := model.NewBytePairEncoding(&vocabulary, pretokenizers...)
|
||||||
|
|
||||||
|
slidingWindow := int32(c.Uint("attention.sliding_window"))
|
||||||
|
slidingWindowPattern := c.Bools("attention.sliding_window_pattern")
|
||||||
|
|
||||||
|
m := Model{
|
||||||
|
TextProcessor: processor,
|
||||||
|
Layers: make([]Layer, c.Uint("block_count")),
|
||||||
|
Options: Options{
|
||||||
|
hiddenSize: int(c.Uint("embedding_length")),
|
||||||
|
numHeads: int(c.Uint("attention.head_count")),
|
||||||
|
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||||
|
headDim: int(c.Uint("attention.key_length")),
|
||||||
|
ropeDim: int(c.Uint("rope.dimension_count")),
|
||||||
|
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||||
|
ropeBase: c.Float("rope.freq_base", 1e4),
|
||||||
|
ropeScale: c.Float("rope.scaling.factor", 1),
|
||||||
|
clampKQV: c.Float("attention.clamp_kqv", 0),
|
||||||
|
originalContextLength: int(c.Uint("rope.scaling.original_context_length")),
|
||||||
|
attnFactor: c.Float("rope.scaling.attn_factor", 1),
|
||||||
|
slidingWindow: slidingWindow,
|
||||||
|
slidingWindowPattern: slidingWindowPattern,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// OLMo3 uses interleaved sliding window attention (every 4th layer is full attention)
|
||||||
|
m.Cache = kvcache.NewWrapperCache(
|
||||||
|
kvcache.NewSWACache(slidingWindow, m.Shift),
|
||||||
|
kvcache.NewCausalCache(m.Shift),
|
||||||
|
)
|
||||||
|
|
||||||
|
return &m, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type SelfAttention struct {
|
||||||
|
Query *nn.Linear `gguf:"attn_q"`
|
||||||
|
Key *nn.Linear `gguf:"attn_k"`
|
||||||
|
Value *nn.Linear `gguf:"attn_v"`
|
||||||
|
Output *nn.Linear `gguf:"attn_output"`
|
||||||
|
QNorm *nn.RMSNorm `gguf:"attn_q_norm"`
|
||||||
|
KNorm *nn.RMSNorm `gguf:"attn_k_norm"`
|
||||||
|
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o *Options) ropeOptions(factors ml.Tensor, isSWA bool) []func(*rope.Options) {
|
||||||
|
opts := []func(*rope.Options){
|
||||||
|
rope.WithFactors(factors),
|
||||||
|
}
|
||||||
|
|
||||||
|
if o.originalContextLength > 0 {
|
||||||
|
if isSWA {
|
||||||
|
// For SWA layers, use regular rope with no YaRN scaling
|
||||||
|
// ext_factor=0.0, attn_factor=1.0 per llama.cpp
|
||||||
|
opts = append(opts,
|
||||||
|
rope.WithOriginalContextLength(o.originalContextLength),
|
||||||
|
rope.WithExtrapolationFactor(0.),
|
||||||
|
rope.WithAttentionFactor(1.),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
// For full attention layers, use YaRN scaling
|
||||||
|
opts = append(opts,
|
||||||
|
rope.WithOriginalContextLength(o.originalContextLength),
|
||||||
|
rope.WithExtrapolationFactor(1.),
|
||||||
|
rope.WithAttentionFactor(o.attnFactor),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return opts
|
||||||
|
}
|
||||||
|
|
||||||
|
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tensor, cache kvcache.Cache, opts *Options, isSWA bool) ml.Tensor {
|
||||||
|
batchSize := hiddenState.Dim(1)
|
||||||
|
headDim := cmp.Or(opts.headDim, opts.hiddenSize/opts.numHeads)
|
||||||
|
ropeDim := cmp.Or(opts.ropeDim, headDim)
|
||||||
|
|
||||||
|
query := sa.Query.Forward(ctx, hiddenState)
|
||||||
|
if sa.QNorm != nil {
|
||||||
|
query = sa.QNorm.Forward(ctx, query, opts.eps)
|
||||||
|
}
|
||||||
|
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
||||||
|
|
||||||
|
key := sa.Key.Forward(ctx, hiddenState)
|
||||||
|
if sa.KNorm != nil {
|
||||||
|
key = sa.KNorm.Forward(ctx, key, opts.eps)
|
||||||
|
}
|
||||||
|
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||||
|
|
||||||
|
value := sa.Value.Forward(ctx, hiddenState)
|
||||||
|
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||||
|
|
||||||
|
freqScale := float32(1.0)
|
||||||
|
if !isSWA {
|
||||||
|
freqScale = 1. / opts.ropeScale
|
||||||
|
}
|
||||||
|
|
||||||
|
ropeOpts := opts.ropeOptions(sa.RopeFactors, isSWA)
|
||||||
|
query = fast.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, freqScale, ropeOpts...)
|
||||||
|
key = fast.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, freqScale, ropeOpts...)
|
||||||
|
|
||||||
|
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
|
||||||
|
attention = attention.Reshape(ctx, headDim*opts.numHeads, batchSize)
|
||||||
|
return sa.Output.Forward(ctx, attention)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||||
|
ropeDim := cmp.Or(m.ropeDim, m.hiddenSize/m.numHeads)
|
||||||
|
isSWA := m.isSWALayer(layer)
|
||||||
|
|
||||||
|
freqScale := float32(1.0)
|
||||||
|
if !isSWA {
|
||||||
|
freqScale = 1. / m.ropeScale
|
||||||
|
}
|
||||||
|
|
||||||
|
ropeOpts := m.Options.ropeOptions(m.Layers[layer].SelfAttention.RopeFactors, isSWA)
|
||||||
|
return fast.RoPE(ctx, key, shift, ropeDim, m.ropeBase, freqScale, ropeOpts...), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type MLP struct {
|
||||||
|
Up *nn.Linear `gguf:"ffn_up"`
|
||||||
|
Down *nn.Linear `gguf:"ffn_down"`
|
||||||
|
Gate *nn.Linear `gguf:"ffn_gate"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
|
||||||
|
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
|
||||||
|
return mlp.Down.Forward(ctx, hiddenState)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Layer struct {
|
||||||
|
SelfAttention *SelfAttention
|
||||||
|
PostAttentionNorm *nn.RMSNorm `gguf:"post_attention_norm"`
|
||||||
|
MLP *MLP
|
||||||
|
PostFFWNorm *nn.RMSNorm `gguf:"post_ffw_norm"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Layer) Forward(ctx ml.Context, hiddenState, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options, isSWA bool) ml.Tensor {
|
||||||
|
residual := hiddenState
|
||||||
|
|
||||||
|
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positions, cache, opts, isSWA)
|
||||||
|
|
||||||
|
if outputs != nil {
|
||||||
|
hiddenState = hiddenState.Rows(ctx, outputs)
|
||||||
|
residual = residual.Rows(ctx, outputs)
|
||||||
|
}
|
||||||
|
|
||||||
|
if l.PostAttentionNorm != nil {
|
||||||
|
hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
|
||||||
|
}
|
||||||
|
|
||||||
|
ffnInput := hiddenState.Add(ctx, residual)
|
||||||
|
|
||||||
|
hiddenState = l.MLP.Forward(ctx, ffnInput, opts)
|
||||||
|
|
||||||
|
if l.PostFFWNorm != nil {
|
||||||
|
hiddenState = l.PostFFWNorm.Forward(ctx, hiddenState, opts.eps)
|
||||||
|
}
|
||||||
|
|
||||||
|
return hiddenState.Add(ctx, ffnInput)
|
||||||
|
}
|
||||||
|
|
||||||
|
// isSWALayer returns true if the layer uses sliding window attention.
|
||||||
|
// Uses the sliding_window_pattern from the model config if available,
|
||||||
|
// otherwise falls back to the default OLMo3 pattern (every 4th layer is full attention).
|
||||||
|
func (m *Model) isSWALayer(layerIdx int) bool {
|
||||||
|
if len(m.slidingWindowPattern) > layerIdx {
|
||||||
|
return m.slidingWindowPattern[layerIdx]
|
||||||
|
}
|
||||||
|
// Fallback: OLMo3 pattern where every 4th layer (indices 3, 7, 11, ...) uses full attention
|
||||||
|
return (layerIdx+1)%4 != 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||||
|
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
|
||||||
|
|
||||||
|
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
|
||||||
|
|
||||||
|
for i, layer := range m.Layers {
|
||||||
|
m.Cache.SetLayer(i)
|
||||||
|
|
||||||
|
isSWA := m.isSWALayer(i)
|
||||||
|
|
||||||
|
// Set cache type for interleaved SWA (OLMo3)
|
||||||
|
if wc, ok := m.Cache.(*kvcache.WrapperCache); ok {
|
||||||
|
if isSWA {
|
||||||
|
wc.SetLayerType(0) // SWA cache
|
||||||
|
} else {
|
||||||
|
wc.SetLayerType(1) // Causal cache
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var outputs ml.Tensor
|
||||||
|
if i == len(m.Layers)-1 {
|
||||||
|
outputs = batch.Outputs
|
||||||
|
}
|
||||||
|
|
||||||
|
hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, &m.Options, isSWA)
|
||||||
|
}
|
||||||
|
|
||||||
|
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
|
||||||
|
return m.Output.Forward(ctx, hiddenState), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
model.Register("olmo2", New)
|
||||||
|
}
|
||||||
132
model/models/olmo/testolmo.go
Normal file
132
model/models/olmo/testolmo.go
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
package olmo
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/ml"
|
||||||
|
"github.com/ollama/ollama/model"
|
||||||
|
"github.com/ollama/ollama/model/input"
|
||||||
|
"github.com/ollama/ollama/sample"
|
||||||
|
)
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
modelPath := "/Users/nicole/models/Olmo-3-7B-Think/olmo-3-7b-think-q8_0.gguf"
|
||||||
|
|
||||||
|
fmt.Println("Loading OLMo model...")
|
||||||
|
m, err := model.New(modelPath, ml.BackendParams{AllocMemory: true})
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.Backend().Load(context.Background(), func(f float32) {}); err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("✅ Model loaded successfully!")
|
||||||
|
|
||||||
|
// Initialize the cache
|
||||||
|
cache := m.Config().Cache
|
||||||
|
if cache != nil {
|
||||||
|
// Initialize with reasonable defaults:
|
||||||
|
// - dtype: F16
|
||||||
|
// - maxSequences: 1 (single sequence)
|
||||||
|
// - capacity: 2048 (context length)
|
||||||
|
// - maxBatch: 512
|
||||||
|
cache.Init(m.Backend(), ml.DTypeF16, 1, 2048, 512)
|
||||||
|
fmt.Printf("✅ Cache initialized (type: %T)\n", cache)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test generation
|
||||||
|
prompt := "Question: What is machine learning? Answer:"
|
||||||
|
fmt.Printf("\nPrompt: %s\n", prompt)
|
||||||
|
|
||||||
|
tp := m.(model.TextProcessor)
|
||||||
|
tokens, err := tp.Encode(prompt, true)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Printf("Tokens: %v (count: %d)\n", tokens, len(tokens))
|
||||||
|
|
||||||
|
// Generate 20 tokens
|
||||||
|
maxTokens := 20
|
||||||
|
generated := make([]int32, 0, maxTokens)
|
||||||
|
|
||||||
|
// Create sampler (temperature=0 for greedy sampling)
|
||||||
|
sampler := sample.NewSampler(0, 0, 0, 0, -1, nil)
|
||||||
|
|
||||||
|
for i := 0; i < maxTokens; i++ {
|
||||||
|
// Create a new context for each generation step to avoid memory buildup
|
||||||
|
ctx := m.Backend().NewContext()
|
||||||
|
|
||||||
|
var inputTokens []int32
|
||||||
|
var positions []int32
|
||||||
|
|
||||||
|
if i == 0 {
|
||||||
|
// First iteration: process all prompt tokens
|
||||||
|
inputTokens = tokens
|
||||||
|
positions = make([]int32, len(tokens))
|
||||||
|
for j := range positions {
|
||||||
|
positions[j] = int32(j)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Subsequent iterations: only process the newly generated token
|
||||||
|
// The last token is at position len(tokens)-1 (its index in the sequence)
|
||||||
|
inputTokens = []int32{tokens[len(tokens)-1]}
|
||||||
|
positions = []int32{int32(len(tokens) - 1)}
|
||||||
|
}
|
||||||
|
|
||||||
|
sequences := make([]int, len(inputTokens))
|
||||||
|
// All tokens belong to sequence 0
|
||||||
|
|
||||||
|
inputsTensor := ctx.Input().FromInts(inputTokens, len(inputTokens))
|
||||||
|
outputs := ctx.Input().FromInts([]int32{int32(len(inputTokens) - 1)}, 1)
|
||||||
|
|
||||||
|
batch := input.Batch{
|
||||||
|
Inputs: inputsTensor,
|
||||||
|
Positions: positions,
|
||||||
|
Sequences: sequences,
|
||||||
|
Outputs: outputs,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forward pass (model.Forward handles cache.StartForward internally)
|
||||||
|
logits, err := model.Forward(ctx, m, batch)
|
||||||
|
if err != nil {
|
||||||
|
ctx.Close()
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logits = logits.Contiguous(ctx)
|
||||||
|
ctx.Forward(logits).Compute(logits)
|
||||||
|
|
||||||
|
logitValues := logits.Floats()
|
||||||
|
|
||||||
|
// Sample next token
|
||||||
|
nextToken, err := sampler.Sample(logitValues)
|
||||||
|
if err != nil {
|
||||||
|
ctx.Close()
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close context before moving to next iteration
|
||||||
|
ctx.Close()
|
||||||
|
|
||||||
|
generated = append(generated, nextToken)
|
||||||
|
tokens = append(tokens, nextToken)
|
||||||
|
|
||||||
|
// Decode and print
|
||||||
|
decoded, _ := tp.Decode([]int32{nextToken})
|
||||||
|
fmt.Print(decoded)
|
||||||
|
|
||||||
|
// Stop on EOS
|
||||||
|
if nextToken == 2 || nextToken == 1 { // Common EOS tokens
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println("\n\n✅ Generation completed!")
|
||||||
|
fullText, _ := tp.Decode(generated)
|
||||||
|
fmt.Printf("Generated: %s\n", fullText)
|
||||||
|
}
|
||||||
@@ -1,44 +0,0 @@
|
|||||||
package parsers
|
|
||||||
|
|
||||||
import (
|
|
||||||
"github.com/ollama/ollama/api"
|
|
||||||
"github.com/ollama/ollama/thinking"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Intellect3Parser combines thinking support using
|
|
||||||
// the built-in thinking parser, with tool call support
|
|
||||||
// via qwen3-coder's parser.
|
|
||||||
type Intellect3Parser struct {
|
|
||||||
thinkingParser thinking.Parser
|
|
||||||
toolParser Qwen3CoderParser
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Intellect3Parser) HasToolSupport() bool {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Intellect3Parser) HasThinkingSupport() bool {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Intellect3Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
|
||||||
p.thinkingParser = thinking.Parser{
|
|
||||||
OpeningTag: "<think>",
|
|
||||||
ClosingTag: "</think>",
|
|
||||||
}
|
|
||||||
p.toolParser = Qwen3CoderParser{}
|
|
||||||
return p.toolParser.Init(tools, lastMessage, thinkValue)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Intellect3Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
|
||||||
// First extract thinking content
|
|
||||||
thinkingContent, remainingContent := p.thinkingParser.AddContent(s)
|
|
||||||
|
|
||||||
// Then process the remaining content for tool calls
|
|
||||||
toolContent, _, toolCalls, err := p.toolParser.Add(remainingContent, done)
|
|
||||||
if err != nil {
|
|
||||||
return "", thinkingContent, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return toolContent, thinkingContent, toolCalls, nil
|
|
||||||
}
|
|
||||||
@@ -1,542 +0,0 @@
|
|||||||
package parsers
|
|
||||||
|
|
||||||
import (
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestIntellect3ParserThinkingOnly(t *testing.T) {
|
|
||||||
cases := []struct {
|
|
||||||
desc string
|
|
||||||
chunks []string
|
|
||||||
wantText string
|
|
||||||
wantThink string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
desc: "simple thinking content",
|
|
||||||
chunks: []string{"<think>I need to analyze this</think>Here is my response"},
|
|
||||||
wantText: "Here is my response",
|
|
||||||
wantThink: "I need to analyze this",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "thinking with whitespace",
|
|
||||||
chunks: []string{"<think>\n Some thoughts \n</think>\n\nContent"},
|
|
||||||
wantText: "Content",
|
|
||||||
wantThink: "Some thoughts \n", // Thinking parser preserves internal whitespace
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "thinking only",
|
|
||||||
chunks: []string{"<think>Just thinking</think>"},
|
|
||||||
wantText: "",
|
|
||||||
wantThink: "Just thinking",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "no thinking tags",
|
|
||||||
chunks: []string{"Just regular content"},
|
|
||||||
wantText: "Just regular content",
|
|
||||||
wantThink: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "streaming thinking content",
|
|
||||||
chunks: []string{"<think>Fir", "st part", " second part</think>Content"},
|
|
||||||
wantText: "Content",
|
|
||||||
wantThink: "First part second part",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "partial opening tag",
|
|
||||||
chunks: []string{"<thi", "nk>Thinking</think>Content"},
|
|
||||||
wantText: "Content",
|
|
||||||
wantThink: "Thinking",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "partial closing tag",
|
|
||||||
chunks: []string{"<think>Thinking</thi", "nk>Content"},
|
|
||||||
wantText: "Content",
|
|
||||||
wantThink: "Thinking",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range cases {
|
|
||||||
t.Run(tc.desc, func(t *testing.T) {
|
|
||||||
parser := Intellect3Parser{}
|
|
||||||
parser.Init(nil, nil, nil)
|
|
||||||
|
|
||||||
var gotText, gotThink string
|
|
||||||
for i, chunk := range tc.chunks {
|
|
||||||
isLast := i == len(tc.chunks)-1
|
|
||||||
text, think, calls, err := parser.Add(chunk, isLast)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
gotText += text
|
|
||||||
gotThink += think
|
|
||||||
if len(calls) > 0 {
|
|
||||||
t.Fatalf("expected no tool calls, got %v", calls)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if gotText != tc.wantText {
|
|
||||||
t.Errorf("content: got %q, want %q", gotText, tc.wantText)
|
|
||||||
}
|
|
||||||
if gotThink != tc.wantThink {
|
|
||||||
t.Errorf("thinking: got %q, want %q", gotThink, tc.wantThink)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIntellect3ParserToolCallsOnly(t *testing.T) {
|
|
||||||
tools := []api.Tool{
|
|
||||||
tool("get_weather", map[string]api.ToolProperty{
|
|
||||||
"location": {Type: api.PropertyType{"string"}},
|
|
||||||
"unit": {Type: api.PropertyType{"string"}},
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
|
|
||||||
cases := []struct {
|
|
||||||
desc string
|
|
||||||
chunks []string
|
|
||||||
wantText string
|
|
||||||
wantCalls []api.ToolCall
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
desc: "simple tool call",
|
|
||||||
chunks: []string{
|
|
||||||
"Let me check the weather<tool_call><function=get_weather>\n<parameter=location>\nSan Francisco\n</parameter>\n<parameter=unit>\ncelsius\n</parameter>\n</function></tool_call>",
|
|
||||||
},
|
|
||||||
wantText: "Let me check the weather",
|
|
||||||
wantCalls: []api.ToolCall{
|
|
||||||
{
|
|
||||||
Function: api.ToolCallFunction{
|
|
||||||
Name: "get_weather",
|
|
||||||
Arguments: map[string]any{
|
|
||||||
"location": "San Francisco",
|
|
||||||
"unit": "celsius",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "tool call streaming",
|
|
||||||
chunks: []string{
|
|
||||||
"Checking<tool_call><function=get_wea",
|
|
||||||
"ther>\n<parameter=location>\nNew York\n</param", // nolint:all
|
|
||||||
"eter>\n<parameter=unit>\nfahrenheit\n</parameter>\n</function></tool_call>Done",
|
|
||||||
},
|
|
||||||
wantText: "CheckingDone",
|
|
||||||
wantCalls: []api.ToolCall{
|
|
||||||
{
|
|
||||||
Function: api.ToolCallFunction{
|
|
||||||
Name: "get_weather",
|
|
||||||
Arguments: map[string]any{
|
|
||||||
"location": "New York",
|
|
||||||
"unit": "fahrenheit",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "multiple tool calls",
|
|
||||||
chunks: []string{
|
|
||||||
"<tool_call><function=get_weather>\n<parameter=location>\nBoston\n</parameter>\n<parameter=unit>\ncelsius\n</parameter>\n</function></tool_call>",
|
|
||||||
"<tool_call><function=get_weather>\n<parameter=location>\nSeattle\n</parameter>\n<parameter=unit>\nfahrenheit\n</parameter>\n</function></tool_call>",
|
|
||||||
},
|
|
||||||
wantText: "",
|
|
||||||
wantCalls: []api.ToolCall{
|
|
||||||
{
|
|
||||||
Function: api.ToolCallFunction{
|
|
||||||
Name: "get_weather",
|
|
||||||
Arguments: map[string]any{
|
|
||||||
"location": "Boston",
|
|
||||||
"unit": "celsius",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Function: api.ToolCallFunction{
|
|
||||||
Name: "get_weather",
|
|
||||||
Arguments: map[string]any{
|
|
||||||
"location": "Seattle",
|
|
||||||
"unit": "fahrenheit",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "no tool calls",
|
|
||||||
chunks: []string{"Just regular content"},
|
|
||||||
wantText: "Just regular content",
|
|
||||||
wantCalls: nil,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range cases {
|
|
||||||
t.Run(tc.desc, func(t *testing.T) {
|
|
||||||
parser := Intellect3Parser{}
|
|
||||||
parser.Init(tools, nil, nil)
|
|
||||||
|
|
||||||
var gotText string
|
|
||||||
var gotCalls []api.ToolCall
|
|
||||||
for i, chunk := range tc.chunks {
|
|
||||||
isLast := i == len(tc.chunks)-1
|
|
||||||
text, think, calls, err := parser.Add(chunk, isLast)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
gotText += text
|
|
||||||
gotCalls = append(gotCalls, calls...)
|
|
||||||
if think != "" {
|
|
||||||
t.Fatalf("expected no thinking, got %q", think)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if gotText != tc.wantText {
|
|
||||||
t.Errorf("content: got %q, want %q", gotText, tc.wantText)
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(gotCalls, tc.wantCalls) {
|
|
||||||
t.Errorf("tool calls: got %#v, want %#v", gotCalls, tc.wantCalls)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIntellect3ParserCombined(t *testing.T) {
|
|
||||||
tools := []api.Tool{
|
|
||||||
tool("get_weather", map[string]api.ToolProperty{
|
|
||||||
"location": {Type: api.PropertyType{"string"}},
|
|
||||||
"unit": {Type: api.PropertyType{"string"}},
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
|
|
||||||
cases := []struct {
|
|
||||||
desc string
|
|
||||||
chunks []string
|
|
||||||
wantText string
|
|
||||||
wantThink string
|
|
||||||
wantCalls []api.ToolCall
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
desc: "thinking then tool call",
|
|
||||||
chunks: []string{
|
|
||||||
"<think>Need to get weather data</think>Let me check<tool_call><function=get_weather>\n<parameter=location>\nParis\n</parameter>\n<parameter=unit>\ncelsius\n</parameter>\n</function></tool_call>",
|
|
||||||
},
|
|
||||||
wantText: "Let me check",
|
|
||||||
wantThink: "Need to get weather data",
|
|
||||||
wantCalls: []api.ToolCall{
|
|
||||||
{
|
|
||||||
Function: api.ToolCallFunction{
|
|
||||||
Name: "get_weather",
|
|
||||||
Arguments: map[string]any{
|
|
||||||
"location": "Paris",
|
|
||||||
"unit": "celsius",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "thinking, tool call, and final content",
|
|
||||||
chunks: []string{
|
|
||||||
"<think>User wants weather info</think>Checking weather<tool_call><function=get_weather>\n<parameter=location>\nTokyo\n</parameter>\n<parameter=unit>\ncelsius\n</parameter>\n</function></tool_call>Done!",
|
|
||||||
},
|
|
||||||
wantText: "Checking weatherDone!",
|
|
||||||
wantThink: "User wants weather info",
|
|
||||||
wantCalls: []api.ToolCall{
|
|
||||||
{
|
|
||||||
Function: api.ToolCallFunction{
|
|
||||||
Name: "get_weather",
|
|
||||||
Arguments: map[string]any{
|
|
||||||
"location": "Tokyo",
|
|
||||||
"unit": "celsius",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "streaming combined content",
|
|
||||||
chunks: []string{
|
|
||||||
"<think>Analyzing",
|
|
||||||
" the request</think>",
|
|
||||||
"Let me help<tool_call>",
|
|
||||||
"<function=get_weather>\n<parameter=location>\nLondon",
|
|
||||||
"\n</parameter>\n<parameter=unit>\ncelsius\n</parameter>\n</function>",
|
|
||||||
"</tool_call>There you go!",
|
|
||||||
},
|
|
||||||
wantText: "Let me helpThere you go!",
|
|
||||||
wantThink: "Analyzing the request",
|
|
||||||
wantCalls: []api.ToolCall{
|
|
||||||
{
|
|
||||||
Function: api.ToolCallFunction{
|
|
||||||
Name: "get_weather",
|
|
||||||
Arguments: map[string]any{
|
|
||||||
"location": "London",
|
|
||||||
"unit": "celsius",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "multiple tool calls with thinking",
|
|
||||||
chunks: []string{
|
|
||||||
"<think>Need multiple locations</think>",
|
|
||||||
"<tool_call><function=get_weather>\n<parameter=location>\nBoston\n</parameter>\n<parameter=unit>\ncelsius\n</parameter>\n</function></tool_call>",
|
|
||||||
"and<tool_call><function=get_weather>\n<parameter=location>\nBerlin\n</parameter>\n<parameter=unit>\ncelsius\n</parameter>\n</function></tool_call>",
|
|
||||||
},
|
|
||||||
wantText: "and",
|
|
||||||
wantThink: "Need multiple locations",
|
|
||||||
wantCalls: []api.ToolCall{
|
|
||||||
{
|
|
||||||
Function: api.ToolCallFunction{
|
|
||||||
Name: "get_weather",
|
|
||||||
Arguments: map[string]any{
|
|
||||||
"location": "Boston",
|
|
||||||
"unit": "celsius",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Function: api.ToolCallFunction{
|
|
||||||
Name: "get_weather",
|
|
||||||
Arguments: map[string]any{
|
|
||||||
"location": "Berlin",
|
|
||||||
"unit": "celsius",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range cases {
|
|
||||||
t.Run(tc.desc, func(t *testing.T) {
|
|
||||||
parser := Intellect3Parser{}
|
|
||||||
parser.Init(tools, nil, nil)
|
|
||||||
|
|
||||||
var gotText, gotThink string
|
|
||||||
var gotCalls []api.ToolCall
|
|
||||||
for i, chunk := range tc.chunks {
|
|
||||||
isLast := i == len(tc.chunks)-1
|
|
||||||
text, think, calls, err := parser.Add(chunk, isLast)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
gotText += text
|
|
||||||
gotThink += think
|
|
||||||
gotCalls = append(gotCalls, calls...)
|
|
||||||
}
|
|
||||||
|
|
||||||
if gotText != tc.wantText {
|
|
||||||
t.Errorf("content: got %q, want %q", gotText, tc.wantText)
|
|
||||||
}
|
|
||||||
if gotThink != tc.wantThink {
|
|
||||||
t.Errorf("thinking: got %q, want %q", gotThink, tc.wantThink)
|
|
||||||
}
|
|
||||||
if !reflect.DeepEqual(gotCalls, tc.wantCalls) {
|
|
||||||
t.Errorf("tool calls: got %#v, want %#v", gotCalls, tc.wantCalls)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIntellect3ParserEdgeCases(t *testing.T) {
|
|
||||||
tools := []api.Tool{
|
|
||||||
tool("test_func", map[string]api.ToolProperty{
|
|
||||||
"param": {Type: api.PropertyType{"string"}},
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
|
|
||||||
cases := []struct {
|
|
||||||
desc string
|
|
||||||
chunks []string
|
|
||||||
wantText string
|
|
||||||
wantThink string
|
|
||||||
wantCalls int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
desc: "empty input",
|
|
||||||
chunks: []string{""},
|
|
||||||
wantText: "",
|
|
||||||
wantThink: "",
|
|
||||||
wantCalls: 0,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "only whitespace",
|
|
||||||
chunks: []string{" \n \t "},
|
|
||||||
wantText: "",
|
|
||||||
wantThink: "",
|
|
||||||
wantCalls: 0,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "unclosed thinking tag",
|
|
||||||
chunks: []string{"<think>Never closes"},
|
|
||||||
wantText: "",
|
|
||||||
wantThink: "Never closes",
|
|
||||||
wantCalls: 0,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "unclosed tool call tag",
|
|
||||||
chunks: []string{"<tool_call><function=test_func>\n<parameter=param>\nvalue\n</parameter>\n</function>"},
|
|
||||||
wantText: "", // Qwen3CoderParser waits for closing tag, doesn't emit partial tool calls
|
|
||||||
wantThink: "",
|
|
||||||
wantCalls: 0, // Won't be parsed until </tool_call> is seen
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "unicode in thinking",
|
|
||||||
chunks: []string{"<think>思考中 🤔</think>答案是 42"},
|
|
||||||
wantText: "答案是 42",
|
|
||||||
wantThink: "思考中 🤔",
|
|
||||||
wantCalls: 0,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "fake thinking tag",
|
|
||||||
chunks: []string{"<thinking>This is not the right tag</thinking>Content"},
|
|
||||||
wantText: "<thinking>This is not the right tag</thinking>Content",
|
|
||||||
wantThink: "",
|
|
||||||
wantCalls: 0,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "fake tool call tag",
|
|
||||||
chunks: []string{"<tool>Not a tool call</tool>"},
|
|
||||||
wantText: "<tool>Not a tool call</tool>",
|
|
||||||
wantThink: "",
|
|
||||||
wantCalls: 0,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range cases {
|
|
||||||
t.Run(tc.desc, func(t *testing.T) {
|
|
||||||
parser := Intellect3Parser{}
|
|
||||||
parser.Init(tools, nil, nil)
|
|
||||||
|
|
||||||
var gotText, gotThink string
|
|
||||||
var gotCalls []api.ToolCall
|
|
||||||
for i, chunk := range tc.chunks {
|
|
||||||
isLast := i == len(tc.chunks)-1
|
|
||||||
text, think, calls, err := parser.Add(chunk, isLast)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
gotText += text
|
|
||||||
gotThink += think
|
|
||||||
gotCalls = append(gotCalls, calls...)
|
|
||||||
}
|
|
||||||
|
|
||||||
if gotText != tc.wantText {
|
|
||||||
t.Errorf("content: got %q, want %q", gotText, tc.wantText)
|
|
||||||
}
|
|
||||||
if gotThink != tc.wantThink {
|
|
||||||
t.Errorf("thinking: got %q, want %q", gotThink, tc.wantThink)
|
|
||||||
}
|
|
||||||
if len(gotCalls) != tc.wantCalls {
|
|
||||||
t.Errorf("tool calls count: got %d, want %d", len(gotCalls), tc.wantCalls)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIntellect3ParserCapabilities(t *testing.T) {
|
|
||||||
parser := Intellect3Parser{}
|
|
||||||
|
|
||||||
if !parser.HasToolSupport() {
|
|
||||||
t.Error("Intellect3Parser should have tool support")
|
|
||||||
}
|
|
||||||
|
|
||||||
if !parser.HasThinkingSupport() {
|
|
||||||
t.Error("Intellect3Parser should have thinking support")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIntellect3ParserInit(t *testing.T) {
|
|
||||||
parser := Intellect3Parser{}
|
|
||||||
|
|
||||||
tools := []api.Tool{
|
|
||||||
tool("test", map[string]api.ToolProperty{
|
|
||||||
"param": {Type: api.PropertyType{"string"}},
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
|
|
||||||
returnedTools := parser.Init(tools, nil, nil)
|
|
||||||
|
|
||||||
// Should return tools unchanged (delegated to Qwen3CoderParser)
|
|
||||||
if !reflect.DeepEqual(returnedTools, tools) {
|
|
||||||
t.Errorf("Init should return tools unchanged")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIntellect3ParserWhitespaceHandling(t *testing.T) {
|
|
||||||
tools := []api.Tool{
|
|
||||||
tool("test", map[string]api.ToolProperty{
|
|
||||||
"param": {Type: api.PropertyType{"string"}},
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
|
|
||||||
cases := []struct {
|
|
||||||
desc string
|
|
||||||
chunks []string
|
|
||||||
wantText string
|
|
||||||
wantThink string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
desc: "whitespace between thinking and content",
|
|
||||||
chunks: []string{"<think>Thinking</think>\n\n\nContent"},
|
|
||||||
wantText: "Content",
|
|
||||||
wantThink: "Thinking",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "whitespace inside thinking tags",
|
|
||||||
chunks: []string{"<think> \n Thinking \n </think>Content"},
|
|
||||||
wantText: "Content",
|
|
||||||
wantThink: "Thinking \n ", // Thinking parser preserves internal whitespace
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "leading whitespace before thinking",
|
|
||||||
chunks: []string{" <think>Thinking</think>Content"},
|
|
||||||
wantText: "Content",
|
|
||||||
wantThink: "Thinking",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "whitespace before tool call",
|
|
||||||
chunks: []string{"Text <tool_call><function=test>\n<parameter=param>\nvalue\n</parameter>\n</function></tool_call>"},
|
|
||||||
wantText: "Text",
|
|
||||||
wantThink: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
desc: "whitespace after tool call",
|
|
||||||
chunks: []string{"<tool_call><function=test>\n<parameter=param>\nvalue\n</parameter>\n</function></tool_call> Text"},
|
|
||||||
wantText: "Text",
|
|
||||||
wantThink: "",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range cases {
|
|
||||||
t.Run(tc.desc, func(t *testing.T) {
|
|
||||||
parser := Intellect3Parser{}
|
|
||||||
parser.Init(tools, nil, nil)
|
|
||||||
|
|
||||||
var gotText, gotThink string
|
|
||||||
for i, chunk := range tc.chunks {
|
|
||||||
isLast := i == len(tc.chunks)-1
|
|
||||||
text, think, _, err := parser.Add(chunk, isLast)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("unexpected error: %v", err)
|
|
||||||
}
|
|
||||||
gotText += text
|
|
||||||
gotThink += think
|
|
||||||
}
|
|
||||||
|
|
||||||
if gotText != tc.wantText {
|
|
||||||
t.Errorf("content: got %q, want %q", gotText, tc.wantText)
|
|
||||||
}
|
|
||||||
if gotThink != tc.wantThink {
|
|
||||||
t.Errorf("thinking: got %q, want %q", gotThink, tc.wantThink)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
469
model/parsers/olmo3.go
Normal file
469
model/parsers/olmo3.go
Normal file
@@ -0,0 +1,469 @@
|
|||||||
|
package parsers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
"github.com/ollama/ollama/logutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
type olmo3ParserState int
|
||||||
|
|
||||||
|
const (
|
||||||
|
olmo3StateContent olmo3ParserState = iota
|
||||||
|
olmo3StateToolCalls
|
||||||
|
olmo3StateToolCallsDone
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
olmo3FuncCallsOpenTag = "<function_calls>"
|
||||||
|
olmo3FuncCallsCloseTag = "</function_calls>"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Olmo3Parser struct {
|
||||||
|
state olmo3ParserState
|
||||||
|
buffer strings.Builder
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Olmo3Parser) HasToolSupport() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Olmo3Parser) HasThinkingSupport() bool {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Olmo3Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
|
||||||
|
p.state = olmo3StateContent
|
||||||
|
return tools
|
||||||
|
}
|
||||||
|
|
||||||
|
type olmo3ParserEvent interface {
|
||||||
|
isOlmo3ParserEvent()
|
||||||
|
}
|
||||||
|
|
||||||
|
type olmo3ParserEventContent struct {
|
||||||
|
content string
|
||||||
|
}
|
||||||
|
|
||||||
|
type olmo3ParserEventToolCalls struct {
|
||||||
|
calls []api.ToolCall
|
||||||
|
}
|
||||||
|
|
||||||
|
func (olmo3ParserEventContent) isOlmo3ParserEvent() {}
|
||||||
|
func (olmo3ParserEventToolCalls) isOlmo3ParserEvent() {}
|
||||||
|
|
||||||
|
func (p *Olmo3Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||||
|
p.buffer.WriteString(s)
|
||||||
|
|
||||||
|
if done {
|
||||||
|
// Drain any remaining content
|
||||||
|
bufStr := p.buffer.String()
|
||||||
|
p.buffer.Reset()
|
||||||
|
if p.state == olmo3StateContent && len(bufStr) > 0 {
|
||||||
|
return bufStr, "", nil, nil
|
||||||
|
}
|
||||||
|
return "", "", nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
events := p.parseEvents()
|
||||||
|
|
||||||
|
var contentSb strings.Builder
|
||||||
|
var allCalls []api.ToolCall
|
||||||
|
for _, event := range events {
|
||||||
|
switch event := event.(type) {
|
||||||
|
case olmo3ParserEventContent:
|
||||||
|
contentSb.WriteString(event.content)
|
||||||
|
case olmo3ParserEventToolCalls:
|
||||||
|
allCalls = append(allCalls, event.calls...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return contentSb.String(), "", allCalls, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Olmo3Parser) parseEvents() []olmo3ParserEvent {
|
||||||
|
var all []olmo3ParserEvent
|
||||||
|
|
||||||
|
keepLooping := true
|
||||||
|
for keepLooping {
|
||||||
|
var events []olmo3ParserEvent
|
||||||
|
events, keepLooping = p.eat()
|
||||||
|
if len(events) > 0 {
|
||||||
|
all = append(all, events...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(all) > 0 {
|
||||||
|
slog.Log(context.TODO(), logutil.LevelTrace, "olmo3 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
return all
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Olmo3Parser) eat() ([]olmo3ParserEvent, bool) {
|
||||||
|
var events []olmo3ParserEvent
|
||||||
|
bufStr := p.buffer.String()
|
||||||
|
if bufStr == "" {
|
||||||
|
return events, false
|
||||||
|
}
|
||||||
|
|
||||||
|
switch p.state {
|
||||||
|
case olmo3StateContent:
|
||||||
|
if strings.Contains(bufStr, olmo3FuncCallsOpenTag) {
|
||||||
|
// Found <function_calls> tag
|
||||||
|
split := strings.SplitN(bufStr, olmo3FuncCallsOpenTag, 2)
|
||||||
|
content := split[0]
|
||||||
|
remaining := split[1]
|
||||||
|
|
||||||
|
p.buffer.Reset()
|
||||||
|
p.buffer.WriteString(remaining)
|
||||||
|
p.state = olmo3StateToolCalls
|
||||||
|
|
||||||
|
if len(content) > 0 {
|
||||||
|
events = append(events, olmo3ParserEventContent{content: content})
|
||||||
|
}
|
||||||
|
return events, true
|
||||||
|
} else if overlapLen := overlap(bufStr, olmo3FuncCallsOpenTag); overlapLen > 0 {
|
||||||
|
// Partial <function_calls> tag - withhold ambiguous content
|
||||||
|
unambiguous := bufStr[:len(bufStr)-overlapLen]
|
||||||
|
ambiguous := bufStr[len(bufStr)-overlapLen:]
|
||||||
|
p.buffer.Reset()
|
||||||
|
p.buffer.WriteString(ambiguous)
|
||||||
|
if len(unambiguous) > 0 {
|
||||||
|
events = append(events, olmo3ParserEventContent{content: unambiguous})
|
||||||
|
}
|
||||||
|
return events, false
|
||||||
|
} else {
|
||||||
|
// Regular content - emit all
|
||||||
|
p.buffer.Reset()
|
||||||
|
if len(bufStr) > 0 {
|
||||||
|
events = append(events, olmo3ParserEventContent{content: bufStr})
|
||||||
|
}
|
||||||
|
return events, false
|
||||||
|
}
|
||||||
|
|
||||||
|
case olmo3StateToolCalls:
|
||||||
|
if strings.Contains(bufStr, olmo3FuncCallsCloseTag) {
|
||||||
|
// Found </function_calls> tag
|
||||||
|
split := strings.SplitN(bufStr, olmo3FuncCallsCloseTag, 2)
|
||||||
|
toolCallsStr := split[0]
|
||||||
|
remaining := split[1]
|
||||||
|
|
||||||
|
p.buffer.Reset()
|
||||||
|
p.buffer.WriteString(remaining)
|
||||||
|
p.state = olmo3StateToolCallsDone
|
||||||
|
|
||||||
|
// Parse the function calls
|
||||||
|
calls, err := parseOlmo3FunctionCalls(toolCallsStr)
|
||||||
|
if err != nil {
|
||||||
|
slog.Log(context.TODO(), logutil.LevelTrace, "failed to parse olmo3 function calls", "error", err, "content", toolCallsStr)
|
||||||
|
} else if len(calls) > 0 {
|
||||||
|
events = append(events, olmo3ParserEventToolCalls{calls: calls})
|
||||||
|
}
|
||||||
|
return events, true
|
||||||
|
} else if overlapLen := overlap(bufStr, olmo3FuncCallsCloseTag); overlapLen > 0 {
|
||||||
|
// Partial </function_calls> tag - wait for more
|
||||||
|
return events, false
|
||||||
|
}
|
||||||
|
// Still collecting tool calls, wait for close tag
|
||||||
|
return events, false
|
||||||
|
|
||||||
|
case olmo3StateToolCallsDone:
|
||||||
|
// After tool calls, emit remaining content
|
||||||
|
p.buffer.Reset()
|
||||||
|
p.state = olmo3StateContent
|
||||||
|
if len(bufStr) > 0 {
|
||||||
|
events = append(events, olmo3ParserEventContent{content: bufStr})
|
||||||
|
}
|
||||||
|
return events, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return events, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseOlmo3FunctionCalls parses function calls in Python-esque format:
|
||||||
|
// func_name(arg1="value1", arg2=123)
|
||||||
|
// Multiple calls are separated by newlines
|
||||||
|
func parseOlmo3FunctionCalls(s string) ([]api.ToolCall, error) {
|
||||||
|
var calls []api.ToolCall
|
||||||
|
s = strings.TrimSpace(s)
|
||||||
|
if s == "" {
|
||||||
|
return calls, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split by newlines for multiple function calls
|
||||||
|
lines := strings.Split(s, "\n")
|
||||||
|
for _, line := range lines {
|
||||||
|
line = strings.TrimSpace(line)
|
||||||
|
if line == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
call, err := parseOlmo3SingleFunctionCall(line)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse function call %q: %w", line, err)
|
||||||
|
}
|
||||||
|
calls = append(calls, call)
|
||||||
|
}
|
||||||
|
|
||||||
|
return calls, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Regex to match function call: func_name(args)
|
||||||
|
var funcCallRegex = regexp.MustCompile(`^(\w+)\((.*)\)$`)
|
||||||
|
|
||||||
|
// Regex to match a single argument: key=value
|
||||||
|
// Value can be: "string", 'string', number, true, false, null, or nested structures
|
||||||
|
var argRegex = regexp.MustCompile(`^(\w+)=(.+)$`)
|
||||||
|
|
||||||
|
func parseOlmo3SingleFunctionCall(s string) (api.ToolCall, error) {
|
||||||
|
matches := funcCallRegex.FindStringSubmatch(s)
|
||||||
|
if matches == nil {
|
||||||
|
return api.ToolCall{}, fmt.Errorf("invalid function call format")
|
||||||
|
}
|
||||||
|
|
||||||
|
funcName := matches[1]
|
||||||
|
argsStr := matches[2]
|
||||||
|
|
||||||
|
args, err := parseOlmo3Arguments(argsStr)
|
||||||
|
if err != nil {
|
||||||
|
return api.ToolCall{}, fmt.Errorf("failed to parse arguments: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: funcName,
|
||||||
|
Arguments: args,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseOlmo3Arguments parses comma-separated key=value pairs
|
||||||
|
// Handles nested parentheses, brackets, braces, and quoted strings
|
||||||
|
func parseOlmo3Arguments(s string) (map[string]any, error) {
|
||||||
|
args := make(map[string]any)
|
||||||
|
s = strings.TrimSpace(s)
|
||||||
|
if s == "" {
|
||||||
|
return args, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Split by commas, but respect nested structures and quotes
|
||||||
|
parts := splitArguments(s)
|
||||||
|
|
||||||
|
for _, part := range parts {
|
||||||
|
part = strings.TrimSpace(part)
|
||||||
|
if part == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the first = sign
|
||||||
|
eqIdx := strings.Index(part, "=")
|
||||||
|
if eqIdx == -1 {
|
||||||
|
return nil, fmt.Errorf("invalid argument format: %s", part)
|
||||||
|
}
|
||||||
|
|
||||||
|
key := strings.TrimSpace(part[:eqIdx])
|
||||||
|
valueStr := strings.TrimSpace(part[eqIdx+1:])
|
||||||
|
|
||||||
|
value, err := parseOlmo3Value(valueStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse value for %s: %w", key, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
args[key] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
return args, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// splitArguments splits arguments by commas, respecting quotes and nested structures
|
||||||
|
func splitArguments(s string) []string {
|
||||||
|
var parts []string
|
||||||
|
var current strings.Builder
|
||||||
|
depth := 0
|
||||||
|
inString := false
|
||||||
|
stringChar := byte(0)
|
||||||
|
escaped := false
|
||||||
|
|
||||||
|
for i := 0; i < len(s); i++ {
|
||||||
|
c := s[i]
|
||||||
|
|
||||||
|
if escaped {
|
||||||
|
current.WriteByte(c)
|
||||||
|
escaped = false
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if c == '\\' && inString {
|
||||||
|
current.WriteByte(c)
|
||||||
|
escaped = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if (c == '"' || c == '\'') && !inString {
|
||||||
|
inString = true
|
||||||
|
stringChar = c
|
||||||
|
current.WriteByte(c)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if c == stringChar && inString {
|
||||||
|
inString = false
|
||||||
|
stringChar = 0
|
||||||
|
current.WriteByte(c)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !inString {
|
||||||
|
switch c {
|
||||||
|
case '(', '[', '{':
|
||||||
|
depth++
|
||||||
|
current.WriteByte(c)
|
||||||
|
case ')', ']', '}':
|
||||||
|
depth--
|
||||||
|
current.WriteByte(c)
|
||||||
|
case ',':
|
||||||
|
if depth == 0 {
|
||||||
|
parts = append(parts, current.String())
|
||||||
|
current.Reset()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
current.WriteByte(c)
|
||||||
|
default:
|
||||||
|
current.WriteByte(c)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
current.WriteByte(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if current.Len() > 0 {
|
||||||
|
parts = append(parts, current.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
return parts
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseOlmo3Value parses a value which can be a string, number, boolean, null, array, or object
|
||||||
|
func parseOlmo3Value(s string) (any, error) {
|
||||||
|
s = strings.TrimSpace(s)
|
||||||
|
|
||||||
|
// Check for quoted string
|
||||||
|
if (strings.HasPrefix(s, `"`) && strings.HasSuffix(s, `"`)) ||
|
||||||
|
(strings.HasPrefix(s, `'`) && strings.HasSuffix(s, `'`)) {
|
||||||
|
// Remove quotes and unescape
|
||||||
|
inner := s[1 : len(s)-1]
|
||||||
|
return unescapeString(inner), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for boolean
|
||||||
|
if s == "true" || s == "True" {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
if s == "false" || s == "False" {
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for null/None
|
||||||
|
if s == "null" || s == "None" || s == "nil" {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for number
|
||||||
|
if i, err := strconv.ParseInt(s, 10, 64); err == nil {
|
||||||
|
return i, nil
|
||||||
|
}
|
||||||
|
if f, err := strconv.ParseFloat(s, 64); err == nil {
|
||||||
|
return f, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for array [...]
|
||||||
|
if strings.HasPrefix(s, "[") && strings.HasSuffix(s, "]") {
|
||||||
|
return parseOlmo3Array(s[1 : len(s)-1])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for object {...}
|
||||||
|
if strings.HasPrefix(s, "{") && strings.HasSuffix(s, "}") {
|
||||||
|
return parseOlmo3Object(s[1 : len(s)-1])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default to string without quotes
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseOlmo3Array(s string) ([]any, error) {
|
||||||
|
s = strings.TrimSpace(s)
|
||||||
|
if s == "" {
|
||||||
|
return []any{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
parts := splitArguments(s)
|
||||||
|
var arr []any
|
||||||
|
for _, part := range parts {
|
||||||
|
val, err := parseOlmo3Value(part)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
arr = append(arr, val)
|
||||||
|
}
|
||||||
|
return arr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseOlmo3Object(s string) (map[string]any, error) {
|
||||||
|
s = strings.TrimSpace(s)
|
||||||
|
if s == "" {
|
||||||
|
return map[string]any{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Objects use key: value or "key": value format
|
||||||
|
obj := make(map[string]any)
|
||||||
|
parts := splitArguments(s)
|
||||||
|
for _, part := range parts {
|
||||||
|
part = strings.TrimSpace(part)
|
||||||
|
if part == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find colon separator
|
||||||
|
colonIdx := strings.Index(part, ":")
|
||||||
|
if colonIdx == -1 {
|
||||||
|
return nil, fmt.Errorf("invalid object entry: %s", part)
|
||||||
|
}
|
||||||
|
|
||||||
|
keyStr := strings.TrimSpace(part[:colonIdx])
|
||||||
|
valueStr := strings.TrimSpace(part[colonIdx+1:])
|
||||||
|
|
||||||
|
// Remove quotes from key if present
|
||||||
|
if (strings.HasPrefix(keyStr, `"`) && strings.HasSuffix(keyStr, `"`)) ||
|
||||||
|
(strings.HasPrefix(keyStr, `'`) && strings.HasSuffix(keyStr, `'`)) {
|
||||||
|
keyStr = keyStr[1 : len(keyStr)-1]
|
||||||
|
}
|
||||||
|
|
||||||
|
val, err := parseOlmo3Value(valueStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse value for key %s: %w", keyStr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
obj[keyStr] = val
|
||||||
|
}
|
||||||
|
|
||||||
|
return obj, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func unescapeString(s string) string {
|
||||||
|
// Handle common escape sequences
|
||||||
|
s = strings.ReplaceAll(s, `\\`, "\x00") // Placeholder for backslash
|
||||||
|
s = strings.ReplaceAll(s, `\"`, `"`)
|
||||||
|
s = strings.ReplaceAll(s, `\'`, `'`)
|
||||||
|
s = strings.ReplaceAll(s, `\n`, "\n")
|
||||||
|
s = strings.ReplaceAll(s, `\t`, "\t")
|
||||||
|
s = strings.ReplaceAll(s, `\r`, "\r")
|
||||||
|
s = strings.ReplaceAll(s, "\x00", `\`) // Restore backslash
|
||||||
|
return s
|
||||||
|
}
|
||||||
483
model/parsers/olmo3_test.go
Normal file
483
model/parsers/olmo3_test.go
Normal file
@@ -0,0 +1,483 @@
|
|||||||
|
package parsers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestOlmo3Parser(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expectedContent string
|
||||||
|
expectedThinking string
|
||||||
|
expectedCalls []api.ToolCall
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple content",
|
||||||
|
input: "Hello, how can I help you?",
|
||||||
|
expectedContent: "Hello, how can I help you?",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "simple tool call",
|
||||||
|
input: `<function_calls>get_weather(location="San Francisco")</function_calls>`,
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Arguments: map[string]any{"location": "San Francisco"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "content then tool call",
|
||||||
|
input: `Let me check the weather.<function_calls>get_weather(location="NYC")</function_calls>`,
|
||||||
|
expectedContent: "Let me check the weather.",
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Arguments: map[string]any{"location": "NYC"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool call with multiple arguments",
|
||||||
|
input: `<function_calls>book_flight(from="SFO", to="NYC", date="2024-01-15")</function_calls>`,
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "book_flight",
|
||||||
|
Arguments: map[string]any{
|
||||||
|
"from": "SFO",
|
||||||
|
"to": "NYC",
|
||||||
|
"date": "2024-01-15",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple tool calls",
|
||||||
|
input: `<function_calls>get_weather(location="San Francisco")
|
||||||
|
get_weather(location="New York")</function_calls>`,
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Arguments: map[string]any{"location": "San Francisco"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Arguments: map[string]any{"location": "New York"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool call with numeric argument",
|
||||||
|
input: `<function_calls>set_temperature(value=72)</function_calls>`,
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "set_temperature",
|
||||||
|
Arguments: map[string]any{"value": int64(72)},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool call with float argument",
|
||||||
|
input: `<function_calls>set_price(amount=19.99)</function_calls>`,
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "set_price",
|
||||||
|
Arguments: map[string]any{"amount": 19.99},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool call with boolean argument",
|
||||||
|
input: `<function_calls>toggle_setting(enabled=true)</function_calls>`,
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "toggle_setting",
|
||||||
|
Arguments: map[string]any{"enabled": true},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool call with null argument",
|
||||||
|
input: `<function_calls>clear_value(field=null)</function_calls>`,
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "clear_value",
|
||||||
|
Arguments: map[string]any{"field": nil},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool call with array argument",
|
||||||
|
input: `<function_calls>process_items(items=["apple", "banana", "cherry"])</function_calls>`,
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "process_items",
|
||||||
|
Arguments: map[string]any{"items": []any{"apple", "banana", "cherry"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool call with dict argument",
|
||||||
|
input: `<function_calls>update_config(settings={"theme": "dark", "fontSize": 14})</function_calls>`,
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "update_config",
|
||||||
|
Arguments: map[string]any{
|
||||||
|
"settings": map[string]any{
|
||||||
|
"theme": "dark",
|
||||||
|
"fontSize": int64(14),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool call with nested dict",
|
||||||
|
input: `<function_calls>create_request(data={"user": {"name": "John", "age": 30}, "active": true})</function_calls>`,
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "create_request",
|
||||||
|
Arguments: map[string]any{
|
||||||
|
"data": map[string]any{
|
||||||
|
"user": map[string]any{
|
||||||
|
"name": "John",
|
||||||
|
"age": int64(30),
|
||||||
|
},
|
||||||
|
"active": true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool call with no arguments",
|
||||||
|
input: `<function_calls>get_current_time()</function_calls>`,
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_current_time",
|
||||||
|
Arguments: map[string]any{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool call with single quotes",
|
||||||
|
input: `<function_calls>search(query='hello world')</function_calls>`,
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "search",
|
||||||
|
Arguments: map[string]any{"query": "hello world"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool call with escaped quotes",
|
||||||
|
input: `<function_calls>search(query="say \"hello\"")</function_calls>`,
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "search",
|
||||||
|
Arguments: map[string]any{"query": `say "hello"`},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool call with mixed argument types",
|
||||||
|
input: `<function_calls>create_user(name="John", age=30, active=true)</function_calls>`,
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "create_user",
|
||||||
|
Arguments: map[string]any{
|
||||||
|
"name": "John",
|
||||||
|
"age": int64(30),
|
||||||
|
"active": true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
p := &Olmo3Parser{}
|
||||||
|
p.Init(nil, nil, nil)
|
||||||
|
|
||||||
|
content, thinking, calls, err := p.Add(tt.input, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Drain remaining content
|
||||||
|
finalContent, finalThinking, finalCalls, err := p.Add("", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error on done: %v", err)
|
||||||
|
}
|
||||||
|
content += finalContent
|
||||||
|
thinking += finalThinking
|
||||||
|
calls = append(calls, finalCalls...)
|
||||||
|
|
||||||
|
if diff := cmp.Diff(content, tt.expectedContent); diff != "" {
|
||||||
|
t.Errorf("content mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(thinking, tt.expectedThinking); diff != "" {
|
||||||
|
t.Errorf("thinking mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(calls, tt.expectedCalls); diff != "" {
|
||||||
|
t.Errorf("calls mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOlmo3Parser_Streaming(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
chunks []string
|
||||||
|
expectedContent string
|
||||||
|
expectedCalls []api.ToolCall
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "streaming content",
|
||||||
|
chunks: []string{"Hello, ", "how ", "can I help?"},
|
||||||
|
expectedContent: "Hello, how can I help?",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "streaming tool call",
|
||||||
|
chunks: []string{"<function_", "calls>get_weather", "(location=\"SF\")", "</function_calls>"},
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Arguments: map[string]any{"location": "SF"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "streaming content then tool call",
|
||||||
|
chunks: []string{"Let me check.", "<function_calls>", "get_weather(location=\"NYC\")", "</function_calls>"},
|
||||||
|
expectedContent: "Let me check.",
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Arguments: map[string]any{"location": "NYC"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool call tag split across chunks",
|
||||||
|
chunks: []string{"<func", "tion_calls>test()</function_calls>"},
|
||||||
|
expectedCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "test",
|
||||||
|
Arguments: map[string]any{},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
p := &Olmo3Parser{}
|
||||||
|
p.Init(nil, nil, nil)
|
||||||
|
|
||||||
|
var allContent string
|
||||||
|
var allCalls []api.ToolCall
|
||||||
|
|
||||||
|
for _, chunk := range tt.chunks {
|
||||||
|
content, _, calls, err := p.Add(chunk, false)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
allContent += content
|
||||||
|
allCalls = append(allCalls, calls...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Drain
|
||||||
|
content, _, calls, err := p.Add("", true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error on done: %v", err)
|
||||||
|
}
|
||||||
|
allContent += content
|
||||||
|
allCalls = append(allCalls, calls...)
|
||||||
|
|
||||||
|
if diff := cmp.Diff(allContent, tt.expectedContent); diff != "" {
|
||||||
|
t.Errorf("content mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(allCalls, tt.expectedCalls); diff != "" {
|
||||||
|
t.Errorf("calls mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOlmo3Parser_HasToolSupport(t *testing.T) {
|
||||||
|
p := &Olmo3Parser{}
|
||||||
|
if !p.HasToolSupport() {
|
||||||
|
t.Error("expected HasToolSupport to return true")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestOlmo3Parser_HasThinkingSupport(t *testing.T) {
|
||||||
|
p := &Olmo3Parser{}
|
||||||
|
if p.HasThinkingSupport() {
|
||||||
|
t.Error("expected HasThinkingSupport to return false")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseOlmo3FunctionCalls(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected []api.ToolCall
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple call",
|
||||||
|
input: `get_weather(location="SF")`,
|
||||||
|
expected: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Arguments: map[string]any{"location": "SF"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple args",
|
||||||
|
input: `send_email(to="user@example.com", subject="Hello", body="Test message")`,
|
||||||
|
expected: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "send_email",
|
||||||
|
Arguments: map[string]any{
|
||||||
|
"to": "user@example.com",
|
||||||
|
"subject": "Hello",
|
||||||
|
"body": "Test message",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple calls with newlines",
|
||||||
|
input: `get_weather(location="SF")
|
||||||
|
get_time(timezone="PST")`,
|
||||||
|
expected: []api.ToolCall{
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Arguments: map[string]any{"location": "SF"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_time",
|
||||||
|
Arguments: map[string]any{"timezone": "PST"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty input",
|
||||||
|
input: "",
|
||||||
|
expected: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "whitespace only",
|
||||||
|
input: " \n ",
|
||||||
|
expected: nil,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
calls, err := parseOlmo3FunctionCalls(tt.input)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("parseOlmo3FunctionCalls() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(calls, tt.expected); diff != "" {
|
||||||
|
t.Errorf("calls mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestParseOlmo3Value(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
expected any
|
||||||
|
}{
|
||||||
|
{"string double quotes", `"hello"`, "hello"},
|
||||||
|
{"string single quotes", `'hello'`, "hello"},
|
||||||
|
{"integer", "42", int64(42)},
|
||||||
|
{"negative integer", "-10", int64(-10)},
|
||||||
|
{"float", "3.14", 3.14},
|
||||||
|
{"boolean true", "true", true},
|
||||||
|
{"boolean True", "True", true},
|
||||||
|
{"boolean false", "false", false},
|
||||||
|
{"null", "null", nil},
|
||||||
|
{"None", "None", nil},
|
||||||
|
{"empty array", "[]", []any{}},
|
||||||
|
{"array with strings", `["a", "b"]`, []any{"a", "b"}},
|
||||||
|
{"array with numbers", "[1, 2, 3]", []any{int64(1), int64(2), int64(3)}},
|
||||||
|
{"empty object", "{}", map[string]any{}},
|
||||||
|
{"simple object", `{"name": "John"}`, map[string]any{"name": "John"}},
|
||||||
|
{"object with number", `{"age": 30}`, map[string]any{"age": int64(30)}},
|
||||||
|
{"object with multiple keys", `{"a": 1, "b": 2}`, map[string]any{"a": int64(1), "b": int64(2)}},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result, err := parseOlmo3Value(tt.input)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(result, tt.expected); diff != "" {
|
||||||
|
t.Errorf("value mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -54,8 +54,8 @@ func ParserForName(name string) Parser {
|
|||||||
return harmony.NewHarmonyMessageHandler()
|
return harmony.NewHarmonyMessageHandler()
|
||||||
case "cogito":
|
case "cogito":
|
||||||
return &CogitoParser{}
|
return &CogitoParser{}
|
||||||
case "intellect-3":
|
case "olmo3":
|
||||||
return &Intellect3Parser{}
|
return &Olmo3Parser{}
|
||||||
default:
|
default:
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,160 +0,0 @@
|
|||||||
package renderers
|
|
||||||
|
|
||||||
import (
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Intellect3Renderer struct{}
|
|
||||||
|
|
||||||
func (r *Intellect3Renderer) Render(messages []api.Message, tools []api.Tool, think *api.ThinkValue) (string, error) {
|
|
||||||
var sb strings.Builder
|
|
||||||
|
|
||||||
// filter out system messages and choose the first (if any) to win
|
|
||||||
var systemMessage string
|
|
||||||
var filteredMessages []api.Message
|
|
||||||
for _, message := range messages {
|
|
||||||
if message.Role != "system" {
|
|
||||||
filteredMessages = append(filteredMessages, message)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if systemMessage == "" {
|
|
||||||
systemMessage = message.Content
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if systemMessage != "" || len(tools) > 0 {
|
|
||||||
sb.WriteString(imStartTag + "system\n")
|
|
||||||
|
|
||||||
sb.WriteString(systemMessage)
|
|
||||||
|
|
||||||
if len(tools) > 0 {
|
|
||||||
sb.WriteString("\n\n# Tools\n\nYou have access to the following functions:\n\n")
|
|
||||||
sb.WriteString("<tools>")
|
|
||||||
for _, tool := range tools {
|
|
||||||
sb.WriteString("\n")
|
|
||||||
sb.WriteString("<function>\n")
|
|
||||||
sb.WriteString("<name>" + tool.Function.Name + "</name>")
|
|
||||||
if tool.Function.Description != "" {
|
|
||||||
sb.WriteString("\n<description>" + tool.Function.Description + "</description>")
|
|
||||||
}
|
|
||||||
sb.WriteString("\n<parameters>")
|
|
||||||
|
|
||||||
for name, prop := range tool.Function.Parameters.Properties {
|
|
||||||
sb.WriteString("\n<parameter>")
|
|
||||||
sb.WriteString("\n<name>" + name + "</name>")
|
|
||||||
|
|
||||||
if len(prop.Type) > 0 {
|
|
||||||
sb.WriteString("\n<type>" + formatToolDefinitionType(prop.Type) + "</type>")
|
|
||||||
}
|
|
||||||
|
|
||||||
if prop.Description != "" {
|
|
||||||
sb.WriteString("\n<description>" + prop.Description + "</description>")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Render any additional keys not already handled
|
|
||||||
handledKeys := map[string]bool{
|
|
||||||
"type": true,
|
|
||||||
"description": true,
|
|
||||||
}
|
|
||||||
sb.WriteString(renderAdditionalKeys(prop, handledKeys))
|
|
||||||
|
|
||||||
sb.WriteString("\n</parameter>")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Render extra keys for parameters (everything except 'type' and 'properties')
|
|
||||||
paramHandledKeys := map[string]bool{
|
|
||||||
"type": true,
|
|
||||||
"properties": true,
|
|
||||||
}
|
|
||||||
sb.WriteString(renderAdditionalKeys(tool.Function.Parameters, paramHandledKeys))
|
|
||||||
|
|
||||||
sb.WriteString("\n</parameters>")
|
|
||||||
sb.WriteString("\n</function>")
|
|
||||||
}
|
|
||||||
sb.WriteString("\n</tools>")
|
|
||||||
sb.WriteString("\n\nIf you choose to call a function ONLY reply in the following format with NO suffix:\n\n<tool_call>\n<function=example_function_name>\n<parameter=example_parameter_1>\nvalue_1\n</parameter>\n<parameter=example_parameter_2>\nThis is the value for the second parameter\nthat can span\nmultiple lines\n</parameter>\n</function>\n</tool_call>\n\n<IMPORTANT>\nReminder:\n- Function calls MUST follow the specified format: an inner <function=...></function> block must be nested within <tool_call></tool_call> XML tags\n- Required parameters MUST be specified\n- You may provide optional reasoning for your function call in natural language BEFORE the function call, but NOT after\n- If there is no function call available, answer the question like normal with your current knowledge and do not tell the user about function calls\n</IMPORTANT>")
|
|
||||||
}
|
|
||||||
|
|
||||||
sb.WriteString(imEndTag + "\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, message := range filteredMessages {
|
|
||||||
lastMessage := i == len(filteredMessages)-1
|
|
||||||
prefill := lastMessage && message.Role == "assistant"
|
|
||||||
switch message.Role {
|
|
||||||
case "assistant":
|
|
||||||
if len(message.ToolCalls) > 0 {
|
|
||||||
sb.WriteString(imStartTag + "assistant")
|
|
||||||
|
|
||||||
// Add thinking tags if present
|
|
||||||
if message.Thinking != "" {
|
|
||||||
sb.WriteString("\n<think>" + strings.TrimSpace(message.Thinking) + "</think>")
|
|
||||||
}
|
|
||||||
|
|
||||||
if message.Content != "" {
|
|
||||||
sb.WriteString("\n" + strings.TrimSpace(message.Content) + "\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, toolCall := range message.ToolCalls {
|
|
||||||
sb.WriteString("\n<tool_call>\n<function=" + toolCall.Function.Name + ">")
|
|
||||||
for name, value := range toolCall.Function.Arguments {
|
|
||||||
valueStr := formatToolCallArgument(value)
|
|
||||||
sb.WriteString("\n<parameter=" + name + ">\n" + valueStr + "\n</parameter>")
|
|
||||||
}
|
|
||||||
sb.WriteString("\n</function>\n</tool_call>")
|
|
||||||
}
|
|
||||||
sb.WriteString("<|im_end|>\n")
|
|
||||||
} else {
|
|
||||||
sb.WriteString(imStartTag + "assistant")
|
|
||||||
|
|
||||||
// Add thinking tags if present
|
|
||||||
if message.Thinking != "" {
|
|
||||||
sb.WriteString("\n<think>" + strings.TrimSpace(message.Thinking) + "</think>")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add content if present
|
|
||||||
if message.Content != "" {
|
|
||||||
if message.Thinking != "" {
|
|
||||||
sb.WriteString("\n" + strings.TrimSpace(message.Content))
|
|
||||||
} else {
|
|
||||||
sb.WriteString("\n" + message.Content)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !prefill {
|
|
||||||
sb.WriteString(imEndTag + "\n")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case "tool":
|
|
||||||
// consecutive tool responses should share a single `<im_start>user`, but
|
|
||||||
// have their own <tool_response> tags
|
|
||||||
|
|
||||||
// only start a new user block if this is the first tool response
|
|
||||||
if i == 0 || filteredMessages[i-1].Role != "tool" {
|
|
||||||
sb.WriteString(imStartTag + "user\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
sb.WriteString("<tool_response>\n")
|
|
||||||
sb.WriteString(message.Content)
|
|
||||||
sb.WriteString("\n</tool_response>\n")
|
|
||||||
|
|
||||||
// close the user block only if this is the last tool response
|
|
||||||
if i == len(filteredMessages)-1 || filteredMessages[i+1].Role != "tool" {
|
|
||||||
sb.WriteString(imEndTag + "\n")
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
sb.WriteString(imStartTag + message.Role + "\n")
|
|
||||||
sb.WriteString(message.Content)
|
|
||||||
sb.WriteString(imEndTag + "\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
if lastMessage && !prefill {
|
|
||||||
sb.WriteString(imStartTag + "assistant\n<think>")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return sb.String(), nil
|
|
||||||
}
|
|
||||||
147
model/renderers/olmo3.go
Normal file
147
model/renderers/olmo3.go
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
package renderers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
olmo3DefaultSystemMessage = "You are a helpful function-calling AI assistant. "
|
||||||
|
olmo3NoFunctionsMessage = "You do not currently have access to any functions. "
|
||||||
|
olmo3WithFunctionsMessage = "You are provided with function signatures within <functions></functions> XML tags. You may call one or more functions to assist with the user query. Output any function calls within <function_calls></function_calls> XML tags. Do not make assumptions about what values to plug into functions."
|
||||||
|
)
|
||||||
|
|
||||||
|
type Olmo3Renderer struct{}
|
||||||
|
|
||||||
|
func (r *Olmo3Renderer) Render(messages []api.Message, tools []api.Tool, _ *api.ThinkValue) (string, error) {
|
||||||
|
var sb strings.Builder
|
||||||
|
|
||||||
|
var systemMessage *api.Message
|
||||||
|
filteredMessages := make([]api.Message, 0, len(messages))
|
||||||
|
for i, message := range messages {
|
||||||
|
if message.Role == "system" {
|
||||||
|
if systemMessage == nil {
|
||||||
|
systemMessage = &messages[i]
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
filteredMessages = append(filteredMessages, message)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Render system message
|
||||||
|
if systemMessage != nil {
|
||||||
|
// Custom system message - single newline after "system"
|
||||||
|
sb.WriteString("<|im_start|>system\n")
|
||||||
|
sb.WriteString(systemMessage.Content)
|
||||||
|
|
||||||
|
if len(tools) > 0 {
|
||||||
|
functionsJSON, err := marshalWithSpaces(tools)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
sb.WriteString("<functions>")
|
||||||
|
sb.WriteString(string(functionsJSON))
|
||||||
|
sb.WriteString("</functions>")
|
||||||
|
}
|
||||||
|
sb.WriteString("<|im_end|>\n")
|
||||||
|
} else {
|
||||||
|
// Default system message - single newline after "system"
|
||||||
|
sb.WriteString("<|im_start|>system\n")
|
||||||
|
sb.WriteString(olmo3DefaultSystemMessage)
|
||||||
|
|
||||||
|
if len(tools) > 0 {
|
||||||
|
functionsJSON, err := marshalWithSpaces(tools)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
sb.WriteString(olmo3WithFunctionsMessage)
|
||||||
|
sb.WriteString("<functions>")
|
||||||
|
sb.WriteString(string(functionsJSON))
|
||||||
|
sb.WriteString("</functions>")
|
||||||
|
} else {
|
||||||
|
sb.WriteString(olmo3NoFunctionsMessage)
|
||||||
|
sb.WriteString("<functions></functions>")
|
||||||
|
}
|
||||||
|
sb.WriteString("<|im_end|>\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, message := range filteredMessages {
|
||||||
|
lastMessage := i == len(filteredMessages)-1
|
||||||
|
|
||||||
|
switch message.Role {
|
||||||
|
case "user":
|
||||||
|
sb.WriteString("<|im_start|>user\n")
|
||||||
|
sb.WriteString(message.Content)
|
||||||
|
sb.WriteString("<|im_end|>\n")
|
||||||
|
|
||||||
|
case "assistant":
|
||||||
|
sb.WriteString("<|im_start|>assistant\n")
|
||||||
|
|
||||||
|
if message.Content != "" {
|
||||||
|
sb.WriteString(message.Content)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(message.ToolCalls) > 0 {
|
||||||
|
sb.WriteString("<function_calls>")
|
||||||
|
for j, tc := range message.ToolCalls {
|
||||||
|
// Format as function_name(arg1="value1", arg2="value2")
|
||||||
|
sb.WriteString(tc.Function.Name)
|
||||||
|
sb.WriteString("(")
|
||||||
|
|
||||||
|
// Get sorted keys for deterministic output
|
||||||
|
keys := make([]string, 0, len(tc.Function.Arguments))
|
||||||
|
for k := range tc.Function.Arguments {
|
||||||
|
keys = append(keys, k)
|
||||||
|
}
|
||||||
|
sort.Strings(keys)
|
||||||
|
|
||||||
|
for k, key := range keys {
|
||||||
|
if k > 0 {
|
||||||
|
sb.WriteString(", ")
|
||||||
|
}
|
||||||
|
value, err := json.Marshal(tc.Function.Arguments[key])
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
sb.WriteString(fmt.Sprintf("%s=%s", key, string(value)))
|
||||||
|
}
|
||||||
|
sb.WriteString(")")
|
||||||
|
|
||||||
|
if j < len(message.ToolCalls)-1 {
|
||||||
|
sb.WriteString("\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sb.WriteString("</function_calls>")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add end tag unless it's the last message with content only (prefill)
|
||||||
|
if !lastMessage || len(message.ToolCalls) > 0 {
|
||||||
|
sb.WriteString("<|im_end|>\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
case "tool":
|
||||||
|
sb.WriteString("<|im_start|>environment\n")
|
||||||
|
sb.WriteString(message.Content)
|
||||||
|
sb.WriteString("<|im_end|>\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add generation prompt if needed
|
||||||
|
needsGenerationPrompt := true
|
||||||
|
if len(filteredMessages) > 0 {
|
||||||
|
lastMsg := filteredMessages[len(filteredMessages)-1]
|
||||||
|
if lastMsg.Role == "assistant" && len(lastMsg.ToolCalls) == 0 && lastMsg.Content != "" {
|
||||||
|
needsGenerationPrompt = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if needsGenerationPrompt {
|
||||||
|
sb.WriteString("<|im_start|>assistant\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
return sb.String(), nil
|
||||||
|
}
|
||||||
290
model/renderers/olmo3_test.go
Normal file
290
model/renderers/olmo3_test.go
Normal file
@@ -0,0 +1,290 @@
|
|||||||
|
package renderers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestOlmo3Renderer(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
msgs []api.Message
|
||||||
|
tools []api.Tool
|
||||||
|
expected string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "basic without system - adds default system",
|
||||||
|
msgs: []api.Message{
|
||||||
|
{Role: "user", Content: "Hello!"},
|
||||||
|
},
|
||||||
|
expected: "<|im_start|>system\n" +
|
||||||
|
"You are a helpful function-calling AI assistant. You do not currently have access to any functions. <functions></functions><|im_end|>\n" +
|
||||||
|
"<|im_start|>user\n" +
|
||||||
|
"Hello!<|im_end|>\n" +
|
||||||
|
"<|im_start|>assistant\n\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with system message no tools",
|
||||||
|
msgs: []api.Message{
|
||||||
|
{Role: "system", Content: "You are a helpful assistant."},
|
||||||
|
{Role: "user", Content: "Hello!"},
|
||||||
|
},
|
||||||
|
expected: "<|im_start|>system\n" +
|
||||||
|
"You are a helpful assistant.<|im_end|>\n" +
|
||||||
|
"<|im_start|>user\n" +
|
||||||
|
"Hello!<|im_end|>\n" +
|
||||||
|
"<|im_start|>assistant\n\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with system message and tools",
|
||||||
|
msgs: []api.Message{
|
||||||
|
{Role: "system", Content: "You are a helpful assistant."},
|
||||||
|
{Role: "user", Content: "What is the weather?"},
|
||||||
|
},
|
||||||
|
tools: []api.Tool{
|
||||||
|
{
|
||||||
|
Type: "function",
|
||||||
|
Function: api.ToolFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Description: "Get the current weather",
|
||||||
|
Parameters: api.ToolFunctionParameters{
|
||||||
|
Type: "object",
|
||||||
|
Required: []string{"location"},
|
||||||
|
Properties: map[string]api.ToolProperty{
|
||||||
|
"location": {Type: api.PropertyType{"string"}, Description: "The city"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: "<|im_start|>system\n" +
|
||||||
|
`You are a helpful assistant.<functions>[{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "The city"}}}}}]</functions><|im_end|>` + "\n" +
|
||||||
|
"<|im_start|>user\n" +
|
||||||
|
"What is the weather?<|im_end|>\n" +
|
||||||
|
"<|im_start|>assistant\n\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "default system with tools - includes function instruction",
|
||||||
|
msgs: []api.Message{
|
||||||
|
{Role: "user", Content: "What is the weather?"},
|
||||||
|
},
|
||||||
|
tools: []api.Tool{
|
||||||
|
{
|
||||||
|
Type: "function",
|
||||||
|
Function: api.ToolFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Description: "Get the current weather",
|
||||||
|
Parameters: api.ToolFunctionParameters{
|
||||||
|
Type: "object",
|
||||||
|
Required: []string{"location"},
|
||||||
|
Properties: map[string]api.ToolProperty{
|
||||||
|
"location": {Type: api.PropertyType{"string"}, Description: "The city"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: "<|im_start|>system\n" +
|
||||||
|
"You are a helpful function-calling AI assistant. " +
|
||||||
|
"You are provided with function signatures within <functions></functions> XML tags. You may call one or more functions to assist with the user query. Output any function calls within <function_calls></function_calls> XML tags. Do not make assumptions about what values to plug into functions." +
|
||||||
|
`<functions>[{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "The city"}}}}}]</functions><|im_end|>` + "\n" +
|
||||||
|
"<|im_start|>user\n" +
|
||||||
|
"What is the weather?<|im_end|>\n" +
|
||||||
|
"<|im_start|>assistant\n\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "assistant with tool calls - function call syntax",
|
||||||
|
msgs: []api.Message{
|
||||||
|
{Role: "system", Content: "You are a helpful assistant."},
|
||||||
|
{Role: "user", Content: "What is the weather in SF?"},
|
||||||
|
{
|
||||||
|
Role: "assistant",
|
||||||
|
Content: "Let me check the weather.",
|
||||||
|
ToolCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
ID: "call_1",
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Arguments: map[string]any{
|
||||||
|
"location": "San Francisco",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{Role: "tool", Content: `{"temperature": 68}`, ToolName: "get_weather"},
|
||||||
|
},
|
||||||
|
tools: []api.Tool{
|
||||||
|
{
|
||||||
|
Type: "function",
|
||||||
|
Function: api.ToolFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Description: "Get the current weather",
|
||||||
|
Parameters: api.ToolFunctionParameters{
|
||||||
|
Type: "object",
|
||||||
|
Required: []string{"location"},
|
||||||
|
Properties: map[string]api.ToolProperty{
|
||||||
|
"location": {Type: api.PropertyType{"string"}, Description: "The city"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: "<|im_start|>system\n" +
|
||||||
|
`You are a helpful assistant.<functions>[{"type": "function", "function": {"name": "get_weather", "description": "Get the current weather", "parameters": {"type": "object", "required": ["location"], "properties": {"location": {"type": "string", "description": "The city"}}}}}]</functions><|im_end|>` + "\n" +
|
||||||
|
"<|im_start|>user\n" +
|
||||||
|
"What is the weather in SF?<|im_end|>\n" +
|
||||||
|
"<|im_start|>assistant\n" +
|
||||||
|
`Let me check the weather.<function_calls>get_weather(location="San Francisco")</function_calls><|im_end|>` + "\n" +
|
||||||
|
"<|im_start|>environment\n" +
|
||||||
|
`{"temperature": 68}<|im_end|>` + "\n" +
|
||||||
|
"<|im_start|>assistant\n\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multi-turn conversation",
|
||||||
|
msgs: []api.Message{
|
||||||
|
{Role: "system", Content: "You are a helpful assistant."},
|
||||||
|
{Role: "user", Content: "Hello"},
|
||||||
|
{Role: "assistant", Content: "Hi there!"},
|
||||||
|
{Role: "user", Content: "How are you?"},
|
||||||
|
},
|
||||||
|
expected: "<|im_start|>system\n" +
|
||||||
|
"You are a helpful assistant.<|im_end|>\n" +
|
||||||
|
"<|im_start|>user\n" +
|
||||||
|
"Hello<|im_end|>\n" +
|
||||||
|
"<|im_start|>assistant\n" +
|
||||||
|
"Hi there!<|im_end|>\n" +
|
||||||
|
"<|im_start|>user\n" +
|
||||||
|
"How are you?<|im_end|>\n" +
|
||||||
|
"<|im_start|>assistant\n\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "parallel tool calls - newline separated",
|
||||||
|
msgs: []api.Message{
|
||||||
|
{Role: "user", Content: "Get weather in SF and NYC"},
|
||||||
|
{
|
||||||
|
Role: "assistant",
|
||||||
|
ToolCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
ID: "call_1",
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Arguments: map[string]any{"location": "San Francisco"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "call_2",
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Arguments: map[string]any{"location": "New York"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{Role: "tool", Content: `{"temperature": 68}`, ToolName: "get_weather"},
|
||||||
|
{Role: "tool", Content: `{"temperature": 55}`, ToolName: "get_weather"},
|
||||||
|
},
|
||||||
|
tools: []api.Tool{
|
||||||
|
{
|
||||||
|
Type: "function",
|
||||||
|
Function: api.ToolFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Parameters: api.ToolFunctionParameters{
|
||||||
|
Type: "object",
|
||||||
|
Properties: map[string]api.ToolProperty{
|
||||||
|
"location": {Type: api.PropertyType{"string"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: "<|im_start|>system\n" +
|
||||||
|
"You are a helpful function-calling AI assistant. " +
|
||||||
|
"You are provided with function signatures within <functions></functions> XML tags. You may call one or more functions to assist with the user query. Output any function calls within <function_calls></function_calls> XML tags. Do not make assumptions about what values to plug into functions." +
|
||||||
|
`<functions>[{"type": "function", "function": {"name": "get_weather", "parameters": {"type": "object", "properties": {"location": {"type": "string"}}}}}]</functions><|im_end|>` + "\n" +
|
||||||
|
"<|im_start|>user\n" +
|
||||||
|
"Get weather in SF and NYC<|im_end|>\n" +
|
||||||
|
"<|im_start|>assistant\n" +
|
||||||
|
`<function_calls>get_weather(location="San Francisco")` + "\n" +
|
||||||
|
`get_weather(location="New York")</function_calls><|im_end|>` + "\n" +
|
||||||
|
"<|im_start|>environment\n" +
|
||||||
|
`{"temperature": 68}<|im_end|>` + "\n" +
|
||||||
|
"<|im_start|>environment\n" +
|
||||||
|
`{"temperature": 55}<|im_end|>` + "\n" +
|
||||||
|
"<|im_start|>assistant\n\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tool call with multiple arguments",
|
||||||
|
msgs: []api.Message{
|
||||||
|
{Role: "user", Content: "Book a flight"},
|
||||||
|
{
|
||||||
|
Role: "assistant",
|
||||||
|
ToolCalls: []api.ToolCall{
|
||||||
|
{
|
||||||
|
ID: "call_1",
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "book_flight",
|
||||||
|
Arguments: map[string]any{
|
||||||
|
"from": "SFO",
|
||||||
|
"to": "NYC",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
tools: []api.Tool{
|
||||||
|
{
|
||||||
|
Type: "function",
|
||||||
|
Function: api.ToolFunction{
|
||||||
|
Name: "book_flight",
|
||||||
|
Parameters: api.ToolFunctionParameters{
|
||||||
|
Type: "object",
|
||||||
|
Properties: map[string]api.ToolProperty{
|
||||||
|
"from": {Type: api.PropertyType{"string"}},
|
||||||
|
"to": {Type: api.PropertyType{"string"}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: "<|im_start|>system\n" +
|
||||||
|
"You are a helpful function-calling AI assistant. " +
|
||||||
|
"You are provided with function signatures within <functions></functions> XML tags. You may call one or more functions to assist with the user query. Output any function calls within <function_calls></function_calls> XML tags. Do not make assumptions about what values to plug into functions." +
|
||||||
|
`<functions>[{"type": "function", "function": {"name": "book_flight", "parameters": {"type": "object", "properties": {"from": {"type": "string"}, "to": {"type": "string"}}}}}]</functions><|im_end|>` + "\n" +
|
||||||
|
"<|im_start|>user\n" +
|
||||||
|
"Book a flight<|im_end|>\n" +
|
||||||
|
"<|im_start|>assistant\n" +
|
||||||
|
`<function_calls>book_flight(from="SFO", to="NYC")</function_calls><|im_end|>` + "\n" +
|
||||||
|
"<|im_start|>assistant\n\n",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "assistant prefill - no generation prompt",
|
||||||
|
msgs: []api.Message{
|
||||||
|
{Role: "user", Content: "Hello"},
|
||||||
|
{Role: "assistant", Content: "Hi there!"},
|
||||||
|
},
|
||||||
|
expected: "<|im_start|>system\n" +
|
||||||
|
"You are a helpful function-calling AI assistant. You do not currently have access to any functions. <functions></functions><|im_end|>\n" +
|
||||||
|
"<|im_start|>user\n" +
|
||||||
|
"Hello<|im_end|>\n" +
|
||||||
|
"<|im_start|>assistant\n" +
|
||||||
|
"Hi there!",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
rendered, err := (&Olmo3Renderer{}).Render(tt.msgs, tt.tools, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(rendered, tt.expected); diff != "" {
|
||||||
|
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -59,8 +59,8 @@ func rendererForName(name string) Renderer {
|
|||||||
case "cogito":
|
case "cogito":
|
||||||
renderer := &CogitoRenderer{isThinking: true}
|
renderer := &CogitoRenderer{isThinking: true}
|
||||||
return renderer
|
return renderer
|
||||||
case "intellect-3":
|
case "olmo3":
|
||||||
renderer := &Intellect3Renderer{}
|
renderer := &Olmo3Renderer{}
|
||||||
return renderer
|
return renderer
|
||||||
default:
|
default:
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -110,6 +110,7 @@ func renderPrompt(m *Model, msgs []api.Message, tools []api.Tool, think *api.Thi
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
slog.Debug("rendered prompt", "renderer", m.Config.Renderer, "prompt", rendered)
|
||||||
return rendered, nil
|
return rendered, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user