server: model info caching system for improved performance
Implements an in-memory cache for loaded models with file modification time tracking to ensure cache validity. Models are now cached after first load and retrieved from cache on subsequent requests if the underlying manifest file hasn't changed.

Key changes:
- Add ModelCache with get/set methods and modification time validation
- Cache models in GetModel() and check cache before disk load
- Move capabilities calculation to model loading time and store in model
- Update capability access to use cached field instead of runtime calculation
- Add test coverage for cache behavior and model loading

This reduces redundant model loading operations and improves response times for model access.
This commit is contained in:
parent
a6fbfc880c
commit
883f655dd6
|
|
@ -0,0 +1,95 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ModelCache is an in-memory cache of loaded models, keyed by model name.
// Entries are validated against the manifest file's modification time and
// size before being returned (see get), so a changed manifest on disk
// invalidates the cached model.
type ModelCache struct {
	mu sync.RWMutex // guards cache; set() writes under it — readers must hold RLock
	cache map[string]*CachedModel
}
|
||||
|
||||
// CachedModel pairs a loaded model with the manifest file metadata observed
// at the time it was cached. The metadata is compared against the file on
// disk to detect staleness.
type CachedModel struct {
	model *Model
	modTime time.Time // manifest file mtime when cached
	fileSize int64 // manifest file size when cached
}
|
||||
|
||||
// modelCache is the process-wide model cache consulted by GetModel.
var modelCache = &ModelCache{
	cache: make(map[string]*CachedModel),
}

// NOTE(review): fill() enumerates manifests and loads every model, so this
// init performs nontrivial disk I/O at package initialization time —
// consider an explicit warm-up call instead of init(); verify startup cost.
func init() {
	modelCache.fill()
}
|
||||
|
||||
func (c *ModelCache) fill() {
|
||||
manifests, err := Manifests(true) // continues on error
|
||||
if err != nil {
|
||||
slog.Warn("Failed to get manifests during cache fill", "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
for modelName := range manifests {
|
||||
nameStr := modelName.String()
|
||||
|
||||
// Load the model (this will populate the cache via GetModel -> set)
|
||||
_, err := GetModel(nameStr)
|
||||
if err != nil {
|
||||
slog.Debug("Failed to load model during cache fill", "name", nameStr, "error", err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
slog.Debug("Model cache filled")
|
||||
}
|
||||
|
||||
func (c *ModelCache) get(name string) (*Model, bool) {
|
||||
mp := ParseModelPath(name)
|
||||
manifestPath, err := mp.GetManifestPath()
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Check manifest file modification time
|
||||
info, err := os.Stat(manifestPath)
|
||||
if err != nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
cached, exists := c.cache[name]
|
||||
if exists && cached.modTime.Equal(info.ModTime()) && cached.fileSize == info.Size() {
|
||||
// Cache hit - return cached model
|
||||
return cached.model, true
|
||||
}
|
||||
|
||||
// Cache miss or stale
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func (c *ModelCache) set(name string, model *Model) {
|
||||
mp := ParseModelPath(name)
|
||||
manifestPath, err := mp.GetManifestPath()
|
||||
if err != nil {
|
||||
slog.Debug("Failed to get manifest path for model", "name", name, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
info, err := os.Stat(manifestPath)
|
||||
if err != nil {
|
||||
slog.Debug("Failed to stat manifest file", "path", manifestPath, "error", err)
|
||||
return
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
c.cache[name] = &CachedModel{
|
||||
model: model,
|
||||
modTime: info.ModTime(),
|
||||
fileSize: info.Size(),
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,257 @@
|
|||
package server
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// TestModelCacheGet exercises ModelCache.get against a table of cache states:
// a valid hit, a missing entry, entries invalidated by a changed mtime or
// size, and an entry whose manifest file does not exist on disk. Each case
// builds an isolated models directory (via OLLAMA_MODELS) and a fresh cache.
func TestModelCacheGet(t *testing.T) {
	testModel := &Model{
		Name:      "test-model:latest",
		ShortName: "test-model",
	}

	tests := []struct {
		name           string
		modelName      string
		setupFunc      func(t *testing.T, modelsDir string, cache *ModelCache) string // returns manifest path
		expectedModel  *Model
		expectedExists bool
	}{
		{
			name:      "cache hit - valid cached model",
			modelName: "test-model:latest",
			setupFunc: func(t *testing.T, modelsDir string, cache *ModelCache) string {
				createTestModel(t, modelsDir, "test-model", []Layer{
					{MediaType: "application/vnd.ollama.image.model", Digest: "sha256-abc123", Size: 1000},
				})

				manifestPath := filepath.Join(modelsDir, "manifests", "registry.ollama.ai", "library", "test-model", "latest")
				info, err := os.Stat(manifestPath)
				if err != nil {
					t.Fatal(err)
				}

				// Seed the cache with metadata matching the on-disk manifest,
				// so get() treats the entry as valid.
				cache.cache["test-model:latest"] = &CachedModel{
					model:    testModel,
					modTime:  info.ModTime(),
					fileSize: info.Size(),
				}
				return manifestPath
			},
			expectedModel:  testModel,
			expectedExists: true,
		},
		{
			name:      "cache miss - no cached entry",
			modelName: "missing-model:latest",
			setupFunc: func(t *testing.T, modelsDir string, cache *ModelCache) string {
				// Manifest exists on disk but nothing is seeded in the cache.
				createTestModel(t, modelsDir, "missing-model", []Layer{
					{MediaType: "application/vnd.ollama.image.model", Digest: "sha256-def456", Size: 2000},
				})
				return filepath.Join(modelsDir, "manifests", "registry.ollama.ai", "library", "missing-model", "latest")
			},
			expectedModel:  nil,
			expectedExists: false,
		},
		{
			name:      "cache stale - modification time changed",
			modelName: "stale-model:latest",
			setupFunc: func(t *testing.T, modelsDir string, cache *ModelCache) string {
				createTestModel(t, modelsDir, "stale-model", []Layer{
					{MediaType: "application/vnd.ollama.image.model", Digest: "sha256-ghi789", Size: 3000},
				})

				manifestPath := filepath.Join(modelsDir, "manifests", "registry.ollama.ai", "library", "stale-model", "latest")
				info, err := os.Stat(manifestPath)
				if err != nil {
					t.Fatal(err)
				}

				// Record a modTime one hour behind the file so the entry
				// fails the freshness check.
				cache.cache["stale-model:latest"] = &CachedModel{
					model:    testModel,
					modTime:  info.ModTime().Add(-time.Hour), // Stale time
					fileSize: info.Size(),
				}
				return manifestPath
			},
			expectedModel:  nil,
			expectedExists: false,
		},
		{
			name:      "cache stale - file size changed",
			modelName: "stale-size-model:latest",
			setupFunc: func(t *testing.T, modelsDir string, cache *ModelCache) string {
				createTestModel(t, modelsDir, "stale-size-model", []Layer{
					{MediaType: "application/vnd.ollama.image.model", Digest: "sha256-jkl012", Size: 4000},
				})

				manifestPath := filepath.Join(modelsDir, "manifests", "registry.ollama.ai", "library", "stale-size-model", "latest")
				info, err := os.Stat(manifestPath)
				if err != nil {
					t.Fatal(err)
				}

				// Record a wrong size so the entry fails validation even
				// though the modTime matches.
				cache.cache["stale-size-model:latest"] = &CachedModel{
					model:    testModel,
					modTime:  info.ModTime(),
					fileSize: info.Size() + 100, // Different size
				}
				return manifestPath
			},
			expectedModel:  nil,
			expectedExists: false,
		},
		{
			name:      "manifest file does not exist",
			modelName: "nonexistent-model:latest",
			setupFunc: func(t *testing.T, modelsDir string, cache *ModelCache) string {
				// Add to cache but don't create manifest file
				cache.cache["nonexistent-model:latest"] = &CachedModel{
					model:    testModel,
					modTime:  time.Now(),
					fileSize: 100,
				}
				return "" // No manifest created
			},
			expectedModel:  nil,
			expectedExists: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			modelsDir := t.TempDir()
			t.Setenv("OLLAMA_MODELS", modelsDir)

			// Create fresh cache instance for each test
			cache := &ModelCache{
				cache: make(map[string]*CachedModel),
			}

			if tt.setupFunc != nil {
				tt.setupFunc(t, modelsDir, cache)
			}

			model, exists := cache.get(tt.modelName)

			if exists != tt.expectedExists {
				t.Errorf("get() exists = %v, want %v", exists, tt.expectedExists)
			}

			if tt.expectedModel != nil && model != tt.expectedModel {
				t.Errorf("get() model = %v, want %v", model, tt.expectedModel)
			}

			if tt.expectedModel == nil && model != nil {
				t.Errorf("get() model = %v, want nil", model)
			}
		})
	}
}
|
||||
|
||||
func TestModelCacheSet(t *testing.T) {
|
||||
testModel := &Model{
|
||||
Name: "test-model:latest",
|
||||
ShortName: "test-model",
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
modelName string
|
||||
model *Model
|
||||
setupFunc func(t *testing.T, modelsDir string) string // returns manifest path
|
||||
expectCached bool
|
||||
}{
|
||||
{
|
||||
name: "successful cache set",
|
||||
modelName: "test-model:latest",
|
||||
model: testModel,
|
||||
setupFunc: func(t *testing.T, modelsDir string) string {
|
||||
createTestModel(t, modelsDir, "test-model", []Layer{
|
||||
{MediaType: "application/vnd.ollama.image.model", Digest: "sha256-abc123", Size: 1000},
|
||||
})
|
||||
return filepath.Join(modelsDir, "manifests", "registry.ollama.ai", "library", "test-model", "latest")
|
||||
},
|
||||
expectCached: true,
|
||||
},
|
||||
{
|
||||
name: "manifest file does not exist",
|
||||
modelName: "nonexistent-model:latest",
|
||||
model: testModel,
|
||||
setupFunc: func(t *testing.T, modelsDir string) string {
|
||||
// Don't create manifest file
|
||||
return ""
|
||||
},
|
||||
expectCached: false,
|
||||
},
|
||||
{
|
||||
name: "overwrite existing cache entry",
|
||||
modelName: "existing-model:latest",
|
||||
model: testModel,
|
||||
setupFunc: func(t *testing.T, modelsDir string) string {
|
||||
createTestModel(t, modelsDir, "existing-model", []Layer{
|
||||
{MediaType: "application/vnd.ollama.image.model", Digest: "sha256-def456", Size: 2000},
|
||||
})
|
||||
return filepath.Join(modelsDir, "manifests", "registry.ollama.ai", "library", "existing-model", "latest")
|
||||
},
|
||||
expectCached: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
modelsDir := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", modelsDir)
|
||||
|
||||
// Create fresh cache instance for each test
|
||||
cache := &ModelCache{
|
||||
cache: make(map[string]*CachedModel),
|
||||
}
|
||||
|
||||
if tt.name == "overwrite existing cache entry" {
|
||||
cache.cache[tt.modelName] = &CachedModel{
|
||||
model: &Model{Name: "old-model"},
|
||||
modTime: time.Now().Add(-time.Hour),
|
||||
fileSize: 50,
|
||||
}
|
||||
}
|
||||
|
||||
if tt.setupFunc != nil {
|
||||
tt.setupFunc(t, modelsDir)
|
||||
}
|
||||
|
||||
cache.set(tt.modelName, tt.model)
|
||||
|
||||
cached, exists := cache.cache[tt.modelName]
|
||||
|
||||
if tt.expectCached {
|
||||
if !exists {
|
||||
t.Errorf("set() expected model to be cached, but it wasn't")
|
||||
return
|
||||
}
|
||||
|
||||
if cached.model != tt.model {
|
||||
t.Errorf("set() cached model = %v, want %v", cached.model, tt.model)
|
||||
}
|
||||
|
||||
// Verify file info is captured correctly if manifest exists
|
||||
expectedManifestPath := filepath.Join(modelsDir, "manifests", "registry.ollama.ai", "library", tt.modelName[:len(tt.modelName)-7], "latest")
|
||||
if info, err := os.Stat(expectedManifestPath); err == nil {
|
||||
if !cached.modTime.Equal(info.ModTime()) {
|
||||
t.Errorf("set() cached modTime = %v, want %v", cached.modTime, info.ModTime())
|
||||
}
|
||||
if cached.fileSize != info.Size() {
|
||||
t.Errorf("set() cached fileSize = %v, want %v", cached.fileSize, info.Size())
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if exists {
|
||||
t.Errorf("set() expected model not to be cached, but it was")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
@ -65,11 +65,12 @@ type Model struct {
|
|||
Options map[string]any
|
||||
Messages []api.Message
|
||||
|
||||
Template *template.Template
|
||||
Capabilities []model.Capability
|
||||
Template *template.Template
|
||||
}
|
||||
|
||||
// Capabilities returns the capabilities that the model supports
|
||||
func (m *Model) Capabilities() []model.Capability {
|
||||
func Capabilities(m *Model) []model.Capability {
|
||||
capabilities := []model.Capability{}
|
||||
|
||||
// Check for completion capability
|
||||
|
|
@ -121,7 +122,6 @@ func (m *Model) Capabilities() []model.Capability {
|
|||
// CheckCapabilities checks if the model has the specified capabilities returning an error describing
|
||||
// any missing or unknown capabilities
|
||||
func (m *Model) CheckCapabilities(want ...model.Capability) error {
|
||||
available := m.Capabilities()
|
||||
var errs []error
|
||||
|
||||
// Map capabilities to their corresponding error
|
||||
|
|
@ -141,7 +141,7 @@ func (m *Model) CheckCapabilities(want ...model.Capability) error {
|
|||
return fmt.Errorf("unknown capability: %s", cap)
|
||||
}
|
||||
|
||||
if !slices.Contains(available, cap) {
|
||||
if !slices.Contains(m.Capabilities, cap) {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
}
|
||||
|
|
@ -272,6 +272,12 @@ func GetManifest(mp ModelPath) (*Manifest, string, error) {
|
|||
}
|
||||
|
||||
func GetModel(name string) (*Model, error) {
|
||||
// Try cache first
|
||||
if model, hit := modelCache.get(name); hit {
|
||||
return model, nil
|
||||
}
|
||||
|
||||
// Cache miss, load the model from disk
|
||||
mp := ParseModelPath(name)
|
||||
manifest, digest, err := GetManifest(mp)
|
||||
if err != nil {
|
||||
|
|
@ -368,6 +374,10 @@ func GetModel(name string) (*Model, error) {
|
|||
}
|
||||
}
|
||||
|
||||
model.Capabilities = Capabilities(model)
|
||||
|
||||
modelCache.set(name, model)
|
||||
|
||||
return model, nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -824,7 +824,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
|||
Template: m.Template.String(),
|
||||
Details: modelDetails,
|
||||
Messages: msgs,
|
||||
Capabilities: m.Capabilities(),
|
||||
Capabilities: m.Capabilities,
|
||||
ModifiedAt: manifest.fi.ModTime(),
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue