renderers/parsers: olmo3 instruct (#13383)

This commit is contained in:
Parth Sareen
2025-12-09 11:12:27 -08:00
committed by GitHub
parent 0c5e5f6630
commit 2bccf8c624
6 changed files with 1390 additions and 0 deletions

465
model/parsers/olmo3.go Normal file
View File

@@ -0,0 +1,465 @@
package parsers
import (
"context"
"fmt"
"log/slog"
"regexp"
"strconv"
"strings"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil"
)
type olmo3ParserState int
const (
olmo3StateContent olmo3ParserState = iota
olmo3StateToolCalls
olmo3StateToolCallsDone
)
const (
olmo3FuncCallsOpenTag = "<function_calls>"
olmo3FuncCallsCloseTag = "</function_calls>"
)
type Olmo3Parser struct {
state olmo3ParserState
buffer strings.Builder
}
func (p *Olmo3Parser) HasToolSupport() bool {
return true
}
func (p *Olmo3Parser) HasThinkingSupport() bool {
return false
}
func (p *Olmo3Parser) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
p.state = olmo3StateContent
return tools
}
type olmo3ParserEvent interface {
isOlmo3ParserEvent()
}
type olmo3ParserEventContent struct {
content string
}
type olmo3ParserEventToolCalls struct {
calls []api.ToolCall
}
func (olmo3ParserEventContent) isOlmo3ParserEvent() {}
func (olmo3ParserEventToolCalls) isOlmo3ParserEvent() {}
func (p *Olmo3Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
p.buffer.WriteString(s)
if done {
// Drain any remaining content
bufStr := p.buffer.String()
p.buffer.Reset()
if p.state == olmo3StateContent && len(bufStr) > 0 {
return bufStr, "", nil, nil
}
return "", "", nil, nil
}
events := p.parseEvents()
var contentSb strings.Builder
var allCalls []api.ToolCall
for _, event := range events {
switch event := event.(type) {
case olmo3ParserEventContent:
contentSb.WriteString(event.content)
case olmo3ParserEventToolCalls:
allCalls = append(allCalls, event.calls...)
}
}
return contentSb.String(), "", allCalls, nil
}
func (p *Olmo3Parser) parseEvents() []olmo3ParserEvent {
var all []olmo3ParserEvent
keepLooping := true
for keepLooping {
var events []olmo3ParserEvent
events, keepLooping = p.eat()
if len(events) > 0 {
all = append(all, events...)
}
}
if len(all) > 0 {
slog.Log(context.TODO(), logutil.LevelTrace, "olmo3 events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
}
return all
}
func (p *Olmo3Parser) eat() ([]olmo3ParserEvent, bool) {
var events []olmo3ParserEvent
bufStr := p.buffer.String()
if bufStr == "" {
return events, false
}
switch p.state {
case olmo3StateContent:
if strings.Contains(bufStr, olmo3FuncCallsOpenTag) {
// Found <function_calls> tag
split := strings.SplitN(bufStr, olmo3FuncCallsOpenTag, 2)
content := split[0]
remaining := split[1]
p.buffer.Reset()
p.buffer.WriteString(remaining)
p.state = olmo3StateToolCalls
if len(content) > 0 {
events = append(events, olmo3ParserEventContent{content: content})
}
return events, true
} else if overlapLen := overlap(bufStr, olmo3FuncCallsOpenTag); overlapLen > 0 {
// Partial <function_calls> tag - withhold ambiguous content
unambiguous := bufStr[:len(bufStr)-overlapLen]
ambiguous := bufStr[len(bufStr)-overlapLen:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, olmo3ParserEventContent{content: unambiguous})
}
return events, false
} else {
// Regular content - emit all
p.buffer.Reset()
if len(bufStr) > 0 {
events = append(events, olmo3ParserEventContent{content: bufStr})
}
return events, false
}
case olmo3StateToolCalls:
if strings.Contains(bufStr, olmo3FuncCallsCloseTag) {
// Found </function_calls> tag
split := strings.SplitN(bufStr, olmo3FuncCallsCloseTag, 2)
toolCallsStr := split[0]
remaining := split[1]
p.buffer.Reset()
p.buffer.WriteString(remaining)
p.state = olmo3StateToolCallsDone
// Parse the function calls
calls, err := parseOlmo3FunctionCalls(toolCallsStr)
if err != nil {
slog.Log(context.TODO(), logutil.LevelTrace, "failed to parse olmo3 function calls", "error", err, "content", toolCallsStr)
} else if len(calls) > 0 {
events = append(events, olmo3ParserEventToolCalls{calls: calls})
}
return events, true
} else if overlapLen := overlap(bufStr, olmo3FuncCallsCloseTag); overlapLen > 0 {
// Partial </function_calls> tag - wait for more
return events, false
}
// Still collecting tool calls, wait for close tag
return events, false
case olmo3StateToolCallsDone:
// After tool calls, emit remaining content
p.buffer.Reset()
p.state = olmo3StateContent
if len(bufStr) > 0 {
events = append(events, olmo3ParserEventContent{content: bufStr})
}
return events, false
}
return events, false
}
// parseOlmo3FunctionCalls parses function calls in Python-esque format:
// func_name(arg1="value1", arg2=123)
// Multiple calls are separated by newlines
func parseOlmo3FunctionCalls(s string) ([]api.ToolCall, error) {
var calls []api.ToolCall
s = strings.TrimSpace(s)
if s == "" {
return calls, nil
}
// Split by newlines for multiple function calls
lines := strings.Split(s, "\n")
for _, line := range lines {
line = strings.TrimSpace(line)
if line == "" {
continue
}
call, err := parseOlmo3SingleFunctionCall(line)
if err != nil {
return nil, fmt.Errorf("failed to parse function call %q: %w", line, err)
}
calls = append(calls, call)
}
return calls, nil
}
// Regex to match function call: func_name(args)
var funcCallRegex = regexp.MustCompile(`^(\w+)\((.*)\)$`)
func parseOlmo3SingleFunctionCall(s string) (api.ToolCall, error) {
matches := funcCallRegex.FindStringSubmatch(s)
if matches == nil {
return api.ToolCall{}, fmt.Errorf("invalid function call format")
}
funcName := matches[1]
argsStr := matches[2]
args, err := parseOlmo3Arguments(argsStr)
if err != nil {
return api.ToolCall{}, fmt.Errorf("failed to parse arguments: %w", err)
}
return api.ToolCall{
Function: api.ToolCallFunction{
Name: funcName,
Arguments: args,
},
}, nil
}
// parseOlmo3Arguments parses comma-separated key=value pairs
// Handles nested parentheses, brackets, braces, and quoted strings
func parseOlmo3Arguments(s string) (map[string]any, error) {
args := make(map[string]any)
s = strings.TrimSpace(s)
if s == "" {
return args, nil
}
// Split by commas, but respect nested structures and quotes
parts := splitArguments(s)
for _, part := range parts {
part = strings.TrimSpace(part)
if part == "" {
continue
}
// Find the first = sign
eqIdx := strings.Index(part, "=")
if eqIdx == -1 {
return nil, fmt.Errorf("invalid argument format: %s", part)
}
key := strings.TrimSpace(part[:eqIdx])
valueStr := strings.TrimSpace(part[eqIdx+1:])
value, err := parseOlmo3Value(valueStr)
if err != nil {
return nil, fmt.Errorf("failed to parse value for %s: %w", key, err)
}
args[key] = value
}
return args, nil
}
// splitArguments splits arguments by commas, respecting quotes and nested structures
func splitArguments(s string) []string {
var parts []string
var current strings.Builder
depth := 0
inString := false
stringChar := byte(0)
escaped := false
for i := range s {
c := s[i]
if escaped {
current.WriteByte(c)
escaped = false
continue
}
if c == '\\' && inString {
current.WriteByte(c)
escaped = true
continue
}
if (c == '"' || c == '\'') && !inString {
inString = true
stringChar = c
current.WriteByte(c)
continue
}
if c == stringChar && inString {
inString = false
stringChar = 0
current.WriteByte(c)
continue
}
if !inString {
switch c {
case '(', '[', '{':
depth++
current.WriteByte(c)
case ')', ']', '}':
depth--
current.WriteByte(c)
case ',':
if depth == 0 {
parts = append(parts, current.String())
current.Reset()
continue
}
current.WriteByte(c)
default:
current.WriteByte(c)
}
} else {
current.WriteByte(c)
}
}
if current.Len() > 0 {
parts = append(parts, current.String())
}
return parts
}
// parseOlmo3Value parses a value which can be a string, number, boolean, null, array, or object
func parseOlmo3Value(s string) (any, error) {
s = strings.TrimSpace(s)
// Check for quoted string
if (strings.HasPrefix(s, `"`) && strings.HasSuffix(s, `"`)) ||
(strings.HasPrefix(s, `'`) && strings.HasSuffix(s, `'`)) {
// Remove quotes and unescape
inner := s[1 : len(s)-1]
return unescapeString(inner), nil
}
// Check for boolean
if s == "true" || s == "True" {
return true, nil
}
if s == "false" || s == "False" {
return false, nil
}
// Check for null/None
if s == "null" || s == "None" || s == "nil" {
return nil, nil
}
// Check for number
if i, err := strconv.ParseInt(s, 10, 64); err == nil {
return i, nil
}
if f, err := strconv.ParseFloat(s, 64); err == nil {
return f, nil
}
// Check for array [...]
if strings.HasPrefix(s, "[") && strings.HasSuffix(s, "]") {
return parseOlmo3Array(s[1 : len(s)-1])
}
// Check for object {...}
if strings.HasPrefix(s, "{") && strings.HasSuffix(s, "}") {
return parseOlmo3Object(s[1 : len(s)-1])
}
// Default to string without quotes
return s, nil
}
func parseOlmo3Array(s string) ([]any, error) {
s = strings.TrimSpace(s)
if s == "" {
return []any{}, nil
}
parts := splitArguments(s)
var arr []any
for _, part := range parts {
val, err := parseOlmo3Value(part)
if err != nil {
return nil, err
}
arr = append(arr, val)
}
return arr, nil
}
func parseOlmo3Object(s string) (map[string]any, error) {
s = strings.TrimSpace(s)
if s == "" {
return map[string]any{}, nil
}
// Objects use key: value or "key": value format
obj := make(map[string]any)
parts := splitArguments(s)
for _, part := range parts {
part = strings.TrimSpace(part)
if part == "" {
continue
}
// Find colon separator
colonIdx := strings.Index(part, ":")
if colonIdx == -1 {
return nil, fmt.Errorf("invalid object entry: %s", part)
}
keyStr := strings.TrimSpace(part[:colonIdx])
valueStr := strings.TrimSpace(part[colonIdx+1:])
// Remove quotes from key if present
if (strings.HasPrefix(keyStr, `"`) && strings.HasSuffix(keyStr, `"`)) ||
(strings.HasPrefix(keyStr, `'`) && strings.HasSuffix(keyStr, `'`)) {
keyStr = keyStr[1 : len(keyStr)-1]
}
val, err := parseOlmo3Value(valueStr)
if err != nil {
return nil, fmt.Errorf("failed to parse value for key %s: %w", keyStr, err)
}
obj[keyStr] = val
}
return obj, nil
}
func unescapeString(s string) string {
// Handle common escape sequences
s = strings.ReplaceAll(s, `\\`, "\x00") // Placeholder for backslash
s = strings.ReplaceAll(s, `\"`, `"`)
s = strings.ReplaceAll(s, `\'`, `'`)
s = strings.ReplaceAll(s, `\n`, "\n")
s = strings.ReplaceAll(s, `\t`, "\t")
s = strings.ReplaceAll(s, `\r`, "\r")
s = strings.ReplaceAll(s, "\x00", `\`) // Restore backslash
return s
}

483
model/parsers/olmo3_test.go Normal file
View File

@@ -0,0 +1,483 @@
package parsers
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
func TestOlmo3Parser(t *testing.T) {
tests := []struct {
name string
input string
expectedContent string
expectedThinking string
expectedCalls []api.ToolCall
}{
{
name: "simple content",
input: "Hello, how can I help you?",
expectedContent: "Hello, how can I help you?",
},
{
name: "simple tool call",
input: `<function_calls>get_weather(location="San Francisco")</function_calls>`,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: map[string]any{"location": "San Francisco"},
},
},
},
},
{
name: "content then tool call",
input: `Let me check the weather.<function_calls>get_weather(location="NYC")</function_calls>`,
expectedContent: "Let me check the weather.",
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: map[string]any{"location": "NYC"},
},
},
},
},
{
name: "tool call with multiple arguments",
input: `<function_calls>book_flight(from="SFO", to="NYC", date="2024-01-15")</function_calls>`,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "book_flight",
Arguments: map[string]any{
"from": "SFO",
"to": "NYC",
"date": "2024-01-15",
},
},
},
},
},
{
name: "multiple tool calls",
input: `<function_calls>get_weather(location="San Francisco")
get_weather(location="New York")</function_calls>`,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: map[string]any{"location": "San Francisco"},
},
},
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: map[string]any{"location": "New York"},
},
},
},
},
{
name: "tool call with numeric argument",
input: `<function_calls>set_temperature(value=72)</function_calls>`,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "set_temperature",
Arguments: map[string]any{"value": int64(72)},
},
},
},
},
{
name: "tool call with float argument",
input: `<function_calls>set_price(amount=19.99)</function_calls>`,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "set_price",
Arguments: map[string]any{"amount": 19.99},
},
},
},
},
{
name: "tool call with boolean argument",
input: `<function_calls>toggle_setting(enabled=true)</function_calls>`,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "toggle_setting",
Arguments: map[string]any{"enabled": true},
},
},
},
},
{
name: "tool call with null argument",
input: `<function_calls>clear_value(field=null)</function_calls>`,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "clear_value",
Arguments: map[string]any{"field": nil},
},
},
},
},
{
name: "tool call with array argument",
input: `<function_calls>process_items(items=["apple", "banana", "cherry"])</function_calls>`,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "process_items",
Arguments: map[string]any{"items": []any{"apple", "banana", "cherry"}},
},
},
},
},
{
name: "tool call with dict argument",
input: `<function_calls>update_config(settings={"theme": "dark", "fontSize": 14})</function_calls>`,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "update_config",
Arguments: map[string]any{
"settings": map[string]any{
"theme": "dark",
"fontSize": int64(14),
},
},
},
},
},
},
{
name: "tool call with nested dict",
input: `<function_calls>create_request(data={"user": {"name": "John", "age": 30}, "active": true})</function_calls>`,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "create_request",
Arguments: map[string]any{
"data": map[string]any{
"user": map[string]any{
"name": "John",
"age": int64(30),
},
"active": true,
},
},
},
},
},
},
{
name: "tool call with no arguments",
input: `<function_calls>get_current_time()</function_calls>`,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_current_time",
Arguments: map[string]any{},
},
},
},
},
{
name: "tool call with single quotes",
input: `<function_calls>search(query='hello world')</function_calls>`,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "search",
Arguments: map[string]any{"query": "hello world"},
},
},
},
},
{
name: "tool call with escaped quotes",
input: `<function_calls>search(query="say \"hello\"")</function_calls>`,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "search",
Arguments: map[string]any{"query": `say "hello"`},
},
},
},
},
{
name: "tool call with mixed argument types",
input: `<function_calls>create_user(name="John", age=30, active=true)</function_calls>`,
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "create_user",
Arguments: map[string]any{
"name": "John",
"age": int64(30),
"active": true,
},
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &Olmo3Parser{}
p.Init(nil, nil, nil)
content, thinking, calls, err := p.Add(tt.input, false)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Drain remaining content
finalContent, finalThinking, finalCalls, err := p.Add("", true)
if err != nil {
t.Fatalf("unexpected error on done: %v", err)
}
content += finalContent
thinking += finalThinking
calls = append(calls, finalCalls...)
if diff := cmp.Diff(content, tt.expectedContent); diff != "" {
t.Errorf("content mismatch (-got +want):\n%s", diff)
}
if diff := cmp.Diff(thinking, tt.expectedThinking); diff != "" {
t.Errorf("thinking mismatch (-got +want):\n%s", diff)
}
if diff := cmp.Diff(calls, tt.expectedCalls); diff != "" {
t.Errorf("calls mismatch (-got +want):\n%s", diff)
}
})
}
}
func TestOlmo3Parser_Streaming(t *testing.T) {
tests := []struct {
name string
chunks []string
expectedContent string
expectedCalls []api.ToolCall
}{
{
name: "streaming content",
chunks: []string{"Hello, ", "how ", "can I help?"},
expectedContent: "Hello, how can I help?",
},
{
name: "streaming tool call",
chunks: []string{"<function_", "calls>get_weather", "(location=\"SF\")", "</function_calls>"},
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: map[string]any{"location": "SF"},
},
},
},
},
{
name: "streaming content then tool call",
chunks: []string{"Let me check.", "<function_calls>", "get_weather(location=\"NYC\")", "</function_calls>"},
expectedContent: "Let me check.",
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: map[string]any{"location": "NYC"},
},
},
},
},
{
name: "tool call tag split across chunks",
chunks: []string{"<func", "tion_calls>test()</function_calls>"},
expectedCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "test",
Arguments: map[string]any{},
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p := &Olmo3Parser{}
p.Init(nil, nil, nil)
var allContent string
var allCalls []api.ToolCall
for _, chunk := range tt.chunks {
content, _, calls, err := p.Add(chunk, false)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
allContent += content
allCalls = append(allCalls, calls...)
}
// Drain
content, _, calls, err := p.Add("", true)
if err != nil {
t.Fatalf("unexpected error on done: %v", err)
}
allContent += content
allCalls = append(allCalls, calls...)
if diff := cmp.Diff(allContent, tt.expectedContent); diff != "" {
t.Errorf("content mismatch (-got +want):\n%s", diff)
}
if diff := cmp.Diff(allCalls, tt.expectedCalls); diff != "" {
t.Errorf("calls mismatch (-got +want):\n%s", diff)
}
})
}
}
func TestOlmo3Parser_HasToolSupport(t *testing.T) {
p := &Olmo3Parser{}
if !p.HasToolSupport() {
t.Error("expected HasToolSupport to return true")
}
}
func TestOlmo3Parser_HasThinkingSupport(t *testing.T) {
p := &Olmo3Parser{}
if p.HasThinkingSupport() {
t.Error("expected HasThinkingSupport to return false")
}
}
func TestParseOlmo3FunctionCalls(t *testing.T) {
tests := []struct {
name string
input string
expected []api.ToolCall
wantErr bool
}{
{
name: "simple call",
input: `get_weather(location="SF")`,
expected: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: map[string]any{"location": "SF"},
},
},
},
},
{
name: "multiple args",
input: `send_email(to="user@example.com", subject="Hello", body="Test message")`,
expected: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "send_email",
Arguments: map[string]any{
"to": "user@example.com",
"subject": "Hello",
"body": "Test message",
},
},
},
},
},
{
name: "multiple calls with newlines",
input: `get_weather(location="SF")
get_time(timezone="PST")`,
expected: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: map[string]any{"location": "SF"},
},
},
{
Function: api.ToolCallFunction{
Name: "get_time",
Arguments: map[string]any{"timezone": "PST"},
},
},
},
},
{
name: "empty input",
input: "",
expected: nil,
},
{
name: "whitespace only",
input: " \n ",
expected: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
calls, err := parseOlmo3FunctionCalls(tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("parseOlmo3FunctionCalls() error = %v, wantErr %v", err, tt.wantErr)
return
}
if diff := cmp.Diff(calls, tt.expected); diff != "" {
t.Errorf("calls mismatch (-got +want):\n%s", diff)
}
})
}
}
func TestParseOlmo3Value(t *testing.T) {
tests := []struct {
name string
input string
expected any
}{
{"string double quotes", `"hello"`, "hello"},
{"string single quotes", `'hello'`, "hello"},
{"integer", "42", int64(42)},
{"negative integer", "-10", int64(-10)},
{"float", "3.14", 3.14},
{"boolean true", "true", true},
{"boolean True", "True", true},
{"boolean false", "false", false},
{"null", "null", nil},
{"None", "None", nil},
{"empty array", "[]", []any{}},
{"array with strings", `["a", "b"]`, []any{"a", "b"}},
{"array with numbers", "[1, 2, 3]", []any{int64(1), int64(2), int64(3)}},
{"empty object", "{}", map[string]any{}},
{"simple object", `{"name": "John"}`, map[string]any{"name": "John"}},
{"object with number", `{"age": 30}`, map[string]any{"age": int64(30)}},
{"object with multiple keys", `{"a": 1, "b": 2}`, map[string]any{"a": int64(1), "b": int64(2)}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := parseOlmo3Value(tt.input)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if diff := cmp.Diff(result, tt.expected); diff != "" {
t.Errorf("value mismatch (-got +want):\n%s", diff)
}
})
}
}

View File

@@ -58,6 +58,8 @@ func ParserForName(name string) Parser {
return harmony.NewHarmonyMessageHandler()
case "cogito":
return &CogitoParser{}
case "olmo3":
return &Olmo3Parser{}
case "olmo3-think":
return &Olmo3ThinkParser{}
default: