Compare commits
4 Commits
v0.12.0-rc
...
jmorganca/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4ef2b2852d | ||
|
|
3677842ff1 | ||
|
|
242df70a75 | ||
|
|
dba39b2eee |
@@ -128,7 +128,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||||
return fast.RoPE(ctx, key, shift, m.Options.attnKeyLen, m.Options.ropeBase, m.Options.ropeScale, rope.WithTypeNeoX()), nil
|
return fast.RoPE(ctx, key, shift, m.Options.attnKeyLen, m.Options.ropeBase, 1/m.Options.ropeScale, rope.WithTypeNeoX()), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type MLP struct {
|
type MLP struct {
|
||||||
|
|||||||
@@ -53,7 +53,10 @@ func newTextModel(c fs.Config) *TextModel {
|
|||||||
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
|
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
|
||||||
ropeLocalBase: c.Float("rope.local.freq_base", 10000.0),
|
ropeLocalBase: c.Float("rope.local.freq_base", 10000.0),
|
||||||
ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
|
ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
|
||||||
ropeScale: c.Float("rope.scaling.factor", 1.0),
|
ropeScale: 1,
|
||||||
|
// NOTE: the rope.scaling.factor is set incorrectly in the official QAT weights
|
||||||
|
// (8 instead of 1)
|
||||||
|
// ropeScale: c.Float("rope.scaling.factor", 1.0),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -113,7 +116,7 @@ func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.T
|
|||||||
ropeBase = m.TextConfig.ropeGlobalBase
|
ropeBase = m.TextConfig.ropeGlobalBase
|
||||||
}
|
}
|
||||||
|
|
||||||
return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil
|
return fast.RoPE(ctx, key, shift, m.TextConfig.attnKeyLen, ropeBase, 1/m.TextConfig.ropeScale, rope.WithTypeNeoX()), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type TextMLP struct {
|
type TextMLP struct {
|
||||||
|
|||||||
@@ -393,18 +393,55 @@ func parseValue(raw string, paramType api.PropertyType) any {
|
|||||||
return raw
|
return raw
|
||||||
}
|
}
|
||||||
|
|
||||||
var qwenTagRegex = regexp.MustCompile(`<(\w+)=([^>]+)>`)
|
var (
|
||||||
|
qwenTagRegex = regexp.MustCompile(`<(\w+)=([^>]+)>`)
|
||||||
|
qwenXMLTagRegex = regexp.MustCompile(`</?(?:function|parameter)(?:\s+name="[^"]*")?>`)
|
||||||
|
)
|
||||||
|
|
||||||
// transformToXML transforms a raw qwen tool call with xml-like tags into valid
|
// transformToXML transforms a raw qwen tool call with xml-like tags into valid
|
||||||
// xml so that it can be parsed by any xml parser
|
// xml so that it can be parsed by any xml parser
|
||||||
func transformToXML(raw string) string {
|
func transformToXML(raw string) string {
|
||||||
// take the form `<tag=abc>` and transform it to `<tag name="abc">`, taking
|
// take the form `<tag=abc>` and transform it to `<tag name="abc">`, taking
|
||||||
// care to properly escape the string that becomes the attribute value
|
// care to properly escape the string that becomes the attribute value
|
||||||
return qwenTagRegex.ReplaceAllStringFunc(raw, func(match string) string {
|
transformed := qwenTagRegex.ReplaceAllStringFunc(raw, func(match string) string {
|
||||||
groups := qwenTagRegex.FindStringSubmatch(match)
|
groups := qwenTagRegex.FindStringSubmatch(match)
|
||||||
tag := groups[1]
|
tag := groups[1]
|
||||||
var escapedValue strings.Builder
|
var escapedValue strings.Builder
|
||||||
xml.EscapeText(&escapedValue, []byte(groups[2]))
|
xml.EscapeText(&escapedValue, []byte(groups[2]))
|
||||||
return fmt.Sprintf(`<%s name="%s">`, tag, escapedValue.String())
|
return fmt.Sprintf(`<%s name="%s">`, tag, escapedValue.String())
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Walk the resulting string, escaping any character data that sits between the
|
||||||
|
// xml tags we just emitted
|
||||||
|
var out strings.Builder
|
||||||
|
lastIdx := 0
|
||||||
|
for _, loc := range qwenXMLTagRegex.FindAllStringIndex(transformed, -1) {
|
||||||
|
if loc[0] > lastIdx {
|
||||||
|
escapeTextNode(&out, transformed[lastIdx:loc[0]])
|
||||||
|
}
|
||||||
|
out.WriteString(transformed[loc[0]:loc[1]])
|
||||||
|
lastIdx = loc[1]
|
||||||
|
}
|
||||||
|
if lastIdx < len(transformed) {
|
||||||
|
escapeTextNode(&out, transformed[lastIdx:])
|
||||||
|
}
|
||||||
|
|
||||||
|
return out.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// escapeTextNode appends s to sb as XML character data, replacing the markup
// characters '&', '<' and '>' with their predefined entity references. Every
// other rune, including newlines and tabs, is copied through unchanged (which
// is why xml.EscapeText is not used for this).
func escapeTextNode(sb *strings.Builder, s string) {
	for _, r := range s {
		switch r {
		case '&':
			// Must be an entity reference: a bare '&' starts a reference in XML.
			sb.WriteString("&amp;")
		case '<':
			// A bare '<' would be read as the start of a tag.
			sb.WriteString("&lt;")
		case '>':
			sb.WriteString("&gt;")
		default:
			sb.WriteRune(r)
		}
	}
}
|
||||||
|
|||||||
@@ -312,6 +312,41 @@ true
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
// regression test for <https://github.com/ollama/ollama/issues/12357>
|
||||||
|
{
|
||||||
|
name: "ampersands in parameter values",
|
||||||
|
tools: []api.Tool{},
|
||||||
|
rawToolCall: `<function=exec>
|
||||||
|
<parameter=command>
|
||||||
|
ls && echo "done"
|
||||||
|
</parameter>
|
||||||
|
</function>`,
|
||||||
|
wantToolCall: api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "exec",
|
||||||
|
Arguments: map[string]any{
|
||||||
|
"command": "ls && echo \"done\"",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "angle brackets in parameter values",
|
||||||
|
tools: []api.Tool{},
|
||||||
|
rawToolCall: `<function=exec>
|
||||||
|
<parameter=command>
|
||||||
|
ls && echo "a > b and a < b"
|
||||||
|
</parameter>
|
||||||
|
</function>`,
|
||||||
|
wantToolCall: api.ToolCall{
|
||||||
|
Function: api.ToolCallFunction{
|
||||||
|
Name: "exec",
|
||||||
|
Arguments: map[string]any{
|
||||||
|
"command": "ls && echo \"a > b and a < b\"",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, step := range steps {
|
for i, step := range steps {
|
||||||
@@ -798,6 +833,19 @@ celsius
|
|||||||
</parameter>
|
</parameter>
|
||||||
</function>`,
|
</function>`,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
desc: "ampersands in parameter values",
|
||||||
|
raw: `<function=get_current_temperature>
|
||||||
|
<parameter=location>
|
||||||
|
San Francisco & San Jose
|
||||||
|
</parameter>
|
||||||
|
</function>`,
|
||||||
|
want: `<function name="get_current_temperature">
|
||||||
|
<parameter name="location">
|
||||||
|
San Francisco &amp; San Jose
|
||||||
|
</parameter>
|
||||||
|
</function>`,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range cases {
|
for _, tc := range cases {
|
||||||
|
|||||||
@@ -29,7 +29,6 @@ import (
|
|||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/auth"
|
|
||||||
"github.com/ollama/ollama/discover"
|
"github.com/ollama/ollama/discover"
|
||||||
"github.com/ollama/ollama/envconfig"
|
"github.com/ollama/ollama/envconfig"
|
||||||
"github.com/ollama/ollama/format"
|
"github.com/ollama/ollama/format"
|
||||||
@@ -251,18 +250,14 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
|||||||
client := api.NewClient(remoteURL, http.DefaultClient)
|
client := api.NewClient(remoteURL, http.DefaultClient)
|
||||||
err = client.Generate(c, &req, fn)
|
err = client.Generate(c, &req, fn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var sErr api.AuthorizationError
|
var authError api.AuthorizationError
|
||||||
if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized {
|
if errors.As(err, &authError) {
|
||||||
pk, pkErr := auth.GetPublicKey()
|
c.JSON(authError.StatusCode, gin.H{"error": "unauthorized", "public_key": authError.PublicKey})
|
||||||
if pkErr != nil {
|
return
|
||||||
slog.Error("couldn't get public key", "error", pkErr)
|
}
|
||||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "error getting public key"})
|
var apiError api.StatusError
|
||||||
return
|
if errors.As(err, &apiError) {
|
||||||
}
|
c.JSON(apiError.StatusCode, apiError)
|
||||||
c.JSON(http.StatusUnauthorized, gin.H{
|
|
||||||
"error": "unauthorized",
|
|
||||||
"public_key": pk,
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
@@ -1813,18 +1808,14 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
|||||||
client := api.NewClient(remoteURL, http.DefaultClient)
|
client := api.NewClient(remoteURL, http.DefaultClient)
|
||||||
err = client.Chat(c, &req, fn)
|
err = client.Chat(c, &req, fn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
var sErr api.AuthorizationError
|
var authError api.AuthorizationError
|
||||||
if errors.As(err, &sErr) && sErr.StatusCode == http.StatusUnauthorized {
|
if errors.As(err, &authError) {
|
||||||
pk, pkErr := auth.GetPublicKey()
|
c.JSON(authError.StatusCode, gin.H{"error": "unauthorized", "public_key": authError.PublicKey})
|
||||||
if pkErr != nil {
|
return
|
||||||
slog.Error("couldn't get public key", "error", pkErr)
|
}
|
||||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "error getting public key"})
|
var apiError api.StatusError
|
||||||
return
|
if errors.As(err, &apiError) {
|
||||||
}
|
c.JSON(apiError.StatusCode, apiError)
|
||||||
c.JSON(http.StatusUnauthorized, gin.H{
|
|
||||||
"error": "unauthorized",
|
|
||||||
"public_key": pk,
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||||
|
|||||||
Reference in New Issue
Block a user