Files
ollama/sample/fast_json.go
2025-01-29 14:28:02 -08:00

359 lines
7.6 KiB
Go

package sample
import (
"errors"
"fmt"
"math"
"github.com/ollama/ollama/model"
)
// JSONState identifies the state carried by a node of the JSON sampling state
// machine.
type JSONState int

// The states are numbered by iota in declaration order, so inserting or
// reordering entries changes their numeric values.
const (
	StateStart JSONState = iota
	StateInObject
	StateInObjectKey
	StateNewline
	StateTab
	StateSpace
	StateInString
	StateInInt
	StateInFloat
	StateInBool
	StateInNull
	StateInColon
	StateInComma
	StateInTab
	StateInSpace
	StateInObjSpace
	StateInList
	StateInListComma
	StateListEnd
	StateInValue
	StateInValueEnd
	StateInListEnd
	StateInListObjectEnd
	StateInNewline
	StateInNumber
	StateInNumberEnd
	StateInStringEnd
	StateInObjectKeyEnd
	StateTerminate
	StateInObjectEnd
	StateTransitioningToTerminate
)

// JSONStates lists every JSONState constant in declaration order.
var JSONStates = []JSONState{
	StateStart,
	StateInObject,
	StateInObjectKey,
	StateNewline,
	StateTab,
	StateSpace,
	StateInString,
	StateInInt,
	StateInFloat,
	StateInBool,
	StateInNull,
	StateInColon,
	StateInComma,
	StateInTab,
	StateInSpace,
	StateInObjSpace,
	StateInList,
	StateInListComma,
	StateListEnd,
	StateInValue,
	StateInValueEnd,
	StateInListEnd,
	StateInListObjectEnd,
	StateInNewline,
	StateInNumber,
	StateInNumberEnd,
	StateInStringEnd,
	StateInObjectKeyEnd,
	StateTerminate,
	StateInObjectEnd,
	StateTransitioningToTerminate,
}

// String returns the identifier of the state's constant (for example
// "StateInObject"). A value that does not correspond to any declared constant
// is rendered as "Unknown state: <n>".
func (s JSONState) String() string {
	switch s {
	case StateStart:
		return "StateStart"
	case StateInObject:
		return "StateInObject"
	case StateInObjectKey:
		return "StateInObjectKey"
	case StateNewline:
		return "StateNewline"
	case StateTab:
		return "StateTab"
	case StateSpace:
		return "StateSpace"
	case StateInString:
		return "StateInString"
	case StateInInt:
		return "StateInInt"
	case StateInFloat:
		return "StateInFloat"
	case StateInBool:
		return "StateInBool"
	case StateInNull:
		return "StateInNull"
	case StateInColon:
		return "StateInColon"
	case StateInComma:
		return "StateInComma"
	case StateInTab:
		return "StateInTab"
	case StateInSpace:
		return "StateInSpace"
	case StateInObjSpace:
		return "StateInObjSpace"
	case StateInList:
		return "StateInList"
	case StateInListComma:
		return "StateInListComma"
	case StateListEnd:
		return "StateListEnd"
	case StateInValue:
		return "StateInValue"
	case StateInValueEnd:
		return "StateInValueEnd"
	case StateInListEnd:
		return "StateInListEnd"
	case StateInListObjectEnd:
		return "StateInListObjectEnd"
	case StateInNewline:
		return "StateInNewline"
	case StateInNumber:
		return "StateInNumber"
	case StateInNumberEnd:
		return "StateInNumberEnd"
	case StateInStringEnd:
		return "StateInStringEnd"
	case StateInObjectKeyEnd:
		return "StateInObjectKeyEnd"
	case StateTerminate:
		return "StateTerminate"
	case StateInObjectEnd:
		return "StateInObjectEnd"
	case StateTransitioningToTerminate:
		return "StateTransitioningToTerminate"
	default:
		return fmt.Sprintf("Unknown state: %d", s)
	}
}
// JSONSampler walks a state machine built by buildStateMachine and uses the
// current node to decide which logits to mask during sampling.
type JSONSampler struct {
// curNode is the node the sampler is currently at. NewJSONSampler sets it to
// the node returned by buildStateMachine and UpdateState advances it.
curNode *Node
// proc is the text processor used to encode and decode tokens and to test
// for special tokens.
proc model.TextProcessor
// stack holds nodes pushed by UpdateState when a consumed token matches one
// of the object-opening variants; UpdateState pops one entry per call while
// the current node is in StateInObjectEnd.
stack []*Node
// bracketCounter is initialised to 0 by NewJSONSampler.
// NOTE(review): no method visible in this part of the file reads or updates
// it — confirm whether it is still needed.
bracketCounter int
}
// NewJSONSampler builds the state machine for proc and returns a sampler
// positioned at the machine's start node with an empty node stack. Any error
// from buildStateMachine is returned unchanged.
func NewJSONSampler(proc model.TextProcessor) (*JSONSampler, error) {
	root, err := buildStateMachine(proc)
	if err != nil {
		return nil, err
	}

	sampler := new(JSONSampler)
	sampler.curNode = root
	sampler.proc = proc
	sampler.stack = []*Node{}
	sampler.bracketCounter = 0
	return sampler, nil
}
// isTokenSubset reports whether subset is contained in superset when both are
// treated as multisets: every id must occur in superset at least as many times
// as it occurs in subset. Order is ignored, and an empty subset is contained
// in anything.
func isTokenSubset(subset, superset []int32) bool {
	// Count what superset offers, then spend those counts on subset.
	available := make(map[int32]int, len(superset))
	for _, id := range superset {
		available[id]++
	}
	for _, id := range subset {
		available[id]--
		if available[id] < 0 {
			// subset needs more copies of id than superset has.
			return false
		}
	}
	return true
}
// UpdateState advances the sampler after tokenSlice has been emitted.
//
// The steps are, in order:
//  1. While the current node is in StateInObjectEnd, pop one entry from the
//     stack (if any) and return without changing the current node.
//  2. Otherwise move to the first outgoing edge that has a token variant of
//     which tokenSlice is a multiset subset (see isTokenSubset). If tokenSlice
//     also matches an object-opening variant, the node just entered is pushed
//     onto the stack, once per matching variant.
//  3. Otherwise move along the first edge carrying a single-element sentinel
//     token (-1 or -2).
//  4. Otherwise report the token and return an "invalid token" error.
//
// TransitionEdges is ranged over as a map, so when several edges match, the
// one taken is not deterministic.
//
// It returns any error from ComputeTokenVariants or from decoding the rejected
// token, and errors.New("invalid token") when no edge accepts tokenSlice.
func (s *JSONSampler) UpdateState(tokenSlice []int32) error {
	// todo: account for strings here
	objectTokens, err := ComputeTokenVariants([]string{"{", " {", "{\n", " {\n"}, s.proc)
	if err != nil {
		return err
	}

	// NOTE(review): this branch never changes s.curNode, so once the sampler
	// reaches StateInObjectEnd every later call returns here. The original
	// comment said "only move to terminate state if stack is empty", which
	// suggests a transition is still to be implemented — confirm.
	if s.curNode.State == StateInObjectEnd {
		fmt.Println("debug: node.State", s.curNode.State)
		if len(s.stack) > 0 {
			s.stack = s.stack[:len(s.stack)-1]
			fmt.Println("popped and cur state", s.curNode.State)
		}
		return nil
	}

	// First preference: an edge that explicitly lists this token.
	for node, edge := range s.curNode.TransitionEdges {
		for _, validToken := range edge {
			if isTokenSubset(tokenSlice, validToken) {
				s.curNode = node
				for _, token := range objectTokens {
					if isTokenSubset(tokenSlice, token) {
						fmt.Println("Appending to stack", s.curNode.State)
						s.stack = append(s.stack, s.curNode)
					}
				}
				return nil
			}
		}
	}

	// Fallback: an edge marked with a sentinel. The length check must guard
	// both comparisons; without the parentheses `&&` binds tighter than `||`,
	// so validToken[0] == -2 was evaluated on its own and panicked on an empty
	// token.
	for node, edge := range s.curNode.TransitionEdges {
		for _, validToken := range edge {
			if len(validToken) == 1 && (validToken[0] == -1 || validToken[0] == -2) {
				s.curNode = node
				return nil
			}
		}
	}

	fmt.Println("invalid token ", tokenSlice)
	dec, err := s.proc.Decode(tokenSlice)
	if err != nil {
		return err
	}
	fmt.Println("decoded token ", dec)
	return errors.New("invalid token")
}
// Sample masks logits in place according to the sampler's current state and
// returns the same slice. Masked entries are set to NaN.
//
//   - StateTerminate: tokens for which proc.Is reports model.SpecialEOS are
//     set to 1.0 and every other entry is masked.
//   - StateInInt: logits are returned unmodified (see the review note in that
//     branch).
//   - StateInString: the ids making up the newline variants, plus the
//     hard-coded id 702, are masked first; then the generic mask is applied.
//   - any other state: entries not returned by getValidStates for the current
//     node are masked via maskLogits.
//
// Errors from encoding, ComputeTokenVariants or the mask helpers are returned
// with a nil slice.
func (s *JSONSampler) Sample(logits []float64) ([]float64, error) {
	fmt.Printf("Sampling in state: %s\n", s.curNode.State)

	switch s.curNode.State {
	case StateTerminate:
		for i := range logits {
			if !s.proc.Is(uint32(i), model.SpecialEOS) {
				logits[i] = math.NaN()
				continue
			}
			logits[i] = 1.0
		}
		return logits, nil

	case StateInInt:
		minus, err := s.proc.Encode("-")
		if err != nil {
			return nil, err
		}
		digits := make([][]int32, 10)
		for d := range digits {
			digits[d], err = s.proc.Encode(fmt.Sprintf("%d", d))
			if err != nil {
				return nil, err
			}
		}

		// Collect the ids of "-" and the digits 0-9 when they encode to a
		// single token.
		// NOTE(review): this allow-list is built but never applied; the
		// logits are returned unmasked. Confirm the intended mask before
		// wiring it in.
		allowed := []int32{}
		for i := range logits {
			id := int32(i)
			for _, digit := range digits {
				if len(digit) == 1 && digit[0] == id {
					allowed = append(allowed, id)
				}
			}
			if len(minus) == 1 && minus[0] == id {
				allowed = append(allowed, id)
			}
		}
		return logits, nil

	case StateInString:
		newlineToks, err := ComputeTokenVariants([]string{"\n", " \"\n"}, s.proc)
		if err != nil {
			return nil, err
		}
		// NOTE(review): 702 is a raw token id and presumably specific to one
		// vocabulary — confirm.
		newlineToks = append(newlineToks, []int32{702})
		masked, err := s.maskSpecificLogits(logits, newlineToks)
		if err != nil {
			return nil, err
		}
		logits = masked
	}

	// Shared tail for StateInString and every state without its own branch.
	masked, err := s.maskLogits(logits, getValidStates(s.curNode))
	if err != nil {
		return nil, err
	}
	return masked, nil
}
// getValidStates flattens every token variant on every outgoing edge of node
// into a single list of ids. Nothing is filtered or de-duplicated, so any
// sentinel values stored on the edges appear in the result as well. The result
// is never nil.
func getValidStates(node *Node) []int32 {
	ids := []int32{}
	for _, variants := range node.TransitionEdges {
		for _, variant := range variants {
			for _, id := range variant {
				ids = append(ids, id)
			}
		}
	}
	return ids
}
// maskLogits sets every entry of logits whose index is not listed in
// validStates to NaN, modifying logits in place and returning it.
//
// If validStates contains the sentinel -1, logits is returned untouched. Ids
// that are negative or beyond len(logits) are ignored. The returned error is
// always nil.
func (s *JSONSampler) maskLogits(logits []float64, validStates []int32) ([]float64, error) {
	// Build a membership table in one pass over validStates so the masking
	// loop is linear instead of rescanning validStates for every logit.
	allowed := make([]bool, len(logits))
	for _, id := range validStates {
		if id == -1 {
			// Sentinel: any token is acceptable, so nothing is masked.
			return logits, nil
		}
		if idx := int(id); idx >= 0 && idx < len(logits) {
			allowed[idx] = true
		}
	}

	for i, ok := range allowed {
		if !ok {
			logits[i] = math.NaN()
		}
	}
	return logits, nil
}
// maskSpecificLogits sets to NaN the logit of every id that appears in any
// entry of tokensToMask, modifying logits in place and returning it.
//
// Each id of a multi-id entry is masked individually. Ids that are negative
// or beyond len(logits) are ignored. The returned error is always nil.
func (s *JSONSampler) maskSpecificLogits(logits []float64, tokensToMask []token) ([]float64, error) {
	// Index straight into logits instead of scanning the whole slice for
	// every id to be masked.
	for _, ids := range tokensToMask {
		for _, id := range ids {
			if idx := int(id); idx >= 0 && idx < len(logits) {
				logits[idx] = math.NaN()
			}
		}
	}
	return logits, nil
}