NOTE(review): this patch was recovered from a copy whose line breaks and
indentation had been collapsed. The line structure of every hunk was checked
against the original hunk counts, but intra-line whitespace is best-effort;
apply with whitespace tolerance. The post-image blob ids on the index lines
are stale.

Changes relative to the recovered patch:
  * computeBatch takes startGate unconditionally around Sample and releases it
    before the error check, so the gate is never leaked on the early return and
    grammarReady is no longer read from the batch goroutine.
  * completion installs the grammar when constraints are allowed up front; the
    recovered patch only set grammarReady there, so the grammar was never set.
  * Sampler.SetGrammar tolerates a nil mutex and has a doc comment.

Still open, because the relevant lines are outside these hunks:
  * any return inside the "if ok" body of completion that happens while the
    gate is held must release it first, or computeBatch blocks forever.
  * SetGrammar(grammar, &s.mu) runs while startGate is held; confirm that
    computeBatch never waits on startGate while holding s.mu.

diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go
index e6b1a6df4..56d553eee 100644
--- a/runner/ollamarunner/runner.go
+++ b/runner/ollamarunner/runner.go
@@ -62,6 +62,15 @@ type Sequence struct {
 	// tokens that have been generated but not returned yet (e.g. for stop sequences)
 	pendingResponses []string
 
+	// startGate serializes token sampling in computeBatch with the completion
+	// handler while it decides whether to install the grammar.
+	startGate *sync.Mutex
+
+	// grammarReady is set once the grammar has been handed to the sampler.
+	// After construction it is only touched by the completion handler, so it
+	// needs no synchronization of its own.
+	grammarReady bool
+
 	// input cache being used by this sequence
 	cache *InputCacheSlot
 
@@ -164,6 +173,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 
 	// TODO(jessegross): Ingest cached history for grammar
 
+	startGate := &sync.Mutex{}
 	return &Sequence{
 		ctxs:                ctxs,
 		mmStore:             mmStore,
@@ -179,6 +189,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
 		embeddingOnly:       params.embedding,
 		stop:                params.stop,
 		numKeep:             params.numKeep,
+		startGate:           startGate,
+		grammarReady:        false,
 	}, nil
 }
 
@@ -707,11 +719,18 @@ func (s *Server) computeBatch(activeBatch batchState) {
 		// sample a token
 		vocabSize := len(outputs) / len(activeBatch.batch.Outputs)
 		logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", len(activeBatch.batch.Outputs), "vocabSize", vocabSize, "iBatches", iBatches)
+
+		// Hold the gate for the duration of sampling so the completion handler
+		// cannot install the grammar mid-sample. The lock is taken
+		// unconditionally and released before the error check, so it is never
+		// leaked on the early return below.
+		seq.startGate.Lock()
 		token, err := seq.sampler.Sample(outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
+		seq.startGate.Unlock()
 		if err != nil {
 			s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)
 			return
 		}
 
 		nextBatchTokens[i].Token = token
 
@@ -830,11 +849,13 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 	}
 
 	tokenParser := parser.NewTokenParser(req.ParserType, req.PrefillString)
-	switch req.ParserType {
-	case parser.TokenParserTypeHarmony:
-		// Do not set grammar until model allows constraining
-	default:
-		seq.sampler.SetGrammar(grammar)
+	// This covers the default case and the case where a prefill has already
+	// moved the parser into a state that allows constraints: install the
+	// grammar up front. Nothing else holds startGate at this point, so it is
+	// safe to let SetGrammar take it.
+	if tokenParser.ConstraintsAllowed() {
+		seq.sampler.SetGrammar(grammar, seq.startGate)
+		seq.grammarReady = true
 	}
 
 	// Ensure there is a place to put the sequence, released when removed from s.seqs
@@ -873,7 +894,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	grammarSet := false
 	for {
 		select {
 		case <-r.Context().Done():
@@ -881,6 +901,11 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 			return
 		case content, ok := <-seq.responses:
 			if ok {
+				// Hold the gate while the parser consumes this content so the
+				// next sample waits for a possible grammar install below.
+				if !seq.grammarReady {
+					seq.startGate.Lock()
+				}
 				var thinking string
 				var err error
 				content, thinking, err = tokenParser.AddContent(content)
@@ -890,9 +915,11 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 					return
 				}
 
-				if !grammarSet && grammar != nil && tokenParser.ConstraintsAllowed() {
-					seq.sampler.SetGrammar(grammar)
-					grammarSet = true
+				// only apply the grammar once
+				if tokenParser.ConstraintsAllowed() && !seq.grammarReady {
+					seq.sampler.SetGrammar(grammar, &s.mu)
+					seq.grammarReady = true
+					seq.startGate.Unlock()
 				}
 
 				if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
@@ -921,6 +948,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
 					return
 				}
 
+				if !seq.grammarReady {
+					seq.startGate.Unlock()
+				}
 			}
 		}
 	}
diff --git a/sample/samplers.go b/sample/samplers.go
index 1c913be9f..af58730dc 100644
--- a/sample/samplers.go
+++ b/sample/samplers.go
@@ -5,6 +5,7 @@ import (
 	"math"
 	"math/rand/v2"
 	"slices"
+	"sync"
 
 	"github.com/ollama/ollama/llama"
 	"github.com/ollama/ollama/model"
@@ -25,7 +26,14 @@ type Sampler struct {
 	grammar *GrammarSampler
 }
 
-func (s *Sampler) SetGrammar(grammar *GrammarSampler) {
+// SetGrammar replaces the grammar used by the sampler. If mutex is non-nil it
+// is held for the duration of the assignment; a nil mutex means the caller
+// already guarantees exclusive access.
+func (s *Sampler) SetGrammar(grammar *GrammarSampler, mutex *sync.Mutex) {
+	if mutex != nil {
+		mutex.Lock()
+		defer mutex.Unlock()
+	}
 	s.grammar = grammar
 }
 