WIP simple SO working

This commit is contained in:
ParthSareen
2025-02-03 11:40:06 -08:00
parent 524029cd6d
commit d5f8670f0a
3 changed files with 32 additions and 37 deletions

View File

@@ -4,6 +4,7 @@ import (
"fmt"
"math"
"runtime"
"time"
"github.com/ollama/ollama/model"
)
@@ -21,15 +22,15 @@ type PushdownSampler struct {
// graph should be built once and reused per tokenizer
func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
// start := time.Now()
start := time.Now()
// fmt.Println("--------------------------------")
// fmt.Println("PDA sampler")
// fmt.Println("--------------------------------")
fmt.Println("--------------------------------")
fmt.Println("PDA sampler")
fmt.Println("--------------------------------")
var m runtime.MemStats
runtime.ReadMemStats(&m)
// before := m.Alloc
// fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024))
before := m.Alloc
fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024))
startNode, stateToNodeMap, err := BuildGraph(proc)
if err != nil {
@@ -40,10 +41,10 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
panic(err)
}
runtime.ReadMemStats(&m)
// after := m.Alloc
// fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024))
// fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
// fmt.Printf("Graph build time = %v\n", time.Since(start))
after := m.Alloc
fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024))
fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
fmt.Printf("Graph build time = %v\n", time.Since(start))
return &PushdownSampler{
curNode: startNode,
@@ -57,13 +58,11 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
// TODO: need to add resampling logic if the first sample was not good
// greedy sample + backtrack?
func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
// fmt.Println(">>> sample:", s.curNode.State)
switch s.curNode.State {
case StateInString:
return s.maskLogits(logits, s.curNode)
case StateInListEnd:
// fmt.Println("in list end", s.braceStack)
// force finish if no braces left
if len(s.braceStack) == 0 {
s.curNode = NewPDANode(StateTerminate)
@@ -100,7 +99,6 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
peek := s.braceStack[len(s.braceStack)-1]
if peek == rune('[') {
s.curNode = s.stateToNodeMap[StateInListObjectEnd]
// fmt.Println("switching to list object end", s.curNode.State)
}
logits, err := s.maskLogits(logits, s.curNode)
@@ -113,7 +111,6 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
peek := s.braceStack[len(s.braceStack)-1]
if peek == rune('[') {
s.curNode = s.stateToNodeMap[StateInListComma]
// fmt.Println("switching to list comma", s.curNode.State)
}
logits, err := s.maskLogits(logits, s.curNode)
if err != nil {
@@ -132,7 +129,7 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
return logits, nil
default:
// fmt.Println("masking logits current state", s.curNode.State)
fmt.Println("masking logits current state", s.curNode.State)
logits, err := s.maskLogits(logits, s.curNode)
if err != nil {
return nil, err
@@ -142,22 +139,20 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
}
func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
// fmt.Println("current state - updating", s.curNode.State)
fmt.Println("current state - updating", s.curNode.State)
mappedString, err := s.proc.Decode(tokenSlice)
if err != nil {
return err
}
// fmt.Println("mappedString", mappedString)
fmt.Println(">>> mappedString", mappedString)
// TODO: should force closing for all braces - not doing square yet
for _, r := range mappedString {
if r == rune('{') {
s.braceStack = append(s.braceStack, r)
// fmt.Println("pushing { brace stack", r)
}
if r == rune('[') {
s.braceStack = append(s.braceStack, r)
// fmt.Println("pushing [ brace stack", r)
}
if r == rune('}') {
if len(s.braceStack) == 0 {
@@ -168,7 +163,6 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '{')
}
s.braceStack = s.braceStack[:len(s.braceStack)-1]
// fmt.Println("popping { brace stack", top)
}
if r == rune(']') {
@@ -180,7 +174,6 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '[')
}
s.braceStack = s.braceStack[:len(s.braceStack)-1]
// fmt.Println("popping [ brace stack", top)
}
}
@@ -190,7 +183,7 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
if !ok {
return fmt.Errorf("invalid token: %q", mappedString)
}
// fmt.Println("transitioning to", nextNodeState)
fmt.Println("transitioning to", nextNode.State)
// TODO: add a penalty for staying in the same state too long
if nextNode.State == s.curNode.State {
@@ -199,7 +192,7 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
s.stateCounter = 0
}
s.curNode = nextNode
// fmt.Println("updated curNode state", s.curNode.State)
fmt.Println("updated curNode state", s.curNode.State)
}
return nil
}