tested with so
This commit is contained in:
@@ -4,7 +4,6 @@ import (
|
||||
"fmt"
|
||||
"math"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/model"
|
||||
)
|
||||
@@ -22,12 +21,15 @@ type PushdownSampler struct {
|
||||
|
||||
// graph should be built once and reused per tokenizer
|
||||
func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
|
||||
start := time.Now()
|
||||
// start := time.Now()
|
||||
|
||||
// fmt.Println("--------------------------------")
|
||||
// fmt.Println("PDA sampler")
|
||||
// fmt.Println("--------------------------------")
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
before := m.Alloc
|
||||
fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024))
|
||||
// before := m.Alloc
|
||||
// fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024))
|
||||
|
||||
startNode, stateToNodeMap, err := BuildGraph(proc)
|
||||
if err != nil {
|
||||
@@ -38,10 +40,10 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
|
||||
panic(err)
|
||||
}
|
||||
runtime.ReadMemStats(&m)
|
||||
after := m.Alloc
|
||||
fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024))
|
||||
fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
|
||||
fmt.Printf("Graph build time = %v\n", time.Since(start))
|
||||
// after := m.Alloc
|
||||
// fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024))
|
||||
// fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
|
||||
// fmt.Printf("Graph build time = %v\n", time.Since(start))
|
||||
|
||||
return &PushdownSampler{
|
||||
curNode: startNode,
|
||||
@@ -53,6 +55,7 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
|
||||
}
|
||||
|
||||
// TODO: need to add resampling logic if the first sample was not good
|
||||
// greedy sample + backtrack?
|
||||
func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
|
||||
// fmt.Println(">>> sample:", s.curNode.State)
|
||||
switch s.curNode.State {
|
||||
@@ -60,7 +63,7 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
|
||||
return s.maskLogits(logits, s.curNode)
|
||||
|
||||
case StateInListEnd:
|
||||
fmt.Println("in list end", s.braceStack)
|
||||
// fmt.Println("in list end", s.braceStack)
|
||||
// force finish if no braces left
|
||||
if len(s.braceStack) == 0 {
|
||||
s.curNode = NewPDANode(StateTerminate)
|
||||
@@ -139,12 +142,12 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
|
||||
}
|
||||
|
||||
func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
|
||||
fmt.Println("update state", s.curNode.State)
|
||||
// fmt.Println("current state - updating", s.curNode.State)
|
||||
mappedString, err := s.proc.Decode(tokenSlice)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Println("mappedString", mappedString)
|
||||
// fmt.Println("mappedString", mappedString)
|
||||
|
||||
// TODO: should force closing for all braces - not doing square yet
|
||||
for _, r := range mappedString {
|
||||
@@ -183,23 +186,25 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
|
||||
|
||||
for _, tokenID := range tokenSlice {
|
||||
// transition to the next node
|
||||
nextNodeState, ok := s.curNode.MaskTokenIDToNode[tokenID]
|
||||
nextNode, ok := s.curNode.MaskTokenIDToNode[tokenID]
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid token: %q", mappedString)
|
||||
}
|
||||
// fmt.Println("transitioning to", nextNodeState)
|
||||
|
||||
// TODO: add a penalty for staying in the same state too long
|
||||
if nextNodeState == s.curNode.State {
|
||||
if nextNode.State == s.curNode.State {
|
||||
s.stateCounter++
|
||||
} else {
|
||||
s.stateCounter = 0
|
||||
}
|
||||
s.curNode = s.stateToNodeMap[nextNodeState]
|
||||
s.curNode = nextNode
|
||||
// fmt.Println("updated curNode state", s.curNode.State)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// greedy sample + backtrack?
|
||||
func (s *PushdownSampler) maskLogits(logits []float64, node *PDANode) ([]float64, error) {
|
||||
// TODO: can be optimized by only masking the logits that are not in the node.MaskTokenIDToNode
|
||||
// Should be possible through bitwise ops as well
|
||||
|
||||
Reference in New Issue
Block a user