new runner
This commit is contained in:
246
runner/oldrunner/cache.go
Normal file
246
runner/oldrunner/cache.go
Normal file
@@ -0,0 +1,246 @@
|
||||
package oldrunner
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/llama"
|
||||
)
|
||||
|
||||
type InputCache struct {
|
||||
// context window size (per slot)
|
||||
numCtx int
|
||||
|
||||
// individual KV caches
|
||||
slots []InputCacheSlot
|
||||
|
||||
// optimize cache eviction for multiple users
|
||||
multiUserCache bool
|
||||
|
||||
lc *llama.Context
|
||||
}
|
||||
|
||||
func NewInputCache(lc *llama.Context, kvSize int, numSlots int, multiUserCache bool) (*InputCache, error) {
|
||||
if kvSize/numSlots < 1 {
|
||||
return nil, fmt.Errorf("must have at least one kv cache entry per parallel sequence (kv: %v parallel: %v)", kvSize, numSlots)
|
||||
}
|
||||
|
||||
slots := make([]InputCacheSlot, numSlots)
|
||||
|
||||
for i := range slots {
|
||||
slots[i] = InputCacheSlot{
|
||||
Id: i,
|
||||
Inputs: make([]input, 0),
|
||||
}
|
||||
}
|
||||
|
||||
return &InputCache{
|
||||
numCtx: kvSize / numSlots,
|
||||
slots: slots,
|
||||
multiUserCache: multiUserCache,
|
||||
lc: lc,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Locking: Operations on InputCacheSlot (including finding one
|
||||
// through LoadCacheSlot) require a lock to be be held that serializes
|
||||
// these operations with each other and llama.Decode
|
||||
|
||||
type InputCacheSlot struct {
|
||||
// Index in the KV cache
|
||||
Id int
|
||||
|
||||
// Inputs that are stored in the KV cache
|
||||
Inputs []input
|
||||
|
||||
// is this cache actively being processed as part of a sequence?
|
||||
InUse bool
|
||||
|
||||
// last time this cache was used (as of start of processing)
|
||||
lastUsed time.Time
|
||||
}
|
||||
|
||||
func (c *InputCache) LoadCacheSlot(prompt []input, cachePrompt bool) (*InputCacheSlot, []input, error) {
|
||||
var slot *InputCacheSlot
|
||||
var numPast int
|
||||
var err error
|
||||
|
||||
// In single-user scenarios, the longest cache slot works fine for getting good input
|
||||
// cache hit rates and it reuses the same VRAM over and over again, which is good for
|
||||
// GPU performance in situations where we miss the input cache.
|
||||
// For multiple users, the "best" cache slot produces better input cache hit rates
|
||||
// at the cost of worse performance when we miss the input cache (because it causes
|
||||
// GPU L2 cache misses due to spreading out accesses across VRAM).
|
||||
if !c.multiUserCache {
|
||||
slot, numPast, err = c.findLongestCacheSlot(prompt)
|
||||
} else {
|
||||
slot, numPast, err = c.findBestCacheSlot(prompt)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if !cachePrompt {
|
||||
numPast = 0
|
||||
}
|
||||
|
||||
slot.InUse = true
|
||||
slot.lastUsed = time.Now()
|
||||
|
||||
if numPast == len(prompt) {
|
||||
// Leave one input to sample so we can get a response
|
||||
numPast--
|
||||
}
|
||||
|
||||
if !c.lc.KvCacheSeqRm(slot.Id, numPast, -1) {
|
||||
// Some models don't support partial erasure
|
||||
c.lc.KvCacheSeqRm(slot.Id, 0, -1)
|
||||
numPast = 0
|
||||
}
|
||||
|
||||
slog.Debug("loading cache slot", "id", slot.Id, "cache", len(slot.Inputs), "prompt", len(prompt),
|
||||
"used", numPast, "remaining", len(prompt)-numPast)
|
||||
|
||||
prompt = prompt[numPast:]
|
||||
slot.Inputs = slot.Inputs[:numPast]
|
||||
|
||||
return slot, prompt, nil
|
||||
}
|
||||
|
||||
func (c *InputCache) findLongestCacheSlot(prompt []input) (*InputCacheSlot, int, error) {
|
||||
longest := -1
|
||||
var longestSlot *InputCacheSlot
|
||||
|
||||
for i, s := range c.slots {
|
||||
if s.InUse {
|
||||
continue
|
||||
}
|
||||
|
||||
count := countCommonPrefix(s.Inputs, prompt)
|
||||
if count > longest {
|
||||
longest = count
|
||||
longestSlot = &c.slots[i]
|
||||
}
|
||||
}
|
||||
|
||||
if longestSlot == nil {
|
||||
return nil, 0, errors.New("no available cache slots")
|
||||
}
|
||||
|
||||
return longestSlot, longest, nil
|
||||
}
|
||||
|
||||
func (c *InputCache) findBestCacheSlot(prompt []input) (*InputCacheSlot, int, error) {
|
||||
oldest := time.Now()
|
||||
var oldestSlot *InputCacheSlot
|
||||
|
||||
longest := -1
|
||||
var longestSlot *InputCacheSlot
|
||||
|
||||
for i, s := range c.slots {
|
||||
count := countCommonPrefix(s.Inputs, prompt)
|
||||
if count > longest {
|
||||
longest = count
|
||||
longestSlot = &c.slots[i]
|
||||
}
|
||||
|
||||
if s.lastUsed.Compare(oldest) < 0 && !s.InUse {
|
||||
oldest = s.lastUsed
|
||||
oldestSlot = &c.slots[i]
|
||||
}
|
||||
}
|
||||
|
||||
if longest == len(longestSlot.Inputs) && !longestSlot.InUse {
|
||||
return longestSlot, longest, nil
|
||||
}
|
||||
|
||||
if oldestSlot.InUse {
|
||||
return nil, 0, errors.New("no available cache slots")
|
||||
}
|
||||
|
||||
if len(oldestSlot.Inputs) != 0 {
|
||||
slog.Debug("evicting cache slot", "id", oldestSlot.Id, "inputs", len(oldestSlot.Inputs),
|
||||
"used", oldestSlot.lastUsed)
|
||||
}
|
||||
|
||||
if longest > 0 && longestSlot != oldestSlot {
|
||||
slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
|
||||
len(longestSlot.Inputs))
|
||||
oldestSlot.Inputs = make([]input, longest)
|
||||
copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
|
||||
// This is only nil for unit tests
|
||||
if c.lc != nil {
|
||||
c.lc.KvCacheSeqRm(oldestSlot.Id, 0, -1)
|
||||
c.lc.KvCacheSeqCp(longestSlot.Id, oldestSlot.Id, 0, longest)
|
||||
}
|
||||
}
|
||||
|
||||
return oldestSlot, longest, nil
|
||||
}
|
||||
|
||||
func countCommonPrefix(a []input, b []input) int {
|
||||
var count int
|
||||
|
||||
for i := range a {
|
||||
if i >= len(b) {
|
||||
break
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(a[i], b[i]) {
|
||||
break
|
||||
}
|
||||
|
||||
count++
|
||||
}
|
||||
|
||||
return count
|
||||
}
|
||||
|
||||
func (c *InputCache) ShiftDiscard(inputLen int, numKeep int) int {
|
||||
targetFree := (c.numCtx - numKeep) / 2
|
||||
targetFree = max(targetFree, 1)
|
||||
|
||||
currentFree := c.numCtx - inputLen
|
||||
discard := targetFree - currentFree
|
||||
|
||||
if discard < 0 {
|
||||
discard = 0
|
||||
}
|
||||
|
||||
return discard
|
||||
}
|
||||
|
||||
// Frees up space in the KV cache by deleting the oldest half of history and shifting
|
||||
// the newest half into that space (saving numKeep inputs at the beginning).
|
||||
//
|
||||
// Assumes that at least 1 entry can be freed up by shifting (i.e. numKeep < numCtx)
|
||||
func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
|
||||
if numKeep >= c.numCtx {
|
||||
return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx)
|
||||
}
|
||||
|
||||
discard := c.ShiftDiscard(len(slot.Inputs), numKeep)
|
||||
|
||||
if discard <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs),
|
||||
"keep", numKeep, "discard", discard)
|
||||
|
||||
// TODO (jessegross): KV cache removal can fail for certain types of models
|
||||
if !c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard) {
|
||||
return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v)", slot.Id, numKeep, discard)
|
||||
}
|
||||
c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, len(slot.Inputs), -discard)
|
||||
|
||||
for i := numKeep + discard; i < len(slot.Inputs); i++ {
|
||||
slot.Inputs[i-discard] = slot.Inputs[i]
|
||||
}
|
||||
slot.Inputs = slot.Inputs[:len(slot.Inputs)-discard]
|
||||
|
||||
return nil
|
||||
}
|
||||
292
runner/oldrunner/cache_test.go
Normal file
292
runner/oldrunner/cache_test.go
Normal file
@@ -0,0 +1,292 @@
|
||||
package oldrunner
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCountCommon(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
t1 []input
|
||||
t2 []input
|
||||
expected int
|
||||
}{
|
||||
{
|
||||
name: "Equal",
|
||||
t1: []input{{token: 1}, {token: 2}, {token: 3}},
|
||||
t2: []input{{token: 1}, {token: 2}, {token: 3}},
|
||||
expected: 3,
|
||||
},
|
||||
{
|
||||
name: "Prefix",
|
||||
t1: []input{{token: 1}},
|
||||
t2: []input{{token: 1}, {token: 2}, {token: 3}},
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "Embeddings Prefix",
|
||||
t1: []input{{embed: []float32{0.1, 0.2, 0.3}}},
|
||||
t2: []input{{embed: []float32{0.1, 0.2, 0.3}}, {embed: []float32{0.4, 0.5, 0.6}}, {embed: []float32{0.7}}},
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "Embeddings Prefix Partial",
|
||||
t1: []input{{embed: []float32{0.1, 0.2, 0.3}}},
|
||||
t2: []input{{embed: []float32{0.1, 0.2}}, {embed: []float32{0.4, 0.5, 0.6}}, {embed: []float32{0.7}}},
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "Mixed",
|
||||
t1: []input{{token: 1}, {embed: []float32{0.2, 0.3, 0.4}}},
|
||||
t2: []input{{token: 1}, {embed: []float32{0.2, 0.3, 0.4}}, {token: 5}},
|
||||
expected: 2,
|
||||
},
|
||||
{
|
||||
name: "Empty",
|
||||
t1: []input{},
|
||||
t2: []input{{token: 1}, {token: 2}, {token: 3}},
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "Both Empty",
|
||||
t1: []input{},
|
||||
t2: []input{},
|
||||
expected: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := countCommonPrefix(tt.t1, tt.t2)
|
||||
if result != tt.expected {
|
||||
t.Errorf("countCommonPrefix(%v, %v): have %v; want %v", tt.t1, tt.t2, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFindCacheSlot(t *testing.T) {
|
||||
type expected struct {
|
||||
result int
|
||||
len int
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
cache InputCache
|
||||
prompt []input
|
||||
longest expected
|
||||
best expected
|
||||
}{
|
||||
{
|
||||
name: "Empty",
|
||||
cache: InputCache{slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []input{},
|
||||
InUse: false,
|
||||
lastUsed: time.Time{},
|
||||
},
|
||||
{
|
||||
Id: 1,
|
||||
Inputs: []input{},
|
||||
InUse: false,
|
||||
lastUsed: time.Time{},
|
||||
},
|
||||
}},
|
||||
prompt: []input{{token: 1}},
|
||||
longest: expected{result: 0, len: 0},
|
||||
best: expected{result: 0, len: 0},
|
||||
},
|
||||
{
|
||||
name: "Extend",
|
||||
cache: InputCache{slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []input{{token: 1}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-time.Second),
|
||||
},
|
||||
{
|
||||
Id: 1,
|
||||
Inputs: []input{{token: 1}, {token: 2}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-2 * time.Second),
|
||||
},
|
||||
}},
|
||||
prompt: []input{{token: 1}, {token: 2}},
|
||||
longest: expected{result: 1, len: 2},
|
||||
best: expected{result: 1, len: 2},
|
||||
},
|
||||
{
|
||||
name: "New",
|
||||
cache: InputCache{slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []input{{token: 1}, {token: 2}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-time.Second),
|
||||
},
|
||||
{
|
||||
Id: 1,
|
||||
Inputs: []input{},
|
||||
InUse: false,
|
||||
lastUsed: time.Time{},
|
||||
},
|
||||
}},
|
||||
prompt: []input{{token: 2}},
|
||||
longest: expected{result: 0, len: 0},
|
||||
best: expected{result: 1, len: 0},
|
||||
},
|
||||
{
|
||||
name: "Fork",
|
||||
cache: InputCache{
|
||||
slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []input{{token: 1}, {token: 2}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-time.Second),
|
||||
},
|
||||
{
|
||||
Id: 1,
|
||||
Inputs: []input{},
|
||||
InUse: false,
|
||||
lastUsed: time.Time{},
|
||||
},
|
||||
},
|
||||
},
|
||||
prompt: []input{{token: 1}},
|
||||
longest: expected{result: 0, len: 1},
|
||||
best: expected{result: 1, len: 1},
|
||||
},
|
||||
{
|
||||
name: "Evict",
|
||||
cache: InputCache{slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []input{{token: 1}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-time.Second),
|
||||
},
|
||||
{
|
||||
Id: 1,
|
||||
Inputs: []input{{token: 1}, {token: 2}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-2 * time.Second),
|
||||
},
|
||||
}},
|
||||
prompt: []input{{token: 2}, {token: 3}},
|
||||
longest: expected{result: 0, len: 0},
|
||||
best: expected{result: 1, len: 0},
|
||||
},
|
||||
{
|
||||
name: "In use",
|
||||
cache: InputCache{slots: []InputCacheSlot{
|
||||
{
|
||||
Id: 0,
|
||||
Inputs: []input{{token: 1}, {token: 2}},
|
||||
InUse: true,
|
||||
lastUsed: time.Now().Add(-time.Second),
|
||||
},
|
||||
{
|
||||
Id: 1,
|
||||
Inputs: []input{{token: 1}},
|
||||
InUse: false,
|
||||
lastUsed: time.Now().Add(-2 * time.Second),
|
||||
},
|
||||
}},
|
||||
prompt: []input{{token: 1}, {token: 2}},
|
||||
longest: expected{result: 1, len: 1},
|
||||
best: expected{result: 1, len: 2},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run("Longest-"+tt.name, func(t *testing.T) {
|
||||
result, resultLen, err := tt.cache.findLongestCacheSlot(tt.prompt)
|
||||
if err != nil {
|
||||
t.Errorf("findLongestCacheSlot: err %v", err)
|
||||
} else if result.Id != tt.longest.result || resultLen != tt.longest.len {
|
||||
t.Errorf("findLongestCacheSlot: slot have %v, want %v len have %v, want %v",
|
||||
result.Id, tt.longest.result, resultLen, tt.longest.len)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run("Best-"+tt.name, func(t *testing.T) {
|
||||
result, resultLen, err := tt.cache.findBestCacheSlot(tt.prompt)
|
||||
if err != nil {
|
||||
t.Errorf("findBestCacheSlot: err %v", err)
|
||||
} else if result.Id != tt.best.result || resultLen != tt.best.len {
|
||||
t.Errorf("findBestCacheSlot: slot have %v, want %v len have %v, want %v",
|
||||
result.Id, tt.best.result, resultLen, tt.best.len)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestShiftDiscard(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
numCtx int
|
||||
numKeep int
|
||||
inputLen int
|
||||
expected int
|
||||
}{
|
||||
{
|
||||
name: "Shift",
|
||||
numCtx: 2048,
|
||||
numKeep: 5,
|
||||
inputLen: 2048,
|
||||
expected: 1021,
|
||||
},
|
||||
{
|
||||
name: "Max Keep",
|
||||
numCtx: 2048,
|
||||
numKeep: 2047,
|
||||
inputLen: 2048,
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "No Keep",
|
||||
numCtx: 2048,
|
||||
numKeep: 0,
|
||||
inputLen: 2048,
|
||||
expected: 1024,
|
||||
},
|
||||
{
|
||||
name: "Truncate",
|
||||
numCtx: 2048,
|
||||
numKeep: 5,
|
||||
inputLen: 5000,
|
||||
expected: 3973,
|
||||
},
|
||||
{
|
||||
name: "Truncate Keep",
|
||||
numCtx: 2048,
|
||||
numKeep: 2047,
|
||||
inputLen: 5000,
|
||||
expected: 2953,
|
||||
},
|
||||
{
|
||||
name: "No Op",
|
||||
numCtx: 2048,
|
||||
numKeep: 5,
|
||||
inputLen: 512,
|
||||
expected: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := InputCache{numCtx: tt.numCtx}
|
||||
result := c.ShiftDiscard(tt.inputLen, tt.numKeep)
|
||||
if result != tt.expected {
|
||||
t.Errorf("shiftDiscard(ctx: %v, keep: %v input: %v): have %v; want %v", tt.numCtx, tt.numKeep, tt.inputLen, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
183
runner/oldrunner/image.go
Normal file
183
runner/oldrunner/image.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package oldrunner
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/maphash"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/llama"
|
||||
)
|
||||
|
||||
const imageCacheSize = 4
|
||||
|
||||
type ImageContext struct {
|
||||
// mu is required to be held when generating embeddings or accessing the cache
|
||||
mu sync.Mutex
|
||||
|
||||
clip *llama.ClipContext
|
||||
mllama *llama.MllamaContext
|
||||
|
||||
// cache of images to embeddings
|
||||
images []imageCache
|
||||
imageHash maphash.Hash
|
||||
}
|
||||
|
||||
func NewImageContext(llamaContext *llama.Context, modelPath string) (*ImageContext, error) {
|
||||
arch, err := llama.GetModelArch(modelPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("unable to determine vision architecture: %w (%s)", err, modelPath)
|
||||
}
|
||||
|
||||
var c ImageContext
|
||||
if arch == "clip" {
|
||||
c.clip, err = llama.NewClipContext(llamaContext, modelPath)
|
||||
} else if arch == "mllama" {
|
||||
c.mllama, err = llama.NewMllamaContext(llamaContext, modelPath)
|
||||
} else {
|
||||
return nil, fmt.Errorf("unknown vision model architecture: %s", arch)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
c.images = make([]imageCache, imageCacheSize)
|
||||
|
||||
return &c, nil
|
||||
}
|
||||
|
||||
func (c *ImageContext) Free(modelPath string) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if c.clip != nil {
|
||||
c.clip.Free()
|
||||
}
|
||||
if c.mllama != nil {
|
||||
c.mllama.Free()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ImageContext) NewEmbed(llamaContext *llama.Context, data []byte, aspectRatioId int) ([][]float32, error) {
|
||||
if c == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if len(data) <= 0 {
|
||||
return nil, errors.New("received zero length image")
|
||||
}
|
||||
|
||||
hash := c.hashImage(data)
|
||||
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
embed, err := c.findImage(hash)
|
||||
if err != nil {
|
||||
if c.mllama != nil {
|
||||
embed, err = c.mllama.NewEmbed(llamaContext, data, aspectRatioId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else if c.clip != nil {
|
||||
embed, err = c.clip.NewEmbed(llamaContext, data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
return nil, errors.New("received image but vision model not loaded")
|
||||
}
|
||||
|
||||
c.addImage(hash, embed)
|
||||
}
|
||||
|
||||
return embed, nil
|
||||
}
|
||||
|
||||
func (c *ImageContext) BatchSize(configuredBatchSize int) int {
|
||||
// If images are not supported, we don't need to allocate embedding batches
|
||||
if c == nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
// Mllama maps an image to 1 embedding token (llava creates many tokens)
|
||||
// and doesn't support more than a single image per request.
|
||||
// The embeddings are large (100 MB), so allocating a big batch can fail
|
||||
// on some systems
|
||||
if c.mllama != nil {
|
||||
return 1
|
||||
}
|
||||
|
||||
return configuredBatchSize
|
||||
}
|
||||
|
||||
func (c *ImageContext) EmbedSize(llamaContext *llama.Context) int {
|
||||
if c != nil && c.mllama != nil {
|
||||
return c.mllama.EmbedSize(llamaContext)
|
||||
} else {
|
||||
return llamaContext.Model().NEmbd()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ImageContext) NeedCrossAttention(inputs ...input) bool {
|
||||
if c == nil || c.mllama == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return slices.ContainsFunc(inputs, func(input input) bool {
|
||||
return input.embed != nil
|
||||
})
|
||||
}
|
||||
|
||||
type imageCache struct {
|
||||
key uint64
|
||||
val [][]float32
|
||||
lastUsed time.Time
|
||||
}
|
||||
|
||||
func (c *ImageContext) hashImage(image []byte) uint64 {
|
||||
c.imageHash.Reset()
|
||||
_, _ = c.imageHash.Write(image)
|
||||
return c.imageHash.Sum64()
|
||||
}
|
||||
|
||||
var errImageNotFound = errors.New("image not found in cache")
|
||||
|
||||
func (c *ImageContext) findImage(hash uint64) ([][]float32, error) {
|
||||
for i := range c.images {
|
||||
if c.images[i].key == hash {
|
||||
slog.Debug("loading image embeddings from cache", "entry", i)
|
||||
c.images[i].lastUsed = time.Now()
|
||||
return c.images[i].val, nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errImageNotFound
|
||||
}
|
||||
|
||||
func (c *ImageContext) addImage(hash uint64, embed [][]float32) {
|
||||
best := time.Now()
|
||||
var bestImage int
|
||||
|
||||
for i := range c.images {
|
||||
if c.images[i].key == hash {
|
||||
bestImage = i
|
||||
break
|
||||
}
|
||||
|
||||
if c.images[i].lastUsed.Compare(best) < 0 {
|
||||
best = c.images[i].lastUsed
|
||||
bestImage = i
|
||||
}
|
||||
}
|
||||
|
||||
slog.Debug("storing image embeddings in cache", "entry", bestImage, "used", c.images[bestImage].lastUsed)
|
||||
c.images[bestImage].key = hash
|
||||
c.images[bestImage].val = embed
|
||||
c.images[bestImage].lastUsed = time.Now()
|
||||
}
|
||||
80
runner/oldrunner/image_test.go
Normal file
80
runner/oldrunner/image_test.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package oldrunner
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestImageCache(t *testing.T) {
|
||||
cache := ImageContext{images: make([]imageCache, 4)}
|
||||
|
||||
valA := [][]float32{{0.1, 0.2}, {0.3}}
|
||||
valB := [][]float32{{0.4}, {0.5}, {0.6}}
|
||||
valC := [][]float32{{0.7}}
|
||||
valD := [][]float32{{0.8}}
|
||||
valE := [][]float32{{0.9}}
|
||||
|
||||
// Empty cache
|
||||
result, err := cache.findImage(0x5adb61d31933a946)
|
||||
if err != errImageNotFound {
|
||||
t.Errorf("found result in empty cache: result %v, err %v", result, err)
|
||||
}
|
||||
|
||||
// Insert A
|
||||
cache.addImage(0x5adb61d31933a946, valA)
|
||||
|
||||
result, err = cache.findImage(0x5adb61d31933a946)
|
||||
if !reflect.DeepEqual(result, valA) {
|
||||
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
||||
}
|
||||
|
||||
// Insert B
|
||||
cache.addImage(0x011551369a34a901, valB)
|
||||
|
||||
result, err = cache.findImage(0x5adb61d31933a946)
|
||||
if !reflect.DeepEqual(result, valA) {
|
||||
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
||||
}
|
||||
result, err = cache.findImage(0x011551369a34a901)
|
||||
if !reflect.DeepEqual(result, valB) {
|
||||
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
||||
}
|
||||
|
||||
// Replace B with C
|
||||
cache.addImage(0x011551369a34a901, valC)
|
||||
|
||||
result, err = cache.findImage(0x5adb61d31933a946)
|
||||
if !reflect.DeepEqual(result, valA) {
|
||||
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
||||
}
|
||||
result, err = cache.findImage(0x011551369a34a901)
|
||||
if !reflect.DeepEqual(result, valC) {
|
||||
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
||||
}
|
||||
|
||||
// Evict A
|
||||
cache.addImage(0x756b218a517e7353, valB)
|
||||
cache.addImage(0x75e5e8d35d7e3967, valD)
|
||||
cache.addImage(0xd96f7f268ca0646e, valE)
|
||||
|
||||
result, err = cache.findImage(0x5adb61d31933a946)
|
||||
if reflect.DeepEqual(result, valA) {
|
||||
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
||||
}
|
||||
result, err = cache.findImage(0x756b218a517e7353)
|
||||
if !reflect.DeepEqual(result, valB) {
|
||||
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
||||
}
|
||||
result, err = cache.findImage(0x011551369a34a901)
|
||||
if !reflect.DeepEqual(result, valC) {
|
||||
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
||||
}
|
||||
result, err = cache.findImage(0x75e5e8d35d7e3967)
|
||||
if !reflect.DeepEqual(result, valD) {
|
||||
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
||||
}
|
||||
result, err = cache.findImage(0xd96f7f268ca0646e)
|
||||
if !reflect.DeepEqual(result, valE) {
|
||||
t.Errorf("failed to find expected value: result %v, err %v", result, err)
|
||||
}
|
||||
}
|
||||
1001
runner/oldrunner/runner.go
Normal file
1001
runner/oldrunner/runner.go
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user