From f52671ecc680c3d09b739b36a8570467f41c40a7 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 16 Dec 2025 15:54:34 -0800 Subject: [PATCH] refactor: bpe and spm tokenizer merges - merge candidates and pairs, which are essentially the same other than the type for rank/score - use binaryheap in sentencepiece instead of implementing a custom structure - update merging algorithm so it uses about 15% fewer allocations --- model/bytepairencoding.go | 98 +++++++++++++-------------------- model/sentencepiece.go | 113 +++++++++++++------------------------- 2 files changed, 76 insertions(+), 135 deletions(-) diff --git a/model/bytepairencoding.go b/model/bytepairencoding.go index 765331bf8..8beb2ddd9 100644 --- a/model/bytepairencoding.go +++ b/model/bytepairencoding.go @@ -7,7 +7,7 @@ import ( "strings" "github.com/dlclark/regexp2" - heap "github.com/emirpasic/gods/v2/trees/binaryheap" + "github.com/emirpasic/gods/v2/trees/binaryheap" "github.com/ollama/ollama/logutil" ) @@ -84,16 +84,15 @@ type fragment struct { ids []int32 } -// pair is a pair of runes and its rank -type pair struct { - a, b int - rank int - value string +// pair is a pair of merges and its rank +type pair[T int | float32] struct { + a, b *merge + rank T } type merge struct { - p, n int - runes []rune + offset, size int + prev, next *merge } func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) { @@ -156,80 +155,61 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) { } runes := []rune(sb.String()) - merges := make([]merge, len(runes)) - for r := range runes { - merges[r] = merge{ - p: r - 1, - n: r + 1, - runes: []rune{runes[r]}, - } + + root := &merge{offset: len(runes) - 1, size: 1} + for i := len(runes) - 2; i >= 0; i-- { + m := &merge{offset: i, size: 1, next: root} + root.prev = m + root = m } - pairwise := func(a, b int) *pair { - if a < 0 || b >= len(runes) { - return nil + pairwise := func(a, b *merge) *pair[int] { + if a != nil && b != nil { 
+ aa := string(runes[a.offset : a.offset+a.size]) + bb := string(runes[b.offset : b.offset+b.size]) + if rank := bpe.vocab.Merge(aa, bb); rank >= 0 { + return &pair[int]{a: a, b: b, rank: rank} + } } - left, right := string(merges[a].runes), string(merges[b].runes) - rank := bpe.vocab.Merge(left, right) - if rank < 0 { - return nil - } - - return &pair{ - a: a, - b: b, - rank: rank, - value: left + right, - } + return nil } - pairs := heap.NewWith(func(i, j *pair) int { - return cmp.Compare(i.rank, j.rank) - }) - - for i := range len(runes) - 1 { - if pair := pairwise(i, i+1); pair != nil { + pairs := binaryheap.NewWith(func(i, j *pair[int]) int { return cmp.Compare(i.rank, j.rank) }) + for m := root; m != nil; m = m.next { + if pair := pairwise(m, m.next); pair != nil { pairs.Push(pair) } } for !pairs.Empty() { - pair, _ := pairs.Pop() - - left, right := merges[pair.a], merges[pair.b] - if len(left.runes) == 0 || len(right.runes) == 0 || - string(left.runes)+string(right.runes) != pair.value { + p, _ := pairs.Pop() + a := string(runes[p.a.offset : p.a.offset+p.a.size]) + b := string(runes[p.b.offset : p.b.offset+p.b.size]) + if a == "" || b == "" || bpe.vocab.Merge(a, b) != p.rank { continue } - if id := bpe.vocab.Encode(pair.value); id < 0 { - continue + p.a.size += p.b.size + p.b.size = 0 + + p.a.next = p.b.next + if p.b.next != nil { + p.b.next.prev = p.a } - merges[pair.a].runes = append(left.runes, right.runes...) 
- merges[pair.b].runes = nil - - merges[pair.a].n = right.n - if right.n < len(merges) { - merges[right.n].p = pair.a - } - - if pair := pairwise(merges[pair.a].p, pair.a); pair != nil { + if pair := pairwise(p.a.prev, p.a); pair != nil { pairs.Push(pair) } - if pair := pairwise(pair.a, merges[pair.a].n); pair != nil { + if pair := pairwise(p.a, p.a.next); pair != nil { pairs.Push(pair) } } - for _, merge := range merges { - if len(merge.runes) > 0 { - // TODO: handle the edge case where the rune isn't in the vocabulary - if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 { - ids = append(ids, id) - } + for m := root; m != nil; m = m.next { + if id := bpe.vocab.Encode(string(runes[m.offset : m.offset+m.size])); id >= 0 { + ids = append(ids, id) } } } diff --git a/model/sentencepiece.go b/model/sentencepiece.go index 2c178ec0c..49dd38f98 100644 --- a/model/sentencepiece.go +++ b/model/sentencepiece.go @@ -1,12 +1,13 @@ package model import ( - "container/heap" + "cmp" "fmt" "log/slog" "strconv" "strings" + "github.com/emirpasic/gods/v2/trees/binaryheap" "github.com/ollama/ollama/logutil" ) @@ -94,79 +95,68 @@ func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) { continue } - q := &queue{} - heap.Init(q) - runes := []rune(text) - merges := make([]merge, len(runes)) - for r := range runes { - merges[r] = merge{ - p: r - 1, - n: r + 1, - runes: []rune{runes[r]}, - } + + root := &merge{offset: len(runes) - 1, size: 1} + for i := len(runes) - 2; i >= 0; i-- { + m := &merge{offset: i, size: 1, next: root} + root.prev = m + root = m } - pairwise := func(a, b int) *candidate { - if a < 0 || b >= len(runes) { - return nil - } - - left, right := string(merges[a].runes), string(merges[b].runes) - if id := spm.vocab.Encode(left + right); id >= 0 { - return &candidate{ - a: a, - b: b, - score: spm.vocab.Scores[id], - size: len(left) + len(right), + pairwise := func(a, b *merge) *pair[float32] { + if a != nil && b != nil { + aa := 
string(runes[a.offset : a.offset+a.size]) + bb := string(runes[b.offset : b.offset+b.size]) + if id := spm.vocab.Encode(aa + bb); id >= 0 { + return &pair[float32]{a: a, b: b, rank: spm.vocab.Scores[id]} } } return nil } - for i := range len(runes) - 1 { - if pair := pairwise(i, i+1); pair != nil { - heap.Push(q, pair) + pairs := binaryheap.NewWith(func(i, j *pair[float32]) int { return cmp.Compare(i.rank, j.rank) }) + for m := root; m != nil; m = m.next { + if pair := pairwise(m, m.next); pair != nil { + pairs.Push(pair) } } - for q.Len() > 0 { - pair := heap.Pop(q).(*candidate) - left, right := merges[pair.a], merges[pair.b] - - if string(left.runes) == "" || string(right.runes) == "" || len(string(left.runes))+len(string(right.runes)) != pair.size { + for !pairs.Empty() { + p, _ := pairs.Pop() + a := string(runes[p.a.offset : p.a.offset+p.a.size]) + b := string(runes[p.b.offset : p.b.offset+p.b.size]) + if id := spm.vocab.Encode(a + b); a == "" || b == "" || id < 0 || spm.vocab.Scores[id] != p.rank { continue } - merges[pair.a].runes = append(left.runes, right.runes...) 
- merges[pair.b].runes = nil - merges[pair.a].n = right.n - if right.n < len(merges) { - merges[right.n].p = pair.a + p.a.size += p.b.size + p.b.size = 0 + + p.a.next = p.b.next + if p.b.next != nil { + p.b.next.prev = p.a } - if pair := pairwise(merges[pair.a].p, pair.a); pair != nil { - heap.Push(q, pair) + if pair := pairwise(p.a.prev, p.a); pair != nil { + pairs.Push(pair) } - if pair := pairwise(pair.a, merges[pair.a].n); pair != nil { - heap.Push(q, pair) + if pair := pairwise(p.a, p.a.next); pair != nil { + pairs.Push(pair) } } - for _, merge := range merges { - if token := string(merge.runes); token != "" { - id := spm.vocab.Encode(token) - - if id >= 0 { + for m := root; m != nil; m = m.next { + if s := string(runes[m.offset : m.offset+m.size]); s != "" { + if id := spm.vocab.Encode(s); id >= 0 { ids = append(ids, id) continue } - // Fallback to byte tokenization var result []int32 - for _, b := range []byte(token) { + for _, b := range []byte(s) { byteToken := fmt.Sprintf("<0x%02X>", b) unknownID := spm.vocab.Encode(byteToken) if unknownID >= 0 { @@ -189,35 +179,6 @@ func (spm SentencePiece) Encode(s string, addSpecial bool) ([]int32, error) { return ids, nil } -type candidate struct { - a, b int - score float32 - size int -} - -type queue []*candidate - -func (q queue) Len() int { return len(q) } - -func (q queue) Less(i, j int) bool { - return (q[i].score > q[j].score) || (q[i].score == q[j].score && q[i].a < q[j].a) -} - -func (q queue) Swap(i, j int) { q[i], q[j] = q[j], q[i] } - -func (q *queue) Push(x interface{}) { - item := x.(*candidate) - *q = append(*q, item) -} - -func (q *queue) Pop() interface{} { - old := *q - n := len(old) - item := old[n-1] - *q = old[0 : n-1] - return item -} - func (spm SentencePiece) Decode(ids []int32) (string, error) { var sb strings.Builder for _, id := range ids {