WIP
This commit is contained in:
@@ -3,6 +3,8 @@ package sample
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/model"
|
||||
)
|
||||
@@ -13,9 +15,17 @@ type PushdownSampler struct {
|
||||
proc model.TextProcessor
|
||||
stateToNodeMap map[JSONState]*PDANode
|
||||
braceStack []rune
|
||||
stateCounter uint32
|
||||
}
|
||||
|
||||
func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
|
||||
start := time.Now()
|
||||
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
before := m.Alloc
|
||||
fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024))
|
||||
|
||||
startNode, stateToNodeMap, err := BuildGraph(proc)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
@@ -24,6 +34,11 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
runtime.ReadMemStats(&m)
|
||||
after := m.Alloc
|
||||
fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024))
|
||||
fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
|
||||
fmt.Printf("Graph build time = %v\n", time.Since(start))
|
||||
// for id, node := range stateToNodeMap[StateInComma].MaskTokenIDToNode {
|
||||
// token, err := proc.Decode([]int32{int32(id)})
|
||||
// if err != nil {
|
||||
@@ -37,6 +52,7 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
|
||||
proc: proc,
|
||||
stateToNodeMap: stateToNodeMap,
|
||||
braceStack: []rune{},
|
||||
stateCounter: 0,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -69,7 +85,19 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
|
||||
}
|
||||
}
|
||||
return logits, nil
|
||||
// return logits, nil
|
||||
|
||||
case StateInComma:
|
||||
peek := s.braceStack[len(s.braceStack)-1]
|
||||
if peek == rune('[') {
|
||||
s.curNode = s.stateToNodeMap[StateInListComma]
|
||||
fmt.Println("switching to list comma", s.curNode.State)
|
||||
}
|
||||
logits, err := s.maskLogits(logits, s.curNode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return logits, nil
|
||||
|
||||
case StateTerminate:
|
||||
for i := range logits {
|
||||
if s.proc.Is(uint32(i), model.SpecialEOS) {
|
||||
@@ -80,9 +108,6 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
|
||||
}
|
||||
return logits, nil
|
||||
|
||||
// case StateInStringEnd:
|
||||
|
||||
// return logits, nil
|
||||
default:
|
||||
fmt.Println("masking logits current state", s.curNode.State)
|
||||
logits, err := s.maskLogits(logits, s.curNode)
|
||||
@@ -96,7 +121,7 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
|
||||
func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
|
||||
fmt.Println("update state", s.curNode.State)
|
||||
|
||||
// TODO: need to handle end states and entering object case
|
||||
// TODO: need to handle end states and entering object case, and list case
|
||||
if s.curNode.State == StateInObjectEnd {
|
||||
fmt.Println("in object end")
|
||||
if len(s.braceStack) > 0 {
|
||||
@@ -111,25 +136,45 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// TODO: should force closing for all braces
|
||||
for _, r := range mappedString {
|
||||
if r == rune('{') {
|
||||
s.braceStack = append(s.braceStack, r)
|
||||
}
|
||||
if r == rune('[') {
|
||||
s.braceStack = append(s.braceStack, r)
|
||||
}
|
||||
if r == rune('}') {
|
||||
if len(s.braceStack) == 0 || s.braceStack[len(s.braceStack)-1] != rune('{') {
|
||||
return fmt.Errorf("unmatched closing brace")
|
||||
}
|
||||
s.braceStack = s.braceStack[:len(s.braceStack)-1]
|
||||
fmt.Println("popping brace stack", s.braceStack)
|
||||
}
|
||||
|
||||
if r == rune(']') {
|
||||
if len(s.braceStack) == 0 || s.braceStack[len(s.braceStack)-1] != rune('[') {
|
||||
return fmt.Errorf("unmatched closing brace")
|
||||
}
|
||||
s.braceStack = s.braceStack[:len(s.braceStack)-1]
|
||||
fmt.Println("popping brace stack", s.braceStack)
|
||||
}
|
||||
}
|
||||
for _, tokenID := range tokenSlice {
|
||||
// transition to the next node
|
||||
nextNode, ok := s.curNode.MaskTokenIDToNode[tokenID]
|
||||
nextNodeState, ok := s.curNode.MaskTokenIDToNode[tokenID]
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid token: %q", mappedString)
|
||||
}
|
||||
fmt.Println("transitioning to", nextNode)
|
||||
s.curNode = s.stateToNodeMap[nextNode]
|
||||
fmt.Println("transitioning to", nextNodeState)
|
||||
|
||||
// TODO: add a penalty for staying in the same state too long
|
||||
if nextNodeState == s.curNode.State {
|
||||
s.stateCounter++
|
||||
} else {
|
||||
s.stateCounter = 0
|
||||
}
|
||||
s.curNode = s.stateToNodeMap[nextNodeState]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user