diff --git a/model/parsers/glm46.go b/model/parsers/glm46.go new file mode 100644 index 000000000..9f08ad512 --- /dev/null +++ b/model/parsers/glm46.go @@ -0,0 +1,422 @@ +package parsers + +import ( + "context" + "encoding/json" + "log/slog" + "regexp" + "strings" + "unicode" + "unicode/utf8" + + "github.com/ollama/ollama/api" + "github.com/ollama/ollama/logutil" +) + +type glm46ParserState int + +const ( + thinkOpenTag = "" + thinkCloseTag = "" + glmToolOpenTag = "" + glmToolCloseTag = "" + argKeyOpenTag = "" + argKeyCloseTag = "" + argValueOpenTag = "" + argValueCloseTag = "" +) + +const ( + glm46ParserState_LookingForTags glm46ParserState = iota + glm46ParserState_CollectingThinking + glm46ParserState_CollectingToolCall +) + +type GLM46Parser struct { + state glm46ParserState + acc strings.Builder + tools []api.Tool +} + +func (p *GLM46Parser) HasToolSupport() bool { + return true +} + +func (p *GLM46Parser) HasThinkingSupport() bool { + return true +} + +func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool { + p.tools = tools + return tools +} + +func (p *GLM46Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) { + p.acc.WriteString(s) + + events := p.parseEvents() + + var toolCalls []api.ToolCall + var contentBuilder strings.Builder + var thinkingBuilder strings.Builder + + for _, event := range events { + switch event := event.(type) { + case glm46EventRawToolCall: + toolCall, err := parseGLMToolCall(event, p.tools) + if err != nil { + slog.Warn("glm46 tool call parsing failed", "error", err) + return "", "", nil, err + } + toolCalls = append(toolCalls, toolCall) + case glm46EventContent: + contentBuilder.WriteString(event.content) + case glm46EventThinking: + thinkingBuilder.WriteString(event.thinking) + } + } + + return contentBuilder.String(), thinkingBuilder.String(), toolCalls, nil +} + +func (p *GLM46Parser) parseEvents() []glm46Event { + var all []glm46Event + + keepLooping := true + for keepLooping { + var events []glm46Event + events, keepLooping = eatGLM(p) + if len(events) > 0 { + all = append(all, events...) + } + } + + if len(all) > 0 { + slog.Log(context.TODO(), logutil.LevelTrace, "glm46 events parsed", "events", all, "state", p.state, "acc", p.acc.String()) + } + + return all +} + +type glm46Event interface { + isGLM46Event() +} + +type glm46EventRawToolCall struct { + raw string +} + +type glm46EventContent struct { + content string +} + +type glm46EventThinking struct { + thinking string +} + +func (glm46EventContent) isGLM46Event() {} +func (glm46EventRawToolCall) isGLM46Event() {} +func (glm46EventThinking) isGLM46Event() {} + +func eatGLM(p *GLM46Parser) ([]glm46Event, bool) { + var events []glm46Event + + switch p.state { + case glm46ParserState_LookingForTags: + buf := p.acc.String() + + // Check for thinking open tag first + if strings.Contains(buf, thinkOpenTag) { + split := strings.SplitN(buf, thinkOpenTag, 2) + before := split[0] + before = strings.TrimRightFunc(before, unicode.IsSpace) + if len(before) > 0 { + events = append(events, glm46EventContent{content: before}) + } + after := split[1] + p.acc.Reset() + p.acc.WriteString(after) + p.state = glm46ParserState_CollectingThinking + return events, true + } + + // Check for tool call open tag + if strings.Contains(buf, glmToolOpenTag) { + split := strings.SplitN(buf, glmToolOpenTag, 2) + before := split[0] + before = strings.TrimRightFunc(before, unicode.IsSpace) + if len(before) > 0 { + events = append(events, glm46EventContent{content: before}) + } + after := split[1] + p.acc.Reset() + p.acc.WriteString(after) + p.state = glm46ParserState_CollectingToolCall + return events, true + } + + // Check for partial tags + if overlap := glmOverlap(buf, thinkOpenTag); overlap > 0 { + beforePartialTag := buf[:len(buf)-overlap] + trailingWhitespaceLen := glmTrailingWhitespaceLen(beforePartialTag) + ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen + unambiguous := buf[:ambiguousStart] + ambiguous := buf[ambiguousStart:] + p.acc.Reset() + p.acc.WriteString(ambiguous) + if len(unambiguous) > 0 { + events = append(events, glm46EventContent{content: unambiguous}) + } + return events, false + } + + if overlap := glmOverlap(buf, glmToolOpenTag); overlap > 0 { + beforePartialTag := buf[:len(buf)-overlap] + trailingWhitespaceLen := glmTrailingWhitespaceLen(beforePartialTag) + ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen + unambiguous := buf[:ambiguousStart] + ambiguous := buf[ambiguousStart:] + p.acc.Reset() + p.acc.WriteString(ambiguous) + if len(unambiguous) > 0 { + events = append(events, glm46EventContent{content: unambiguous}) + } + return events, false + } + + // No tags found, emit content but withhold trailing whitespace + whitespaceLen := glmTrailingWhitespaceLen(buf) + ambiguousStart := len(buf) - whitespaceLen + unambiguous := buf[:ambiguousStart] + ambiguous := buf[ambiguousStart:] + p.acc.Reset() + p.acc.WriteString(ambiguous) + if len(unambiguous) > 0 { + events = append(events, glm46EventContent{content: unambiguous}) + } + return events, false + + case glm46ParserState_CollectingThinking: + if strings.Contains(p.acc.String(), thinkCloseTag) { + split := strings.SplitN(p.acc.String(), thinkCloseTag, 2) + thinkingContent := split[0] + after := strings.TrimLeftFunc(split[1], unicode.IsSpace) + p.acc.Reset() + p.acc.WriteString(after) + events = append(events, glm46EventThinking{thinking: thinkingContent}) + p.state = glm46ParserState_LookingForTags + return events, true + } + return events, false + + case glm46ParserState_CollectingToolCall: + if strings.Contains(p.acc.String(), glmToolCloseTag) { + split := strings.SplitN(p.acc.String(), glmToolCloseTag, 2) + toolCallContent := split[0] + if len(toolCallContent) == 0 { + slog.Warn("glm46 tool call closing tag found but no content before it") + } + after := strings.TrimLeftFunc(split[1], unicode.IsSpace) + p.acc.Reset() + p.acc.WriteString(after) + events = append(events, glm46EventRawToolCall{raw: toolCallContent}) + p.state = glm46ParserState_LookingForTags + return events, true + } + return events, false + + default: + panic("unreachable") + } +} + +var ( + glmFunctionNameRegex = regexp.MustCompile(`^([^\n<]+)`) + glmArgKeyRegex = regexp.MustCompile(`(.*?)`) + glmArgValueRegex = regexp.MustCompile(`(.*?)`) +) + +// parseGLMToolCall parses a raw GLM tool call string into an api.ToolCall. +// The raw string has the format: +// {function-name} +// {arg-key-1} +// {arg-value-1} +// {arg-key-2} +// {arg-value-2} +// ... +func parseGLMToolCall(raw glm46EventRawToolCall, tools []api.Tool) (api.ToolCall, error) { + toolCall := api.ToolCall{} + + // Extract function name (first line or until first <) + functionNameMatch := glmFunctionNameRegex.FindStringSubmatch(raw.raw) + if len(functionNameMatch) < 2 { + return api.ToolCall{}, nil + } + + functionName := strings.TrimSpace(functionNameMatch[1]) + toolCall.Function = api.ToolCallFunction{ + Name: functionName, + } + + // Find the matching tool to get parameter types + var matchedTool *api.Tool + for i := range tools { + if tools[i].Function.Name == functionName { + matchedTool = &tools[i] + break + } + } + + // Extract all arg_key and arg_value pairs + argKeys := glmArgKeyRegex.FindAllStringSubmatch(raw.raw, -1) + argValues := glmArgValueRegex.FindAllStringSubmatch(raw.raw, -1) + + if len(argKeys) != len(argValues) { + slog.Warn("glm46 tool call has mismatched arg_key and arg_value counts", "keys", len(argKeys), "values", len(argValues)) + } + + toolCall.Function.Arguments = make(api.ToolCallFunctionArguments) + minLen := min(len(argKeys), len(argValues)) + + for i := 0; i < minLen; i++ { + if len(argKeys[i]) < 2 || len(argValues[i]) < 2 { + continue + } + + key := strings.TrimSpace(argKeys[i][1]) + value := argValues[i][1] + + // Trim leading and trailing newlines from value (following reference implementation) + value = strings.TrimPrefix(value, "\n") + value = strings.TrimSuffix(value, "\n") + + // Look up the parameter type if we found the tool + var paramType api.PropertyType + if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil { + if prop, ok := matchedTool.Function.Parameters.Properties[key]; ok { + paramType = prop.Type + } + } + + // Parse the value according to its type + toolCall.Function.Arguments[key] = parseGLMValue(value, paramType) + } + + return toolCall, nil +} + +// longest overlap between suffix of s and prefix of delim +func glmOverlap(s, delim string) int { + max := min(len(delim), len(s)) + for i := max; i > 0; i-- { + if strings.HasSuffix(s, delim[:i]) { + return i + } + } + return 0 +} + +func glmTrailingWhitespaceLen(s string) int { + remaining := s + total := 0 + for len(remaining) > 0 { + r, size := utf8.DecodeLastRuneInString(remaining) + // if it's an invalid utf8 rune, assume it isn't whitespace + if r == utf8.RuneError && size == 1 { + break + } + if !unicode.IsSpace(r) { + break + } + total += size + remaining = remaining[:len(remaining)-size] + } + return total +} + +func parseGLMValue(raw string, paramType api.PropertyType) any { + // Check for null first (case-insensitive) - this takes precedence over any type + if strings.ToLower(raw) == "null" { + return nil + } + + // If no type is specified, try to parse as JSON, otherwise return as string + if len(paramType) == 0 { + var val any + if err := json.Unmarshal([]byte(raw), &val); err == nil { + return val + } + return raw + } + + // Check if any of the specified types match, using type precedence + // Order: boolean -> integer -> number -> array -> object -> string + typeSet := make(map[string]bool) + for _, t := range paramType { + typeSet[t] = true + } + + // Try boolean first (most restrictive) + if typeSet["boolean"] { + lower := strings.ToLower(raw) + switch lower { + case "true": + return true + case "false": + return false + } + // If not a valid boolean but boolean is the only type, return false + if len(paramType) == 1 { + return false + } + } + + // Try parsing as JSON for complex types + var jsonVal any + if err := json.Unmarshal([]byte(raw), &jsonVal); err == nil { + // Check if the parsed type matches any of the expected types + switch v := jsonVal.(type) { + case float64: + if typeSet["number"] { + return v + } + if typeSet["integer"] && v == float64(int64(v)) { + return int64(v) + } + case bool: + if typeSet["boolean"] { + return v + } + case []any: + if typeSet["array"] { + return v + } + case map[string]any: + if typeSet["object"] { + return v + } + case string: + if typeSet["string"] { + return v + } + case nil: + return nil + } + + // If JSON parsed but type doesn't match, check if string is valid + if typeSet["string"] { + return raw + } + + // Return the parsed JSON value as fallback + return jsonVal + } + + // If JSON parsing failed but string is valid, return as string + if typeSet["string"] { + return raw + } + + // Fallback to string + return raw +} diff --git a/model/parsers/parsers.go b/model/parsers/parsers.go index a1d4e8127..57f29f982 100644 --- a/model/parsers/parsers.go +++ b/model/parsers/parsers.go @@ -21,6 +21,9 @@ func ParserForName(name string) Parser { case "qwen3-coder": parser := &Qwen3CoderParser{} return parser + case "glm-4.6": + parser := &GLM46Parser{} + return parser case "passthrough": return &PassthroughParser{} case "harmony": diff --git a/model/renderers/glm46_test.go b/model/renderers/glm46_test.go new file mode 100644 index 000000000..82c3bccc0 --- /dev/null +++ b/model/renderers/glm46_test.go @@ -0,0 +1,216 @@ +package renderers + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/ollama/ollama/api" +) + +func TestGLM46Renderer(t *testing.T) { + tests := []struct { + name string + messages []api.Message + tools []api.Tool + thinkValue *api.ThinkValue + expected string + }{ + { + name: "basic", + messages: []api.Message{ + {Role: "user", Content: "Hello, how are you?"}, + }, + expected: `[gMASK]<|user|> +Hello, how are you?<|assistant|>`, + }, + { + name: "basic with system message", + messages: []api.Message{ + {Role: "system", Content: "You are a helpful assistant."}, + {Role: "user", Content: "Hello, how are you?"}, + }, + expected: `[gMASK]<|system|> +You are a helpful assistant.<|user|> +Hello, how are you?<|assistant|>`, + }, + { + name: "basic with user assistant user", + messages: []api.Message{ + {Role: "user", Content: "What is the capital of France?"}, + {Role: "assistant", Content: "The capital of France is Paris."}, + {Role: "user", Content: "Fantastic!"}, + }, + expected: `[gMASK]<|user|> +What is the capital of France?<|assistant|> +The capital of France is Paris.<|user|> +Fantastic!<|assistant|>`, + }, + { + name: "tools", + messages: []api.Message{ + {Role: "system", Content: "You are a helpful assistant with access to tools."}, + {Role: "user", Content: "What is the weather like in Tokyo?"}, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get the current weather in a given location", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Required: []string{"location"}, + Properties: map[string]api.ToolProperty{ + "location": { + Type: api.PropertyType{"string"}, + Description: "The city and state, e.g. San Francisco, CA", + }, + "unit": { + Type: api.PropertyType{"string"}, + Enum: []any{"celsius", "fahrenheit"}, + }, + }, + }, + }, + }, + }, + expected: `[gMASK]<|system|> +# Tools + +You may call one or more functions to assist with the user query. + +You are provided with function signatures within XML tags: + +{"type":"function","function":{"name":"get_weather","description":"Get the current weather in a given location","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","description":"","enum":["celsius","fahrenheit"]}}}}} + + +For each function call, output the function name and arguments within the following XML format: +{function-name} +{arg-key-1} +{arg-value-1} +{arg-key-2} +{arg-value-2} +... +<|system|> +You are a helpful assistant with access to tools.<|user|> +What is the weather like in Tokyo?<|assistant|>`, + }, + { + name: "tool calls", + messages: []api.Message{ + {Role: "system", Content: "You are a helpful assistant with access to tools."}, + {Role: "user", Content: "What is the weather like in Tokyo?"}, + { + Role: "assistant", + ToolCalls: []api.ToolCall{ + { + Function: api.ToolCallFunction{ + Name: "get_weather", + Arguments: api.ToolCallFunctionArguments{ + "location": "Tokyo, Japan", + "unit": "celsius", + }, + }, + }, + }, + }, + { + Role: "tool", + Content: "{\"temperature\": 22, \"weather\": \"partly cloudy\", \"humidity\": 65}", + ToolName: "get_weather", + }, + { + Role: "assistant", + Content: "The weather in Tokyo is currently partly cloudy with a temperature of 22°C and 65% humidity. It's a pleasant day with moderate temperatures.", + }, + }, + tools: []api.Tool{ + { + Type: "function", + Function: api.ToolFunction{ + Name: "get_weather", + Description: "Get the current weather in a given location", + Parameters: api.ToolFunctionParameters{ + Type: "object", + Required: []string{"location"}, + Properties: map[string]api.ToolProperty{ + "location": { + Type: api.PropertyType{"string"}, + Description: "The city and state, e.g. San Francisco, CA", + }, + "unit": { + Type: api.PropertyType{"string"}, + Enum: []any{"celsius", "fahrenheit"}, + }, + }, + }, + }, + }, + }, + expected: `[gMASK]<|system|> +# Tools + +You may call one or more functions to assist with the user query. + +You are provided with function signatures within XML tags: + +{"type":"function","function":{"name":"get_weather","description":"Get the current weather in a given location","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","description":"","enum":["celsius","fahrenheit"]}}}}} + + +For each function call, output the function name and arguments within the following XML format: +{function-name} +{arg-key-1} +{arg-value-1} +{arg-key-2} +{arg-value-2} +... +<|system|> +You are a helpful assistant with access to tools.<|user|> +What is the weather like in Tokyo?<|assistant|> + +get_weather +location +Tokyo, Japan +unit +celsius +<|observation|> + +{"temperature": 22, "weather": "partly cloudy", "humidity": 65} +<|assistant|> + +The weather in Tokyo is currently partly cloudy with a temperature of 22°C and 65% humidity. It's a pleasant day with moderate temperatures.<|assistant|>`, + }, + { + name: "think true", + messages: []api.Message{ + {Role: "user", Content: "Hello, how are you?"}, + }, + thinkValue: &api.ThinkValue{Value: true}, + expected: `[gMASK]<|user|> +Hello, how are you?<|assistant|>`, + }, + { + name: "think false", + messages: []api.Message{ + {Role: "user", Content: "Hello, how are you?"}, + }, + thinkValue: &api.ThinkValue{Value: false}, + expected: `[gMASK]<|user|> +Hello, how are you?<|assistant|> +`, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rendered, err := GLM46Renderer(tt.messages, tt.tools, tt.thinkValue) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(rendered, tt.expected); diff != "" { + t.Errorf("mismatch (-got +want):\n%s", diff) + t.Logf("Got:\n%s", rendered) + t.Logf("Expected:\n%s", tt.expected) + } + }) + } +} diff --git a/model/renderers/gml46.go b/model/renderers/gml46.go new file mode 100644 index 000000000..d18e01873 --- /dev/null +++ b/model/renderers/gml46.go @@ -0,0 +1,109 @@ +package renderers + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/ollama/ollama/api" +) + +func GLM46Renderer(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) { + var sb strings.Builder + + sb.WriteString("[gMASK]") + + var lastUserIndex int + for i, message := range messages { + if message.Role == "user" { + lastUserIndex = i + } + } + + if len(tools) > 0 { + sb.WriteString("<|system|>\n") + sb.WriteString("# Tools\n\n") + sb.WriteString("You may call one or more functions to assist with the user query.\n\n") + sb.WriteString("You are provided with function signatures within XML tags:\n") + sb.WriteString("\n") + for _, tool := range tools { + d, _ := json.Marshal(tool) + sb.WriteString(string(d) + "\n") + } + sb.WriteString("\n\n") + sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n") + sb.WriteString("{function-name}\n") + sb.WriteString("{arg-key-1}\n") + sb.WriteString("{arg-value-1}\n") + sb.WriteString("{arg-key-2}\n") + sb.WriteString("{arg-value-2}\n") + sb.WriteString("...\n") + sb.WriteString("") + } + + for i, message := range messages { + switch message.Role { + case "user": + sb.WriteString("<|user|>\n") + sb.WriteString(message.Content) + if thinkValue != nil && !thinkValue.Bool() && !strings.HasSuffix(message.Content, "/nothink") { + sb.WriteString("/nothink") + } + case "assistant": + sb.WriteString("<|assistant|>") + if i > lastUserIndex { + if message.Thinking != "" { + sb.WriteString("\n" + message.Thinking + "") + } else { + sb.WriteString("\n") + } + } + if message.Content != "" { + sb.WriteString("\n" + message.Content) + } + if len(message.ToolCalls) > 0 { + for _, toolCall := range message.ToolCalls { + sb.WriteString("\n" + toolCall.Function.Name + "\n") + for key, value := range toolCall.Function.Arguments { + sb.WriteString("" + key + "\n") + + var valueStr string + if str, ok := value.(string); ok { + valueStr = str + } else { + jsonBytes, err := json.Marshal(value) + if err != nil { + valueStr = fmt.Sprintf("%v", value) + } else { + valueStr = string(jsonBytes) + } + } + + sb.WriteString("" + valueStr + "\n") + } + + sb.WriteString("") + } + } + case "tool": + if i == 0 || messages[i-1].Role != "tool" { + sb.WriteString("<|observation|>") + } + sb.WriteString("\n\n") + sb.WriteString(message.Content) + sb.WriteString("\n") + case "system": + sb.WriteString("<|system|>\n") + sb.WriteString(message.Content) + } + } + + // Add generation prompt + sb.WriteString("<|assistant|>") + fmt.Println("thinkValue", thinkValue, thinkValue.Bool()) + if thinkValue != nil && !thinkValue.Bool() { + sb.WriteString("\n") + } + + return sb.String(), nil +} diff --git a/model/renderers/renderer.go b/model/renderers/renderer.go index 2dfb51e49..24ad4ee35 100644 --- a/model/renderers/renderer.go +++ b/model/renderers/renderer.go @@ -20,6 +20,8 @@ func rendererForName(name string) rendererFunc { switch name { case "qwen3-coder": return Qwen3CoderRenderer + case "glm-4.6": + return GLM46Renderer default: return nil }