Compare commits

..

6 Commits

Author SHA1 Message Date
Michael Yang
021dcf089d Merge pull request #9824 from ollama/mxyng/sched
conditionally enable parallel pipelines
2025-03-17 15:41:37 -07:00
Jesse Gross
bf24498b1e ollamarunner: Check for minBatch of context space when shifting
Models can specify that a group of inputs need to be handled a single
batch. However, context shifting didn't respect this and could trigger
a break anyways. In this case, we should instead trigger a context
shift earlier so that it occurs before the grouped batch.

Note that there still some corner cases:
 - A long prompt that exceeds the context window can get truncated
   in the middle of an image. With the current models, this will
   result in the model not recognizing the image at all, which is
   pretty much the expected result with truncation.
 - The context window is set less than the minimum batch size. The
   only solution to this is to refuse to load the model with these
   settings. However, this can never occur with current models and
   default settings.

Since users are unlikely to run into these scenarios, fixing them is
left as a follow up.
2025-03-17 15:33:16 -07:00
Bruce MacDonald
95e271d98f runner: remove cache prompt flag from ollama runner (#9826)
We do not need to bypass the prompt caching in the ollama runner yet, as
only embedding models needed to bypass the prompt caching. When embedding
models are implemented they can skip initializing this cache completely.
2025-03-17 15:11:15 -07:00
Jeffrey Morgan
364629b8d6 ml/backend/ggml: allocate memory with malloc when loading model (#9822) 2025-03-17 13:32:40 -07:00
Parth Sareen
108fe02165 sample: make mutations in transforms explicit (#9743)
* updated minP to use early exit making use of sorted tokens
2025-03-17 11:24:18 -07:00
Michael Yang
4561fff36e conditionally enable parallel pipelines 2025-03-17 09:46:07 -07:00
13 changed files with 269 additions and 884 deletions

View File

@@ -184,8 +184,6 @@ func ConvertModel(fsys fs.FS, ws io.WriteSeeker) error {
switch p.Architectures[0] {
case "LlamaForCausalLM", "MistralForCausalLM":
conv = &llamaModel{}
case "Mistral3ForConditionalGeneration":
conv = &mistralModel{}
case "MixtralForCausalLM":
conv = &mixtralModel{}
case "GemmaForCausalLM":

View File

@@ -1,246 +0,0 @@
package convert
import (
"cmp"
"fmt"
"math"
"strings"
"github.com/pdevine/tensor"
"github.com/pdevine/tensor/native"
"github.com/ollama/ollama/fs/ggml"
)
type mistralModel struct {
ModelParameters
// Text model parameters
TextConfig struct {
NumHiddenLayers uint32 `json:"num_hidden_layers"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RopeTheta float32 `json:"rope_theta"`
RMSNormEPS float32 `json:"rms_norm_eps"`
HeadDim uint32 `json:"head_dim"`
} `json:"text_config"`
// Vision model parameters
VisionConfig struct {
NumHiddenLayers uint32 `json:"num_hidden_layers"`
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
ImageSize uint32 `json:"image_size"`
PatchSize uint32 `json:"patch_size"`
RopeTheta float32 `json:"rope_theta"`
} `json:"vision_config"`
// Multimodal specific parameters
ImageTokenIndex uint32 `json:"image_token_index"`
MultimodalProjectorBias bool `json:"multimodal_projector_bias"`
ProjectorHiddenAct string `json:"projector_hidden_act"`
SpatialMergeSize uint32 `json:"spatial_merge_size"`
VisionFeatureLayer int32 `json:"vision_feature_layer"`
// For RoPE scaling if needed
RopeScaling struct {
Type string `json:"type"`
RopeType string `json:"rope_type"`
Factor float32 `json:"factor"`
LowFrequencyFactor float32 `json:"low_freq_factor"`
HighFrequencyFactor float32 `json:"high_freq_factor"`
OriginalMaxPositionalEmbeddings uint32 `json:"original_max_positional_embeddings"`
factors ropeFactor
} `json:"rope_scaling"`
}
func (p *mistralModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "mistral"
kv["mistral.vocab_size"] = p.VocabSize
kv["mistral.image_token_index"] = p.ImageTokenIndex
kv["mistral.multimodal_projector_bias"] = p.MultimodalProjectorBias
kv["mistral.projector_hidden_act"] = p.ProjectorHiddenAct
kv["mistral.spatial_merge_size"] = p.SpatialMergeSize
// kv["mistral.vision_feature_layer"] = p.VisionFeatureLayer
// Text model config
kv["mistral.block_count"] = p.TextConfig.NumHiddenLayers
kv["mistral.context_length"] = p.TextConfig.MaxPositionEmbeddings
kv["mistral.embedding_length"] = p.TextConfig.HiddenSize
kv["mistral.feed_forward_length"] = p.TextConfig.IntermediateSize
kv["mistral.attention.head_count"] = p.TextConfig.NumAttentionHeads
kv["mistral.attention.head_count_kv"] = p.TextConfig.NumKeyValueHeads
kv["mistral.rope.dimension_count"] = p.TextConfig.HiddenSize / p.TextConfig.NumAttentionHeads
kv["mistral.rope.freq_base"] = p.TextConfig.RopeTheta
kv["mistral.attention.layer_norm_rms_epsilon"] = p.TextConfig.RMSNormEPS
kv["mistral.attention.key_length"] = p.TextConfig.HeadDim
kv["mistral.attention.value_length"] = p.TextConfig.HeadDim
// Vision model config
kv["mistral.vision.block_count"] = p.VisionConfig.NumHiddenLayers
kv["mistral.vision.embedding_length"] = p.VisionConfig.HiddenSize
kv["mistral.vision.feed_forward_length"] = p.VisionConfig.IntermediateSize
kv["mistral.vision.attention.head_count"] = p.VisionConfig.NumAttentionHeads
kv["mistral.vision.image_size"] = p.VisionConfig.ImageSize
kv["mistral.vision.patch_size"] = p.VisionConfig.PatchSize
kv["mistral.vision.rope.freq_base"] = p.VisionConfig.RopeTheta
// If RoPE scaling is present
if p.RopeScaling.Type == "linear" {
kv["mistral.rope.scaling.type"] = p.RopeScaling.Type
kv["mistral.rope.scaling.factor"] = p.RopeScaling.Factor
} else if p.RopeScaling.RopeType == "llama3" {
dim := p.TextConfig.HiddenSize / p.TextConfig.NumAttentionHeads
for i := uint32(0); i < dim; i += 2 {
factor := cmp.Or(p.RopeScaling.Factor, 8.0)
factorLow := cmp.Or(p.RopeScaling.LowFrequencyFactor, 1.0)
factorHigh := cmp.Or(p.RopeScaling.HighFrequencyFactor, 4.0)
original := cmp.Or(p.RopeScaling.OriginalMaxPositionalEmbeddings, 8192)
lambdaLow := float32(original) / factorLow
lambdaHigh := float32(original) / factorHigh
lambda := 2 * math.Pi * math.Pow(float64(p.TextConfig.RopeTheta), float64(i)/float64(dim))
if lambda < float64(lambdaHigh) {
p.RopeScaling.factors = append(p.RopeScaling.factors, 1.0)
} else if lambda > float64(lambdaLow) {
p.RopeScaling.factors = append(p.RopeScaling.factors, factor)
} else {
smooth := (float32(original)/float32(lambda) - factorLow) / (factorHigh - factorLow)
p.RopeScaling.factors = append(p.RopeScaling.factors, 1.0/((1-smooth)/factor+smooth))
}
}
}
return kv
}
func (p *mistralModel) Tensors(ts []Tensor) []ggml.Tensor {
var out []ggml.Tensor
if p.RopeScaling.factors != nil {
out = append(out, ggml.Tensor{
Name: "rope_freqs.weight",
Kind: 0,
Shape: []uint64{uint64(len(p.RopeScaling.factors))},
WriterTo: p.RopeScaling.factors,
})
}
for _, t := range ts {
// Process tensors that require repacking
if strings.HasSuffix(t.Name(), "attn_q.weight") ||
strings.HasSuffix(t.Name(), "attn_k.weight") {
t.SetRepacker(p.repack)
}
// Add all tensors to output
out = append(out, ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (p *mistralModel) Replacements() []string {
return []string{
// Language model replacements
"language_model.model.embed_tokens", "token_embd",
"language_model.model.norm", "output_norm",
"language_model.model.layers", "blk",
"language_model.model.layers.*.input_layernorm", "input_layernorm",
"language_model.model.layers.*.self_attn.q_proj", "self_attn.q_proj",
"language_model.model.layers.*.self_attn.k_proj", "self_attn.k_proj",
"language_model.model.layers.*.self_attn.v_proj", "self_attn.v_proj",
"language_model.model.layers.*.self_attn.o_proj", "self_attn.o_proj",
"language_model.model.layers.*.mlp.gate_proj", "mlp.gate_proj",
"language_model.model.layers.*.mlp.down_proj", "mlp.down_proj",
"language_model.model.layers.*.mlp.up_proj", "mlp.up_proj",
"language_model.model.layers.*.post_attention_layernorm", "post_attention_layernorm",
"language_model.lm_head", "output",
// Vision model replacements - map to shorter prefixes
"vision_tower", "v",
"multi_modal_projector", "mm",
// Vision transformer blocks - these should be updated accordingly
"vision_tower.transformer.layers", "v.blk",
"vision_tower.transformer.layers.*.attention_norm", "v.attn_norm",
"vision_tower.transformer.layers.*.attention.q_proj", "v.attn_q",
"vision_tower.transformer.layers.*.attention.k_proj", "v.attn_k",
"vision_tower.transformer.layers.*.attention.v_proj", "v.attn_v",
"vision_tower.transformer.layers.*.attention.o_proj", "v.attn_output",
"vision_tower.transformer.layers.*.feed_forward.gate_proj", "v.ffn_gate",
"vision_tower.transformer.layers.*.feed_forward.down_proj", "v.ffn_down",
"vision_tower.transformer.layers.*.feed_forward.up_proj", "v.ffn_up",
"vision_tower.transformer.layers.*.ffn_norm", "v.ffn_norm",
"vision_tower.ln_pre", "v.encoder_norm",
"vision_tower.patch_conv", "v.patch_conv",
// Multimodal projector components
"multi_modal_projector.patch_merger", "mm.patch_merger",
"multi_modal_projector.norm", "mm.norm",
}
}
func (p *mistralModel) repack(name string, data []float32, shape []uint64) ([]float32, error) {
var dims []int
for _, dim := range shape {
dims = append(dims, int(dim))
}
var heads uint32
if strings.HasSuffix(name, "attn_q.weight") {
if strings.Contains(name, "vision") {
heads = p.VisionConfig.NumAttentionHeads
} else {
heads = p.TextConfig.NumAttentionHeads
}
} else if strings.HasSuffix(name, "attn_k.weight") {
if strings.Contains(name, "vision") {
heads = p.VisionConfig.NumAttentionHeads
} else {
heads = cmp.Or(p.TextConfig.NumKeyValueHeads, p.TextConfig.NumAttentionHeads)
}
} else {
return nil, fmt.Errorf("unknown tensor for repack: %s", name)
}
n := tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data))
if err := n.Reshape(append([]int{int(heads), 2, dims[0] / int(heads) / 2}, dims[1:]...)...); err != nil {
return nil, err
}
if err := n.T(0, 2, 1, 3); err != nil {
return nil, err
}
if err := n.Reshape(dims...); err != nil {
return nil, err
}
if err := n.Transpose(); err != nil {
return nil, err
}
ts, err := native.SelectF32(n, 1)
if err != nil {
return nil, err
}
var f32s []float32
for _, t := range ts {
f32s = append(f32s, t...)
}
return f32s, nil
}

View File

@@ -312,17 +312,19 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
return fmt.Errorf("unassigned tensor: %s", t.Name)
}
bts := make([]byte, t.Size())
n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), bts)
if err != nil {
return err
bts := C.malloc(C.size_t(t.Size()))
if bts == nil {
return errors.New("failed to allocate tensor buffer")
}
defer C.free(bts)
buf := unsafe.Slice((*byte)(bts), t.Size())
n, err := io.ReadFull(io.NewSectionReader(sr, int64(t.Offset), int64(t.Size())), buf)
if err != nil || n != len(buf) {
return errors.New("read failed")
}
if n != len(bts) {
return errors.New("short read")
}
C.ggml_backend_tensor_set(tt, unsafe.Pointer(&bts[0]), 0, C.size_t(t.Size()))
C.ggml_backend_tensor_set(tt, bts, 0, C.size_t(t.Size()))
return nil
})
}
@@ -371,7 +373,7 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&schedBufts[0])),
C.int(len(schedBackends)),
C.size_t(maxGraphNodes),
true,
C._Bool(len(gpus) > 1 && slices.Contains(gpus, output.d)),
),
input: deviceBufferTypes[input.d],
output: deviceBufferTypes[output.d],

View File

@@ -1,207 +0,0 @@
package llama
import (
"fmt"
"math"
"strings"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
)
type Options struct {
hiddenSize, numHeads, numKVHeads, headDim int
eps, ropeBase, ropeScale float32
ropeDim uint32
}
type Model struct {
model.Base
model.BytePairEncoding
TokenEmbedding *nn.Embedding `gguf:"token_embd"`
Layers []Layer `gguf:"blk"`
OutputNorm *nn.RMSNorm `gguf:"output_norm"`
Output *nn.Linear `gguf:"output,alt:token_embd"`
*Options
}
func New(c ml.Config) (model.Model, error) {
if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") {
return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model"))
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer", `[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
&model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Types: c.Uints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
},
),
Layers: make([]Layer, c.Uint("block_count")),
Options: &Options{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
headDim: int(c.Uint("attention.key_length")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1),
ropeDim: c.Uint("rope.dimension_count"),
},
}
fmt.Println("Model Parameters:")
fmt.Printf(" model_type: %q\n", "gpt2")
fmt.Printf(" vocab_size: %d\n", len(c.Strings("tokenizer.ggml.tokens")))
fmt.Printf(" hidden_size: %d\n", m.Options.hiddenSize)
fmt.Printf(" num_hidden_layers: %d\n", c.Uint("block_count"))
fmt.Printf(" num_attention_heads: %d\n", m.Options.numHeads)
fmt.Printf(" num_key_value_heads: %d\n", m.Options.numKVHeads)
fmt.Printf(" rms_norm_eps: %g\n", m.Options.eps)
fmt.Printf(" rope_theta: %g\n", m.Options.ropeBase)
fmt.Printf(" bos_token_id: %d\n", c.Uint("tokenizer.ggml.bos_token_id"))
fmt.Printf(" eos_token_id: %d\n", c.Uint("tokenizer.ggml.eos_token_id"))
fmt.Printf(" pad_token_id: %d\n", c.Uint("tokenizer.ggml.pad_token_id", 0))
m.Cache = kvcache.NewCausalCache(m.Shift)
return &m, nil
}
type SelfAttention struct {
Query *nn.Linear `gguf:"self_attn.q_proj"`
Key *nn.Linear `gguf:"self_attn.k_proj"`
Value *nn.Linear `gguf:"self_attn.v_proj"`
Output *nn.Linear `gguf:"self_attn.o_proj"`
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
}
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
batchSize := hiddenState.Dim(1)
ropeType := uint32(0)
// Get head dimension - use explicit value if available, otherwise calculate
headDim := opts.headDim
if headDim == 0 {
headDim = opts.hiddenSize / opts.numHeads
}
// Query projection and reshape
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
// Key projection and reshape
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
// Value projection and reshape
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
// Attention computation
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
// Reshape attention output for final projection
outputDim := headDim * opts.numHeads
kqv = kqv.Reshape(ctx, outputDim, batchSize)
// Apply output projection
return sa.Output.Forward(ctx, kqv)
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil
}
type MLP struct {
Up *nn.Linear `gguf:"mlp.up_proj"`
Down *nn.Linear `gguf:"mlp.down_proj"`
Gate *nn.Linear `gguf:"mlp.gate_proj"`
}
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
type Layer struct {
AttentionNorm *nn.RMSNorm `gguf:"input_layernorm"`
SelfAttention *SelfAttention
MLPNorm *nn.RMSNorm `gguf:"post_attention_layernorm"`
MLP *MLP
}
func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
residual := hiddenState
hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
// In the final layer (outputs != nil), optimize by pruning to just the token positions
// we need logits for.
if outputs != nil {
hiddenState = hiddenState.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
hiddenState = l.MLPNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
return hiddenState.Add(ctx, residual)
}
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
if err != nil {
return nil, err
}
positions, err := ctx.Input().FromIntSlice(opts.Positions, len(opts.Positions))
if err != nil {
return nil, err
}
outputs, err := ctx.Output().FromIntSlice(opts.Outputs, len(opts.Outputs))
if err != nil {
return nil, err
}
// Get token embeddings
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
for i, layer := range m.Layers {
m.Cache.SetLayer(i)
var lastLayerOutputs ml.Tensor
if i == len(m.Layers)-1 {
lastLayerOutputs = outputs
}
hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
}
// Apply output normalization
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
// Apply output projection
return m.Output.Forward(ctx, hiddenState), nil
}
func init() {
model.Register("mistral", New)
}

View File

@@ -4,6 +4,5 @@ import (
_ "github.com/ollama/ollama/model/models/gemma2"
_ "github.com/ollama/ollama/model/models/gemma3"
_ "github.com/ollama/ollama/model/models/llama"
_ "github.com/ollama/ollama/model/models/mistral"
_ "github.com/ollama/ollama/model/models/mllama"
)

View File

@@ -263,10 +263,6 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
continue
}
if id := bpe.vocab.Encode(pair.value); id < 0 {
continue
}
merges[pair.a].runes = append(left.runes, right.runes...)
merges[pair.b].runes = nil

View File

@@ -209,326 +209,6 @@ func TestLlama(t *testing.T) {
})
}
// tekken loads the Tekken tokenizer for testing
func tekken(t testing.TB) TextProcessor {
t.Helper()
// Load tokenizer config from mistral-small
tokenizerConfigPath := filepath.Join("testdata", "mistral-small", "tokenizer_config.json")
configFile, err := os.Open(tokenizerConfigPath)
if err != nil {
t.Fatal(err)
}
defer configFile.Close()
var config struct {
AddBosToken bool `json:"add_bos_token"`
AddEosToken bool `json:"add_eos_token"`
BosToken struct {
Content string `json:"content"`
} `json:"bos_token"`
EosToken struct {
Content string `json:"content"`
} `json:"eos_token"`
}
if err := json.NewDecoder(configFile).Decode(&config); err != nil {
t.Fatal(err)
}
// Load tokenizer.json which contains the vocabulary and other settings
tokenizerJsonPath := filepath.Join("testdata", "mistral-small", "tokenizer.json")
tokenizerFile, err := os.Open(tokenizerJsonPath)
if err != nil {
t.Fatal(err)
}
defer tokenizerFile.Close()
var tokenizerData struct {
Model struct {
Type string `json:"type"`
Vocab map[string]int32 `json:"vocab"`
Merges []string `json:"merges"`
} `json:"model"`
AddedTokens []struct {
Id int32 `json:"id"`
Content string `json:"content"`
Special bool `json:"special"`
} `json:"added_tokens"`
PreTokenizer struct {
Type string `json:"type"`
Pretokenizers []struct {
Type string `json:"type"`
Pattern struct {
String string `json:"String"`
} `json:"pattern"`
Behavior string `json:"behavior"`
} `json:"pretokenizers"`
} `json:"pre_tokenizer"`
}
if err := json.NewDecoder(tokenizerFile).Decode(&tokenizerData); err != nil {
t.Fatal(err)
}
// Extract the pattern from pre_tokenizer if available
var pattern string
if tokenizerData.PreTokenizer.Type == "Sequence" && len(tokenizerData.PreTokenizer.Pretokenizers) > 0 {
pattern = tokenizerData.PreTokenizer.Pretokenizers[0].Pattern.String
}
// Combine regular vocab and added tokens
vocab := tokenizerData.Model.Vocab
// Add special tokens from added_tokens
for _, token := range tokenizerData.AddedTokens {
vocab[token.Content] = token.Id
}
// Create vocabulary arrays
maxId := int32(-1)
for _, id := range vocab {
if id > maxId {
maxId = id
}
}
vocabSize := int(maxId + 1)
types := make([]uint32, vocabSize)
tokens := make([]string, vocabSize)
scores := make([]float32, vocabSize)
for token, id := range vocab {
tokens[id] = token
types[id] = TOKEN_TYPE_NORMAL
// Assign appropriate token types for special tokens
if token == "<s>" {
types[id] = TOKEN_TYPE_CONTROL
} else if token == "</s>" {
types[id] = TOKEN_TYPE_CONTROL
} else if token == "[INST]" || token == "[/INST]" {
types[id] = TOKEN_TYPE_CONTROL
}
}
// In Tekken, we don't need to load merges separately as they're part of the model
var merges []string
// Create vocabulary object
vocabObj := &Vocabulary{
Values: tokens,
Types: types,
Scores: scores,
Merges: merges,
BOS: vocab[config.BosToken.Content],
EOS: vocab[config.EosToken.Content],
AddBOS: config.AddBosToken,
AddEOS: config.AddEosToken,
}
// Use pattern from tokenizer.json if available
if pattern != "" {
// Ensure pattern has proper escaping for Go regexp
pattern = strings.ReplaceAll(pattern, "p{", "\\p{")
return NewBytePairEncoding(pattern, vocabObj)
}
// Fallback pattern if not found
return NewBytePairEncoding(
`\p{L}+|\p{N}+|[^\s\p{L}\p{N}]+|\s+`,
vocabObj,
)
}
func TestTekken(t *testing.T) {
// Skip if the test data isn't available
if _, err := os.Stat(filepath.Join("testdata", "mistral-small")); os.IsNotExist(err) {
t.Skip("Mistral-small test data not available")
}
tokenizer := tekken(t)
t.Run("whitespace_handling", func(t *testing.T) {
t.Parallel()
// The key difference from SentencePiece is that Tekken doesn't prepend whitespace
cases := []struct {
input string
expected string
}{
{" hello", " hello"},
{"hello ", "hello "},
{"hello world", "hello world"},
{" hello world ", " hello world "},
}
for _, tc := range cases {
ids, err := tokenizer.Encode(tc.input, false)
if err != nil {
t.Errorf("Failed to encode %q: %v", tc.input, err)
continue
}
decoded, err := tokenizer.Decode(ids)
if err != nil {
t.Errorf("Failed to decode tokens for %q: %v", tc.input, err)
continue
}
if decoded != tc.expected {
t.Errorf("Whitespace handling: got %q, want %q", decoded, tc.expected)
}
}
})
t.Run("chat_templates", func(t *testing.T) {
t.Parallel()
// Test the Tekken chat template format which doesn't have spaces after special tokens
templates := []struct {
input string
expectSpace bool // whether we expect a space after special tokens
}{
{"<s>[INST]user message[/INST]", false},
{"<s>[INST] user message[/INST]", true},
{"<s>[INST]user message [/INST]", true},
}
for _, tc := range templates {
ids, err := tokenizer.Encode(tc.input, false)
if err != nil {
t.Errorf("Failed to encode %q: %v", tc.input, err)
continue
}
decoded, err := tokenizer.Decode(ids)
if err != nil {
t.Errorf("Failed to decode tokens for %q: %v", tc.input, err)
continue
}
// Check if there's a space after special tokens
hasSpaceAfterINST := strings.Contains(decoded, "[INST] ")
if hasSpaceAfterINST != tc.expectSpace {
t.Errorf("Chat template space handling: got space=%v, want space=%v for %q",
hasSpaceAfterINST, tc.expectSpace, tc.input)
}
}
})
t.Run("special_tokens", func(t *testing.T) {
t.Parallel()
// Test how Tekken handles special tokens
cases := []struct {
input string
expected []string // We'll check if these tokens are in the decoded output
}{
{"<s>[INST]hello[/INST]", []string{"<s>", "[INST]", "hello", "[/INST]"}},
{"[INST]hello[/INST]</s>", []string{"[INST]", "hello", "[/INST]", "</s>"}},
{"<s>[INST]hello[/INST]</s>[INST]again[/INST]", []string{"<s>", "[INST]", "hello", "[/INST]", "</s>", "[INST]", "again", "[/INST]"}},
}
for _, tc := range cases {
ids, err := tokenizer.Encode(tc.input, false)
if err != nil {
t.Errorf("Failed to encode %q: %v", tc.input, err)
continue
}
decoded, err := tokenizer.Decode(ids)
if err != nil {
t.Errorf("Failed to decode tokens for %q: %v", tc.input, err)
continue
}
for _, expected := range tc.expected {
if !strings.Contains(decoded, expected) {
t.Errorf("Special token handling: %q missing in decoded output %q", expected, decoded)
}
}
}
})
t.Run("vocabulary_coverage", func(t *testing.T) {
t.Parallel()
// Tekken has a larger vocabulary, so test coverage of various token types
samples := []string{
"Hello world!",
"This is a test of the Tekken tokenizer.",
"It has a considerably larger vocabulary size.",
"Special characters: !@#$%^&*()",
"Numbers: 1234567890",
"Multiple languages: こんにちは 你好 안녕하세요",
"Code snippets: def function(): return True",
}
for _, sample := range samples {
ids, err := tokenizer.Encode(sample, false)
if err != nil {
t.Errorf("Failed to encode %q: %v", sample, err)
continue
}
decoded, err := tokenizer.Decode(ids)
if err != nil {
t.Errorf("Failed to decode tokens for %q: %v", sample, err)
continue
}
if decoded != sample {
t.Errorf("Vocabulary coverage: got %q, want %q", decoded, sample)
}
}
})
t.Run("splitting_behavior", func(t *testing.T) {
t.Parallel()
// Test the splitting behavior which might differ from SentencePiece
cases := map[string][]string{
"Hello World!": {"Hello", " World", "!"},
"user message": {"user", " message"},
"[INST]hello": {"[INST]", "hello"},
"hello[/INST]": {"hello", "[/INST]"},
}
for s, want := range cases {
got := slices.Collect(tokenizer.(*BytePairEncoding).split(s))
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("Splitting behavior no match (-want +got):\n%s", diff)
}
}
})
t.Run("full_chat_sequence", func(t *testing.T) {
t.Parallel()
// Test a complete chat sequence with Tekken's format
chatSequence := "<s>[INST]user message[/INST]assistant message</s>[INST]new user message[/INST]"
ids, err := tokenizer.Encode(chatSequence, false)
if err != nil {
t.Fatalf("Failed to encode chat sequence: %v", err)
}
decoded, err := tokenizer.Decode(ids)
if err != nil {
t.Fatalf("Failed to decode chat sequence tokens: %v", err)
}
// In Tekken, the whitespace shouldn't be added after special tokens
if strings.Contains(decoded, "[INST] ") {
t.Errorf("Tekken chat sequence has unexpected space after [INST]: %q", decoded)
}
if strings.Contains(decoded, "[/INST] ") {
t.Errorf("Tekken chat sequence has unexpected space after [/INST]: %q", decoded)
}
})
}
func BenchmarkBytePairEncoding(b *testing.B) {
tokenizer := llama(b)
bts, err := os.ReadFile(filepath.Join("testdata", "war-and-peace.txt"))

View File

@@ -89,7 +89,7 @@ type InputCacheSlot struct {
lastUsed time.Time
}
func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*InputCacheSlot, []input.Input, error) {
func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []input.Input, error) {
var slot *InputCacheSlot
var numPast int32
var err error
@@ -107,11 +107,6 @@ func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*Inp
return nil, nil, err
}
// TODO (brucemacd): cachePrompt is always true for completion, but false for embedding, can this be improved?
if !cachePrompt {
numPast = 0
}
slot.InUse = true
slot.lastUsed = time.Now()

