no parser yet
This commit is contained in:
parent
9ff8e5a64d
commit
38ed7c7a4f
|
|
@ -4,79 +4,76 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
"regexp"
|
|
||||||
"strings"
|
"strings"
|
||||||
"unicode"
|
"unicode"
|
||||||
"unicode/utf8"
|
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/logutil"
|
"github.com/ollama/ollama/logutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
type glm46ParserState int
|
|
||||||
|
|
||||||
const (
|
const (
|
||||||
thinkOpenTag = "<think>"
|
glm46CollectingContent glm46ParserState = iota
|
||||||
thinkCloseTag = "</think>"
|
CollectingThinkingContent
|
||||||
glmToolOpenTag = "<tool_call>"
|
CollectingToolContent
|
||||||
glmToolCloseTag = "</tool_call>"
|
|
||||||
argKeyOpenTag = "<arg_key>"
|
|
||||||
argKeyCloseTag = "</arg_key>"
|
|
||||||
argValueOpenTag = "<arg_value>"
|
|
||||||
argValueCloseTag = "</arg_value>"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
glm46ParserState_LookingForTags glm46ParserState = iota
|
thinkingCloseTag = "</think>"
|
||||||
glm46ParserState_CollectingThinking
|
|
||||||
glm46ParserState_CollectingToolCall
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// TODO(gguo): add a field for isThinking
|
||||||
type GLM46Parser struct {
|
type GLM46Parser struct {
|
||||||
state glm46ParserState
|
state qwenParserState
|
||||||
acc strings.Builder
|
buffer strings.Builder
|
||||||
tools []api.Tool
|
tools []api.Tool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *GLM46Parser) HasToolSupport() bool {
|
func (p *GLM46Parser) HasToolSupport() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(gguo): changes this to reference an objects param
|
||||||
func (p *GLM46Parser) HasThinkingSupport() bool {
|
func (p *GLM46Parser) HasThinkingSupport() bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
|
func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
|
||||||
p.tools = tools
|
p.tools = tools
|
||||||
|
// p.state = p.initialState()
|
||||||
return tools
|
return tools
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *GLM46Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
type glm46EventThinkingContent struct {
|
||||||
p.acc.WriteString(s)
|
content string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (glm46EventThinkingContent) isGLM46Event() {}
|
||||||
|
|
||||||
|
func (p *GLM46Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||||
|
p.buffer.WriteString(s)
|
||||||
events := p.parseEvents()
|
events := p.parseEvents()
|
||||||
|
|
||||||
var toolCalls []api.ToolCall
|
var toolCalls []api.ToolCall
|
||||||
var contentBuilder strings.Builder
|
var sb strings.Builder
|
||||||
var thinkingBuilder strings.Builder
|
|
||||||
|
|
||||||
for _, event := range events {
|
for _, event := range events {
|
||||||
switch event := event.(type) {
|
switch event := event.(type) {
|
||||||
case glm46EventRawToolCall:
|
case glm46EventRawToolCall:
|
||||||
toolCall, err := parseGLMToolCall(event, p.tools)
|
toolCall, err := parseJSONToolCall(event, p.tools)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
slog.Warn("glm46 tool call parsing failed", "error", err)
|
slog.Warn("qwen tool call parsing failed", "error", err)
|
||||||
return "", "", nil, err
|
return "", "", nil, err
|
||||||
}
|
}
|
||||||
toolCalls = append(toolCalls, toolCall)
|
toolCalls = append(toolCalls, toolCall)
|
||||||
|
case glm46EventThinkingContent:
|
||||||
|
sb.WriteString(event.content)
|
||||||
case glm46EventContent:
|
case glm46EventContent:
|
||||||
contentBuilder.WriteString(event.content)
|
// TODO(drifkin): if the same turn contains multiple interleaved content
|
||||||
case glm46EventThinking:
|
// events, we naively append them together here.
|
||||||
thinkingBuilder.WriteString(event.thinking)
|
sb.WriteString(event.content)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return contentBuilder.String(), thinkingBuilder.String(), toolCalls, nil
|
return sb.String(), "", toolCalls, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *GLM46Parser) parseEvents() []glm46Event {
|
func (p *GLM46Parser) parseEvents() []glm46Event {
|
||||||
|
|
@ -85,338 +82,138 @@ func (p *GLM46Parser) parseEvents() []glm46Event {
|
||||||
keepLooping := true
|
keepLooping := true
|
||||||
for keepLooping {
|
for keepLooping {
|
||||||
var events []glm46Event
|
var events []glm46Event
|
||||||
events, keepLooping = eatGLM(p)
|
events, keepLooping = p.eat()
|
||||||
if len(events) > 0 {
|
if len(events) > 0 {
|
||||||
all = append(all, events...)
|
all = append(all, events...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(all) > 0 {
|
if len(all) > 0 {
|
||||||
slog.Log(context.TODO(), logutil.LevelTrace, "glm46 events parsed", "events", all, "state", p.state, "acc", p.acc.String())
|
slog.Log(context.TODO(), logutil.LevelTrace, "qwen events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
return all
|
return all
|
||||||
}
|
}
|
||||||
|
|
||||||
type glm46Event interface {
|
func emitContentBeforeTag(p *GLM46Parser, events []glm46Event, tag string) []glm46Event {
|
||||||
isGLM46Event()
|
split := strings.SplitN(p.buffer.String(), tag, 2)
|
||||||
|
before := split[0]
|
||||||
|
before = strings.TrimRightFunc(before, unicode.IsSpace)
|
||||||
|
if len(before) > 0 {
|
||||||
|
events = append(events, glm46EventContent{content: before})
|
||||||
|
}
|
||||||
|
after := split[1]
|
||||||
|
p.buffer.Reset()
|
||||||
|
p.buffer.WriteString(after)
|
||||||
|
return events
|
||||||
}
|
}
|
||||||
|
|
||||||
type glm46EventRawToolCall struct {
|
func (p *GLM46Parser) eat() ([]glm46Event, bool) {
|
||||||
raw string
|
|
||||||
}
|
|
||||||
|
|
||||||
type glm46EventContent struct {
|
|
||||||
content string
|
|
||||||
}
|
|
||||||
|
|
||||||
type glm46EventThinking struct {
|
|
||||||
thinking string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (glm46EventContent) isGLM46Event() {}
|
|
||||||
func (glm46EventRawToolCall) isGLM46Event() {}
|
|
||||||
func (glm46EventThinking) isGLM46Event() {}
|
|
||||||
|
|
||||||
func eatGLM(p *GLM46Parser) ([]glm46Event, bool) {
|
|
||||||
var events []glm46Event
|
var events []glm46Event
|
||||||
|
|
||||||
switch p.state {
|
switch p.state {
|
||||||
case glm46ParserState_LookingForTags:
|
case glm46CollectingContent:
|
||||||
buf := p.acc.String()
|
if strings.Contains(p.buffer.String(), toolOpenTag) {
|
||||||
|
events = emitContentBeforeTag(p, events, toolOpenTag)
|
||||||
// Check for thinking open tag first
|
p.state = glm46CollectingToolContent
|
||||||
if strings.Contains(buf, thinkOpenTag) {
|
|
||||||
split := strings.SplitN(buf, thinkOpenTag, 2)
|
|
||||||
before := split[0]
|
|
||||||
before = strings.TrimRightFunc(before, unicode.IsSpace)
|
|
||||||
if len(before) > 0 {
|
|
||||||
events = append(events, glm46EventContent{content: before})
|
|
||||||
}
|
|
||||||
after := split[1]
|
|
||||||
p.acc.Reset()
|
|
||||||
p.acc.WriteString(after)
|
|
||||||
p.state = glm46ParserState_CollectingThinking
|
|
||||||
return events, true
|
return events, true
|
||||||
}
|
} else if overlapLen := overlap(p.buffer.String(), toolOpenTag); overlapLen > 0 {
|
||||||
|
beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen]
|
||||||
// Check for tool call open tag
|
trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
|
||||||
if strings.Contains(buf, glmToolOpenTag) {
|
|
||||||
split := strings.SplitN(buf, glmToolOpenTag, 2)
|
|
||||||
before := split[0]
|
|
||||||
before = strings.TrimRightFunc(before, unicode.IsSpace)
|
|
||||||
if len(before) > 0 {
|
|
||||||
events = append(events, glm46EventContent{content: before})
|
|
||||||
}
|
|
||||||
after := split[1]
|
|
||||||
p.acc.Reset()
|
|
||||||
p.acc.WriteString(after)
|
|
||||||
p.state = glm46ParserState_CollectingToolCall
|
|
||||||
return events, true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for partial tags
|
|
||||||
if overlap := glmOverlap(buf, thinkOpenTag); overlap > 0 {
|
|
||||||
beforePartialTag := buf[:len(buf)-overlap]
|
|
||||||
trailingWhitespaceLen := glmTrailingWhitespaceLen(beforePartialTag)
|
|
||||||
ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen
|
ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen
|
||||||
unambiguous := buf[:ambiguousStart]
|
|
||||||
ambiguous := buf[ambiguousStart:]
|
unambiguous := p.buffer.String()[:ambiguousStart]
|
||||||
p.acc.Reset()
|
ambiguous := p.buffer.String()[ambiguousStart:]
|
||||||
p.acc.WriteString(ambiguous)
|
p.buffer.Reset()
|
||||||
|
p.buffer.WriteString(ambiguous)
|
||||||
|
if len(unambiguous) > 0 { // why does qwen3coder not have this here
|
||||||
|
events = append(events, glm46EventContent{content: unambiguous})
|
||||||
|
}
|
||||||
|
return events, false
|
||||||
|
} else {
|
||||||
|
whitespaceLen := trailingWhitespaceLen(p.buffer.String())
|
||||||
|
ambiguousStart := len(p.buffer.String()) - whitespaceLen
|
||||||
|
|
||||||
|
unambiguous := p.buffer.String()[:ambiguousStart]
|
||||||
|
ambiguous := p.buffer.String()[ambiguousStart:]
|
||||||
|
p.buffer.Reset()
|
||||||
|
p.buffer.WriteString(ambiguous)
|
||||||
if len(unambiguous) > 0 {
|
if len(unambiguous) > 0 {
|
||||||
events = append(events, glm46EventContent{content: unambiguous})
|
events = append(events, glm46EventContent{content: unambiguous})
|
||||||
}
|
}
|
||||||
return events, false
|
return events, false
|
||||||
}
|
}
|
||||||
|
case CollectingToolContent:
|
||||||
|
if strings.Contains(p.buffer.String(), glm46ToolCloseTag) {
|
||||||
|
split := strings.SplitN(p.buffer.String(), toolCloseTag, 2)
|
||||||
|
before := split[0]
|
||||||
|
if len(before) == 0 {
|
||||||
|
slog.Warn("qwen tool call closing tag found but no content before it")
|
||||||
|
}
|
||||||
|
|
||||||
if overlap := glmOverlap(buf, glmToolOpenTag); overlap > 0 {
|
after := strings.TrimLeftFunc(split[1], unicode.IsSpace)
|
||||||
beforePartialTag := buf[:len(buf)-overlap]
|
events = append(events, glm46EventRawToolCall{raw: before})
|
||||||
trailingWhitespaceLen := glmTrailingWhitespaceLen(beforePartialTag)
|
p.buffer.Reset()
|
||||||
|
p.buffer.WriteString(after)
|
||||||
|
p.state = glm46CollectingContent
|
||||||
|
return events, true
|
||||||
|
} else {
|
||||||
|
return events, false
|
||||||
|
}
|
||||||
|
case glm46CollectingThinkingContent: // so we want to hip the unambiguous stuff
|
||||||
|
if strings.Contains(p.buffer.String(), thinkingCloseTag) {
|
||||||
|
split := strings.SplitN(p.buffer.String(), thinkingCloseTag, 2)
|
||||||
|
before := split[0]
|
||||||
|
if len(before) == 0 {
|
||||||
|
slog.Warn("qwen tool call closing tag found but no content before it")
|
||||||
|
}
|
||||||
|
after := strings.TrimLeftFunc(split[1], unicode.IsSpace)
|
||||||
|
if len(before) > 0 {
|
||||||
|
events = append(events, glm46EventThinkingContent{content: before})
|
||||||
|
}
|
||||||
|
p.buffer.Reset()
|
||||||
|
p.buffer.WriteString(after)
|
||||||
|
p.state = glm46CollectingContent
|
||||||
|
return events, true
|
||||||
|
} else if overlapLen := overlap(p.buffer.String(), thinkingCloseTag); overlapLen > 0 { // we see part of a close thinking tag
|
||||||
|
beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen]
|
||||||
|
trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
|
||||||
ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen
|
ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen
|
||||||
unambiguous := buf[:ambiguousStart]
|
|
||||||
ambiguous := buf[ambiguousStart:]
|
unambiguous := p.buffer.String()[:ambiguousStart]
|
||||||
p.acc.Reset()
|
ambiguous := p.buffer.String()[ambiguousStart:]
|
||||||
p.acc.WriteString(ambiguous)
|
p.buffer.Reset()
|
||||||
|
p.buffer.WriteString(ambiguous)
|
||||||
if len(unambiguous) > 0 {
|
if len(unambiguous) > 0 {
|
||||||
events = append(events, glm46EventContent{content: unambiguous})
|
events = append(events, glm46EventThinkingContent{content: unambiguous})
|
||||||
|
}
|
||||||
|
return events, false
|
||||||
|
} else {
|
||||||
|
whitespaceLen := trailingWhitespaceLen(p.buffer.String())
|
||||||
|
ambiguousStart := len(p.buffer.String()) - whitespaceLen
|
||||||
|
|
||||||
|
unambiguous := p.buffer.String()[:ambiguousStart]
|
||||||
|
ambiguous := p.buffer.String()[ambiguousStart:]
|
||||||
|
p.buffer.Reset()
|
||||||
|
p.buffer.WriteString(ambiguous)
|
||||||
|
if len(unambiguous) > 0 {
|
||||||
|
events = append(events, glm46EventThinkingContent{content: unambiguous})
|
||||||
}
|
}
|
||||||
return events, false
|
return events, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// No tags found, emit content but withhold trailing whitespace
|
|
||||||
whitespaceLen := glmTrailingWhitespaceLen(buf)
|
|
||||||
ambiguousStart := len(buf) - whitespaceLen
|
|
||||||
unambiguous := buf[:ambiguousStart]
|
|
||||||
ambiguous := buf[ambiguousStart:]
|
|
||||||
p.acc.Reset()
|
|
||||||
p.acc.WriteString(ambiguous)
|
|
||||||
if len(unambiguous) > 0 {
|
|
||||||
events = append(events, glm46EventContent{content: unambiguous})
|
|
||||||
}
|
|
||||||
return events, false
|
|
||||||
|
|
||||||
case glm46ParserState_CollectingThinking:
|
|
||||||
if strings.Contains(p.acc.String(), thinkCloseTag) {
|
|
||||||
split := strings.SplitN(p.acc.String(), thinkCloseTag, 2)
|
|
||||||
thinkingContent := split[0]
|
|
||||||
after := strings.TrimLeftFunc(split[1], unicode.IsSpace)
|
|
||||||
p.acc.Reset()
|
|
||||||
p.acc.WriteString(after)
|
|
||||||
events = append(events, glm46EventThinking{thinking: thinkingContent})
|
|
||||||
p.state = glm46ParserState_LookingForTags
|
|
||||||
return events, true
|
|
||||||
}
|
|
||||||
return events, false
|
|
||||||
|
|
||||||
case glm46ParserState_CollectingToolCall:
|
|
||||||
if strings.Contains(p.acc.String(), glmToolCloseTag) {
|
|
||||||
split := strings.SplitN(p.acc.String(), glmToolCloseTag, 2)
|
|
||||||
toolCallContent := split[0]
|
|
||||||
if len(toolCallContent) == 0 {
|
|
||||||
slog.Warn("glm46 tool call closing tag found but no content before it")
|
|
||||||
}
|
|
||||||
after := strings.TrimLeftFunc(split[1], unicode.IsSpace)
|
|
||||||
p.acc.Reset()
|
|
||||||
p.acc.WriteString(after)
|
|
||||||
events = append(events, glm46EventRawToolCall{raw: toolCallContent})
|
|
||||||
p.state = glm46ParserState_LookingForTags
|
|
||||||
return events, true
|
|
||||||
}
|
|
||||||
return events, false
|
|
||||||
|
|
||||||
default:
|
default:
|
||||||
panic("unreachable")
|
panic("unreachable")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var (
|
func parseJSONToolCall(raw glm46EventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
|
||||||
glmFunctionNameRegex = regexp.MustCompile(`^([^\n<]+)`)
|
var toolCallFunction api.ToolCallFunction
|
||||||
glmArgKeyRegex = regexp.MustCompile(`<arg_key>(.*?)</arg_key>`)
|
if err := json.Unmarshal([]byte(raw.raw), &toolCallFunction); err != nil {
|
||||||
glmArgValueRegex = regexp.MustCompile(`<arg_value>(.*?)</arg_value>`)
|
return api.ToolCall{}, err
|
||||||
)
|
}
|
||||||
|
|
||||||
// parseGLMToolCall parses a raw GLM tool call string into an api.ToolCall.
|
|
||||||
// The raw string has the format:
|
|
||||||
// {function-name}
|
|
||||||
// <arg_key>{arg-key-1}</arg_key>
|
|
||||||
// <arg_value>{arg-value-1}</arg_value>
|
|
||||||
// <arg_key>{arg-key-2}</arg_key>
|
|
||||||
// <arg_value>{arg-value-2}</arg_value>
|
|
||||||
// ...
|
|
||||||
func parseGLMToolCall(raw glm46EventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
|
|
||||||
toolCall := api.ToolCall{}
|
toolCall := api.ToolCall{}
|
||||||
|
toolCall.Function = toolCallFunction
|
||||||
// Extract function name (first line or until first <)
|
|
||||||
functionNameMatch := glmFunctionNameRegex.FindStringSubmatch(raw.raw)
|
|
||||||
if len(functionNameMatch) < 2 {
|
|
||||||
return api.ToolCall{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
functionName := strings.TrimSpace(functionNameMatch[1])
|
|
||||||
toolCall.Function = api.ToolCallFunction{
|
|
||||||
Name: functionName,
|
|
||||||
}
|
|
||||||
|
|
||||||
// Find the matching tool to get parameter types
|
|
||||||
var matchedTool *api.Tool
|
|
||||||
for i := range tools {
|
|
||||||
if tools[i].Function.Name == functionName {
|
|
||||||
matchedTool = &tools[i]
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract all arg_key and arg_value pairs
|
|
||||||
argKeys := glmArgKeyRegex.FindAllStringSubmatch(raw.raw, -1)
|
|
||||||
argValues := glmArgValueRegex.FindAllStringSubmatch(raw.raw, -1)
|
|
||||||
|
|
||||||
if len(argKeys) != len(argValues) {
|
|
||||||
slog.Warn("glm46 tool call has mismatched arg_key and arg_value counts", "keys", len(argKeys), "values", len(argValues))
|
|
||||||
}
|
|
||||||
|
|
||||||
toolCall.Function.Arguments = make(api.ToolCallFunctionArguments)
|
|
||||||
minLen := min(len(argKeys), len(argValues))
|
|
||||||
|
|
||||||
for i := 0; i < minLen; i++ {
|
|
||||||
if len(argKeys[i]) < 2 || len(argValues[i]) < 2 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
key := strings.TrimSpace(argKeys[i][1])
|
|
||||||
value := argValues[i][1]
|
|
||||||
|
|
||||||
// Trim leading and trailing newlines from value (following reference implementation)
|
|
||||||
value = strings.TrimPrefix(value, "\n")
|
|
||||||
value = strings.TrimSuffix(value, "\n")
|
|
||||||
|
|
||||||
// Look up the parameter type if we found the tool
|
|
||||||
var paramType api.PropertyType
|
|
||||||
if matchedTool != nil && matchedTool.Function.Parameters.Properties != nil {
|
|
||||||
if prop, ok := matchedTool.Function.Parameters.Properties[key]; ok {
|
|
||||||
paramType = prop.Type
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse the value according to its type
|
|
||||||
toolCall.Function.Arguments[key] = parseGLMValue(value, paramType)
|
|
||||||
}
|
|
||||||
|
|
||||||
return toolCall, nil
|
return toolCall, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// longest overlap between suffix of s and prefix of delim
|
|
||||||
func glmOverlap(s, delim string) int {
|
|
||||||
max := min(len(delim), len(s))
|
|
||||||
for i := max; i > 0; i-- {
|
|
||||||
if strings.HasSuffix(s, delim[:i]) {
|
|
||||||
return i
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func glmTrailingWhitespaceLen(s string) int {
|
|
||||||
remaining := s
|
|
||||||
total := 0
|
|
||||||
for len(remaining) > 0 {
|
|
||||||
r, size := utf8.DecodeLastRuneInString(remaining)
|
|
||||||
// if it's an invalid utf8 rune, assume it isn't whitespace
|
|
||||||
if r == utf8.RuneError && size == 1 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
if !unicode.IsSpace(r) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
total += size
|
|
||||||
remaining = remaining[:len(remaining)-size]
|
|
||||||
}
|
|
||||||
return total
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseGLMValue(raw string, paramType api.PropertyType) any {
|
|
||||||
// Check for null first (case-insensitive) - this takes precedence over any type
|
|
||||||
if strings.ToLower(raw) == "null" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// If no type is specified, try to parse as JSON, otherwise return as string
|
|
||||||
if len(paramType) == 0 {
|
|
||||||
var val any
|
|
||||||
if err := json.Unmarshal([]byte(raw), &val); err == nil {
|
|
||||||
return val
|
|
||||||
}
|
|
||||||
return raw
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if any of the specified types match, using type precedence
|
|
||||||
// Order: boolean -> integer -> number -> array -> object -> string
|
|
||||||
typeSet := make(map[string]bool)
|
|
||||||
for _, t := range paramType {
|
|
||||||
typeSet[t] = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try boolean first (most restrictive)
|
|
||||||
if typeSet["boolean"] {
|
|
||||||
lower := strings.ToLower(raw)
|
|
||||||
switch lower {
|
|
||||||
case "true":
|
|
||||||
return true
|
|
||||||
case "false":
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
// If not a valid boolean but boolean is the only type, return false
|
|
||||||
if len(paramType) == 1 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try parsing as JSON for complex types
|
|
||||||
var jsonVal any
|
|
||||||
if err := json.Unmarshal([]byte(raw), &jsonVal); err == nil {
|
|
||||||
// Check if the parsed type matches any of the expected types
|
|
||||||
switch v := jsonVal.(type) {
|
|
||||||
case float64:
|
|
||||||
if typeSet["number"] {
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
if typeSet["integer"] && v == float64(int64(v)) {
|
|
||||||
return int64(v)
|
|
||||||
}
|
|
||||||
case bool:
|
|
||||||
if typeSet["boolean"] {
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
case []any:
|
|
||||||
if typeSet["array"] {
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
case map[string]any:
|
|
||||||
if typeSet["object"] {
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
case string:
|
|
||||||
if typeSet["string"] {
|
|
||||||
return v
|
|
||||||
}
|
|
||||||
case nil:
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// If JSON parsed but type doesn't match, check if string is valid
|
|
||||||
if typeSet["string"] {
|
|
||||||
return raw
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return the parsed JSON value as fallback
|
|
||||||
return jsonVal
|
|
||||||
}
|
|
||||||
|
|
||||||
// If JSON parsing failed but string is valid, return as string
|
|
||||||
if typeSet["string"] {
|
|
||||||
return raw
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fallback to string
|
|
||||||
return raw
|
|
||||||
}
|
|
||||||
|
|
|
||||||
|
|
@ -37,7 +37,7 @@ Hello, how are you?<|assistant|>`,
|
||||||
name: "basic with user assistant user",
|
name: "basic with user assistant user",
|
||||||
messages: []api.Message{
|
messages: []api.Message{
|
||||||
{Role: "user", Content: "What is the capital of France?"},
|
{Role: "user", Content: "What is the capital of France?"},
|
||||||
{Role: "assistant", Content: "The capital of France is Paris."},
|
{Role: "assistant", Thinking: "Let me analyze the request...", Content: "The capital of France is Paris."},
|
||||||
{Role: "user", Content: "Fantastic!"},
|
{Role: "user", Content: "Fantastic!"},
|
||||||
},
|
},
|
||||||
expected: `[gMASK]<sop><|user|>
|
expected: `[gMASK]<sop><|user|>
|
||||||
|
|
@ -112,6 +112,15 @@ What is the weather like in Tokyo?<|assistant|>`,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "get_weather",
|
||||||
|
Arguments: api.ToolCallFunctionArguments{
|
||||||
|
"location": "Japan",
|
||||||
|
"unit": "fahrenheit",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
@ -119,6 +128,11 @@ What is the weather like in Tokyo?<|assistant|>`,
|
||||||
Content: "{\"temperature\": 22, \"weather\": \"partly cloudy\", \"humidity\": 65}",
|
Content: "{\"temperature\": 22, \"weather\": \"partly cloudy\", \"humidity\": 65}",
|
||||||
ToolName: "get_weather",
|
ToolName: "get_weather",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
Role: "tool",
|
||||||
|
Content: "{\"temperature\": 68, \"weather\": \"sunny\", \"humidity\": 75}",
|
||||||
|
ToolName: "get_weather",
|
||||||
|
},
|
||||||
{
|
{
|
||||||
Role: "assistant",
|
Role: "assistant",
|
||||||
Content: "The weather in Tokyo is currently partly cloudy with a temperature of 22°C and 65% humidity. It's a pleasant day with moderate temperatures.",
|
Content: "The weather in Tokyo is currently partly cloudy with a temperature of 22°C and 65% humidity. It's a pleasant day with moderate temperatures.",
|
||||||
|
|
@ -173,9 +187,18 @@ What is the weather like in Tokyo?<|assistant|>
|
||||||
<arg_value>Tokyo, Japan</arg_value>
|
<arg_value>Tokyo, Japan</arg_value>
|
||||||
<arg_key>unit</arg_key>
|
<arg_key>unit</arg_key>
|
||||||
<arg_value>celsius</arg_value>
|
<arg_value>celsius</arg_value>
|
||||||
|
</tool_call>
|
||||||
|
<tool_call>get_weather
|
||||||
|
<arg_key>location</arg_key>
|
||||||
|
<arg_value>Japan</arg_value>
|
||||||
|
<arg_key>unit</arg_key>
|
||||||
|
<arg_value>fahrenheit</arg_value>
|
||||||
</tool_call><|observation|>
|
</tool_call><|observation|>
|
||||||
<tool_response>
|
<tool_response>
|
||||||
{"temperature": 22, "weather": "partly cloudy", "humidity": 65}
|
{"temperature": 22, "weather": "partly cloudy", "humidity": 65}
|
||||||
|
</tool_response>
|
||||||
|
<tool_response>
|
||||||
|
{"temperature": 68, "weather": "sunny", "humidity": 75}
|
||||||
</tool_response><|assistant|>
|
</tool_response><|assistant|>
|
||||||
<think></think>
|
<think></think>
|
||||||
The weather in Tokyo is currently partly cloudy with a temperature of 22°C and 65% humidity. It's a pleasant day with moderate temperatures.<|assistant|>`,
|
The weather in Tokyo is currently partly cloudy with a temperature of 22°C and 65% humidity. It's a pleasant day with moderate temperatures.<|assistant|>`,
|
||||||
|
|
@ -196,7 +219,7 @@ Hello, how are you?<|assistant|>`,
|
||||||
},
|
},
|
||||||
thinkValue: &api.ThinkValue{Value: false},
|
thinkValue: &api.ThinkValue{Value: false},
|
||||||
expected: `[gMASK]<sop><|user|>
|
expected: `[gMASK]<sop><|user|>
|
||||||
Hello, how are you?<|assistant|>
|
Hello, how are you?/nothink<|assistant|>
|
||||||
<think></think>`,
|
<think></think>`,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue