From 63e7634014c0f12536dc77583d2dc9f36e499024 Mon Sep 17 00:00:00 2001
From: Bruce MacDonald
Date: Mon, 16 Jun 2025 16:08:38 -0700
Subject: [PATCH] pr feedback

---
 server/cache/capabilities.go      |  33 ++++-
 server/cache/capabilities_test.go | 211 ++++++++++++++++++++++++++++++
 server/images_test.go             | 124 ------------------
 3 files changed, 238 insertions(+), 130 deletions(-)
 create mode 100644 server/cache/capabilities_test.go

diff --git a/server/cache/capabilities.go b/server/cache/capabilities.go
index 354944319..cf1c66665 100644
--- a/server/cache/capabilities.go
+++ b/server/cache/capabilities.go
@@ -6,6 +6,7 @@ import (
 	"os"
 	"slices"
 	"sync"
+	"time"
 
 	"github.com/ollama/ollama/fs/ggml"
 	"github.com/ollama/ollama/template"
@@ -13,6 +14,12 @@ import (
 	"github.com/ollama/ollama/types/model"
 )
 
+// cacheEntry stores capabilities and the modification time of the model file
+type cacheEntry struct {
+	capabilities []model.Capability
+	modTime      time.Time
+}
+
 // ggufCapabilities is a cache for gguf model capabilities
 var ggufCapabilities = &sync.Map{}
 
@@ -59,12 +66,23 @@ func Capabilities(info ModelInfo) []model.Capability {
 }
 
 func ggufCapabilties(modelPath string) ([]model.Capability, error) {
-	if ggufCapabilities, ok := ggufCapabilities.Load(modelPath); ok {
-		capabilities := ggufCapabilities.([]model.Capability)
-		return capabilities, nil
+	// Get file info to check modification time
+	fileInfo, err := os.Stat(modelPath)
+	if err != nil {
+		return nil, err
+	}
+	currentModTime := fileInfo.ModTime()
+
+	// Check if we have a cached entry
+	if cached, ok := ggufCapabilities.Load(modelPath); ok {
+		entry := cached.(cacheEntry)
+		// If the file hasn't been modified since we cached it, return the cached capabilities
+		if entry.modTime.Equal(currentModTime) {
+			return entry.capabilities, nil
+		}
 	}
 
-	// If not cached, read the model file to determine capabilities
+	// If not cached or file was modified, read the model file to determine capabilities
 	capabilities := []model.Capability{}
 
 	r, err := os.Open(modelPath)
@@ -87,8 +105,11 @@ func ggufCapabilties(modelPath string) ([]model.Capability, error) {
 		capabilities = append(capabilities, model.CapabilityVision)
 	}
 
-	// Cache the capabilities for future use
-	ggufCapabilities.Store(modelPath, capabilities)
+	// Cache the capabilities with the modification time
+	ggufCapabilities.Store(modelPath, cacheEntry{
+		capabilities: capabilities,
+		modTime:      currentModTime,
+	})
 
 	return capabilities, nil
 }
diff --git a/server/cache/capabilities_test.go b/server/cache/capabilities_test.go
new file mode 100644
index 000000000..4d23961e9
--- /dev/null
+++ b/server/cache/capabilities_test.go
@@ -0,0 +1,211 @@
+package cache
+
+import (
+	"bytes"
+	"maps"
+	"os"
+	"slices"
+	"testing"
+	"time"
+
+	"github.com/ollama/ollama/fs/ggml"
+	"github.com/ollama/ollama/template"
+	"github.com/ollama/ollama/types/model"
+)
+
+// testGGUF creates a temporary GGUF model file for testing with custom key-value pairs
+func testGGUF(tb testing.TB, customKV ggml.KV) string {
+	tb.Helper()
+	f, err := os.CreateTemp(tb.TempDir(), "test*.gguf")
+	if err != nil {
+		tb.Fatal(err)
+	}
+	defer f.Close()
+
+	kv := ggml.KV{}
+	maps.Copy(kv, customKV)
+
+	tensors := []*ggml.Tensor{
+		{
+			Name:     "token_embd.weight",
+			Kind:     0,
+			Shape:    []uint64{1, 1},
+			WriterTo: bytes.NewBuffer(make([]byte, 4)),
+		},
+	}
+
+	if err := ggml.WriteGGUF(f, kv, tensors); err != nil {
+		tb.Fatal(err)
+	}
+
+	return f.Name()
+}
+
+func TestCapabilities(t *testing.T) {
+	ggufCapabilities.Range(func(key, value any) bool {
+		ggufCapabilities.Delete(key)
+		return true
+	})
+
+	// Create test model paths
+	completionModelPath := testGGUF(t, ggml.KV{
+		"general.architecture": "llama",
+	})
+
+	visionModelPath := testGGUF(t, ggml.KV{
+		"general.architecture":     "llama",
+		"llama.vision.block_count": uint32(1),
+	})
+
+	embeddingModelPath := testGGUF(t, ggml.KV{
+		"general.architecture": "bert",
+		"bert.pooling_type":    uint32(1),
+	})
+
+	// Create templates
+	toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
+	if err != nil {
+		t.Fatalf("Failed to parse template: %v", err)
+	}
+
+	chatTemplate, err := template.Parse("{{ .prompt }}")
+	if err != nil {
+		t.Fatalf("Failed to parse template: %v", err)
+	}
+
+	toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}")
+	if err != nil {
+		t.Fatalf("Failed to parse template: %v", err)
+	}
+
+	testCases := []struct {
+		name         string
+		model        ModelInfo
+		expectedCaps []model.Capability
+	}{
+		{
+			name: "model with completion capability",
+			model: ModelInfo{
+				ModelPath: completionModelPath,
+				Template:  chatTemplate,
+			},
+			expectedCaps: []model.Capability{model.CapabilityCompletion},
+		},
+		{
+			name: "model with completion, tools, and insert capability",
+			model: ModelInfo{
+				ModelPath: completionModelPath,
+				Template:  toolsInsertTemplate,
+			},
+			expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools, model.CapabilityInsert},
+		},
+		{
+			name: "model with tools capability",
+			model: ModelInfo{
+				ModelPath: completionModelPath,
+				Template:  toolsTemplate,
+			},
+			expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools},
+		},
+		{
+			name: "model with vision capability from gguf",
+			model: ModelInfo{
+				ModelPath: visionModelPath,
+				Template:  chatTemplate,
+			},
+			expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision},
+		},
+		{
+			name: "model with vision capability from projector",
+			model: ModelInfo{
+				ModelPath:      completionModelPath,
+				ProjectorPaths: []string{"/path/to/projector"},
+				Template:       chatTemplate,
+			},
+			expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision},
+		},
+		{
+			name: "model with vision, tools, and insert capability",
+			model: ModelInfo{
+				ModelPath: visionModelPath,
+				Template:  toolsInsertTemplate,
+			},
+			expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision, model.CapabilityTools, model.CapabilityInsert},
+		},
+		{
+			name: "model with embedding capability",
+			model: ModelInfo{
+				ModelPath: embeddingModelPath,
+				Template:  chatTemplate,
+			},
+			expectedCaps: []model.Capability{model.CapabilityEmbedding},
+		},
+	}
+
+	for _, tc := range testCases {
+		t.Run(tc.name, func(t *testing.T) {
+			// First call - should read from file
+			caps := Capabilities(tc.model)
+			slices.Sort(caps)
+			slices.Sort(tc.expectedCaps)
+			if !slices.Equal(caps, tc.expectedCaps) {
+				t.Errorf("Expected capabilities %v, got %v", tc.expectedCaps, caps)
+			}
+
+			// Verify caching for models that read from GGUF
+			if tc.model.ModelPath != "" {
+				// Check that entry is cached
+				_, ok := ggufCapabilities.Load(tc.model.ModelPath)
+				if !ok {
+					t.Error("Expected capabilities to be cached")
+				}
+
+				// Second call - should use cache
+				caps2 := Capabilities(tc.model)
+				slices.Sort(caps2)
+				if !slices.Equal(caps, caps2) {
+					t.Errorf("Cached capabilities don't match original: expected %v, got %v", caps, caps2)
+				}
+			}
+		})
+	}
+
+	// Test cache invalidation on file modification
+	t.Run("cache invalidation", func(t *testing.T) {
+		// Use completion model for this test
+		info := ModelInfo{
+			ModelPath: completionModelPath,
+			Template:  chatTemplate,
+		}
+
+		// Get initial cached entry
+		cached, ok := ggufCapabilities.Load(completionModelPath)
+		if !ok {
+			t.Fatal("Expected model to be cached from previous tests")
+		}
+		entry := cached.(cacheEntry)
+
+		// Modify the file's timestamp to the future
+		future := time.Now().Add(time.Hour)
+		err := os.Chtimes(completionModelPath, future, future)
+		if err != nil {
+			t.Fatalf("Failed to update file timestamp: %v", err)
+		}
+
+		// Call should re-read from file due to changed modtime
+		caps := Capabilities(info)
+		if len(caps) != 1 || caps[0] != model.CapabilityCompletion {
+			t.Errorf("Expected [CapabilityCompletion], got %v", caps)
+		}
+
+		// Check that cache was updated with new modtime
+		cached2, ok := ggufCapabilities.Load(completionModelPath)
+		if !ok {
+			t.Error("Expected capabilities to be cached after re-read")
+		}
+		entry2 := cached2.(cacheEntry)
+		if entry2.modTime.Equal(entry.modTime) {
+			t.Error("Expected cache entry to have updated modTime")
+		}
+	})
+}
diff --git a/server/images_test.go b/server/images_test.go
index a2fba8d98..4cacf38d9 100644
--- a/server/images_test.go
+++ b/server/images_test.go
@@ -9,130 +9,6 @@ import (
 	"github.com/ollama/ollama/types/model"
 )
 
