package cache

import (
	"bytes"
	"maps"
	"os"
	"slices"
	"testing"
	"time"

	"github.com/ollama/ollama/fs/ggml"
	"github.com/ollama/ollama/template"
	"github.com/ollama/ollama/types/model"
)

// testGGUF creates a temporary GGUF model file for testing with custom key-value pairs
func testGGUF(tb testing.TB, customKV ggml.KV) string {
	tb.Helper()
	f, err := os.CreateTemp(tb.TempDir(), "test*.gguf")
	if err != nil {
		tb.Fatal(err)
	}
	defer f.Close()

	kv := ggml.KV{}
	maps.Copy(kv, customKV)

	tensors := []*ggml.Tensor{
		{
			Name:     "token_embd.weight",
			Kind:     0,
			Shape:    []uint64{1, 1},
			WriterTo: bytes.NewBuffer(make([]byte, 4)),
		},
	}

	if err := ggml.WriteGGUF(f, kv, tensors); err != nil {
		tb.Fatal(err)
	}

	return f.Name()
}

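// TestCapabilities checks that Capabilities reports the expected capability set
// for models whose GGUF metadata, templates, and projector paths differ, and
// that results are cached per model path and refreshed when the file changes.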
func TestCapabilities(t *testing.T) {
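	// Clear the shared capability cache so this test starts with no cached entries.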
	ggufCapabilities.Range(func(key, value any) bool {
		ggufCapabilities.Delete(key)
		return true
	})

	// Create test model paths
	completionModelPath := testGGUF(t, ggml.KV{
		"general.architecture": "llama",
	})

	visionModelPath := testGGUF(t, ggml.KV{
		"general.architecture":     "llama",
		"llama.vision.block_count": uint32(1),
	})

	embeddingModelPath := testGGUF(t, ggml.KV{
		"general.architecture": "bert",
		"bert.pooling_type":    uint32(1),
	})

	// Create templates; tools and insert capabilities are inferred from whether
	// a template references the .tools and .suffix variables.
	toolsInsertTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}{{ if .suffix }}{{ .suffix }}{{ end }}")
	if err != nil {
		t.Fatalf("Failed to parse template: %v", err)
	}

	chatTemplate, err := template.Parse("{{ .prompt }}")
	if err != nil {
		t.Fatalf("Failed to parse template: %v", err)
	}

	toolsTemplate, err := template.Parse("{{ .prompt }}{{ if .tools }}{{ .tools }}{{ end }}")
	if err != nil {
		t.Fatalf("Failed to parse template: %v", err)
	}

	testCases := []struct {
		name         string
		model        ModelInfo
		expectedCaps []model.Capability
	}{
		{
			name: "model with completion capability",
			model: ModelInfo{
				ModelPath: completionModelPath,
				Template:  chatTemplate,
			},
			expectedCaps: []model.Capability{model.CapabilityCompletion},
		},
		{
			name: "model with completion, tools, and insert capability",
			model: ModelInfo{
				ModelPath: completionModelPath,
				Template:  toolsInsertTemplate,
			},
			expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools, model.CapabilityInsert},
		},
		{
			name: "model with tools capability",
			model: ModelInfo{
				ModelPath: completionModelPath,
				Template:  toolsTemplate,
			},
			expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityTools},
		},
		{
			name: "model with vision capability from gguf",
			model: ModelInfo{
				ModelPath: visionModelPath,
				Template:  chatTemplate,
			},
			expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision},
		},
		{
			name: "model with vision capability from projector",
			model: ModelInfo{
				ModelPath:      completionModelPath,
				ProjectorPaths: []string{"/path/to/projector"},
				Template:       chatTemplate,
			},
			expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision},
		},
		{
			name: "model with vision, tools, and insert capability",
			model: ModelInfo{
				ModelPath: visionModelPath,
				Template:  toolsInsertTemplate,
			},
			expectedCaps: []model.Capability{model.CapabilityCompletion, model.CapabilityVision, model.CapabilityTools, model.CapabilityInsert},
		},
		{
			name: "model with embedding capability",
			model: ModelInfo{
				ModelPath: embeddingModelPath,
				Template:  chatTemplate,
			},
			expectedCaps: []model.Capability{model.CapabilityEmbedding},
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// First call - should read from file
			caps := Capabilities(tc.model)
			slices.Sort(caps)
			slices.Sort(tc.expectedCaps)
			if !slices.Equal(caps, tc.expectedCaps) {
				t.Errorf("Expected capabilities %v, got %v", tc.expectedCaps, caps)
			}

			// Verify caching for models that read from GGUF
			if tc.model.ModelPath != "" {
				// Check that entry is cached
				_, ok := ggufCapabilities.Load(tc.model.ModelPath)
				if !ok {
					t.Error("Expected capabilities to be cached")
				}

				// Second call - should use cache
				caps2 := Capabilities(tc.model)
				slices.Sort(caps2)
				if !slices.Equal(caps, caps2) {
					t.Errorf("Cached capabilities don't match original: expected %v, got %v", caps, caps2)
				}
			}
		})
	}

	// Test cache invalidation on file modification
	t.Run("cache invalidation", func(t *testing.T) {
		// Use completion model for this test
		info := ModelInfo{
			ModelPath: completionModelPath,
			Template:  chatTemplate,
		}

		// Get initial cached entry
		cached, ok := ggufCapabilities.Load(completionModelPath)
		if !ok {
			t.Fatal("Expected model to be cached from previous tests")
		}
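		// Keep the original entry so its modTime can be compared after the re-read.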
		entry := cached.(cacheEntry)

		// Modify the file's timestamp to the future
		future := time.Now().Add(time.Hour)
		err := os.Chtimes(completionModelPath, future, future)
		if err != nil {
			t.Fatalf("Failed to update file timestamp: %v", err)
		}

		// Call should re-read from file due to changed modtime
		caps := Capabilities(info)
		if len(caps) != 1 || caps[0] != model.CapabilityCompletion {
			t.Errorf("Expected [CapabilityCompletion], got %v", caps)
		}

		// Check that cache was updated with new modtime
		cached2, ok := ggufCapabilities.Load(completionModelPath)
		if !ok {
			t.Error("Expected capabilities to be cached after re-read")
		}
		entry2 := cached2.(cacheEntry)
		if entry2.modTime.Equal(entry.modTime) {
			t.Error("Expected cache entry to have updated modTime")
		}
	})
}