runner: add sync between computeBatch and completion

This commit is contained in:
ParthSareen 2025-09-10 18:50:01 -07:00
parent 1e5fecbbc3
commit c0aeb3531b
2 changed files with 33 additions and 10 deletions

View File

@ -62,6 +62,11 @@ type Sequence struct {
// tokens that have been generated but not returned yet (e.g. for stop sequences) // tokens that have been generated but not returned yet (e.g. for stop sequences)
pendingResponses []string pendingResponses []string
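// startGate is locked by computeBatch around sampling and by completion while it handles a response, for as long as grammarReady is false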
startGate *sync.Mutex
grammarReady bool
// input cache being used by this sequence // input cache being used by this sequence
cache *InputCacheSlot cache *InputCacheSlot
@ -164,6 +169,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
// TODO(jessegross): Ingest cached history for grammar // TODO(jessegross): Ingest cached history for grammar
startGate := &sync.Mutex{}
return &Sequence{ return &Sequence{
ctxs: ctxs, ctxs: ctxs,
mmStore: mmStore, mmStore: mmStore,
@ -179,6 +185,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
embeddingOnly: params.embedding, embeddingOnly: params.embedding,
stop: params.stop, stop: params.stop,
numKeep: params.numKeep, numKeep: params.numKeep,
startGate: startGate,
grammarReady: false,
}, nil }, nil
} }
@ -707,11 +715,18 @@ func (s *Server) computeBatch(activeBatch batchState) {
// sample a token // sample a token
vocabSize := len(outputs) / len(activeBatch.batch.Outputs) vocabSize := len(outputs) / len(activeBatch.batch.Outputs)
logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", len(activeBatch.batch.Outputs), "vocabSize", vocabSize, "iBatches", iBatches) logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", len(activeBatch.batch.Outputs), "vocabSize", vocabSize, "iBatches", iBatches)
if !seq.grammarReady {
seq.startGate.Lock()
}
token, err := seq.sampler.Sample(outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize]) token, err := seq.sampler.Sample(outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
if err != nil { if err != nil {
s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err) s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)
return return
} }
if !seq.grammarReady {
seq.startGate.Unlock()
}
nextBatchTokens[i].Token = token nextBatchTokens[i].Token = token
@ -830,11 +845,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
} }
tokenParser := parser.NewTokenParser(req.ParserType, req.PrefillString) tokenParser := parser.NewTokenParser(req.ParserType, req.PrefillString)
switch req.ParserType { // this accounts for the default case and also the case where there is a prefill which moves the state of the parser to allow for constraints
case parser.TokenParserTypeHarmony: if tokenParser.ConstraintsAllowed() {
// Do not set grammar until model allows constraining seq.grammarReady = true
default:
seq.sampler.SetGrammar(grammar)
} }
// Ensure there is a place to put the sequence, released when removed from s.seqs // Ensure there is a place to put the sequence, released when removed from s.seqs
@ -873,7 +886,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return return
} }
grammarSet := false
for { for {
select { select {
case <-r.Context().Done(): case <-r.Context().Done():
@ -881,6 +893,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return return
case content, ok := <-seq.responses: case content, ok := <-seq.responses:
if ok { if ok {
if !seq.grammarReady {
seq.startGate.Lock()
}
var thinking string var thinking string
var err error var err error
content, thinking, err = tokenParser.AddContent(content) content, thinking, err = tokenParser.AddContent(content)
@ -890,9 +905,11 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return return
} }
if !grammarSet && grammar != nil && tokenParser.ConstraintsAllowed() { // only apply the grammar once
seq.sampler.SetGrammar(grammar) if tokenParser.ConstraintsAllowed() && !seq.grammarReady {
grammarSet = true seq.sampler.SetGrammar(grammar, &s.mu)
seq.grammarReady = true
seq.startGate.Unlock()
} }
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{ if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
@ -921,6 +938,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return return
} }
if !seq.grammarReady {
seq.startGate.Unlock()
}
} }
} }
} }

View File

@ -5,6 +5,7 @@ import (
"math" "math"
"math/rand/v2" "math/rand/v2"
"slices" "slices"
"sync"
"github.com/ollama/ollama/llama" "github.com/ollama/ollama/llama"
"github.com/ollama/ollama/model" "github.com/ollama/ollama/model"
@ -25,7 +26,9 @@ type Sampler struct {
grammar *GrammarSampler grammar *GrammarSampler
} }
func (s *Sampler) SetGrammar(grammar *GrammarSampler) { func (s *Sampler) SetGrammar(grammar *GrammarSampler, mutex *sync.Mutex) {
mutex.Lock()
defer mutex.Unlock()
s.grammar = grammar s.grammar = grammar
} }