runner/parser: allow on-the-fly grammar constraining
This commit is contained in:
parent
8d6fffaead
commit
1e5fecbbc3
|
|
@ -47,12 +47,13 @@ func (s harmonyParserState) String() string {
|
|||
}
|
||||
|
||||
type HarmonyParser struct {
|
||||
state harmonyParserState
|
||||
MessageStartTag string
|
||||
MessageEndTag string
|
||||
HeaderEndTag string
|
||||
acc strings.Builder
|
||||
lifetimeAcc strings.Builder
|
||||
state harmonyParserState
|
||||
MessageStartTag string
|
||||
MessageEndTag string
|
||||
HeaderEndTag string
|
||||
constraintsAllowed bool
|
||||
acc strings.Builder
|
||||
lifetimeAcc strings.Builder
|
||||
}
|
||||
|
||||
type HarmonyEvent interface {
|
||||
|
|
@ -89,6 +90,10 @@ func (s *HarmonyParser) AddImplicitStart() {
|
|||
s.acc.WriteString("<|start|>assistant")
|
||||
}
|
||||
|
||||
func (s *HarmonyParser) ConstraintsAllowed() bool {
|
||||
return s.constraintsAllowed
|
||||
}
|
||||
|
||||
func Prefill(lastMessage api.Message) string {
|
||||
if lastMessage.Role != "assistant" {
|
||||
return ""
|
||||
|
|
@ -341,6 +346,7 @@ func (h *HarmonyMessageHandler) AddContent(content string) (string, string, stri
|
|||
}
|
||||
case "final":
|
||||
h.state = harmonyMessageState_Normal
|
||||
h.HarmonyParser.constraintsAllowed = true
|
||||
}
|
||||
case HarmonyEventContentEmitted:
|
||||
logutil.Trace("harmony event content", "content", event.Content, "state", h.state)
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ type MessageHandler interface {
|
|||
|
||||
type ParserInternals interface {
|
||||
AddImplicitStartOrPrefill(prefillString string)
|
||||
ConstraintsAllowed() bool
|
||||
}
|
||||
|
||||
type ToolParser interface {
|
||||
|
|
@ -51,6 +52,10 @@ type defaultEngine struct{}
|
|||
|
||||
func (defaultEngine) AddImplicitStartOrPrefill(prefillString string) {}
|
||||
|
||||
func (defaultEngine) ConstraintsAllowed() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
type defaultToolParser struct{}
|
||||
|
||||
func (defaultToolParser) Add(token string) {}
|
||||
|
|
@ -104,6 +109,10 @@ func (p *TokenParser) repeatLimitReached(token string) bool {
|
|||
return p.tokenRepeat >= p.repeatLimit
|
||||
}
|
||||
|
||||
func (p *TokenParser) ConstraintsAllowed() bool {
|
||||
return p.parserEngine.ConstraintsAllowed()
|
||||
}
|
||||
|
||||
// TODO: update to work with multiple toolcalls - unmarshalling should also happen on parser level
|
||||
func (p *TokenParser) Drain() []api.ToolCall {
|
||||
toolName, toolContent := p.toolParser.Drain()
|
||||
|
|
|
|||
|
|
@ -782,8 +782,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
tokenParser := parser.NewTokenParser(req.ParserType, req.PrefillString)
|
||||
|
||||
if req.Options == nil {
|
||||
opts := api.DefaultOptions()
|
||||
req.Options = &opts
|
||||
|
|
@ -816,7 +814,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||
req.Options.TopP,
|
||||
req.Options.MinP,
|
||||
req.Options.Seed,
|
||||
grammar,
|
||||
nil,
|
||||
)
|
||||
|
||||
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
||||
|
|
@ -831,6 +829,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
tokenParser := parser.NewTokenParser(req.ParserType, req.PrefillString)
|
||||
switch req.ParserType {
|
||||
case parser.TokenParserTypeHarmony:
|
||||
// Do not set grammar until model allows constraining
|
||||
default:
|
||||
seq.sampler.SetGrammar(grammar)
|
||||
}
|
||||
|
||||
// Ensure there is a place to put the sequence, released when removed from s.seqs
|
||||
if err := s.seqsSem.Acquire(r.Context(), 1); err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
|
|
@ -867,6 +873,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
grammarSet := false
|
||||
for {
|
||||
select {
|
||||
case <-r.Context().Done():
|
||||
|
|
@ -883,6 +890,11 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
|||
return
|
||||
}
|
||||
|
||||
if !grammarSet && grammar != nil && tokenParser.ConstraintsAllowed() {
|
||||
seq.sampler.SetGrammar(grammar)
|
||||
grammarSet = true
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
|
||||
Content: content,
|
||||
Thinking: thinking,
|
||||
|
|
|
|||
|
|
@ -25,6 +25,10 @@ type Sampler struct {
|
|||
grammar *GrammarSampler
|
||||
}
|
||||
|
||||
func (s *Sampler) SetGrammar(grammar *GrammarSampler) {
|
||||
s.grammar = grammar
|
||||
}
|
||||
|
||||
func (s *Sampler) Sample(logits []float32) (int32, error) {
|
||||
if len(logits) == 0 {
|
||||
return -1, errors.New("sample: no logits provided to sample")
|
||||
|
|
|
|||
Loading…
Reference in New Issue