View File

@@ -297,3 +297,131 @@ func TestShiftDiscard(t *testing.T) {
})
}
}
func TestLoadCacheSlot(t *testing.T) {
tests := []struct {
name string
cache InputCache
prompt []input.Input
wantErr bool
expectedSlotId int
expectedPrompt int // expected length of remaining prompt
}{
{
name: "Basic cache hit - single user",
cache: InputCache{
multiUserCache: false,
slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input.Input{},
InUse: false,
lastUsed: time.Now().Add(-2 * time.Second),
},
},
},
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
wantErr: false,
expectedSlotId: 0,
expectedPrompt: 1, // Only token 3 remains
},
{
name: "Basic cache hit - multi user",
cache: InputCache{
multiUserCache: true,
slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []input.Input{},
InUse: false,
lastUsed: time.Now().Add(-2 * time.Second),
},
},
},
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
wantErr: false,
expectedSlotId: 0,
expectedPrompt: 1, // Only token 3 remains
},
{
name: "Exact match - leave one input",
cache: InputCache{
multiUserCache: false,
slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
},
},
prompt: []input.Input{{Token: 1}, {Token: 2}},
wantErr: false,
expectedSlotId: 0,
expectedPrompt: 1, // Should leave 1 token for sampling
},
{
name: "No available slots",
cache: InputCache{
multiUserCache: false,
slots: []InputCacheSlot{
{
Id: 0,
Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: true,
lastUsed: time.Now().Add(-time.Second),
},
},
},
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
wantErr: true,
expectedSlotId: -1,
expectedPrompt: -1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
slot, remainingPrompt, err := tt.cache.LoadCacheSlot(tt.prompt)
// Check error state
if (err != nil) != tt.wantErr {
t.Errorf("LoadCacheSlot() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr {
return // Skip further checks if we expected an error
}
// Verify slot ID
if slot.Id != tt.expectedSlotId {
t.Errorf("LoadCacheSlot() slot ID = %v, expected %v", slot.Id, tt.expectedSlotId)
}
// Verify slot is now marked in use
if !slot.InUse {
t.Errorf("LoadCacheSlot() slot not marked InUse")
}
// Verify remaining prompt length
if len(remainingPrompt) != tt.expectedPrompt {
t.Errorf("LoadCacheSlot() remaining prompt length = %v, expected %v",
len(remainingPrompt), tt.expectedPrompt)
}
})
}
}

View File

@@ -115,6 +115,9 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
params.numKeep = int32(len(inputs))
}
// TODO(jessegross): We should ensure that we always leave minBatch of context space to shift,
// otherwise we might truncate or split the batch against the model's wishes
// Ensure that at least 1 input can be discarded during shift
params.numKeep = min(params.numKeep, s.cache.numCtx-1)
@@ -179,10 +182,6 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, *
return nil, nil, err
}
for _, t := range tokens {
decoded, _ := s.model.(model.TextProcessor).Decode([]int32{t})
fmt.Println("token", t, "decoded", decoded)
}
for _, t := range tokens {
inputs = append(inputs, input.Input{Token: t})
}
@@ -370,17 +369,6 @@ func (s *Server) processBatch() error {
batchSize := s.batchSize
for j, inp := range seq.inputs {
if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+1) > s.cache.numCtx {
if len(seq.pendingInputs) == 0 {
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
if err != nil {
return err
}
} else {
break
}
}
// If we are required to put following inputs into a single batch then extend the
// batch size. Since we are only extending the size the minimum amount possible, this
// will cause a break if we have pending inputs.
@@ -393,6 +381,20 @@ func (s *Server) processBatch() error {
break
}
// If the sum of our working set (already processed tokens, tokens we added to this
// batch, required following tokens) exceeds the context size, then trigger a shift
// now so we don't have to do one later when we can't break the batch.
if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+minBatch) > s.cache.numCtx {
if len(seq.pendingInputs) != 0 {
break
}
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
if err != nil {
return err
}
}
options.Inputs = append(options.Inputs, inp.Token)
if inp.Multimodal != nil {
options.Multimodal = append(options.Multimodal, input.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: inp.Multimodal})
@@ -594,7 +596,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
found := false
for i, sq := range s.seqs {
if sq == nil {
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs)
if err != nil {
s.mu.Unlock()
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)

View File

@@ -87,8 +87,9 @@ func (s *Sampler) sample(tokens []token) (token, error) {
// topK also sorts the tokens in descending order of logits
tokens = topK(tokens, s.topK)
tokens = temperature(tokens, s.temperature)
tokens = softmax(tokens)
// scale and normalize the tokens in place
temperature(tokens, s.temperature)
softmax(tokens)
tokens = topP(tokens, s.topP)
tokens = minP(tokens, s.minP)

View File

@@ -26,17 +26,16 @@ func (h *tokenHeap) Pop() any {
}
// temperature applies scaling to the logits
func temperature(ts []token, temp float32) []token {
func temperature(ts []token, temp float32) {
// Ensure temperature clipping near 0 to avoid numerical instability
temp = max(temp, 1e-7)
for i := range ts {
ts[i].value = ts[i].value / temp
}
return ts
}
// softmax applies normalization to the logits
func softmax(ts []token) []token {
func softmax(ts []token) {
// Find max logit for numerical stability
maxLogit := float32(math.Inf(-1))
for _, t := range ts {
@@ -56,8 +55,6 @@ func softmax(ts []token) []token {
for i := range ts {
ts[i].value /= sum
}
return ts
}
// topK limits the number of tokens considered to the k highest logits
@@ -99,6 +96,7 @@ func topK(ts []token, k int) []token {
}
// topP limits tokens to those with cumulative probability p
// requires ts to be sorted in descending order of probabilities
func topP(ts []token, p float32) []token {
if p == 1.0 {
return ts
@@ -109,37 +107,24 @@ func topP(ts []token, p float32) []token {
for i, t := range ts {
sum += t.value
if sum > float32(p) {
ts = ts[:i+1]
return ts
return ts[:i+1]
}
}
return ts
}
// minP limits tokens to those with cumulative probability p
// minP filters tokens with probabilities >= p * max_prob
// requires ts to be sorted in descending order of probabilities
func minP(ts []token, p float32) []token {
if p == 1.0 {
return ts
}
maxProb := ts[0].value
maxProb := float32(math.Inf(-1))
for _, token := range ts {
if token.value > maxProb {
maxProb = token.value
threshold := maxProb * p
for i, t := range ts {
if t.value < threshold {
return ts[:i]
}
}
threshold := maxProb * float32(p)
// Filter tokens in-place
validTokens := ts[:0]
for i, token := range ts {
if token.value >= threshold {
validTokens = append(validTokens, ts[i])
}
}
ts = validTokens
return ts
}

View File

@@ -34,17 +34,22 @@ func compareLogits(t *testing.T, name string, want []float32, got []token) {
func TestTemperature(t *testing.T) {
input := []float32{1.0, 4.0, -2.0, 0.0}
got := temperature(toTokens(input), 0.5)
tokens := toTokens(input)
temperature(tokens, 0.5)
want := []float32{2.0, 8.0, -4.0, 0.0}
compareLogits(t, "temperature(0.5)", want, got)
compareLogits(t, "temperature(0.5)", want, tokens)
got = temperature(toTokens(input), 1.0)
input = []float32{1.0, 4.0, -2.0, 0.0}
tokens = toTokens(input)
temperature(tokens, 1.0)
want = []float32{1.0, 4.0, -2.0, 0.0}
compareLogits(t, "temperature(1)", want, got)
compareLogits(t, "temperature(1)", want, tokens)
got = temperature(toTokens(input), 0.0)
input = []float32{1.0, 4.0, -2.0, 0.0}
tokens = toTokens(input)
temperature(tokens, 0.0)
want = []float32{1e7, 4e7, -2e7, 0.0}
compareLogits(t, "temperature(0)", want, got)
compareLogits(t, "temperature(0)", want, tokens)
}
func TestSoftmax(t *testing.T) {
@@ -90,16 +95,17 @@ func TestSoftmax(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := softmax(toTokens(tt.input))
tokens := toTokens(tt.input)
softmax(tokens)
if tt.expected != nil {
compareLogits(t, tt.name, tt.expected, got)
compareLogits(t, tt.name, tt.expected, tokens)
return
}
// Check probabilities sum to 1
var sum float32
for _, token := range got {
for _, token := range tokens {
sum += token.value
if token.value < 0 || token.value > 1 {
t.Errorf("probability out of range [0,1]: got %f", token.value)
@@ -114,38 +120,44 @@ func TestSoftmax(t *testing.T) {
func TestTopK(t *testing.T) {
input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
// Test k=5
got := topK(toTokens(input), 5)
if len(got) != 5 {
t.Errorf("topK(5): wrong length: want 5, got %d", len(got))
tokens := toTokens(input)
tokens = topK(tokens, 5)
if len(tokens) != 5 {
t.Errorf("topK(5): wrong length: want 5, got %d", len(tokens))
}
// Should keep highest 3 values in descending order
want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154}
compareLogits(t, "topK(3)", want, got)
compareLogits(t, "topK(3)", want, tokens)
got = topK(toTokens(input), 20)
if len(got) != len(input) {
t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(got))
tokens = toTokens(input)
tokens = topK(tokens, 20)
if len(tokens) != len(input) {
t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(tokens))
}
// Test k=-1
input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
got = topK(toTokens(input), -1)
if len(got) != len(input) {
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got))
tokens = toTokens(input)
tokens = topK(tokens, -1)
if len(tokens) != len(input) {
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(tokens))
}
compareLogits(t, "topK(-1)", want, got)
compareLogits(t, "topK(-1)", want, tokens)
// Test k=0
input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
got = topK(toTokens(input), 0)
if len(got) != len(input) {
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got))
tokens = toTokens(input)
tokens = topK(tokens, 0)
if len(tokens) != len(input) {
t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(tokens))
}
compareLogits(t, "topK(-1)", want, tokens)
input = []float32{-1e7, -2e7, -3e7, -4e7}
tokens = toTokens(input)
tokens = topK(tokens, 1)
if len(tokens) < 1 {
t.Error("topK should keep at least one token")
}
compareLogits(t, "topK(-1)", want, got)
}
func TestTopP(t *testing.T) {
@@ -153,16 +165,25 @@ func TestTopP(t *testing.T) {
tokens := toTokens(input)
// First apply temperature and softmax to get probabilities
tokens = softmax(tokens)
softmax(tokens)
tokens = topK(tokens, 20)
// Then apply topP
got := topP(tokens, 0.95)
tokens = topP(tokens, 0.95)
// Should keep tokens until cumsum > 0.95
if len(got) > 3 {
t.Errorf("topP(0.95): kept too many tokens: got %d", len(got))
t.Logf("got: %v", got)
if len(tokens) > 3 {
t.Errorf("topP(0.95): kept too many tokens: got %d", len(tokens))
t.Logf("got: %v", tokens)
}
// Test edge case - ensure at least one token remains
input = []float32{-1e6, -1e6, -1e6} // One dominant token
tokens = toTokens(input)
softmax(tokens)
tokens = topP(tokens, 0.0) // Very small p
if len(tokens) < 1 {
t.Error("topP should keep at least one token")
}
}
@@ -171,14 +192,45 @@ func TestMinP(t *testing.T) {
tokens := toTokens(input)
// First apply temperature and softmax
tokens = softmax(tokens)
tokens = topK(tokens, 20)
softmax(tokens)
// Then apply minP
got := minP(tokens, 0.2)
tokens = minP(tokens, 1.0)
if len(tokens) != 1 {
t.Errorf("minP(1.0): should keep all tokens, got %d, want %d", len(tokens), len(tokens))
}
// Test with normal p value
tokens = toTokens(input) // Reset tokens
tokens = topK(tokens, 20)
softmax(tokens)
tokens = minP(tokens, 0.2)
// Should keep tokens with prob >= 0.2 * max_prob
if len(got) > 3 {
t.Errorf("minP(0.2): kept too many tokens: got %d", len(got))
if len(tokens) > 3 {
t.Errorf("minP(0.2): kept too many tokens: got %d", len(tokens))
t.Logf("got: %v", tokens)
}
// Test with zero p value
tokens = toTokens(input) // Reset tokens
tokens = topK(tokens, 20)
softmax(tokens)
tokens = minP(tokens, 0.0)
// Should keep only the highest probability token
if len(tokens) != len(input) {
t.Errorf("minP(0.0): should keep only one token, got %d", len(tokens))
t.Logf("got: %v", tokens)
}
input = []float32{1e-10, 1e-10, 1e-10}
tokens = toTokens(input)
softmax(tokens)
tokens = minP(tokens, 1.0)
if len(tokens) < 1 {
t.Error("minP should keep at least one token even with extreme probabilities")
}
}
@@ -231,7 +283,7 @@ func BenchmarkTransforms(b *testing.B) {
b.ResetTimer()
for b.Loop() {
copy(tokensCopy, tokens)
topK(tokensCopy, 10)
tokens = topK(tokensCopy, 10)
}
})
@@ -239,7 +291,7 @@ func BenchmarkTransforms(b *testing.B) {
b.ResetTimer()
for b.Loop() {
copy(tokensCopy, tokens)
topP(tokensCopy, 0.9)
tokens = topP(tokensCopy, 0.9)
}
})
@@ -247,7 +299,7 @@ func BenchmarkTransforms(b *testing.B) {
b.ResetTimer()
for b.Loop() {
copy(tokensCopy, tokens)
minP(tokensCopy, 0.2)
tokens = minP(tokensCopy, 0.2)
}
})
@@ -255,7 +307,7 @@ func BenchmarkTransforms(b *testing.B) {
b.ResetTimer()
for b.Loop() {
copy(tokensCopy, tokens)
topK(tokensCopy, 200000)
tokens = topK(tokensCopy, 200000)
}
})
}