test + model

ParthSareen 2025-12-09 18:28:29 -08:00
parent 5d50848c52
commit 3015146cda
2 changed files with 629 additions and 82 deletions

@@ -28,8 +28,6 @@ type Options struct {
ropeType string
ropeExtrapolation float32
ropeBetaFast float32
ropeBetaSlow float32
slidingWindowPattern []bool
}
@@ -52,7 +50,7 @@ func New(c fs.Config) (model.Model, error) {
Scores: c.Floats("tokenizer.ggml.scores"),
Types: c.Ints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false),
BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))},
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
EOS: append(
@@ -85,8 +83,6 @@ func New(c fs.Config) (model.Model, error) {
attnFactor := c.Float("rope.scaling.attn_factor", 1)
ropeType := c.String("rope.scaling.type")
ropeExtrapolation := c.Float("rope.scaling.extrapolation_factor", 1.0)
ropeBetaFast := c.Float("rope.scaling.beta_fast", 64.0)
ropeBetaSlow := c.Float("rope.scaling.beta_slow", 1.0)
fmt.Printf("hiddenSize: %d\n", hiddenSize)
fmt.Printf("numHeads: %d\n", numHeads)
@@ -100,8 +96,6 @@ func New(c fs.Config) (model.Model, error) {
fmt.Printf("attnFactor: %f\n", attnFactor)
fmt.Printf("ropeType: %s\n", ropeType)
fmt.Printf("ropeExtrapolation: %f\n", ropeExtrapolation)
fmt.Printf("ropeBetaFast: %f\n", ropeBetaFast)
fmt.Printf("ropeBetaSlow: %f\n", ropeBetaSlow)
fmt.Printf("sliding_window_pattern: %v\n", c.Bools("attention.sliding_window_pattern"))
m := Model{
@@ -120,14 +114,14 @@ func New(c fs.Config) (model.Model, error) {
attnFactor: attnFactor,
ropeType: ropeType,
ropeExtrapolation: ropeExtrapolation,
ropeBetaFast: ropeBetaFast,
ropeBetaSlow: ropeBetaSlow,
slidingWindowPattern: c.Bools("attention.sliding_window_pattern"),
},
}
m.Cache = kvcache.NewWrapperCache(kvcache.NewSWACache(int32(c.Uint("attention.sliding_window")), m.Shift), kvcache.NewCausalCache(m.Shift))
// m.Cache = kvcache.NewCausalCache(m.Shift)
m.Cache = kvcache.NewWrapperCache(
kvcache.NewSWACache(int32(c.Uint("attention.sliding_window")), m.Shift),
kvcache.NewCausalCache(m.Shift),
)
return &m, nil
}
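
Note: the per-layer split between the sliding-window and causal caches is driven by attention.sliding_window_pattern, and Shift and Forward below rely on an isSWALayer helper that is not part of this diff. A minimal sketch of what that helper presumably does, assuming a true entry in the pattern marks a sliding-window layer:

func (m *Model) isSWALayer(layer int) bool {
	// Assumption: slidingWindowPattern holds one bool per layer, true meaning
	// the layer uses sliding-window attention (wrapper-cache index 0);
	// anything out of range falls back to the causal cache.
	if layer >= 0 && layer < len(m.slidingWindowPattern) {
		return m.slidingWindowPattern[layer]
	}
	return false
}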
@@ -142,65 +136,59 @@ type SelfAttention struct {
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
}
func (o *Options) ropeOptions(factors ml.Tensor, isSWA bool) []func(*rope.Options) {
opts := []func(*rope.Options){
rope.WithFactors(factors),
}
func (m *Model) applyRoPE(ctx ml.Context, states, positions ml.Tensor, ropeDim int, isSWA bool) ml.Tensor {
if !isSWA && o.originalContextLength > 0 {
// opts = append(opts,
// rope.WithOriginalContextLength(o.originalContextLength),
// rope.WithAttentionFactor(o.attnFactor),
var ropeOpts []func(*rope.Options)
// Both SWA and non-SWA use beta_fast and beta_slow
// But SWA uses freq_scale=1.0, ext_factor=0.0, attn_factor=1.0
// Non-SWA uses full yarn parameters
if m.originalContextLength > 0 {
ropeOpts = append(ropeOpts,
rope.WithOriginalContextLength(m.originalContextLength),
)
// if !isSWA {
ropeOpts = append(ropeOpts,
rope.WithExtrapolationFactor(m.ropeExtrapolation),
)
// rope.WithAttentionFactor(m.attnFactor),
// )
opts = append(opts,
rope.WithOriginalContextLength(o.originalContextLength),
rope.WithExtrapolationFactor(o.ropeExtrapolation),
rope.WithAttentionFactor(o.attnFactor),
rope.WithBetaFast(o.ropeBetaFast),
rope.WithBetaSlow(o.ropeBetaSlow),
)
} else if isSWA && o.originalContextLength > 0 {
opts = append(opts,
rope.WithOriginalContextLength(o.originalContextLength),
rope.WithExtrapolationFactor(0.),
rope.WithAttentionFactor(1.),
)
}
return opts
}
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tensor, cache kvcache.Cache, opts *Options, isSWA bool) ml.Tensor {
batchSize := hiddenState.Dim(1)
headDim := opts.hiddenSize / opts.numHeads
ropeDim := headDim
query := sa.Query.Forward(ctx, hiddenState)
if sa.QNorm != nil {
query = sa.QNorm.Forward(ctx, query, opts.eps)
}
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
key := sa.Key.Forward(ctx, hiddenState)
if sa.KNorm != nil {
key = sa.KNorm.Forward(ctx, key, opts.eps)
}
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
freqScale := float32(1.0)
if !isSWA {
freqScale = 1. / opts.ropeScale
freqScale = 1. / m.ropeScale
}
ropeOpts := opts.ropeOptions(sa.RopeFactors, isSWA)
query = nn.RoPE(ctx, query, positions, ropeDim, opts.ropeBase, freqScale, ropeOpts...)
key = nn.RoPE(ctx, key, positions, ropeDim, opts.ropeBase, freqScale, ropeOpts...)
return nn.RoPE(ctx, states, positions, ropeDim, m.ropeBase, freqScale, ropeOpts...)
}
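
Note: as applyRoPE is currently written, both layer types receive the original-context-length and extrapolation options whenever originalContextLength > 0 (the !isSWA guard is commented out), and only the frequency scale differs. A condensed sketch of the effective calls, for illustration only:

// Illustration, not part of the diff: the effective nn.RoPE calls when
// m.originalContextLength > 0.
yarnOpts := []func(*rope.Options){
	rope.WithOriginalContextLength(m.originalContextLength),
	rope.WithExtrapolationFactor(m.ropeExtrapolation),
}
// Global (non-SWA) layers: frequencies scaled by 1/ropeScale.
_ = nn.RoPE(ctx, states, positions, ropeDim, m.ropeBase, 1./m.ropeScale, yarnOpts...)
// Sliding-window layers: unscaled frequencies (freqScale = 1).
_ = nn.RoPE(ctx, states, positions, ropeDim, m.ropeBase, 1.0, yarnOpts...)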
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tensor, cache kvcache.Cache, m *Model, isSWA bool) ml.Tensor {
batchSize := hiddenState.Dim(1)
headDim := m.hiddenSize / m.numHeads
ropeDim := headDim
query := sa.Query.Forward(ctx, hiddenState)
// double check type
query = sa.QNorm.Forward(ctx, query, m.eps)
query = query.Reshape(ctx, headDim, m.numHeads, batchSize)
// check here
query = m.applyRoPE(ctx, query, positions, ropeDim, isSWA)
// and here
key := sa.Key.Forward(ctx, hiddenState)
key = sa.KNorm.Forward(ctx, key, m.eps)
key = key.Reshape(ctx, headDim, m.numKVHeads, batchSize)
key = m.applyRoPE(ctx, key, positions, ropeDim, isSWA)
value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, m.numKVHeads, batchSize)
// check attention scaling as well
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), cache)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
attention = attention.Reshape(ctx, m.hiddenSize, batchSize)
return sa.Output.Forward(ctx, attention)
}
@@ -208,14 +196,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positions ml.Tenso
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
ropeDim := m.hiddenSize / m.numHeads
isSWA := m.isSWALayer(layer)
freqScale := float32(1.0)
if !isSWA {
freqScale = 1. / m.ropeScale
}
ropeOpts := m.Options.ropeOptions(m.Layers[layer].SelfAttention.RopeFactors, isSWA)
return nn.RoPE(ctx, key, shift, ropeDim, m.ropeBase, freqScale, ropeOpts...), nil
return m.applyRoPE(ctx, key, shift, ropeDim, isSWA), nil
}
type MLP struct {
@@ -224,7 +205,7 @@ type MLP struct {
Gate *nn.Linear `gguf:"ffn_gate"`
}
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *Options) ml.Tensor {
func (mlp *MLP) Forward(ctx ml.Context, hiddenState ml.Tensor, m *Model) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).SILU(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
@@ -236,13 +217,11 @@ type Layer struct {
PostFFWNorm *nn.RMSNorm `gguf:"post_ffw_norm"`
}
func (l *Layer) Forward(ctx ml.Context, hiddenState, positions, outputs ml.Tensor, cache kvcache.Cache, opts *Options, isSWA bool) ml.Tensor {
func (l *Layer) Forward(ctx ml.Context, hiddenState, positions, outputs ml.Tensor, cache kvcache.Cache, m *Model, isSWA bool) ml.Tensor {
residual := hiddenState
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positions, cache, opts, isSWA)
if l.PostAttentionNorm != nil {
hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, opts.eps)
}
hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positions, cache, m, isSWA)
hiddenState = l.PostAttentionNorm.Forward(ctx, hiddenState, m.eps)
if outputs != nil {
hiddenState = hiddenState.Rows(ctx, outputs)
@@ -251,8 +230,9 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positions, outputs ml.Tenso
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
hiddenState = l.MLP.Forward(ctx, hiddenState, opts)
hiddenState = l.PostFFWNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = l.MLP.Forward(ctx, hiddenState, m)
hiddenState = l.PostFFWNorm.Forward(ctx, hiddenState, m.eps)
return hiddenState.Add(ctx, residual)
}
@@ -266,7 +246,6 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
positions := ctx.Input().FromInts(batch.Positions, len(batch.Positions))
hiddenState := m.TokenEmbedding.Forward(ctx, batch.Inputs)
hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.hiddenSize)))
for i, layer := range m.Layers {
m.Cache.SetLayer(i)
@@ -277,10 +256,10 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
cacheType = cacheTypeCausal
}
if wc, ok := m.Cache.(*kvcache.WrapperCache); ok {
wc.SetLayerType(cacheType)
}
if causal, ok := m.Cache.(*kvcache.Causal); ok {
cache := m.Cache.(*kvcache.WrapperCache)
cache.SetLayerType(cacheType)
// would need to check the cache at the layer instead
if causal, ok := cache.UnderlyingCache().(*kvcache.Causal); ok {
causal.SetCausal(ctx, kvcache.CausalOptions{Except: []int{i}})
}
@@ -289,7 +268,7 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
outputs = batch.Outputs
}
hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, &m.Options, isSWA)
hiddenState = layer.Forward(ctx, hiddenState, positions, outputs, m.Cache, m, isSWA)
}
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)

@@ -0,0 +1,568 @@
package olmo
import (
"encoding/binary"
"encoding/json"
"flag"
"fmt"
"log/slog"
"math"
"os"
"path/filepath"
"strings"
"testing"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
typemodel "github.com/ollama/ollama/types/model"
)
var args struct {
model,
prompt string
layers int
}
func TestMain(m *testing.M) {
flag.StringVar(&args.model, "model", "", "model name (e.g., olmo3:latest)")
flag.StringVar(&args.prompt, "prompt", "Hello, how are", "model prompt")
flag.IntVar(&args.layers, "layers", math.MaxInt, "num of gpu layers")
flag.Parse()
slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel()))
os.Exit(m.Run())
}
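
Note: these flags are parsed in TestMain above, so a run against a locally pulled model could look like the following, executed from this package's directory (model name and prompt are only examples; everything after -args is handed to the test binary):

go test -v -run TestFullForward -args -model olmo3:latest -prompt "Hello, how are"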
func blob(tb testing.TB, modelName string) string {
tb.Helper()
models := envconfig.Models()
manifest, err := os.Open(filepath.Join(models, "manifests", typemodel.ParseName(modelName).Filepath()))
if err != nil {
tb.Fatal(err)
}
defer manifest.Close()
var m struct {
Layers []struct {
MediaType string `json:"mediaType"`
Digest string `json:"digest"`
} `json:"layers"`
}
if err := json.NewDecoder(manifest).Decode(&m); err != nil {
tb.Fatal(err)
}
for _, layer := range m.Layers {
if layer.MediaType == "application/vnd.ollama.image.model" {
tb.Log("using model blob", layer.Digest)
return filepath.Join(models, "blobs", strings.ReplaceAll(layer.Digest, ":", "-"))
}
}
tb.Fatal("model blob not found")
return ""
}
func loadFloatsFromBinary(filename string) ([]float32, error) {
f, err := os.Open(filename)
if err != nil {
return nil, err
}
defer f.Close()
fi, err := f.Stat()
if err != nil {
return nil, err
}
if fi.Size()%4 != 0 {
return nil, fmt.Errorf("file size %d not multiple of 4", fi.Size())
}
n := int(fi.Size() / 4)
floats := make([]float32, n)
if err := binary.Read(f, binary.LittleEndian, floats); err != nil {
return nil, err
}
return floats, nil
}
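
Note: loadFloatsFromBinary is not called by the tests below; it appears intended for comparing the .bin dumps they write against a reference. A sketch of that kind of check, with a hypothetical helper name and tolerance, and the reference file assumed to come from some other implementation:

func compareAgainstReference(t *testing.T, gotPath, refPath string, tol float64) {
	t.Helper()
	got, err := loadFloatsFromBinary(gotPath)
	if err != nil {
		t.Fatal(err)
	}
	want, err := loadFloatsFromBinary(refPath) // e.g. a dump produced by a reference implementation
	if err != nil {
		t.Fatal(err)
	}
	if len(got) != len(want) {
		t.Fatalf("length mismatch: got %d, want %d", len(got), len(want))
	}
	for i := range got {
		if diff := math.Abs(float64(got[i] - want[i])); diff > tol {
			t.Fatalf("element %d differs: got %f, want %f (|diff| = %f)", i, got[i], want[i], diff)
		}
	}
}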
func TestTokenization(t *testing.T) {
if args.model == "" {
t.Skip("no model specified, use -model flag")
}
m, err := model.New(blob(t, args.model), ml.BackendParams{AllocMemory: true})
if err != nil {
t.Fatal(err)
}
if err := m.Backend().Load(t.Context(), func(float32) {}); err != nil {
t.Fatal(err)
}
prompt := args.prompt
if prompt == "" {
prompt = "hello"
}
tp := m.(model.TextProcessor)
tokens, err := tp.Encode(prompt, false)
if err != nil {
t.Fatal(err)
}
t.Logf("prompt: %q", prompt)
t.Logf("tokens: %v", tokens)
t.Logf("num tokens: %d", len(tokens))
decoded, err := tp.Decode(tokens)
if err != nil {
t.Fatal(err)
}
t.Logf("decoded: %q", decoded)
}
func TestAttentionForward(t *testing.T) {
if args.model == "" {
t.Skip("no model specified, use -model flag")
}
m, err := model.New(blob(t, args.model), ml.BackendParams{AllocMemory: true})
if err != nil {
t.Fatal(err)
}
if err := m.Backend().Load(t.Context(), func(float32) {}); err != nil {
t.Fatal(err)
}
olmoModel := m.(*Model)
t.Logf("Model options: hiddenSize=%d, numHeads=%d, numKVHeads=%d",
olmoModel.hiddenSize, olmoModel.numHeads, olmoModel.numKVHeads)
t.Logf("Layer 0 attention: %+v", olmoModel.Layers[0].SelfAttention)
ctx := m.Backend().NewContext()
// Create test hidden states: (hiddenSize, batchSize)
batchSize := 4
hiddenSize := olmoModel.hiddenSize
hsFloats := make([]float32, hiddenSize*batchSize)
for i := range hsFloats {
hsFloats[i] = float32(i%100) / 100.0 // Simple test values
}
hiddenStates := ctx.Input().FromFloats(hsFloats, hiddenSize, batchSize)
t.Logf("hiddenStates shape: %v", hiddenStates.Shape())
positions := ctx.Input().FromInts([]int32{0, 1, 2, 3}, batchSize)
// Test attention forward (without cache for simplicity)
attentionBlock := olmoModel.Layers[0].SelfAttention
isSWA := olmoModel.isSWALayer(0)
t.Logf("Layer 0 isSWA: %v", isSWA)
result := attentionBlock.Forward(ctx, hiddenStates, positions, nil, olmoModel, isSWA)
result = result.Contiguous(ctx)
ctx.Forward(result).Compute(result)
t.Logf("Attention result shape: %v dtype: %v", result.Shape(), result.DType())
// Optionally dump to file
// if err := os.WriteFile("/tmp/olmo_attention_output.bin", result.Bytes(), 0644); err != nil {
// t.Fatal(err)
// }
}
func TestMLPForward(t *testing.T) {
if args.model == "" {
t.Skip("no model specified, use -model flag")
}
m, err := model.New(blob(t, args.model), ml.BackendParams{AllocMemory: true})
if err != nil {
t.Fatal(err)
}
if err := m.Backend().Load(t.Context(), func(float32) {}); err != nil {
t.Fatal(err)
}
olmoModel := m.(*Model)
ctx := m.Backend().NewContext()
// Create test hidden states
batchSize := 4
hiddenSize := olmoModel.hiddenSize
hsFloats := make([]float32, hiddenSize*batchSize)
for i := range hsFloats {
hsFloats[i] = float32(i%100) / 100.0
}
hiddenStates := ctx.Input().FromFloats(hsFloats, hiddenSize, batchSize)
t.Logf("hiddenStates shape: %v", hiddenStates.Shape())
mlpBlock := olmoModel.Layers[0].MLP
result := mlpBlock.Forward(ctx, hiddenStates, olmoModel)
result = result.Contiguous(ctx)
ctx.Forward(result).Compute(result)
t.Logf("MLP result shape: %v dtype: %v", result.Shape(), result.DType())
// Parse result bytes to float32
resultBytes := result.Bytes()
resultFloats := make([]float32, len(resultBytes)/4)
for i := range resultFloats {
bits := binary.LittleEndian.Uint32(resultBytes[i*4 : (i+1)*4])
resultFloats[i] = math.Float32frombits(bits)
}
// Compute statistics
var minVal, maxVal, sum float32
minVal = resultFloats[0]
maxVal = resultFloats[0]
for _, v := range resultFloats {
if v < minVal {
minVal = v
}
if v > maxVal {
maxVal = v
}
sum += v
}
mean := sum / float32(len(resultFloats))
// Build readable output
var sb strings.Builder
sb.WriteString("# MLP Forward Output\n\n")
sb.WriteString(fmt.Sprintf("# Input Shape: [%d, %d] (hiddenSize, batchSize)\n", hiddenSize, batchSize))
sb.WriteString(fmt.Sprintf("# Output Shape: %v\n", result.Shape()))
sb.WriteString(fmt.Sprintf("# DType: %v\n", result.DType()))
sb.WriteString(fmt.Sprintf("# Layer: 0\n\n"))
sb.WriteString("## Statistics\n\n")
sb.WriteString(fmt.Sprintf(" Total elements: %d\n", len(resultFloats)))
sb.WriteString(fmt.Sprintf(" Min: %v\n", minVal))
sb.WriteString(fmt.Sprintf(" Max: %v\n", maxVal))
sb.WriteString(fmt.Sprintf(" Mean: %v\n\n", mean))
sb.WriteString("## Input Hidden States (first 20 values)\n\n")
sb.WriteString(" [")
for i := 0; i < min(20, len(hsFloats)); i++ {
if i > 0 {
sb.WriteString(", ")
}
sb.WriteString(fmt.Sprintf("%v", hsFloats[i]))
}
sb.WriteString("]\n\n")
sb.WriteString("## Output Values\n\n")
// Per-position output (each position in batch)
for pos := 0; pos < batchSize; pos++ {
sb.WriteString(fmt.Sprintf("Position %d (hiddenSize=%d values):\n", pos, hiddenSize))
// Extract values for this position
posStart := pos * hiddenSize
posEnd := posStart + hiddenSize
if posEnd > len(resultFloats) {
posEnd = len(resultFloats)
}
posValues := resultFloats[posStart:posEnd]
// Full tensor values
sb.WriteString(" [")
for i, v := range posValues {
if i > 0 {
sb.WriteString(", ")
}
sb.WriteString(fmt.Sprintf("%v", v))
}
sb.WriteString("]\n\n")
}
// Save to file
if err := os.WriteFile("/tmp/olmo_mlp_forward.txt", []byte(sb.String()), 0644); err != nil {
t.Fatal(err)
}
t.Log("Saved /tmp/olmo_mlp_forward.txt")
// Also save binary
if err := os.WriteFile("/tmp/olmo_mlp_forward.bin", resultBytes, 0644); err != nil {
t.Fatal(err)
}
t.Log("Saved /tmp/olmo_mlp_forward.bin")
// Print summary to console
fmt.Println(sb.String())
}
func TestFullForward(t *testing.T) {
if args.model == "" {
t.Skip("no model specified, use -model flag")
}
m, err := model.New(blob(t, args.model), ml.BackendParams{AllocMemory: true})
if err != nil {
t.Fatal(err)
}
if err := m.Backend().Load(t.Context(), func(float32) {}); err != nil {
t.Fatal(err)
}
ctx := m.Backend().NewContext()
prompt := args.prompt
if prompt == "" {
prompt = "Hello, how are you?"
}
tp := m.(model.TextProcessor)
tokens, err := tp.Encode(prompt, false)
if err != nil {
t.Fatal(err)
}
t.Logf("prompt: %q", prompt)
t.Logf("tokens: %v", tokens)
decoded, err := tp.Decode(tokens)
if err != nil {
t.Fatal(err)
}
t.Logf("decoded: %q", decoded)
seqLen := len(tokens)
inputsTensor := ctx.Input().FromInts(tokens, seqLen)
positions := make([]int32, seqLen)
sequences := make([]int, seqLen)
for i := range tokens {
positions[i] = int32(i)
sequences[i] = 0
}
// Output ALL positions
outputIndices := make([]int32, seqLen)
for i := range outputIndices {
outputIndices[i] = int32(i)
}
outputs := ctx.Input().FromInts(outputIndices, seqLen)
batch := input.Batch{
Inputs: inputsTensor,
Positions: positions,
Sequences: sequences,
Outputs: outputs,
}
// Initialize cache
if cache := m.Config().Cache; cache != nil {
cache.Init(m.Backend(), ml.DTypeF16, 1, 4096, seqLen)
}
result, err := model.Forward(ctx, m, batch)
if err != nil {
t.Fatal(err)
}
result = result.Contiguous(ctx)
ctx.Forward(result).Compute(result)
t.Logf("Forward pass completed, result shape: %v", result.Shape())
// Dump logits to binary file
if err := os.WriteFile("/tmp/olmo_logits.bin", result.Bytes(), 0644); err != nil {
t.Fatal(err)
}
t.Log("Saved /tmp/olmo_logits.bin")
// Parse logits from bytes for detailed analysis
logitsBytes := result.Bytes()
vocabSize := result.Shape()[0]
// Read float32 values - shape is (vocab_size, seq_len)
allLogits := make([]float32, len(logitsBytes)/4)
for i := range allLogits {
bits := binary.LittleEndian.Uint32(logitsBytes[i*4 : (i+1)*4])
allLogits[i] = math.Float32frombits(bits)
}
// Create detailed text dump matching Python format
var sb strings.Builder
sb.WriteString("# Full Forward Logits\n\n")
sb.WriteString(fmt.Sprintf("# Shape: [1, %d, %d]\n", seqLen, vocabSize))
sb.WriteString(fmt.Sprintf("# Layout: (batch=1, seq_len=%d, vocab_size=%d)\n", seqLen, vocabSize))
sb.WriteString(fmt.Sprintf("# Prompt: '%s'\n", prompt))
sb.WriteString(fmt.Sprintf("# Tokens: %v\n\n", tokens))
type logitPair struct {
tokenID int
value float32
}
// Process each position
for pos := 0; pos < seqLen; pos++ {
// Extract logits for this position
// Shape is (vocab_size, seq_len), so logits[v*seqLen + pos] gives logit for vocab v at position pos
posLogits := make([]float32, vocabSize)
for v := 0; v < vocabSize; v++ {
posLogits[v] = allLogits[v*seqLen+pos]
}
// Find top 10 logits
pairs := make([]logitPair, len(posLogits))
for i, v := range posLogits {
pairs[i] = logitPair{tokenID: i, value: v}
}
// Partial selection sort: only the top min(10, len) entries end up ordered (descending)
for i := 0; i < min(10, len(pairs)); i++ {
for j := i + 1; j < len(pairs); j++ {
if pairs[j].value > pairs[i].value {
pairs[i], pairs[j] = pairs[j], pairs[i]
}
}
}
tokenStr, _ := tp.Decode([]int32{tokens[pos]})
sb.WriteString(fmt.Sprintf("Position %d (token_id=%d, token='%s'):\n", pos, tokens[pos], tokenStr))
sb.WriteString(" Top 10 logits:\n")
for i := 0; i < min(10, len(pairs)); i++ {
tokStr, _ := tp.Decode([]int32{int32(pairs[i].tokenID)})
// Pad token string to 20 chars for alignment
paddedTok := fmt.Sprintf("%-20s", fmt.Sprintf("'%s'", tokStr))
sb.WriteString(fmt.Sprintf(" %d. token_id=%6d (%s): %f\n", i+1, pairs[i].tokenID, paddedTok, pairs[i].value))
}
// First 20 logits
sb.WriteString(" Full logits (first 20): [")
for i := 0; i < min(20, len(posLogits)); i++ {
if i > 0 {
sb.WriteString(", ")
}
sb.WriteString(fmt.Sprintf("%v", posLogits[i]))
}
sb.WriteString("]\n")
// Last 20 logits
sb.WriteString(" Full logits (last 20): [")
start := max(0, len(posLogits)-20)
for i := start; i < len(posLogits); i++ {
if i > start {
sb.WriteString(", ")
}
sb.WriteString(fmt.Sprintf("%v", posLogits[i]))
}
sb.WriteString("]\n\n")
}
if err := os.WriteFile("/tmp/olmo_logits.txt", []byte(sb.String()), 0644); err != nil {
t.Fatal(err)
}
t.Log("Saved /tmp/olmo_logits.txt")
// Print to console as well
fmt.Println(sb.String())
}
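
Note: the position loop above reads the flat logits slice as (vocab_size, seq_len) with index v*seqLen+pos. A tiny helper, included only to make that layout explicit (not used by the tests):

// logitAt returns the logit for token id v at position pos, assuming the flat
// slice is laid out the way the loop above indexes it.
func logitAt(allLogits []float32, v, pos, seqLen int) float32 {
	return allLogits[v*seqLen+pos]
}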
func TestRoPE(t *testing.T) {
if args.model == "" {
t.Skip("no model specified, use -model flag")
}
m, err := model.New(blob(t, args.model), ml.BackendParams{AllocMemory: true})
if err != nil {
t.Fatal(err)
}
if err := m.Backend().Load(t.Context(), func(float32) {}); err != nil {
t.Fatal(err)
}
olmoModel := m.(*Model)
// Test RoPE on a simple tensor
headDim := olmoModel.hiddenSize / olmoModel.numHeads
batchSize := 4
numHeads := olmoModel.numHeads
t.Logf("headDim: %d, numHeads: %d", headDim, numHeads)
t.Logf("ropeBase: %f, ropeScale: %f, originalContextLength: %d",
olmoModel.ropeBase, olmoModel.ropeScale, olmoModel.originalContextLength)
// Create test query tensor: (headDim, numHeads, batchSize)
queryFloats := make([]float32, headDim*numHeads*batchSize)
for i := range queryFloats {
queryFloats[i] = float32(i%100) / 100.0
}
// Test 1: Dump initial query values (fresh context)
{
ctx := m.Backend().NewContext()
query := ctx.Input().FromFloats(queryFloats, headDim, numHeads, batchSize)
t.Logf("query shape: %v", query.Shape())
query = query.Contiguous(ctx)
ctx.Forward(query).Compute(query)
dump := ml.Dump(ctx, query, ml.DumpWithPrecision(6), ml.DumpWithThreshold(1000000))
t.Logf("Query BEFORE RoPE sample values: %s", dump[:min(500, len(dump))])
// Write to file
header := fmt.Sprintf("Shape: %v\nDType: %v\n\n", query.Shape(), query.DType())
if err := os.WriteFile("/tmp/olmo_query_before_rope.txt", []byte(header+dump), 0644); err != nil {
t.Errorf("Failed to write file: %v", err)
}
if err := os.WriteFile("/tmp/olmo_query_before_rope.bin", query.Bytes(), 0644); err != nil {
t.Errorf("Failed to write binary file: %v", err)
}
t.Log("Wrote /tmp/olmo_query_before_rope.txt and .bin")
}
// Test 2: SWA RoPE (fresh context)
{
ctx := m.Backend().NewContext()
query := ctx.Input().FromFloats(queryFloats, headDim, numHeads, batchSize)
positions := ctx.Input().FromInts([]int32{0, 1, 2, 3}, batchSize)
resultSWA := olmoModel.applyRoPE(ctx, query, positions, headDim, true)
resultSWA = resultSWA.Contiguous(ctx)
ctx.Forward(resultSWA).Compute(resultSWA)
t.Logf("SWA RoPE result shape: %v", resultSWA.Shape())
dump := ml.Dump(ctx, resultSWA, ml.DumpWithPrecision(6), ml.DumpWithThreshold(1000000))
t.Logf("Query AFTER SWA RoPE sample values: %s", dump[:min(500, len(dump))])
// Write to file
header := fmt.Sprintf("Shape: %v\nDType: %v\nfreqScale: 1.0 (SWA)\n\n", resultSWA.Shape(), resultSWA.DType())
if err := os.WriteFile("/tmp/olmo_query_after_swa_rope.txt", []byte(header+dump), 0644); err != nil {
t.Errorf("Failed to write file: %v", err)
}
if err := os.WriteFile("/tmp/olmo_query_after_swa_rope.bin", resultSWA.Bytes(), 0644); err != nil {
t.Errorf("Failed to write binary file: %v", err)
}
t.Log("Wrote /tmp/olmo_query_after_swa_rope.txt and .bin")
}
// Test 3: Global (non-SWA) RoPE (fresh context)
{
ctx := m.Backend().NewContext()
query := ctx.Input().FromFloats(queryFloats, headDim, numHeads, batchSize)
positions := ctx.Input().FromInts([]int32{0, 1, 2, 3}, batchSize)
resultGlobal := olmoModel.applyRoPE(ctx, query, positions, headDim, false)
resultGlobal = resultGlobal.Contiguous(ctx)
ctx.Forward(resultGlobal).Compute(resultGlobal)
t.Logf("Global RoPE result shape: %v", resultGlobal.Shape())
dump := ml.Dump(ctx, resultGlobal, ml.DumpWithPrecision(6), ml.DumpWithThreshold(1000000))
t.Logf("Query AFTER Global RoPE sample values: %s", dump[:min(500, len(dump))])
// Write to file
header := fmt.Sprintf("Shape: %v\nDType: %v\nfreqScale: %f (Global, 1/ropeScale)\n\n",
resultGlobal.Shape(), resultGlobal.DType(), 1.0/olmoModel.ropeScale)
if err := os.WriteFile("/tmp/olmo_query_after_global_rope.txt", []byte(header+dump), 0644); err != nil {
t.Errorf("Failed to write file: %v", err)
}
if err := os.WriteFile("/tmp/olmo_query_after_global_rope.bin", resultGlobal.Bytes(), 0644); err != nil {
t.Errorf("Failed to write binary file: %v", err)
}
t.Log("Wrote /tmp/olmo_query_after_global_rope.txt and .bin")
}
}