Compare commits

..

1 Commits

Author SHA1 Message Date
jmorganca
04314765f2 llm: consider null format same as empty value 2024-12-17 09:16:01 -08:00
6 changed files with 37 additions and 82 deletions

View File

@@ -700,24 +700,20 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
}
if len(req.Format) > 0 {
switch string(req.Format) {
case `null`, `""`:
// Field was set, but "missing" a value. We accept
// these as "not set".
break
case `"json"`:
switch {
case bytes.Equal(req.Format, []byte(`""`)) || bytes.Equal(req.Format, []byte(`null`)):
// fallthrough
case bytes.Equal(req.Format, []byte(`"json"`)):
request["grammar"] = grammarJSON
default:
if req.Format[0] != '{' {
return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format)
}
case bytes.HasPrefix(req.Format, []byte("{")):
// User provided a JSON schema
g := llama.SchemaToGrammar(req.Format)
if g == nil {
return fmt.Errorf("invalid JSON schema in format")
}
request["grammar"] = string(g)
default:
return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema", req.Format)
}
}

View File

@@ -39,34 +39,25 @@ func TestLLMServerCompletionFormat(t *testing.T) {
cancel() // prevent further processing if request makes it past the format check
checkValid := func(err error) {
checkCanceled := func(err error) {
t.Helper()
if !errors.Is(err, context.Canceled) {
t.Fatalf("Completion: err = %v; expected context.Canceled", err)
}
}
valids := []string{
// "missing"
``,
`""`,
`null`,
// JSON
`"json"`,
`{"type":"object"}`,
}
valids := []string{`"json"`, `{"type":"object"}`, ``, `""`, `null`}
for _, valid := range valids {
err := s.Completion(ctx, CompletionRequest{
Options: new(api.Options),
Format: []byte(valid),
}, nil)
checkValid(err)
checkCanceled(err)
}
err := s.Completion(ctx, CompletionRequest{
Options: new(api.Options),
Format: nil, // missing format
}, nil)
checkValid(err)
checkCanceled(err)
}

View File

@@ -302,7 +302,7 @@ func parseObjects(s string) []map[string]any {
// mxyng: this only really works if the input contains tool calls in some JSON format
func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
// create a subtree from the node that ranges over .ToolCalls
tmpl := m.Template.Sub(func(n parse.Node) bool {
tmpl := m.Template.Subtree(func(n parse.Node) bool {
if t, ok := n.(*parse.RangeNode); ok {
return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
}
@@ -315,7 +315,7 @@ func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
}
var b bytes.Buffer
if err := tmpl.Template().Execute(&b, map[string][]api.ToolCall{
if err := tmpl.Execute(&b, map[string][]api.ToolCall{
"ToolCalls": {
{
Function: api.ToolCallFunction{

View File

@@ -518,8 +518,8 @@ func TestCreateTemplateSystem(t *testing.T) {
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status code 400, actual %d", w.Code)
}
})
}

View File

@@ -93,8 +93,8 @@ func Named(s string) (*named, error) {
var DefaultTemplate, _ = Parse("{{ .Prompt }}")
type Template struct {
tree *parse.Tree
raw string
*template.Template
raw string
}
// response is a template node that can be added to templates that don't already have one
@@ -124,18 +124,17 @@ var funcs = template.FuncMap{
}
func Parse(s string) (*Template, error) {
tree := parse.New("")
tree.Mode = tree.Mode | parse.SkipFuncCheck
tmpl := template.New("").Option("missingkey=zero").Funcs(funcs)
tree, err := tree.Parse(s, "", "", map[string]*parse.Tree{})
tmpl, err := tmpl.Parse(s)
if err != nil {
return nil, err
}
t := Template{tree, s}
t := Template{Template: tmpl, raw: s}
if vars := t.Vars(); !slices.Contains(vars, "messages") && !slices.Contains(vars, "response") {
// touch up the template and append {{ .Response }}
t.tree.Root.Nodes = append(t.tree.Root.Nodes, &response)
tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, &response)
}
return &t, nil
@@ -147,8 +146,10 @@ func (t *Template) String() string {
func (t *Template) Vars() []string {
var vars []string
for _, n := range t.tree.Root.Nodes {
vars = append(vars, Identifiers(n)...)
for _, tt := range t.Templates() {
for _, n := range tt.Root.Nodes {
vars = append(vars, Identifiers(n)...)
}
}
set := make(map[string]struct{})
@@ -171,8 +172,7 @@ type Values struct {
forceLegacy bool
}
// Sub returns a new template with the subtree that matches the predicate
func (t *Template) Sub(fn func(parse.Node) bool) *Template {
func (t *Template) Subtree(fn func(parse.Node) bool) *template.Template {
var walk func(parse.Node) parse.Node
walk = func(n parse.Node) parse.Node {
if fn(n) {
@@ -205,34 +205,29 @@ func (t *Template) Sub(fn func(parse.Node) bool) *Template {
return nil
}
if n := walk(t.tree.Root); n != nil {
return &Template{
tree: &parse.Tree{
if n := walk(t.Tree.Root); n != nil {
return (&template.Template{
Tree: &parse.Tree{
Root: &parse.ListNode{
Nodes: []parse.Node{n},
},
},
}
}).Funcs(funcs)
}
return nil
}
func (t *Template) Template() *template.Template {
return template.Must(template.New("").Option("missingkey=zero").Funcs(funcs).AddParseTree("", t.tree))
}
func (t *Template) Execute(w io.Writer, v Values) error {
tmpl := t.Template()
system, messages := collate(v.Messages)
if v.Prompt != "" && v.Suffix != "" {
return tmpl.Execute(w, map[string]any{
return t.Template.Execute(w, map[string]any{
"Prompt": v.Prompt,
"Suffix": v.Suffix,
"Response": "",
})
} else if !v.forceLegacy && slices.Contains(t.Vars(), "messages") {
return tmpl.Execute(w, map[string]any{
return t.Template.Execute(w, map[string]any{
"System": system,
"Messages": messages,
"Tools": v.Tools,
@@ -245,7 +240,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
var prompt, response string
for _, m := range messages {
execute := func() error {
if err := tmpl.Execute(&b, map[string]any{
if err := t.Template.Execute(&b, map[string]any{
"System": system,
"Prompt": prompt,
"Response": response,
@@ -280,7 +275,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
}
var cut bool
nodes := deleteNode(t.tree.Root.Copy(), func(n parse.Node) bool {
nodes := deleteNode(t.Template.Root.Copy(), func(n parse.Node) bool {
if field, ok := n.(*parse.FieldNode); ok && slices.Contains(field.Ident, "Response") {
cut = true
return false
@@ -290,7 +285,7 @@ func (t *Template) Execute(w io.Writer, v Values) error {
})
tree := parse.Tree{Root: nodes.(*parse.ListNode)}
if err := template.Must(tmpl.AddParseTree("", &tree)).Execute(&b, map[string]any{
if err := template.Must(template.New("").AddParseTree("", &tree)).Execute(&b, map[string]any{
"System": system,
"Prompt": prompt,
"Response": response,

View File

@@ -54,7 +54,7 @@ func TestNamed(t *testing.T) {
t.Fatal(err)
}
if tmpl.tree.Root.String() == "" {
if tmpl.Tree.Root.String() == "" {
t.Errorf("empty %s template", k)
}
})
@@ -153,7 +153,7 @@ func TestTemplate(t *testing.T) {
}
}
func TestParseVars(t *testing.T) {
func TestParse(t *testing.T) {
cases := []struct {
template string
vars []string
@@ -181,9 +181,6 @@ func TestParseVars(t *testing.T) {
{{ end }}<|im_start|>assistant
{{ .Response }}<|im_end|>
{{- end -}}`, []string{"content", "messages", "prompt", "response", "role", "system"}},
{"{{ json .Messages }}", []string{"messages"}},
// undefined functions should not error
{"{{ undefined }}", []string{"response"}},
}
for _, tt := range cases {
@@ -200,30 +197,6 @@ func TestParseVars(t *testing.T) {
}
}
func TestParseExecute(t *testing.T) {
t.Run("undefined function", func(t *testing.T) {
tmpl, err := Parse(`{{- if .Suffix }}{{ .Prompt }} {{ .Suffix }}{{- else }}{{ undefined }}{{- end }}`)
if err != nil {
t.Fatal(err)
}
var b bytes.Buffer
if err := tmpl.Execute(&b, Values{Prompt: "def add(", Suffix: " return c"}); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(b.String(), "def add( return c"); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
if err := tmpl.Execute(io.Discard, Values{}); err == nil {
t.Fatal("expected error")
} else if !strings.Contains(err.Error(), "\"undefined\" is not a defined function") {
t.Fatal(err)
}
})
}
func TestExecuteWithMessages(t *testing.T) {
type template struct {
name string