diff --git a/model/models/olmo/model.go b/model/models/olmo/model.go
index 13ed0cc59..231cbf7f8 100644
--- a/model/models/olmo/model.go
+++ b/model/models/olmo/model.go
@@ -2,6 +2,7 @@ package olmo
 
 import (
 	"cmp"
+	"fmt"
 	"math"
 
 	"github.com/ollama/ollama/fs"
@@ -37,6 +38,7 @@ type Model struct {
 }
 
 func New(c fs.Config) (model.Model, error) {
+	fmt.Println("šŸ¦™ OLMo model loaded!")
 	vocabulary := model.Vocabulary{
 		Values: c.Strings("tokenizer.ggml.tokens"),
 		Scores: c.Floats("tokenizer.ggml.scores"),
diff --git a/model/models/olmo/testolmo.go b/model/models/olmo/testolmo.go
new file mode 100644
index 000000000..d2ced0afc
--- /dev/null
+++ b/model/models/olmo/testolmo.go
@@ -0,0 +1,132 @@
+package olmo
+
+import (
+	"context"
+	"fmt"
+	"log"
+
+	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/model"
+	"github.com/ollama/ollama/model/input"
+	"github.com/ollama/ollama/sample"
+)
+
+func main() {
+	modelPath := "/Users/nicole/models/Olmo-3-7B-Think/olmo-3-7b-think-q8_0.gguf"
+
+	fmt.Println("Loading OLMo model...")
+	m, err := model.New(modelPath, ml.BackendParams{AllocMemory: true})
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	if err := m.Backend().Load(context.Background(), func(f float32) {}); err != nil {
+		log.Fatal(err)
+	}
+
+	fmt.Println("āœ… Model loaded successfully!")
+
+	// Initialize the cache
+	cache := m.Config().Cache
+	if cache != nil {
+		// Initialize with reasonable defaults:
+		// - dtype: F16
+		// - maxSequences: 1 (single sequence)
+		// - capacity: 2048 (context length)
+		// - maxBatch: 512
+		cache.Init(m.Backend(), ml.DTypeF16, 1, 2048, 512)
+		fmt.Printf("āœ… Cache initialized (type: %T)\n", cache)
+	}
+
+	// Test generation
+	prompt := "Question: What is machine learning? Answer:"
+	fmt.Printf("\nPrompt: %s\n", prompt)
+
+	tp := m.(model.TextProcessor)
+	tokens, err := tp.Encode(prompt, true)
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	fmt.Printf("Tokens: %v (count: %d)\n", tokens, len(tokens))
+
+	// Generate 20 tokens
+	maxTokens := 20
+	generated := make([]int32, 0, maxTokens)
+
+	// Create sampler (temperature=0 for greedy sampling)
+	sampler := sample.NewSampler(0, 0, 0, 0, -1, nil)
+
+	for i := 0; i < maxTokens; i++ {
+		// Create a new context for each generation step to avoid memory buildup
+		ctx := m.Backend().NewContext()
+
+		var inputTokens []int32
+		var positions []int32
+
+		if i == 0 {
+			// First iteration: process all prompt tokens
+			inputTokens = tokens
+			positions = make([]int32, len(tokens))
+			for j := range positions {
+				positions[j] = int32(j)
+			}
+		} else {
+			// Subsequent iterations: only process the newly generated token
+			// The last token is at position len(tokens)-1 (its index in the sequence)
+			inputTokens = []int32{tokens[len(tokens)-1]}
+			positions = []int32{int32(len(tokens) - 1)}
+		}
+
+		sequences := make([]int, len(inputTokens))
+		// All tokens belong to sequence 0
+
+		inputsTensor := ctx.Input().FromInts(inputTokens, len(inputTokens))
+		outputs := ctx.Input().FromInts([]int32{int32(len(inputTokens) - 1)}, 1)
+
+		batch := input.Batch{
+			Inputs:    inputsTensor,
+			Positions: positions,
+			Sequences: sequences,
+			Outputs:   outputs,
+		}
+
+		// Forward pass (model.Forward handles cache.StartForward internally)
+		logits, err := model.Forward(ctx, m, batch)
+		if err != nil {
+			ctx.Close()
+			log.Fatal(err)
+		}
+
+		logits = logits.Contiguous(ctx)
+		ctx.Forward(logits).Compute(logits)
+
+		logitValues := logits.Floats()
+
+		// Sample next token
+		nextToken, err := sampler.Sample(logitValues)
+		if err != nil {
+			ctx.Close()
+			log.Fatal(err)
+		}
+
+		// Close context before moving to next iteration
+		ctx.Close()
+
+		generated = append(generated, nextToken)
+		tokens = append(tokens, nextToken)
+
+		// Decode and print
+		decoded, _ := tp.Decode([]int32{nextToken})
+		fmt.Print(decoded)
+
+		// Stop on EOS
+		if nextToken == 2 || nextToken == 1 { // Common EOS tokens
+			break
+		}
+	}
+
+	fmt.Println("\n\nāœ… Generation completed!")
+	fullText, _ := tp.Decode(generated)
+	fmt.Printf("Generated: %s\n", fullText)
+}
diff --git a/model/renderers/olmo3.go b/model/renderers/olmo3.go
index 4a5f59fd5..24ade20dc 100644
--- a/model/renderers/olmo3.go
+++ b/model/renderers/olmo3.go
@@ -10,9 +10,9 @@ import (
 )
 
 const (
-	olmo3DefaultSystemMessage = "You are a helpful function-calling AI assistant. "
-	olmo3NoFunctionsMessage   = "You do not currently have access to any functions. "
-	olmo3WithFunctionsMessage = "You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Output any function calls within XML tags. Do not make assumptions about what values to plug into functions."
+	olmo3DefaultSystemMessage = "You are a helpful function-calling AI assistant. "
+	olmo3NoFunctionsMessage   = "You do not currently have access to any functions. "
+	olmo3WithFunctionsMessage = "You are provided with function signatures within XML tags. You may call one or more functions to assist with the user query. Output any function calls within XML tags. Do not make assumptions about what values to plug into functions."
 )
 
 type Olmo3Renderer struct{}
@@ -145,4 +145,3 @@ func (r *Olmo3Renderer) Render(messages []api.Message, tools []api.Tool, _ *api.
 
 	return sb.String(), nil
 }
-
diff --git a/server/prompt.go b/server/prompt.go
index 217591982..a7120d3a8 100644
--- a/server/prompt.go
+++ b/server/prompt.go
@@ -110,6 +110,7 @@ func renderPrompt(m *Model, msgs []api.Message, tools []api.Tool, think *api.Thi
 	if err != nil {
 		return "", err
 	}
+	slog.Debug("rendered prompt", "renderer", m.Config.Renderer, "prompt", rendered)
 	return rendered, nil
 }