diff --git a/harmony/harmonyparser.go b/harmony/harmonyparser.go index 3ec2c21f1..62bc8a47a 100644 --- a/harmony/harmonyparser.go +++ b/harmony/harmonyparser.go @@ -47,12 +47,13 @@ func (s harmonyParserState) String() string { } type HarmonyParser struct { - state harmonyParserState - MessageStartTag string - MessageEndTag string - HeaderEndTag string - acc strings.Builder - lifetimeAcc strings.Builder + state harmonyParserState + MessageStartTag string + MessageEndTag string + HeaderEndTag string + constraintsAllowed bool + acc strings.Builder + lifetimeAcc strings.Builder } type HarmonyEvent interface { @@ -89,6 +90,10 @@ func (s *HarmonyParser) AddImplicitStart() { s.acc.WriteString("<|start|>assistant") } +func (s *HarmonyParser) ConstraintsAllowed() bool { + return s.constraintsAllowed +} + func Prefill(lastMessage api.Message) string { if lastMessage.Role != "assistant" { return "" @@ -341,6 +346,7 @@ func (h *HarmonyMessageHandler) AddContent(content string) (string, string, stri } case "final": h.state = harmonyMessageState_Normal + h.HarmonyParser.constraintsAllowed = true } case HarmonyEventContentEmitted: logutil.Trace("harmony event content", "content", event.Content, "state", h.state) diff --git a/parser/token_parser.go b/parser/token_parser.go index 812458299..a889a19fd 100644 --- a/parser/token_parser.go +++ b/parser/token_parser.go @@ -33,6 +33,7 @@ type MessageHandler interface { type ParserInternals interface { AddImplicitStartOrPrefill(prefillString string) + ConstraintsAllowed() bool } type ToolParser interface { @@ -51,6 +52,10 @@ type defaultEngine struct{} func (defaultEngine) AddImplicitStartOrPrefill(prefillString string) {} +func (defaultEngine) ConstraintsAllowed() bool { + return true +} + type defaultToolParser struct{} func (defaultToolParser) Add(token string) {} @@ -104,6 +109,10 @@ func (p *TokenParser) repeatLimitReached(token string) bool { return p.tokenRepeat >= p.repeatLimit } +func (p *TokenParser) ConstraintsAllowed() bool { + return p.parserEngine.ConstraintsAllowed() +} + // TODO: update to work with multiple toolcalls - unmarshalling should also happen on parser level func (p *TokenParser) Drain() []api.ToolCall { toolName, toolContent := p.toolParser.Drain() diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 201d55a16..e6b1a6df4 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -782,8 +782,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { return } - tokenParser := parser.NewTokenParser(req.ParserType, req.PrefillString) - if req.Options == nil { opts := api.DefaultOptions() req.Options = &opts @@ -816,7 +814,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { req.Options.TopP, req.Options.MinP, req.Options.Seed, - grammar, + nil, ) seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{ @@ -831,6 +829,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { return } + tokenParser := parser.NewTokenParser(req.ParserType, req.PrefillString) + switch req.ParserType { + case parser.TokenParserTypeHarmony: + // Do not set grammar until model allows constraining + default: + seq.sampler.SetGrammar(grammar) + } + // Ensure there is a place to put the sequence, released when removed from s.seqs if err := s.seqsSem.Acquire(r.Context(), 1); err != nil { if errors.Is(err, context.Canceled) { @@ -867,6 +873,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { return } + grammarSet := false for { select { case <-r.Context().Done(): @@ -883,6 +890,11 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { return } + if !grammarSet && grammar != nil && tokenParser.ConstraintsAllowed() { + seq.sampler.SetGrammar(grammar) + grammarSet = true + } + if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ Content: content, Thinking: thinking, diff --git a/sample/samplers.go b/sample/samplers.go index d395650d9..1c913be9f 100644 --- a/sample/samplers.go +++ b/sample/samplers.go @@ -25,6 +25,10 @@ type Sampler struct { grammar *GrammarSampler } +func (s *Sampler) SetGrammar(grammar *GrammarSampler) { + s.grammar = grammar +} + func (s *Sampler) Sample(logits []float32) (int32, error) { if len(logits) == 0 { return -1, errors.New("sample: no logits provided to sample")