package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model"
	"github.com/ollama/ollama/model/input"
	_ "github.com/ollama/ollama/model/models" // Register all models
	"github.com/ollama/ollama/model/renderers"
	"github.com/ollama/ollama/sample"
)

func main() {
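	// Hand-rolled generation harness: load the model, render a chat prompt,
	// tokenize it, then alternate forward passes with greedy sampling.
	// The blob path below is machine-specific; point it at a model blob
	// under your own ~/.ollama/models/blobs.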
	modelPath := "/Users/parth/.ollama/models/blobs/sha256-a87e10578f328b087f888ac7bd1018555e26028a1130980f20312b4de3a10d70"

	fmt.Println("Loading OLMo model...")
	m, err := model.New(modelPath, ml.BackendParams{AllocMemory: true})
	if err != nil {
		log.Fatal(err)
	}

	// Load the weights; the callback is a load-progress hook, ignored here.
	if err := m.Backend().Load(context.Background(), func(f float32) {}); err != nil {
		log.Fatal(err)
	}

	fmt.Println("✅ Model loaded successfully!")
	// Initialize the cache
	cache := m.Config().Cache
	if cache != nil {
		// Initialize with reasonable defaults:
		// - dtype: F16
		// - maxSequences: 1 (single sequence)
		// - capacity: 2048 (context length)
		// - maxBatch: 512
		cache.Init(m.Backend(), ml.DTypeF16, 1, 2048, 512)
		fmt.Printf("✅ Cache initialized (type: %T)\n", cache)
	}
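
	// Chat models expect prompts wrapped in their chat template (role markers,
	// turn separators, and so on); RenderWithRenderer applies the olmo3
	// template to the message list rather than hand-building that string.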
	// Use the olmo3 renderer to format the prompt properly
	messages := []api.Message{
		{Role: "user", Content: "wagwan"},
	}
	// A raw, untemplated prompt would also work here, e.g.:
	// prompt := "Question: What is machine learning? Answer:"
	prompt, err := renderers.RenderWithRenderer("olmo3", messages, nil, nil)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("\nRendered prompt:\n%s\n", prompt)
	tp := m.(model.TextProcessor)
	tokens, err := tp.Encode(prompt, false)
	if err != nil {
		log.Fatal(err)
	}

	fmt.Printf("Tokens: %v (count: %d)\n", tokens, len(tokens))
	// Generate up to maxTokens tokens (generation may stop early on EOS)
	maxTokens := 20
	generated := make([]int32, 0, maxTokens)

	// Create sampler (temperature=0 for greedy sampling)
	sampler := sample.NewSampler(0, 0, 0, 0, -1, nil)
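	// The zero-valued and -1/nil arguments are assumed to disable top-k/top-p/
	// min-p filtering, seeding, and grammar constraints, so with temperature 0
	// sampling reduces to a greedy argmax over the logits.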

	for i := 0; i < maxTokens; i++ {
		// Create a new context for each generation step to avoid memory buildup
		ctx := m.Backend().NewContext()

		var inputTokens []int32
		var positions []int32
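
		// Prefill vs. decode: the first pass pushes the whole prompt through
		// the model to populate the KV cache; each later pass feeds only the
		// newly sampled token at its absolute position in the sequence.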
		if i == 0 {
			// First iteration: process all prompt tokens
			inputTokens = tokens
			positions = make([]int32, len(tokens))
			for j := range positions {
				positions[j] = int32(j)
			}
		} else {
			// Subsequent iterations: only process the newly generated token
			// The last token is at position len(tokens)-1 (its index in the sequence)
			inputTokens = []int32{tokens[len(tokens)-1]}
			positions = []int32{int32(len(tokens) - 1)}
		}

		// All tokens belong to sequence 0 (make zero-initializes the slice)
		sequences := make([]int, len(inputTokens))

		inputsTensor := ctx.Input().FromInts(inputTokens, len(inputTokens))
		outputs := ctx.Input().FromInts([]int32{int32(len(inputTokens) - 1)}, 1)
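
		// Outputs picks which rows of the batch produce logits; only the last
		// token's row is needed, since that holds the next-token prediction.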
		batch := input.Batch{
			Inputs:    inputsTensor,
			Positions: positions,
			Sequences: sequences,
			Outputs:   outputs,
		}

		// Forward pass (model.Forward handles cache.StartForward internally)
		logits, err := model.Forward(ctx, m, batch)
		if err != nil {
			ctx.Close()
			log.Fatal(err)
		}

		// Tensor evaluation is deferred: Forward marks logits as a graph
		// output, Compute executes the graph, and Floats copies the result
		// out as a []float32.
		logits = logits.Contiguous(ctx)
		ctx.Forward(logits).Compute(logits)

		logitValues := logits.Floats()

		// Sample next token
		nextToken, err := sampler.Sample(logitValues)
		if err != nil {
			ctx.Close()
			log.Fatal(err)
		}

		// Close context before moving to next iteration
		ctx.Close()

		generated = append(generated, nextToken)
		tokens = append(tokens, nextToken)

		// Decode and print
		decoded, _ := tp.Decode([]int32{nextToken})
		fmt.Print(decoded)
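
		// Stop conditions: the numeric IDs below are a heuristic (EOS token
		// IDs vary by model vocabulary); matching the decoded <|im_end|>
		// string is the more reliable check for a chat-templated model.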
		// Stop on EOS or <|im_end|>
		if nextToken == 2 || nextToken == 1 { // Common EOS tokens
			break
		}
		// Check if we generated <|im_end|> (stop token for chat)
		if decoded == "<|im_end|>" {
			break
		}
	}

	fmt.Println("\n\n✅ Generation completed!")
	fullText, _ := tp.Decode(generated)
	fmt.Printf("Generated: %s\n", fullText)
}