This commit is contained in:
ParthSareen 2025-12-08 14:42:18 -08:00
parent dd3306d3a0
commit f475cc365a
4 changed files with 138 additions and 4 deletions

View File

@ -2,6 +2,7 @@ package olmo
import (
"cmp"
"fmt"
"math"
"github.com/ollama/ollama/fs"
@ -37,6 +38,7 @@ type Model struct {
}
func New(c fs.Config) (model.Model, error) {
fmt.Println("🦙 OLMo model loaded!")
vocabulary := model.Vocabulary{
Values: c.Strings("tokenizer.ggml.tokens"),
Scores: c.Floats("tokenizer.ggml.scores"),

View File

@ -0,0 +1,132 @@
package olmo
import (
"context"
"fmt"
"log"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/input"
"github.com/ollama/ollama/sample"
)
func main() {
modelPath := "/Users/nicole/models/Olmo-3-7B-Think/olmo-3-7b-think-q8_0.gguf"
fmt.Println("Loading OLMo model...")
m, err := model.New(modelPath, ml.BackendParams{AllocMemory: true})
if err != nil {
log.Fatal(err)
}
if err := m.Backend().Load(context.Background(), func(f float32) {}); err != nil {
log.Fatal(err)
}
fmt.Println("✅ Model loaded successfully!")
// Initialize the cache
cache := m.Config().Cache
if cache != nil {
// Initialize with reasonable defaults:
// - dtype: F16
// - maxSequences: 1 (single sequence)
// - capacity: 2048 (context length)
// - maxBatch: 512
cache.Init(m.Backend(), ml.DTypeF16, 1, 2048, 512)
fmt.Printf("✅ Cache initialized (type: %T)\n", cache)
}
// Test generation
prompt := "Question: What is machine learning? Answer:"
fmt.Printf("\nPrompt: %s\n", prompt)
tp := m.(model.TextProcessor)
tokens, err := tp.Encode(prompt, true)
if err != nil {
log.Fatal(err)
}
fmt.Printf("Tokens: %v (count: %d)\n", tokens, len(tokens))
// Generate 20 tokens
maxTokens := 20
generated := make([]int32, 0, maxTokens)
// Create sampler (temperature=0 for greedy sampling)
sampler := sample.NewSampler(0, 0, 0, 0, -1, nil)
for i := 0; i < maxTokens; i++ {
// Create a new context for each generation step to avoid memory buildup
ctx := m.Backend().NewContext()
var inputTokens []int32
var positions []int32
if i == 0 {
// First iteration: process all prompt tokens
inputTokens = tokens
positions = make([]int32, len(tokens))
for j := range positions {
positions[j] = int32(j)
}
} else {
// Subsequent iterations: only process the newly generated token
// The last token is at position len(tokens)-1 (its index in the sequence)
inputTokens = []int32{tokens[len(tokens)-1]}
positions = []int32{int32(len(tokens) - 1)}
}
sequences := make([]int, len(inputTokens))
// All tokens belong to sequence 0
inputsTensor := ctx.Input().FromInts(inputTokens, len(inputTokens))
outputs := ctx.Input().FromInts([]int32{int32(len(inputTokens) - 1)}, 1)
batch := input.Batch{
Inputs: inputsTensor,
Positions: positions,
Sequences: sequences,
Outputs: outputs,
}
// Forward pass (model.Forward handles cache.StartForward internally)
logits, err := model.Forward(ctx, m, batch)
if err != nil {
ctx.Close()
log.Fatal(err)
}
logits = logits.Contiguous(ctx)
ctx.Forward(logits).Compute(logits)
logitValues := logits.Floats()
// Sample next token
nextToken, err := sampler.Sample(logitValues)
if err != nil {
ctx.Close()
log.Fatal(err)
}
// Close context before moving to next iteration
ctx.Close()
generated = append(generated, nextToken)
tokens = append(tokens, nextToken)
// Decode and print
decoded, _ := tp.Decode([]int32{nextToken})
fmt.Print(decoded)
// Stop on EOS
if nextToken == 2 || nextToken == 1 { // Common EOS tokens
break
}
}
fmt.Println("\n\n✅ Generation completed!")
fullText, _ := tp.Decode(generated)
fmt.Printf("Generated: %s\n", fullText)
}

View File

@ -10,9 +10,9 @@ import (
)
const (
olmo3DefaultSystemMessage = "You are a helpful function-calling AI assistant. "
olmo3NoFunctionsMessage = "You do not currently have access to any functions. "
olmo3WithFunctionsMessage = "You are provided with function signatures within <functions></functions> XML tags. You may call one or more functions to assist with the user query. Output any function calls within <function_calls></function_calls> XML tags. Do not make assumptions about what values to plug into functions."
olmo3DefaultSystemMessage = "You are a helpful function-calling AI assistant. "
olmo3NoFunctionsMessage = "You do not currently have access to any functions. "
olmo3WithFunctionsMessage = "You are provided with function signatures within <functions></functions> XML tags. You may call one or more functions to assist with the user query. Output any function calls within <function_calls></function_calls> XML tags. Do not make assumptions about what values to plug into functions."
)
type Olmo3Renderer struct{}
@ -145,4 +145,3 @@ func (r *Olmo3Renderer) Render(messages []api.Message, tools []api.Tool, _ *api.
return sb.String(), nil
}

View File

@ -110,6 +110,7 @@ func renderPrompt(m *Model, msgs []api.Message, tools []api.Tool, think *api.Thi
if err != nil {
return "", err
}
slog.Debug("rendered prompt", "renderer", m.Config.Renderer, "prompt", rendered)
return rendered, nil
}