sampler/runner: enable gpt-oss structured outputs
This commit is contained in:
parent
40d3436cd1
commit
1fe7e07f63
|
|
@ -48,12 +48,13 @@ func (s harmonyParserState) String() string {
|
||||||
}
|
}
|
||||||
|
|
||||||
type HarmonyParser struct {
|
type HarmonyParser struct {
|
||||||
state harmonyParserState
|
state harmonyParserState
|
||||||
MessageStartTag string
|
MessageStartTag string
|
||||||
MessageEndTag string
|
MessageEndTag string
|
||||||
HeaderEndTag string
|
HeaderEndTag string
|
||||||
acc strings.Builder
|
ConstrainAllowed bool
|
||||||
lifetimeAcc strings.Builder
|
acc strings.Builder
|
||||||
|
lifetimeAcc strings.Builder
|
||||||
}
|
}
|
||||||
|
|
||||||
type HarmonyEvent interface {
|
type HarmonyEvent interface {
|
||||||
|
|
@ -328,6 +329,7 @@ func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyTo
|
||||||
}
|
}
|
||||||
case "final":
|
case "final":
|
||||||
h.state = harmonyMessageState_Normal
|
h.state = harmonyMessageState_Normal
|
||||||
|
h.HarmonyParser.ConstrainAllowed = true
|
||||||
}
|
}
|
||||||
case HarmonyEventContentEmitted:
|
case HarmonyEventContentEmitted:
|
||||||
logutil.Trace("harmony event content", "content", event.Content, "state", h.state)
|
logutil.Trace("harmony event content", "content", event.Content, "state", h.state)
|
||||||
|
|
|
||||||
|
|
@ -814,7 +814,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||||
req.Options.TopP,
|
req.Options.TopP,
|
||||||
req.Options.MinP,
|
req.Options.MinP,
|
||||||
req.Options.Seed,
|
req.Options.Seed,
|
||||||
grammar,
|
nil,
|
||||||
)
|
)
|
||||||
|
|
||||||
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
||||||
|
|
@ -865,6 +865,12 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(parthsareen): generalize grammar enablement on the fly for all thinking models
|
||||||
|
if harmonyMessageHandler == nil {
|
||||||
|
seq.sampler.SetGrammar(grammar)
|
||||||
|
}
|
||||||
|
|
||||||
|
grammarSet := false
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-r.Context().Done():
|
case <-r.Context().Done():
|
||||||
|
|
@ -877,6 +883,10 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||||
var toolContent string
|
var toolContent string
|
||||||
content, thinking, toolContent = harmonyMessageHandler.AddContent(content, harmonyToolParser)
|
content, thinking, toolContent = harmonyMessageHandler.AddContent(content, harmonyToolParser)
|
||||||
harmonyToolParser.Add(toolContent)
|
harmonyToolParser.Add(toolContent)
|
||||||
|
if harmonyMessageHandler.HarmonyParser.ConstrainAllowed && !grammarSet {
|
||||||
|
seq.sampler.SetGrammar(grammar)
|
||||||
|
grammarSet = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
|
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,10 @@ type Sampler struct {
|
||||||
grammar *GrammarSampler
|
grammar *GrammarSampler
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetGrammar replaces the grammar sampler held by s with grammar.
// NOTE(review): the assignment is unsynchronized; confirm it is never
// called concurrently with Sample on the same Sampler.
func (s *Sampler) SetGrammar(grammar *GrammarSampler) {
|
||||||
|
s.grammar = grammar
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Sampler) Sample(logits []float32) (int32, error) {
|
func (s *Sampler) Sample(logits []float32) (int32, error) {
|
||||||
if len(logits) == 0 {
|
if len(logits) == 0 {
|
||||||
return -1, errors.New("sample: no logits provided to sample")
|
return -1, errors.New("sample: no logits provided to sample")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue