From 883f655dd6e33d923026606902d8b81eccd0fc71 Mon Sep 17 00:00:00 2001 From: Bruce MacDonald Date: Tue, 10 Jun 2025 16:10:13 -0700 Subject: [PATCH] server: model info caching system for improved performance Implements an in-memory cache for loaded models with file modification time tracking to ensure cache validity. Models are now cached after first load and retrieved from cache on subsequent requests if the underlying manifest file hasn't changed. Key changes: - Add ModelCache with get/set methods and modification time validation - Cache models in GetModel() and check cache before disk load - Move capabilities calculation to model loading time and store in model - Update capability access to use cached field instead of runtime calculation - Add test coverage for cache behavior and model loading This reduces redundant model loading operations and improves response times for model access. --- server/cache.go | 95 ++++++++++++++++ server/cache_test.go | 257 +++++++++++++++++++++++++++++++++++++++++++ server/images.go | 18 ++- server/routes.go | 2 +- 4 files changed, 367 insertions(+), 5 deletions(-) create mode 100644 server/cache.go create mode 100644 server/cache_test.go diff --git a/server/cache.go b/server/cache.go new file mode 100644 index 000000000..e6ac7ea01 --- /dev/null +++ b/server/cache.go @@ -0,0 +1,95 @@ +package server + +import ( + "log/slog" + "os" + "sync" + "time" +) + +type ModelCache struct { + mu sync.RWMutex + cache map[string]*CachedModel +} + +type CachedModel struct { + model *Model + modTime time.Time + fileSize int64 +} + +var modelCache = &ModelCache{ + cache: make(map[string]*CachedModel), +} + +func init() { + modelCache.fill() +} + +func (c *ModelCache) fill() { + manifests, err := Manifests(true) // continues on error + if err != nil { + slog.Warn("Failed to get manifests during cache fill", "error", err) + return + } + + for modelName := range manifests { + nameStr := modelName.String() + + // Load the model (this will populate the cache via GetModel -> set) + _, err := GetModel(nameStr) + if err != nil { + slog.Debug("Failed to load model during cache fill", "name", nameStr, "error", err) + continue + } + } + + slog.Debug("Model cache filled") +} + +func (c *ModelCache) get(name string) (*Model, bool) { + mp := ParseModelPath(name) + manifestPath, err := mp.GetManifestPath() + if err != nil { + return nil, false + } + + // Check manifest file modification time + info, err := os.Stat(manifestPath) + if err != nil { + return nil, false + } + + cached, exists := c.cache[name] + if exists && cached.modTime.Equal(info.ModTime()) && cached.fileSize == info.Size() { + // Cache hit - return cached model + return cached.model, true + } + + // Cache miss or stale + return nil, false +} + +func (c *ModelCache) set(name string, model *Model) { + mp := ParseModelPath(name) + manifestPath, err := mp.GetManifestPath() + if err != nil { + slog.Debug("Failed to get manifest path for model", "name", name, "error", err) + return + } + + info, err := os.Stat(manifestPath) + if err != nil { + slog.Debug("Failed to stat manifest file", "path", manifestPath, "error", err) + return + } + + c.mu.Lock() + defer c.mu.Unlock() + + c.cache[name] = &CachedModel{ + model: model, + modTime: info.ModTime(), + fileSize: info.Size(), + } +} diff --git a/server/cache_test.go b/server/cache_test.go new file mode 100644 index 000000000..3e76544c2 --- /dev/null +++ b/server/cache_test.go @@ -0,0 +1,257 @@ +package server + +import ( + "os" + "path/filepath" + "testing" + "time" +) + +func TestModelCacheGet(t *testing.T) { + testModel := &Model{ + Name: "test-model:latest", + ShortName: "test-model", + } + + tests := []struct { + name string + modelName string + setupFunc func(t *testing.T, modelsDir string, cache *ModelCache) string // returns manifest path + expectedModel *Model + expectedExists bool + }{ + { + name: "cache hit - valid cached model", + modelName: "test-model:latest", + setupFunc: func(t *testing.T, modelsDir string, cache *ModelCache) string { + createTestModel(t, modelsDir, "test-model", []Layer{ + {MediaType: "application/vnd.ollama.image.model", Digest: "sha256-abc123", Size: 1000}, + }) + + manifestPath := filepath.Join(modelsDir, "manifests", "registry.ollama.ai", "library", "test-model", "latest") + info, err := os.Stat(manifestPath) + if err != nil { + t.Fatal(err) + } + + cache.cache["test-model:latest"] = &CachedModel{ + model: testModel, + modTime: info.ModTime(), + fileSize: info.Size(), + } + return manifestPath + }, + expectedModel: testModel, + expectedExists: true, + }, + { + name: "cache miss - no cached entry", + modelName: "missing-model:latest", + setupFunc: func(t *testing.T, modelsDir string, cache *ModelCache) string { + createTestModel(t, modelsDir, "missing-model", []Layer{ + {MediaType: "application/vnd.ollama.image.model", Digest: "sha256-def456", Size: 2000}, + }) + return filepath.Join(modelsDir, "manifests", "registry.ollama.ai", "library", "missing-model", "latest") + }, + expectedModel: nil, + expectedExists: false, + }, + { + name: "cache stale - modification time changed", + modelName: "stale-model:latest", + setupFunc: func(t *testing.T, modelsDir string, cache *ModelCache) string { + createTestModel(t, modelsDir, "stale-model", []Layer{ + {MediaType: "application/vnd.ollama.image.model", Digest: "sha256-ghi789", Size: 3000}, + }) + + manifestPath := filepath.Join(modelsDir, "manifests", "registry.ollama.ai", "library", "stale-model", "latest") + info, err := os.Stat(manifestPath) + if err != nil { + t.Fatal(err) + } + + cache.cache["stale-model:latest"] = &CachedModel{ + model: testModel, + modTime: info.ModTime().Add(-time.Hour), // Stale time + fileSize: info.Size(), + } + return manifestPath + }, + expectedModel: nil, + expectedExists: false, + }, + { + name: "cache stale - file size changed", + modelName: "stale-size-model:latest", + setupFunc: func(t *testing.T, modelsDir string, cache *ModelCache) string { + createTestModel(t, modelsDir, "stale-size-model", []Layer{ + {MediaType: "application/vnd.ollama.image.model", Digest: "sha256-jkl012", Size: 4000}, + }) + + manifestPath := filepath.Join(modelsDir, "manifests", "registry.ollama.ai", "library", "stale-size-model", "latest") + info, err := os.Stat(manifestPath) + if err != nil { + t.Fatal(err) + } + + cache.cache["stale-size-model:latest"] = &CachedModel{ + model: testModel, + modTime: info.ModTime(), + fileSize: info.Size() + 100, // Different size + } + return manifestPath + }, + expectedModel: nil, + expectedExists: false, + }, + { + name: "manifest file does not exist", + modelName: "nonexistent-model:latest", + setupFunc: func(t *testing.T, modelsDir string, cache *ModelCache) string { + // Add to cache but don't create manifest file + cache.cache["nonexistent-model:latest"] = &CachedModel{ + model: testModel, + modTime: time.Now(), + fileSize: 100, + } + return "" // No manifest created + }, + expectedModel: nil, + expectedExists: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + modelsDir := t.TempDir() + t.Setenv("OLLAMA_MODELS", modelsDir) + + // Create fresh cache instance for each test + cache := &ModelCache{ + cache: make(map[string]*CachedModel), + } + + if tt.setupFunc != nil { + tt.setupFunc(t, modelsDir, cache) + } + + model, exists := cache.get(tt.modelName) + + if exists != tt.expectedExists { + t.Errorf("get() exists = %v, want %v", exists, tt.expectedExists) + } + + if tt.expectedModel != nil && model != tt.expectedModel { + t.Errorf("get() model = %v, want %v", model, tt.expectedModel) + } + + if tt.expectedModel == nil && model != nil { + t.Errorf("get() model = %v, want nil", model) + } + }) + } +} + +func TestModelCacheSet(t *testing.T) { + testModel := &Model{ + Name: "test-model:latest", + ShortName: "test-model", + } + + tests := []struct { + name string + modelName string + model *Model + setupFunc func(t *testing.T, modelsDir string) string // returns manifest path + expectCached bool + }{ + { + name: "successful cache set", + modelName: "test-model:latest", + model: testModel, + setupFunc: func(t *testing.T, modelsDir string) string { + createTestModel(t, modelsDir, "test-model", []Layer{ + {MediaType: "application/vnd.ollama.image.model", Digest: "sha256-abc123", Size: 1000}, + }) + return filepath.Join(modelsDir, "manifests", "registry.ollama.ai", "library", "test-model", "latest") + }, + expectCached: true, + }, + { + name: "manifest file does not exist", + modelName: "nonexistent-model:latest", + model: testModel, + setupFunc: func(t *testing.T, modelsDir string) string { + // Don't create manifest file + return "" + }, + expectCached: false, + }, + { + name: "overwrite existing cache entry", + modelName: "existing-model:latest", + model: testModel, + setupFunc: func(t *testing.T, modelsDir string) string { + createTestModel(t, modelsDir, "existing-model", []Layer{ + {MediaType: "application/vnd.ollama.image.model", Digest: "sha256-def456", Size: 2000}, + }) + return filepath.Join(modelsDir, "manifests", "registry.ollama.ai", "library", "existing-model", "latest") + }, + expectCached: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + modelsDir := t.TempDir() + t.Setenv("OLLAMA_MODELS", modelsDir) + + // Create fresh cache instance for each test + cache := &ModelCache{ + cache: make(map[string]*CachedModel), + } + + if tt.name == "overwrite existing cache entry" { + cache.cache[tt.modelName] = &CachedModel{ + model: &Model{Name: "old-model"}, + modTime: time.Now().Add(-time.Hour), + fileSize: 50, + } + } + + if tt.setupFunc != nil { + tt.setupFunc(t, modelsDir) + } + + cache.set(tt.modelName, tt.model) + + cached, exists := cache.cache[tt.modelName] + + if tt.expectCached { + if !exists { + t.Errorf("set() expected model to be cached, but it wasn't") + return + } + + if cached.model != tt.model { + t.Errorf("set() cached model = %v, want %v", cached.model, tt.model) + } + + // Verify file info is captured correctly if manifest exists + expectedManifestPath := filepath.Join(modelsDir, "manifests", "registry.ollama.ai", "library", tt.modelName[:len(tt.modelName)-7], "latest") + if info, err := os.Stat(expectedManifestPath); err == nil { + if !cached.modTime.Equal(info.ModTime()) { + t.Errorf("set() cached modTime = %v, want %v", cached.modTime, info.ModTime()) + } + if cached.fileSize != info.Size() { + t.Errorf("set() cached fileSize = %v, want %v", cached.fileSize, info.Size()) + } + } + } else { + if exists { + t.Errorf("set() expected model not to be cached, but it was") + } + } + }) + } +} diff --git a/server/images.go b/server/images.go index 38505cc51..b1ec3acd2 100644 --- a/server/images.go +++ b/server/images.go @@ -65,11 +65,12 @@ type Model struct { Options map[string]any Messages []api.Message - Template *template.Template + Capabilities []model.Capability + Template *template.Template } // Capabilities returns the capabilities that the model supports -func (m *Model) Capabilities() []model.Capability { +func Capabilities(m *Model) []model.Capability { capabilities := []model.Capability{} // Check for completion capability @@ -121,7 +122,6 @@ func (m *Model) Capabilities() []model.Capability { // CheckCapabilities checks if the model has the specified capabilities returning an error describing // any missing or unknown capabilities func (m *Model) CheckCapabilities(want ...model.Capability) error { - available := m.Capabilities() var errs []error // Map capabilities to their corresponding error @@ -141,7 +141,7 @@ func (m *Model) CheckCapabilities(want ...model.Capability) error { return fmt.Errorf("unknown capability: %s", cap) } - if !slices.Contains(available, cap) { + if !slices.Contains(m.Capabilities, cap) { errs = append(errs, err) } } @@ -272,6 +272,12 @@ func GetManifest(mp ModelPath) (*Manifest, string, error) { } func GetModel(name string) (*Model, error) { + // Try cache first + if model, hit := modelCache.get(name); hit { + return model, nil + } + + // Cache miss, load the model from disk mp := ParseModelPath(name) manifest, digest, err := GetManifest(mp) if err != nil { @@ -368,6 +374,10 @@ func GetModel(name string) (*Model, error) { } } + model.Capabilities = Capabilities(model) + + modelCache.set(name, model) + return model, nil } diff --git a/server/routes.go b/server/routes.go index cb46cef11..2bb99f037 100644 --- a/server/routes.go +++ b/server/routes.go @@ -824,7 +824,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) { Template: m.Template.String(), Details: modelDetails, Messages: msgs, - Capabilities: m.Capabilities(), + Capabilities: m.Capabilities, ModifiedAt: manifest.fi.ModTime(), }