diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go index 81d41f2ab..2b16dc62e 100644 --- a/model/models/gemma2/model.go +++ b/model/models/gemma2/model.go @@ -128,7 +128,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten } func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return fast.RoPE(ctx, key, shift, m.Options.attnKeyLen, m.Options.ropeBase, m.Options.ropeScale, rope.WithTypeNeoX()), nil + return fast.RoPE(ctx, key, shift, m.Options.attnKeyLen, m.Options.ropeBase, 1/m.Options.ropeScale, rope.WithTypeNeoX()), nil } type MLP struct { diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go index c2a526080..631baeccd 100644 --- a/model/models/gemma3/model_text.go +++ b/model/models/gemma3/model_text.go @@ -53,7 +53,10 @@ func newTextModel(c fs.Config) *TextModel { eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06), ropeLocalBase: c.Float("rope.local.freq_base", 10000.0), ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0), - ropeScale: c.Float("rope.scaling.factor", 1.0), + ropeScale: 1, + // NOTE: the rope.scaling.factor is set incorrectly in the official QAT weights + // (8 instead of 1) + // ropeScale: c.Float("rope.scaling.factor", 1.0), }, } @@ -113,7 +116,7 @@ func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.T ropeBase = m.TextConfig.ropeGlobalBase } - return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil + return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, 1/m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil } type TextMLP struct { diff --git a/model/parsers/qwen3coder.go b/model/parsers/qwen3coder.go index b0e8ec48c..a7838fb78 100644 --- a/model/parsers/qwen3coder.go +++ b/model/parsers/qwen3coder.go @@ -393,18 +393,55 @@ func parseValue(raw string, paramType api.PropertyType) any { return raw } -var qwenTagRegex = regexp.MustCompile(`<(\w+)=([^>]+)>`) +var ( + qwenTagRegex = regexp.MustCompile(`<(\w+)=([^>]+)>`) + qwenXMLTagRegex = regexp.MustCompile(``) +) // transformToXML transforms a raw qwen tool call with xml-like tags into valid // xml so that it can be parsed by any xml parser func transformToXML(raw string) string { // take the form `` and transform it to ``, taking // care to properly escape the string that becomes the attribute value - return qwenTagRegex.ReplaceAllStringFunc(raw, func(match string) string { + transformed := qwenTagRegex.ReplaceAllStringFunc(raw, func(match string) string { groups := qwenTagRegex.FindStringSubmatch(match) tag := groups[1] var escapedValue strings.Builder xml.EscapeText(&escapedValue, []byte(groups[2])) return fmt.Sprintf(`<%s name="%s">`, tag, escapedValue.String()) }) + + // Walk the resulting string, escaping any character data that sits between the + // xml tags we just emitted + var out strings.Builder + lastIdx := 0 + for _, loc := range qwenXMLTagRegex.FindAllStringIndex(transformed, -1) { + if loc[0] > lastIdx { + escapeTextNode(&out, transformed[lastIdx:loc[0]]) + } + out.WriteString(transformed[loc[0]:loc[1]]) + lastIdx = loc[1] + } + if lastIdx < len(transformed) { + escapeTextNode(&out, transformed[lastIdx:]) + } + + return out.String() +} + +// escapeTextNode escapes XML character data without altering other characters +// like newlines or tabs (which is why we don't use xml.EscapeText for this) +func escapeTextNode(sb *strings.Builder, s string) { + for _, r := range s { + switch r { + case '&': + sb.WriteString("&") + case '<': + sb.WriteString("<") + case '>': + sb.WriteString(">") + default: + sb.WriteRune(r) + } + } } diff --git a/model/parsers/qwen3coder_test.go b/model/parsers/qwen3coder_test.go index 2389c77b5..43823e6fc 100644 --- a/model/parsers/qwen3coder_test.go +++ b/model/parsers/qwen3coder_test.go @@ -312,6 +312,41 @@ true }, }, }, + // regression test for + { + name: "ampersands in parameter values", + tools: []api.Tool{}, + rawToolCall: ` + +ls && echo "done" + +`, + wantToolCall: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "exec", + Arguments: map[string]any{ + "command": "ls && echo \"done\"", + }, + }, + }, + }, + { + name: "angle brackets in parameter values", + tools: []api.Tool{}, + rawToolCall: ` + +ls && echo "a > b and a < b" + +`, + wantToolCall: api.ToolCall{ + Function: api.ToolCallFunction{ + Name: "exec", + Arguments: map[string]any{ + "command": "ls && echo \"a > b and a < b\"", + }, + }, + }, + }, } for i, step := range steps { @@ -798,6 +833,19 @@ celsius `, }, + { + desc: "ampersands in parameter values", + raw: ` + + San Francisco & San Jose + + `, + want: ` + + San Francisco & San Jose + + `, + }, } for _, tc := range cases {