saved state

This commit is contained in:
ParthSareen
2025-01-30 15:17:42 -08:00
parent c56a8b7749
commit b973dedb4b
7 changed files with 130 additions and 45 deletions

View File

@@ -52,6 +52,7 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
}
}
// TODO: need to add resampling logic if the first sample was not good
func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
// fmt.Println(">>> sample:", s.curNode.State)
switch s.curNode.State {
@@ -156,8 +157,11 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
// fmt.Println("pushing [ brace stack", r)
}
if r == rune('}') {
if len(s.braceStack) == 0 {
return fmt.Errorf("stack is empty, extra closing brace %c", r)
}
top := s.braceStack[len(s.braceStack)-1]
if len(s.braceStack) == 0 || top != rune('{') {
if top != rune('{') {
return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '{')
}
s.braceStack = s.braceStack[:len(s.braceStack)-1]
@@ -165,8 +169,11 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
}
if r == rune(']') {
if len(s.braceStack) == 0 {
return fmt.Errorf("stack is empty, extra closing brace %c", r)
}
top := s.braceStack[len(s.braceStack)-1]
if len(s.braceStack) == 0 || top != rune('[') {
if top != rune('[') {
return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '[')
}
s.braceStack = s.braceStack[:len(s.braceStack)-1]
@@ -194,6 +201,8 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
}
func (s *PushdownSampler) maskLogits(logits []float64, node *PDANode) ([]float64, error) {
// TODO: can be optimized by only masking the logits that are not in the node.MaskTokenIDToNode
// Should be possible through bitwise ops as well
for i := range logits {
_, exists := node.MaskTokenIDToNode[int32(i)]
if !exists {