From 1fe7e07f6334e530a6409831e4f2b3d00dfd1a75 Mon Sep 17 00:00:00 2001 From: ParthSareen Date: Wed, 3 Sep 2025 15:04:20 -0700 Subject: [PATCH] sampler/runner: enable gpt-oss structured outputs --- harmony/harmonyparser.go | 14 ++++++++------ runner/ollamarunner/runner.go | 12 +++++++++++- sample/samplers.go | 4 ++++ 3 files changed, 23 insertions(+), 7 deletions(-) diff --git a/harmony/harmonyparser.go b/harmony/harmonyparser.go index 93e8dd6ae..ae071932e 100644 --- a/harmony/harmonyparser.go +++ b/harmony/harmonyparser.go @@ -48,12 +48,13 @@ func (s harmonyParserState) String() string { } type HarmonyParser struct { - state harmonyParserState - MessageStartTag string - MessageEndTag string - HeaderEndTag string - acc strings.Builder - lifetimeAcc strings.Builder + state harmonyParserState + MessageStartTag string + MessageEndTag string + HeaderEndTag string + ConstrainAllowed bool + acc strings.Builder + lifetimeAcc strings.Builder } type HarmonyEvent interface { @@ -328,6 +329,7 @@ func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyTo } case "final": h.state = harmonyMessageState_Normal + h.HarmonyParser.ConstrainAllowed = true } case HarmonyEventContentEmitted: logutil.Trace("harmony event content", "content", event.Content, "state", h.state) diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index fc57877ea..4594f070e 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -814,7 +814,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { req.Options.TopP, req.Options.MinP, req.Options.Seed, - grammar, + nil, ) seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{ @@ -865,6 +865,12 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { return } + // TODO(parthsareen): generalize grammar enablement on the fly for all thinking models + if harmonyMessageHandler == nil { + seq.sampler.SetGrammar(grammar) + } + + grammarSet := false for { select { case <-r.Context().Done(): @@ -877,6 +883,10 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { var toolContent string content, thinking, toolContent = harmonyMessageHandler.AddContent(content, harmonyToolParser) harmonyToolParser.Add(toolContent) + if harmonyMessageHandler.HarmonyParser.ConstrainAllowed && !grammarSet { + seq.sampler.SetGrammar(grammar) + grammarSet = true + } } if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ diff --git a/sample/samplers.go b/sample/samplers.go index d395650d9..1c913be9f 100644 --- a/sample/samplers.go +++ b/sample/samplers.go @@ -25,6 +25,10 @@ type Sampler struct { grammar *GrammarSampler } +func (s *Sampler) SetGrammar(grammar *GrammarSampler) { + s.grammar = grammar +} + func (s *Sampler) Sample(logits []float32) (int32, error) { if len(logits) == 0 { return -1, errors.New("sample: no logits provided to sample")