Merge remote-tracking branch 'upstream/main' into vulkanV3
This commit is contained in:
commit
66bdd882f5
|
|
@ -128,7 +128,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
|||
}
|
||||
|
||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
return fast.RoPE(ctx, key, shift, m.Options.attnKeyLen, m.Options.ropeBase, m.Options.ropeScale, rope.WithTypeNeoX()), nil
|
||||
return fast.RoPE(ctx, key, shift, m.Options.attnKeyLen, m.Options.ropeBase, 1/m.Options.ropeScale, rope.WithTypeNeoX()), nil
|
||||
}
|
||||
|
||||
type MLP struct {
|
||||
|
|
|
|||
|
|
@ -53,7 +53,10 @@ func newTextModel(c fs.Config) *TextModel {
|
|||
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
|
||||
ropeLocalBase: c.Float("rope.local.freq_base", 10000.0),
|
||||
ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
|
||||
ropeScale: c.Float("rope.scaling.factor", 1.0),
|
||||
ropeScale: 1,
|
||||
// NOTE: the rope.scaling.factor is set incorrectly in the official QAT weights
|
||||
// (8 instead of 1)
|
||||
// ropeScale: c.Float("rope.scaling.factor", 1.0),
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -113,7 +116,7 @@ func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.T
|
|||
ropeBase = m.TextConfig.ropeGlobalBase
|
||||
}
|
||||
|
||||
return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil
|
||||
return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, 1/m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil
|
||||
}
|
||||
|
||||
type TextMLP struct {
|
||||
|
|
|
|||
|
|
@ -393,18 +393,55 @@ func parseValue(raw string, paramType api.PropertyType) any {
|
|||
return raw
|
||||
}
|
||||
|
||||
var qwenTagRegex = regexp.MustCompile(`<(\w+)=([^>]+)>`)
|
||||
var (
|
||||
qwenTagRegex = regexp.MustCompile(`<(\w+)=([^>]+)>`)
|
||||
qwenXMLTagRegex = regexp.MustCompile(`</?(?:function|parameter)(?:\s+name="[^"]*")?>`)
|
||||
)
|
||||
|
||||
// transformToXML transforms a raw qwen tool call with xml-like tags into valid
|
||||
// xml so that it can be parsed by any xml parser
|
||||
func transformToXML(raw string) string {
|
||||
// take the form `<tag=abc>` and transform it to `<tag name="abc">`, taking
|
||||
// care to properly escape the string that becomes the attribute value
|
||||
return qwenTagRegex.ReplaceAllStringFunc(raw, func(match string) string {
|
||||
transformed := qwenTagRegex.ReplaceAllStringFunc(raw, func(match string) string {
|
||||
groups := qwenTagRegex.FindStringSubmatch(match)
|
||||
tag := groups[1]
|
||||
var escapedValue strings.Builder
|
||||
xml.EscapeText(&escapedValue, []byte(groups[2]))
|
||||
return fmt.Sprintf(`<%s name="%s">`, tag, escapedValue.String())
|
||||
})
|
||||
|
||||
// Walk the resulting string, escaping any character data that sits between the
|
||||
// xml tags we just emitted
|
||||
var out strings.Builder
|
||||
lastIdx := 0
|
||||
for _, loc := range qwenXMLTagRegex.FindAllStringIndex(transformed, -1) {
|
||||
if loc[0] > lastIdx {
|
||||
escapeTextNode(&out, transformed[lastIdx:loc[0]])
|
||||
}
|
||||
out.WriteString(transformed[loc[0]:loc[1]])
|
||||
lastIdx = loc[1]
|
||||
}
|
||||
if lastIdx < len(transformed) {
|
||||
escapeTextNode(&out, transformed[lastIdx:])
|
||||
}
|
||||
|
||||
return out.String()
|
||||
}
|
||||
|
||||
// escapeTextNode escapes XML character data without altering other characters
|
||||
// like newlines or tabs (which is why we don't use xml.EscapeText for this)
|
||||
func escapeTextNode(sb *strings.Builder, s string) {
|
||||
for _, r := range s {
|
||||
switch r {
|
||||
case '&':
|
||||
sb.WriteString("&")
|
||||
case '<':
|
||||
sb.WriteString("<")
|
||||
case '>':
|
||||
sb.WriteString(">")
|
||||
default:
|
||||
sb.WriteRune(r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -312,6 +312,41 @@ true
|
|||
},
|
||||
},
|
||||
},
|
||||
// regression test for <https://github.com/ollama/ollama/issues/12357>
|
||||
{
|
||||
name: "ampersands in parameter values",
|
||||
tools: []api.Tool{},
|
||||
rawToolCall: `<function=exec>
|
||||
<parameter=command>
|
||||
ls && echo "done"
|
||||
</parameter>
|
||||
</function>`,
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "exec",
|
||||
Arguments: map[string]any{
|
||||
"command": "ls && echo \"done\"",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "angle brackets in parameter values",
|
||||
tools: []api.Tool{},
|
||||
rawToolCall: `<function=exec>
|
||||
<parameter=command>
|
||||
ls && echo "a > b and a < b"
|
||||
</parameter>
|
||||
</function>`,
|
||||
wantToolCall: api.ToolCall{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "exec",
|
||||
Arguments: map[string]any{
|
||||
"command": "ls && echo \"a > b and a < b\"",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i, step := range steps {
|
||||
|
|
@ -798,6 +833,19 @@ celsius
|
|||
</parameter>
|
||||
</function>`,
|
||||
},
|
||||
{
|
||||
desc: "ampersands in parameter values",
|
||||
raw: `<function=get_current_temperature>
|
||||
<parameter=location>
|
||||
San Francisco & San Jose
|
||||
</parameter>
|
||||
</function>`,
|
||||
want: `<function name="get_current_temperature">
|
||||
<parameter name="location">
|
||||
San Francisco & San Jose
|
||||
</parameter>
|
||||
</function>`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
|
|
|
|||
Loading…
Reference in New Issue