pr feedback

This commit is contained in:
Bruce MacDonald 2025-06-16 16:08:38 -07:00
parent 8d51d92f3b
commit 63e7634014
3 changed files with 238 additions and 130 deletions

View File

@ -6,6 +6,7 @@ import (
"os"
"slices"
"sync"
"time"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/template"
@ -13,6 +14,12 @@ import (
"github.com/ollama/ollama/types/model"
)
// cacheEntry stores capabilities and the modification time of the model file
type cacheEntry struct {
capabilities []model.Capability
modTime time.Time
}
// ggufCapabilities is a cache for gguf model capabilities
var ggufCapabilities = &sync.Map{}
@ -59,12 +66,23 @@ func Capabilities(info ModelInfo) []model.Capability {
}
func ggufCapabilties(modelPath string) ([]model.Capability, error) {
if ggufCapabilities, ok := ggufCapabilities.Load(modelPath); ok {
capabilities := ggufCapabilities.([]model.Capability)
return capabilities, nil
// Get file info to check modification time
fileInfo, err := os.Stat(modelPath)
if err != nil {
return nil, err
}
currentModTime := fileInfo.ModTime()
// Check if we have a cached entry
if cached, ok := ggufCapabilities.Load(modelPath); ok {
entry := cached.(cacheEntry)
// If the file hasn't been modified since we cached it, return the cached capabilities
if entry.modTime.Equal(currentModTime) {
return entry.capabilities, nil
}
}
// If not cached, read the model file to determine capabilities
// If not cached or file was modified, read the model file to determine capabilities
capabilities := []model.Capability{}
r, err := os.Open(modelPath)
@ -87,8 +105,11 @@ func ggufCapabilties(modelPath string) ([]model.Capability, error) {
capabilities = append(capabilities, model.CapabilityVision)
}
// Cache the capabilities for future use
ggufCapabilities.Store(modelPath, capabilities)
// Cache the capabilities with the modification time
ggufCapabilities.Store(modelPath, cacheEntry{
capabilities: capabilities,
modTime: currentModTime,
})
return capabilities, nil
}

211
server/cache/capabilities_test.go vendored Normal file
View File

@ -0,0 +1,211 @@
package cache
import (
"bytes"
"maps"
"os"
"slices"
"testing"
"time"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/model"
)
// testGGUF creates a temporary GGUF model file for testing with custom key-value pairs
func testGGUF(tb testing.TB, customKV ggml.KV) string {
tb.Helper()
f, err := os.CreateTemp(tb.TempDir(), "test*.gguf")
if err != nil {
tb.Fatal(err)
}
defer f.Close()
kv := ggml.KV{}
maps.Copy(kv, customKV)
tensors := []*ggml.Tensor{
{
Name: "token_embd.weight",
Kind: 0,
Shape: []uint64{1, 1},
WriterTo: bytes.NewBuffer(make([]byte, 4)),
},
}
if err := ggml.WriteGGUF(f, kv, tensors); err != nil {
tb.Fatal(err)
}
return f.Name()
}
func TestCapabilities(t *testing.T) {
ggufCapabilities.Range(func(key, value any) bool {
ggufCapabilities.Delete(key)
return true
})
// Create test model paths
completionModelPath := testGGUF(t, ggml.KV{
"general.architecture": "llama",
})
visionModelPath := testGGUF(t, ggml.KV{
"general.architecture": "llama",
"llama.vision.block_count": uint32(1),
})
embeddingModelPath := testGGUF(t, ggml.KV{
"general.architecture": "bert",
"bert.pooling_type": uint32(1),
})
// Create templates
toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
chatTemplate, err := template.Parse("{{ .prompt }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
testCases := []struct {
name string
model ModelInfo
expectedCaps []model.Capability
}{
{
name: "model with completion capability",
model: ModelInfo{
ModelPath: completionModelPath,
Template: chatTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion},
},
{
name: "model with completion, tools, and insert capability",
model: ModelInfo{
ModelPath: completionModelPath,
Template: toolsInsertTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools, model.CapabilityInsert},
},
{
name: "model with tools capability",
model: ModelInfo{
ModelPath: completionModelPath,
Template: toolsTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools},
},
{
name: "model with vision capability from gguf",
model: ModelInfo{
ModelPath: visionModelPath,
Template: chatTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision},
},
{
name: "model with vision capability from projector",
model: ModelInfo{
ModelPath: completionModelPath,
ProjectorPaths: []string{"/path/to/projector"},
Template: chatTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision},
},
{
name: "model with vision, tools, and insert capability",
model: ModelInfo{
ModelPath: visionModelPath,
Template: toolsInsertTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision, model.CapabilityTools, model.CapabilityInsert},
},
{
name: "model with embedding capability",
model: ModelInfo{
ModelPath: embeddingModelPath,
Template: chatTemplate,
},
expectedCaps: []model.Capability{model.CapabilityEmbedding},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// First call - should read from file
caps := Capabilities(tc.model)
slices.Sort(caps)
slices.Sort(tc.expectedCaps)
if !slices.Equal(caps, tc.expectedCaps) {
t.Errorf("Expected capabilities %v, got %v", tc.expectedCaps, caps)
}
// Verify caching for models that read from GGUF
if tc.model.ModelPath != "" {
// Check that entry is cached
_, ok := ggufCapabilities.Load(tc.model.ModelPath)
if !ok {
t.Error("Expected capabilities to be cached")
}
// Second call - should use cache
caps2 := Capabilities(tc.model)
slices.Sort(caps2)
if !slices.Equal(caps, caps2) {
t.Errorf("Cached capabilities don't match original: expected %v, got %v", caps, caps2)
}
}
})
}
// Test cache invalidation on file modification
t.Run("cache invalidation", func(t *testing.T) {
// Use completion model for this test
info := ModelInfo{
ModelPath: completionModelPath,
Template: chatTemplate,
}
// Get initial cached entry
cached, ok := ggufCapabilities.Load(completionModelPath)
if !ok {
t.Fatal("Expected model to be cached from previous tests")
}
entry := cached.(cacheEntry)
// Modify the file's timestamp to the future
future := time.Now().Add(time.Hour)
err := os.Chtimes(completionModelPath, future, future)
if err != nil {
t.Fatalf("Failed to update file timestamp: %v", err)
}
// Call should re-read from file due to changed modtime
caps := Capabilities(info)
if len(caps) != 1 || caps[0] != model.CapabilityCompletion {
t.Errorf("Expected [CapabilityCompletion], got %v", caps)
}
// Check that cache was updated with new modtime
cached2, ok := ggufCapabilities.Load(completionModelPath)
if !ok {
t.Error("Expected capabilities to be cached after re-read")
}
entry2 := cached2.(cacheEntry)
if entry2.modTime.Equal(entry.modTime) {
t.Error("Expected cache entry to have updated modTime")
}
})
}

View File

@ -9,130 +9,6 @@ import (
"github.com/ollama/ollama/types/model"
)
func TestModelCapabilities(t *testing.T) {
// Create completion model (llama architecture without vision)
completionModelPath, _ := createBinFile(t, ggml.KV{
"general.architecture": "llama",
}, []*ggml.Tensor{})
// Create vision model (llama architecture with vision block count)
visionModelPath, _ := createBinFile(t, ggml.KV{
"general.architecture": "llama",
"llama.vision.block_count": uint32(1),
}, []*ggml.Tensor{})
// Create embedding model (bert architecture with pooling type)
embeddingModelPath, _ := createBinFile(t, ggml.KV{
"general.architecture": "bert",
"bert.pooling_type": uint32(1),
}, []*ggml.Tensor{})
toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
chatTemplate, err := template.Parse("{{ .prompt }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}")
if err != nil {
t.Fatalf("Failed to parse template: %v", err)
}
testModels := []struct {
name string
model Model
expectedCaps []model.Capability
}{
{
name: "model with completion capability",
model: Model{
ModelPath: completionModelPath,
Template: chatTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion},
},
{
name: "model with completion, tools, and insert capability",
model: Model{
ModelPath: completionModelPath,
Template: toolsInsertTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools, model.CapabilityInsert},
},
{
name: "model with tools capability",
model: Model{
ModelPath: completionModelPath,
Template: toolsTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools},
},
{
name: "model with vision capability",
model: Model{
ModelPath: visionModelPath,
Template: chatTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision},
},
{
name: "model with vision, tools, and insert capability",
model: Model{
ModelPath: visionModelPath,
Template: toolsInsertTemplate,
},
expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision, model.CapabilityTools, model.CapabilityInsert},
},
{
name: "model with embedding capability",
model: Model{
ModelPath: embeddingModelPath,
Template: chatTemplate,
},
expectedCaps: []model.Capability{model.CapabilityEmbedding},
},
}
// compare two slices of model.Capability regardless of order
compareCapabilities := func(a, b []model.Capability) bool {
if len(a) != len(b) {
return false
}
aCount := make(map[model.Capability]int)
for _, cap := range a {
aCount[cap]++
}
bCount := make(map[model.Capability]int)
for _, cap := range b {
bCount[cap]++
}
for cap, count := range aCount {
if bCount[cap] != count {
return false
}
}
return true
}
for _, tt := range testModels {
t.Run(tt.name, func(t *testing.T) {
// Test Capabilities method
caps := tt.model.Capabilities()
if !compareCapabilities(caps, tt.expectedCaps) {
t.Errorf("Expected capabilities %v, got %v", tt.expectedCaps, caps)
}
})
}
}
func TestModelCheckCapabilities(t *testing.T) {
// Create simple model file for tests that don't depend on GGUF content