Merge 0b7d8eda68 into 76912c062a
This commit is contained in:
commit
57920300fe
|
|
@ -440,9 +440,20 @@ func (h *HarmonyMessageHandler) Add(s string, done bool) (content string, thinki
|
|||
if toolName != nil {
|
||||
name := strings.TrimPrefix(*toolName, "functions.")
|
||||
name = h.FunctionNameMap.OriginalFromConverted(name)
|
||||
|
||||
// Extract JSON arguments from raw content
|
||||
// The model may include thinking/analysis segments before JSON:
|
||||
// e.g., "<|channel|>analysis<|message|>thinking text{...}"
|
||||
jsonContent, extractedThinking := extractToolCallJSON(raw)
|
||||
|
||||
// Add any extracted thinking to the thinking output
|
||||
if extractedThinking != "" {
|
||||
thinking += extractedThinking
|
||||
}
|
||||
|
||||
var args api.ToolCallFunctionArguments
|
||||
if err := json.Unmarshal([]byte(raw), &args); err != nil {
|
||||
return "", "", nil, fmt.Errorf("error parsing tool call: raw='%s', err=%w", raw, err)
|
||||
if err := json.Unmarshal([]byte(jsonContent), &args); err != nil {
|
||||
return "", "", nil, fmt.Errorf("error parsing tool call: raw='%s', extracted='%s', err=%w", raw, jsonContent, err)
|
||||
}
|
||||
calls = append(calls, api.ToolCall{Function: api.ToolCallFunction{Name: name, Arguments: args}})
|
||||
}
|
||||
|
|
@ -451,6 +462,126 @@ func (h *HarmonyMessageHandler) Add(s string, done bool) (content string, thinki
|
|||
return content, thinking, calls, nil
|
||||
}
|
||||
|
||||
// extractToolCallJSON extracts JSON arguments from tool content that may contain
|
||||
// embedded harmony segments. The model sometimes outputs thinking within tool calls:
|
||||
// "<|channel|>analysis<|message|>I need to list files.{"path": "/tmp"}"
|
||||
// This function extracts the JSON object and returns any thinking text separately.
|
||||
func extractToolCallJSON(raw string) (jsonStr string, thinking string) {
|
||||
// If content starts with valid JSON, return as-is
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if len(trimmed) > 0 && trimmed[0] == '{' {
|
||||
return extractCompleteJSON(trimmed), ""
|
||||
}
|
||||
|
||||
// Strip any leading <|channel|>...<|message|> segments (embedded thinking)
|
||||
stripped := raw
|
||||
var thinkingParts []string
|
||||
|
||||
for {
|
||||
// Check for channel tag at or near the start
|
||||
channelIdx := strings.Index(stripped, "<|channel|>")
|
||||
if channelIdx != 0 && channelIdx != -1 {
|
||||
// Channel tag not at start - check if it's preceded only by whitespace
|
||||
prefix := strings.TrimSpace(stripped[:channelIdx])
|
||||
if prefix != "" {
|
||||
break // Non-whitespace content before channel tag
|
||||
}
|
||||
}
|
||||
if channelIdx == -1 {
|
||||
break
|
||||
}
|
||||
|
||||
// Find the corresponding <|message|> tag
|
||||
messageIdx := strings.Index(stripped[channelIdx:], "<|message|>")
|
||||
if messageIdx == -1 {
|
||||
break
|
||||
}
|
||||
messageIdx += channelIdx
|
||||
|
||||
// Extract content after <|message|>
|
||||
afterMessage := stripped[messageIdx+len("<|message|>"):]
|
||||
|
||||
// Find where the JSON starts (or another segment)
|
||||
jsonStart := strings.Index(afterMessage, "{")
|
||||
nextChannel := strings.Index(afterMessage, "<|channel|>")
|
||||
|
||||
if jsonStart == -1 {
|
||||
// No JSON found, this might all be thinking
|
||||
thinkingParts = append(thinkingParts, strings.TrimSpace(afterMessage))
|
||||
break
|
||||
}
|
||||
|
||||
// If there's another channel tag before JSON, extract thinking up to it
|
||||
if nextChannel != -1 && nextChannel < jsonStart {
|
||||
thinkingParts = append(thinkingParts, strings.TrimSpace(afterMessage[:nextChannel]))
|
||||
stripped = afterMessage[nextChannel:]
|
||||
continue
|
||||
}
|
||||
|
||||
// Extract thinking (text before JSON) and the JSON
|
||||
if jsonStart > 0 {
|
||||
thinkingParts = append(thinkingParts, strings.TrimSpace(afterMessage[:jsonStart]))
|
||||
}
|
||||
stripped = afterMessage[jsonStart:]
|
||||
break
|
||||
}
|
||||
|
||||
// Build thinking string
|
||||
thinking = strings.Join(thinkingParts, " ")
|
||||
|
||||
// Extract complete JSON object from remaining content
|
||||
jsonStr = extractCompleteJSON(stripped)
|
||||
|
||||
return jsonStr, thinking
|
||||
}
|
||||
|
||||
// extractCompleteJSON finds and extracts a complete JSON object from a string
|
||||
func extractCompleteJSON(s string) string {
|
||||
s = strings.TrimSpace(s)
|
||||
if len(s) == 0 || s[0] != '{' {
|
||||
// Try to find JSON object start
|
||||
idx := strings.Index(s, "{")
|
||||
if idx == -1 {
|
||||
return s // Return as-is, will fail JSON parsing with clear error
|
||||
}
|
||||
s = s[idx:]
|
||||
}
|
||||
|
||||
// Match braces to find complete JSON object
|
||||
depth := 0
|
||||
inString := false
|
||||
escaped := false
|
||||
|
||||
for i, ch := range s {
|
||||
if escaped {
|
||||
escaped = false
|
||||
continue
|
||||
}
|
||||
if ch == '\\' && inString {
|
||||
escaped = true
|
||||
continue
|
||||
}
|
||||
if ch == '"' {
|
||||
inString = !inString
|
||||
continue
|
||||
}
|
||||
if inString {
|
||||
continue
|
||||
}
|
||||
if ch == '{' {
|
||||
depth++
|
||||
} else if ch == '}' {
|
||||
depth--
|
||||
if depth == 0 {
|
||||
return s[:i+1]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Couldn't find matching braces, return what we have
|
||||
return s
|
||||
}
|
||||
|
||||
// HasToolSupport implements the Parser interface
|
||||
func (h *HarmonyMessageHandler) HasToolSupport() bool {
|
||||
return true
|
||||
|
|
|
|||
|
|
@ -536,3 +536,85 @@ func TestFunctionConvertAndAdd(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractToolCallJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
raw string
|
||||
wantJSON string
|
||||
wantThinking string
|
||||
}{
|
||||
{
|
||||
name: "clean JSON",
|
||||
raw: `{"path": "/tmp"}`,
|
||||
wantJSON: `{"path": "/tmp"}`,
|
||||
wantThinking: "",
|
||||
},
|
||||
{
|
||||
name: "JSON with leading whitespace",
|
||||
raw: ` {"path": "/tmp"}`,
|
||||
wantJSON: `{"path": "/tmp"}`,
|
||||
wantThinking: "",
|
||||
},
|
||||
{
|
||||
name: "analysis segment before JSON",
|
||||
raw: `<|channel|>analysis<|message|>We need to list files.{"path": "/tmp"}`,
|
||||
wantJSON: `{"path": "/tmp"}`,
|
||||
wantThinking: "We need to list files.",
|
||||
},
|
||||
{
|
||||
name: "analysis segment with just thinking",
|
||||
raw: `<|channel|>analysis<|message|>Let me check{"arg": "value"}`,
|
||||
wantJSON: `{"arg": "value"}`,
|
||||
wantThinking: "Let me check",
|
||||
},
|
||||
{
|
||||
name: "multiple analysis segments",
|
||||
raw: `<|channel|>analysis<|message|>First thought<|channel|>analysis<|message|>Second thought{"path": "/home"}`,
|
||||
wantJSON: `{"path": "/home"}`,
|
||||
wantThinking: "First thought Second thought",
|
||||
},
|
||||
{
|
||||
name: "nested JSON objects",
|
||||
raw: `<|channel|>analysis<|message|>Thinking{"nested": {"key": "value"}, "arr": [1, 2]}`,
|
||||
wantJSON: `{"nested": {"key": "value"}, "arr": [1, 2]}`,
|
||||
wantThinking: "Thinking",
|
||||
},
|
||||
{
|
||||
name: "JSON with escaped quotes",
|
||||
raw: `<|channel|>analysis<|message|>Think{"message": "Hello \"world\""}`,
|
||||
wantJSON: `{"message": "Hello \"world\""}`,
|
||||
wantThinking: "Think",
|
||||
},
|
||||
{
|
||||
name: "whitespace before channel tag",
|
||||
raw: ` <|channel|>analysis<|message|>Thinking{"path": "/tmp"}`,
|
||||
wantJSON: `{"path": "/tmp"}`,
|
||||
wantThinking: "Thinking",
|
||||
},
|
||||
{
|
||||
name: "content before channel tag (extracts JSON anyway)",
|
||||
raw: `some text<|channel|>analysis<|message|>{"path": "/tmp"}`,
|
||||
wantJSON: `{"path": "/tmp"}`,
|
||||
wantThinking: "",
|
||||
},
|
||||
{
|
||||
name: "JSON with trailing content",
|
||||
raw: `{"path": "/tmp"}<|call|>`,
|
||||
wantJSON: `{"path": "/tmp"}`,
|
||||
wantThinking: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotJSON, gotThinking := extractToolCallJSON(tt.raw)
|
||||
if gotJSON != tt.wantJSON {
|
||||
t.Errorf("extractToolCallJSON() JSON = %q, want %q", gotJSON, tt.wantJSON)
|
||||
}
|
||||
if gotThinking != tt.wantThinking {
|
||||
t.Errorf("extractToolCallJSON() thinking = %q, want %q", gotThinking, tt.wantThinking)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue