saved state
This commit is contained in:
@@ -52,6 +52,7 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: need to add resampling logic if the first sample was not good
|
||||
func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
|
||||
// fmt.Println(">>> sample:", s.curNode.State)
|
||||
switch s.curNode.State {
|
||||
@@ -156,8 +157,11 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
|
||||
// fmt.Println("pushing [ brace stack", r)
|
||||
}
|
||||
if r == rune('}') {
|
||||
if len(s.braceStack) == 0 {
|
||||
return fmt.Errorf("stack is empty, extra closing brace %c", r)
|
||||
}
|
||||
top := s.braceStack[len(s.braceStack)-1]
|
||||
if len(s.braceStack) == 0 || top != rune('{') {
|
||||
if top != rune('{') {
|
||||
return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '{')
|
||||
}
|
||||
s.braceStack = s.braceStack[:len(s.braceStack)-1]
|
||||
@@ -165,8 +169,11 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
|
||||
}
|
||||
|
||||
if r == rune(']') {
|
||||
if len(s.braceStack) == 0 {
|
||||
return fmt.Errorf("stack is empty, extra closing brace %c", r)
|
||||
}
|
||||
top := s.braceStack[len(s.braceStack)-1]
|
||||
if len(s.braceStack) == 0 || top != rune('[') {
|
||||
if top != rune('[') {
|
||||
return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '[')
|
||||
}
|
||||
s.braceStack = s.braceStack[:len(s.braceStack)-1]
|
||||
@@ -194,6 +201,8 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
|
||||
}
|
||||
|
||||
func (s *PushdownSampler) maskLogits(logits []float64, node *PDANode) ([]float64, error) {
|
||||
// TODO: can be optimized by only masking the logits that are not in the node.MaskTokenIDToNode
|
||||
// Should be possible through bitwise ops as well
|
||||
for i := range logits {
|
||||
_, exists := node.MaskTokenIDToNode[int32(i)]
|
||||
if !exists {
|
||||
|
||||
Reference in New Issue
Block a user