-func TestModelCapabilities(t *testing.T) {
-	// Create completion model (llama architecture without vision)
-	completionModelPath, _ := createBinFile(t, ggml.KV{
-		"general.architecture": "llama",
-	}, []*ggml.Tensor{})
-
-	// Create vision model (llama architecture with vision block count)
-	visionModelPath, _ := createBinFile(t, ggml.KV{
-		"general.architecture":     "llama",
-		"llama.vision.block_count": uint32(1),
-	}, []*ggml.Tensor{})
-
-	// Create embedding model (bert architecture with pooling type)
-	embeddingModelPath, _ := createBinFile(t, ggml.KV{
-		"general.architecture": "bert",
-		"bert.pooling_type":    uint32(1),
-	}, []*ggml.Tensor{})
-
-	toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
-	if err != nil {
-		t.Fatalf("Failed to parse template: %v", err)
-	}
-
-	chatTemplate, err := template.Parse("{{ .prompt }}")
-	if err != nil {
-		t.Fatalf("Failed to parse template: %v", err)
-	}
-
-	toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}")
-	if err != nil {
-		t.Fatalf("Failed to parse template: %v", err)
-	}
-
-	testModels := []struct {
-		name         string
-		model        Model
-		expectedCaps []model.Capability
-	}{
-		{
-			name: "model with completion capability",
-			model: Model{
-				ModelPath: completionModelPath,
-				Template:  chatTemplate,
-			},
-			expectedCaps: []model.Capability{model.CapabilityCompletion},
-		},
-
-		{
-			name: "model with completion, tools, and insert capability",
-			model: Model{
-				ModelPath: completionModelPath,
-				Template:  toolsInsertTemplate,
-			},
-			expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools, model.CapabilityInsert},
-		},
-		{
-			name: "model with tools capability",
-			model: Model{
-				ModelPath: completionModelPath,
-				Template:  toolsTemplate,
-			},
-			expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools},
-		},
-		{
-			name: "model with vision capability",
-			model: Model{
-				ModelPath: visionModelPath,
-				Template:  chatTemplate,
-			},
-			expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision},
-		},
-		{
-			name: "model with vision, tools, and insert capability",
-			model: Model{
-				ModelPath: visionModelPath,
-				Template:  toolsInsertTemplate,
-			},
-			expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision, model.CapabilityTools, model.CapabilityInsert},
-		},
-		{
-			name: "model with embedding capability",
-			model: Model{
-				ModelPath: embeddingModelPath,
-				Template:  chatTemplate,
-			},
-			expectedCaps: []model.Capability{model.CapabilityEmbedding},
-		},
-	}
-
-	// compare two slices of model.Capability regardless of order
-	compareCapabilities := func(a, b []model.Capability) bool {
-		if len(a) != len(b) {
-			return false
-		}
-
-		aCount := make(map[model.Capability]int)
-		for _, cap := range a {
-			aCount[cap]++
-		}
-
-		bCount := make(map[model.Capability]int)
-		for _, cap := range b {
-			bCount[cap]++
-		}
-
-		for cap, count := range aCount {
-			if bCount[cap] != count {
-				return false
-			}
-		}
-
-		return true
-	}
-
-	for _, tt := range testModels {
-		t.Run(tt.name, func(t *testing.T) {
-			// Test Capabilities method
-			caps := tt.model.Capabilities()
-			if !compareCapabilities(caps, tt.expectedCaps) {
-				t.Errorf("Expected capabilities %v, got %v", tt.expectedCaps, caps)
-			}
-		})
-	}
-}
 
 func TestModelCheckCapabilities(t *testing.T) {
 	// Create simple model file for tests that don't depend on GGUF content