saved state
This commit is contained in:
@@ -44,8 +44,6 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
|
||||
// consider adding a node to just point to values, could be good to compute that
|
||||
// mask rather than many different nodes
|
||||
|
||||
// Connect nodes
|
||||
// TODO: if all are single tokens then this can just be connected instead of defining the token
|
||||
stateToNodeMap[StateStart].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||
stateToNodeMap[StateStart].TransitionEdges['['] = stateToNodeMap[StateInList]
|
||||
|
||||
@@ -161,6 +159,7 @@ func addValueConnections(node *PDANode, stateToNodeMap map[JSONState]*PDANode) {
|
||||
node.TransitionEdges['n'] = stateToNodeMap[StateInNull]
|
||||
}
|
||||
|
||||
// TODO: tough life fr. plz fix.
|
||||
func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.TextProcessor) error {
|
||||
|
||||
vocab := proc.GetVocabulary()
|
||||
@@ -176,33 +175,42 @@ func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.Tex
|
||||
|
||||
var err error
|
||||
for _, node := range stateToNodeMap {
|
||||
for i := range vocab.Values {
|
||||
token := decodedToks[i]
|
||||
// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
|
||||
if proc.Is(uint32(i), model.SpecialEOS) || proc.Is(uint32(i), model.SpecialBOS) || token == "" || token == "\"\"" {
|
||||
continue
|
||||
}
|
||||
valid := true
|
||||
curNode := node
|
||||
consumedSpecialRunes := make(map[rune]bool)
|
||||
for _, r := range token {
|
||||
valid, curNode, err = isRuneValid(r, curNode, consumedSpecialRunes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !valid {
|
||||
break
|
||||
}
|
||||
}
|
||||
if valid {
|
||||
node.MaskTokenIDToNode[int32(i)] = curNode.State
|
||||
}
|
||||
err = createMask(node, proc, decodedToks, vocab)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// garbage interface plz fix
|
||||
func createMask(node *PDANode, proc model.TextProcessor, decodedToks []string, vocab *model.Vocabulary) error {
|
||||
for i := range vocab.Values {
|
||||
token := decodedToks[i]
|
||||
// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
|
||||
if proc.Is(uint32(i), model.SpecialEOS) || proc.Is(uint32(i), model.SpecialBOS) || token == "" || token == "\"\"" {
|
||||
continue
|
||||
}
|
||||
valid := true
|
||||
curNode := node
|
||||
consumedSpecialRunes := make(map[rune]bool)
|
||||
var err error
|
||||
for _, r := range token {
|
||||
valid, curNode, err = isRuneValid(r, curNode, consumedSpecialRunes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !valid {
|
||||
break
|
||||
}
|
||||
}
|
||||
if valid {
|
||||
node.MaskTokenIDToNode[int32(i)] = curNode.State
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: garbage interface plz fix
|
||||
func isRuneValid(r rune, curNode *PDANode, consumedSpecialRunes map[rune]bool) (bool, *PDANode, error) {
|
||||
if consumedSpecialRunes[r] {
|
||||
return false, nil, nil
|
||||
|
||||
@@ -52,6 +52,7 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: need to add resampling logic if the first sample was not good
|
||||
func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
|
||||
// fmt.Println(">>> sample:", s.curNode.State)
|
||||
switch s.curNode.State {
|
||||
@@ -156,8 +157,11 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
|
||||
// fmt.Println("pushing [ brace stack", r)
|
||||
}
|
||||
if r == rune('}') {
|
||||
if len(s.braceStack) == 0 {
|
||||
return fmt.Errorf("stack is empty, extra closing brace %c", r)
|
||||
}
|
||||
top := s.braceStack[len(s.braceStack)-1]
|
||||
if len(s.braceStack) == 0 || top != rune('{') {
|
||||
if top != rune('{') {
|
||||
return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '{')
|
||||
}
|
||||
s.braceStack = s.braceStack[:len(s.braceStack)-1]
|
||||
@@ -165,8 +169,11 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
|
||||
}
|
||||
|
||||
if r == rune(']') {
|
||||
if len(s.braceStack) == 0 {
|
||||
return fmt.Errorf("stack is empty, extra closing brace %c", r)
|
||||
}
|
||||
top := s.braceStack[len(s.braceStack)-1]
|
||||
if len(s.braceStack) == 0 || top != rune('[') {
|
||||
if top != rune('[') {
|
||||
return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '[')
|
||||
}
|
||||
s.braceStack = s.braceStack[:len(s.braceStack)-1]
|
||||
@@ -194,6 +201,8 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
|
||||
}
|
||||
|
||||
func (s *PushdownSampler) maskLogits(logits []float64, node *PDANode) ([]float64, error) {
|
||||
// TODO: can be optimized by only masking the logits that are not in the node.MaskTokenIDToNode
|
||||
// Should be possible through bitwise ops as well
|
||||
for i := range logits {
|
||||
_, exists := node.MaskTokenIDToNode[int32(i)]
|
||||
if !exists {
|
||||
|
||||
@@ -165,11 +165,12 @@ func (s weighed) Sample(logits []float64) ([]float64, error) {
|
||||
if len(logitsCopy) == 0 {
|
||||
return nil, errors.New("no valid tokens found")
|
||||
}
|
||||
logitsCopy, err := computeSoftmax(logitsCopy)
|
||||
|
||||
softmax, err := computeSoftmax(logitsCopy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
w := sampleuv.NewWeighted(logitsCopy, nil)
|
||||
w := sampleuv.NewWeighted(softmax, nil)
|
||||
if v, ok := w.Take(); ok {
|
||||
// returns the token ID
|
||||
return []float64{float64(indices[v])}, nil
|
||||
|
||||
@@ -3,10 +3,52 @@ package sample
|
||||
import "github.com/ollama/ollama/model"
|
||||
|
||||
type StructuredOutput struct {
|
||||
schema *Schema
|
||||
schema *Schema
|
||||
stateToNodeMap map[JSONState]*PDANode
|
||||
}
|
||||
|
||||
func BuildStructuredOutputGraph(schema *Schema, proc model.TextProcessor) *PDANode {
|
||||
func BuildStructuredOutputGraph(schema *Schema, proc model.TextProcessor) *StructuredOutput {
|
||||
_, stateToNodeMap, err := BuildGraph(proc)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return &StructuredOutput{
|
||||
schema: schema,
|
||||
stateToNodeMap: stateToNodeMap,
|
||||
}
|
||||
}
|
||||
|
||||
func (so *StructuredOutput) schemaToGraph(proc model.TextProcessor) *PDANode {
|
||||
|
||||
schemaType := so.schema.EffectiveType()
|
||||
switch schemaType {
|
||||
case "object":
|
||||
// each prop is a key
|
||||
// prevState := StateInObjectKey
|
||||
for _, prop := range so.schema.Properties {
|
||||
// name of key
|
||||
name := prop.Name
|
||||
prevState := StateInObjectKey
|
||||
for i, r := range name {
|
||||
newState := JSONState(int(StateInObjectKey) + i + 1) // Create new unique state for each rune
|
||||
|
||||
// Create new node for this state if it doesn't exist
|
||||
if _, exists := so.stateToNodeMap[newState]; !exists {
|
||||
so.stateToNodeMap[newState] = &PDANode{
|
||||
State: newState,
|
||||
TransitionEdges: make(map[rune]*PDANode),
|
||||
MaskTokenIDToNode: make(map[int32]JSONState),
|
||||
}
|
||||
}
|
||||
|
||||
// Connect previous state to this state via the rune
|
||||
so.stateToNodeMap[prevState].TransitionEdges[r] = so.stateToNodeMap[newState]
|
||||
prevState = newState
|
||||
}
|
||||
// type of value
|
||||
// propType := prop.Type
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user