Compare commits


9 Commits

Author SHA1 Message Date
Jeffrey Morgan
1579c4f06d build: install binutils alongside gcc in Dockerfile (#9475) 2025-03-03 01:20:49 -08:00
Blake Mizerany
3519dd1c6e server/internal/client/ollama: hold DiskCache on Registry (#9463)
Previously, using a Registry required a DiskCache to be passed in for
use in various methods. This was a bit cumbersome, as the DiskCache is
required for most operations, and the DefaultCache is used in most of
those cases. This change makes the DiskCache an optional field on the
Registry struct.

This also changes DefaultCache to initialize on first use. This avoids
burdening clients with the cost of creating a new cache per use, or of
having to hold onto a cache for the lifetime of the Registry.

Also, slip in some minor docs updates for Trace.
2025-03-02 20:55:44 -08:00
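For illustration, a minimal sketch of the resulting call pattern, based on the hunks below (hedged: "smol" is a placeholder model name, and the package is internal to the repo, so this only compiles from within it):

package main

import (
	"context"
	"log"

	"github.com/ollama/ollama/server/internal/client/ollama"
)

func main() {
	// Cache is now an optional field. When it is nil, methods fall back
	// to the lazily initialized package-level default cache
	// ($OLLAMA_MODELS, or $HOME/.ollama/models).
	var r ollama.Registry

	// Pull (like Push, Unlink, and ResolveLocal) no longer takes a
	// *blob.DiskCache argument.
	if err := r.Pull(context.Background(), "smol"); err != nil {
		log.Fatal(err)
	}
}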
Jeffrey Morgan
e41c4cbea7 build: install ccache manually in Dockerfile (#9464)
Reverts the ccache installation to the manual curl method instead of the
dnf package manager, which has the side effect of prepending ccache's
install directory to the front of the PATH.
2025-03-02 16:48:31 -08:00
Blake Mizerany
ee048b76d4 server/internal/client/ollama: handle extended names in client/ollama (#9454)
The extended name format is a superset of the name format. Only the
client needs to know about it, not the server or other dependents of the
name package, so move the split logic into the client package.

Also, take advantage of the extended name format to allow clients to use
it when unlinking, verifying that they are unlinking the manifest with
the content they intend.
2025-03-02 13:30:41 -08:00
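The split logic itself is small; here it is, adapted from the diff below into a standalone sketch (the example name in main is illustrative):

package main

import (
	"fmt"
	"strings"
)

// splitExtended splits an extended name of the form
// [scheme://]name[@digest] into its parts.
func splitExtended(s string) (scheme, name, digest string) {
	if i := strings.Index(s, "://"); i >= 0 {
		scheme, s = s[:i], s[i+3:]
	}
	if i := strings.LastIndex(s, "@"); i >= 0 {
		digest, s = s[i+1:], s[:i]
	}
	return scheme, s, digest
}

func main() {
	fmt.Println(splitExtended("http://ollama.com/bmizerany/smol:latest@sha256:abc"))
	// http ollama.com/bmizerany/smol:latest sha256:abc
}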
Soulter
af68d60a58 readme: add AstrBot to community integrations (#9442) 2025-03-01 21:58:34 -08:00
Jesse Gross
21aa666a1e ml: Enable support for flash attention
The GGML flash attention kernel has specific requirements for
padding and permutation. This adds support to the KV cache
for conforming to these requirements so that flash attention
can be enabled.

Flash attention can be used in the same situations as the llama
engine and is enabled by the user in the same way.
2025-03-01 20:53:23 -08:00
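The padding requirement is enforced with simple round-up arithmetic; a runnable sketch using the helpers added in the causal-cache diff (padding values taken from the GGML backend's CacheConfig below):

package main

import "fmt"

// roundDown and roundUp mirror the helpers added to the causal cache.
func roundDown(length, pad int) int { return (length / pad) * pad }
func roundUp(length, pad int) int   { return ((length + pad - 1) / pad) * pad }

func main() {
	// With flash attention the GGML backend requests CachePadding=256,
	// so a cache capacity of 2050 grows to the next multiple:
	fmt.Println(roundUp(2050, 256)) // 2304
	// Without flash attention the padding is 32:
	fmt.Println(roundUp(100, 32)) // 128
}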
Jesse Gross
ee141cc821 ml: Empty tensor constructor for tensors
In cases where we allocate a tensor and then fully overwrite it with
copied data, it is wasteful to first zero out the memory.
2025-03-01 20:53:23 -08:00
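A toy version of the Empty/Zeros split, mirroring the testContext changes in the diff (note: in the real GGML backend, Empty skips the C.ggml_set_zero call on device memory; in this pure-Go stand-in, make already zeroes, so the distinction is only illustrative):

package main

import "fmt"

type tensor struct{ data []float32 }

// empty: allocation only; the caller promises to overwrite the data.
func empty(n int) *tensor { return &tensor{data: make([]float32, n)} }

// zeros: empty plus explicit zeroing.
func zeros(n int) *tensor { t := empty(n); clear(t.data); return t }

func main() {
	t := empty(4)
	copy(t.data, []float32{1, 2, 3, 4}) // fully overwritten; zeroing first would be wasted work
	fmt.Println(t.data, zeros(2).data)
}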
Jesse Gross
55e5776c44 ggml-backend: Store parent backend as part of tensor
It can be important for a tensor to know what backend it came from -
for example, to know if flash attention is enabled.
2025-03-01 20:53:23 -08:00
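A toy version of the pattern (the real change adds a b *Backend field to the ggml Tensor struct; the names here otherwise stand in for the cgo types):

package main

import "fmt"

type Backend struct{ flashAttention bool }

// Tensors carry a pointer to their parent backend so later operations
// can consult backend-level settings without extra plumbing.
type Tensor struct {
	b    *Backend
	data []float32
}

func (b *Backend) NewTensor(n int) *Tensor { return &Tensor{b: b, data: make([]float32, n)} }

func (t *Tensor) attentionKernel() string {
	if t.b.flashAttention {
		return "fused flash attention"
	}
	return "mulmat + softmax fallback"
}

func main() {
	b := &Backend{flashAttention: true}
	fmt.Println(b.NewTensor(8).attentionKernel())
}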
Jesse Gross
854a9195f3 attention: Remove unnecessary contiguous operations
Prior to performing attention, we need to permute query, key
and value. Currently we call Contiguous after each of these
permutations, which is correct but expensive. Avoiding the
3 calls to Contiguous increases performance by over 20%.

The permutations of query and key do not violate the continuity
rules for mulmat, so their Contiguous calls can simply be removed.

Value requires a different permutation and does require Contiguous.
However, we can use the copy into the cache as a way to perform this
without further overhead.

To support this and avoid exposing unexpected tensor shapes to
models, we need tighter integration between attention, cache
and backend. Future optimizations will also likely need this
structure; for example, flash attention has special padding
requirements in the cache, and other backends may have their own needs.

This further contains the operations that go into attention so that
these and other optimizations can be handled transparently. Models
that have special requirements for attention can still implement
their own version of it.
2025-03-01 20:53:23 -08:00
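Schematically, the permutation pattern looks like this (a compile-checkable sketch against a stand-in for the ml.Tensor API from the diff; the stub methods do nothing, the point is which Contiguous calls survive):

package main

type Tensor interface {
	Permute(order ...int) Tensor
	Contiguous() Tensor
	MulmatFullPrec(t2 Tensor) Tensor
	Mulmat(t2 Tensor) Tensor
	Scale(s float64) Tensor
	Softmax() Tensor
}

type stub struct{}

func (s stub) Permute(...int) Tensor        { return s }
func (s stub) Contiguous() Tensor           { return s }
func (s stub) MulmatFullPrec(Tensor) Tensor { return s }
func (s stub) Mulmat(Tensor) Tensor         { return s }
func (s stub) Scale(float64) Tensor         { return s }
func (s stub) Softmax() Tensor              { return s }

func attention(query, key, value Tensor, scale float64) Tensor {
	// Previously, every Permute below was followed by Contiguous.
	// The query and key views are mulmat-compatible as-is, so two of
	// the three copies disappear; value still needs a contiguous
	// layout, which the cache's Put copy can absorb when a cache is
	// in use.
	query = query.Permute(0, 2, 1, 3)
	key = key.Permute(0, 2, 1, 3)
	value = value.Permute(1, 2, 0, 3).Contiguous()

	kq := key.MulmatFullPrec(query).Scale(scale).Softmax()
	kqv := value.Mulmat(kq)
	return kqv.Permute(0, 2, 1, 3).Contiguous()
}

func main() {
	_ = attention(stub{}, stub{}, stub{}, 1.0)
}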
22 changed files with 663 additions and 255 deletions

View File

@@ -12,7 +12,7 @@ FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base
RUN yum install -y yum-utils \
&& yum-config-manager --add-repo https://dl.rockylinux.org/vault/rocky/8.5/AppStream/\$basearch/os/ \
&& rpm --import https://dl.rockylinux.org/pub/rocky/RPM-GPG-KEY-Rocky-8 \
&& dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 \
&& dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 gcc-toolset-10-binutils-2.35-11.el8 \
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
ENV PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH

View File

@@ -387,6 +387,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [yla](https://github.com/danielekp/yla) (Web interface to freely interact with your customized models)
- [LangBot](https://github.com/RockChinQ/LangBot) (LLM-based instant messaging bots platform, with Agents, RAG features, supports multiple platforms)
- [1Panel](https://github.com/1Panel-dev/1Panel/) (Web-based Linux Server Management Tool)
- [AstrBot](https://github.com/Soulter/AstrBot/) (User-friendly LLM-based multi-platform chatbot with a WebUI, supporting RAG, LLM agents, and plugins integration)
### Cloud

View File

@@ -29,6 +29,17 @@ type Cache interface {
// cache implementation used.
Put(ctx ml.Context, key, value ml.Tensor)
// SetConfig controls optimizations (mostly backend-specific) that may transform
// the output of the cache to work better with specific kernels. If not called,
// the backend settings will be used. This works well when calling Attention.
//
// The config can be overridden by models, especially if they require vanilla
// output when implementing their own version of attention. To do this, pass
// an empty ml.CacheConfig.
//
// Most models will not need to use this.
SetConfig(ml.CacheConfig)
// ** cache management **
// Init sets up runtime parameters

View File

@@ -22,6 +22,9 @@ type Causal struct {
Capacity int32
windowSize int32
// config controls mostly backend-specific optimizations
config *ml.CacheConfig
// ** current forward pass **
// the active layer for Get and Put
@@ -75,14 +78,42 @@ func NewSWACache(windowSize int32, shift shiftFn) *Causal {
}
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
if c.config == nil {
var config ml.CacheConfig
if cc, ok := backend.(ml.BackendCacheConfig); ok {
config = cc.CacheConfig()
}
c.config = &config
}
if c.config.CachePadding == 0 {
c.config.CachePadding = 1
}
if c.config.MaskBatchPadding == 0 {
c.config.MaskBatchPadding = 1
}
if c.config.MaskDType == ml.DTypeOther {
c.config.MaskDType = ml.DTypeF32
}
c.DType = dtype
c.Capacity = capacity
c.cells = make([]cacheCell, capacity)
c.Capacity = int32(roundUp(int(capacity), c.config.CachePadding))
c.cells = make([]cacheCell, c.Capacity)
c.cellRanges = make(map[int]cellRange)
c.backend = backend
c.cacheCtx = backend.NewContext()
}
func (c *Causal) SetConfig(config ml.CacheConfig) {
if c.config != nil {
panic("config cannot be changed after being previously set, either by the model or backend")
}
c.config = &config
}
func (c *Causal) Close() {
c.cacheCtx.Close()
}
@@ -157,36 +188,91 @@ func (c *Causal) findStartLoc() (int, error) {
return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity)
}
func roundDown(length, pad int) int {
return (length / pad) * pad
}
func roundUp(length, pad int) int {
return ((length + pad - 1) / pad) * pad
}
// Builds a mask of history x batch indicating whether, for each token in the batch, the
// token in the history should apply. This is based on both the sequence and causality (the
// position of the history is not ahead of the token in the batch).
func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Tensor, error) {
// TODO(jessegross): This does not do padding, which is required for flash attention
len := c.curCellRange.max - c.curCellRange.min + 1
mask := make([]float32, c.curBatchSize*len)
// Align and pad the two dimensions as required by the backend
batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
length := c.curCellRange.max - c.curCellRange.min + 1
mask := make([]float32, batchSize*length)
for i := range c.curBatchSize {
for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
if !slices.Contains(c.cells[j].sequences, seqs[i]) || c.cells[j].pos > positions[i] ||
c.cells[j].pos < positions[i]-c.windowSize {
mask[i*len+(j-c.curCellRange.min)] = float32(math.Inf(-1))
mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
}
}
}
return ctx.FromFloatSlice(mask, len, c.curBatchSize)
// Mask out any padding tokens we added. For padding that we added to the cache history, this
// has already been masked out because the sequence doesn't match.
for i := c.curBatchSize * length; i < len(mask); i++ {
mask[i] = float32(math.Inf(-1))
}
maskTensor, err := ctx.FromFloatSlice(mask, length, batchSize)
if err != nil {
return nil, err
}
if c.config.MaskDType != ml.DTypeF32 {
out := ctx.Empty(c.config.MaskDType, maskTensor.Shape()...)
ctx.Forward(maskTensor.Copy(ctx, out))
maskTensor = out
}
return maskTensor, nil
}
func moveCell(ctx ml.Context, objs []ml.Tensor, src, dst, len int) {
for _, obj := range objs {
if obj == nil {
func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
for i := range c.keys {
if c.keys[i] == nil {
continue
}
srcView := obj.View(ctx, obj.Stride(2)*src, obj.Dim(0)*obj.Dim(1)*len)
dstView := obj.View(ctx, obj.Stride(2)*dst, obj.Dim(0)*obj.Dim(1)*len)
key := c.keys[i]
ctx.Forward(srcView.Copy(ctx, dstView))
kHeadDim := key.Dim(0)
numKVHeads := key.Dim(1)
rowSize := key.Stride(2)
kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*len)
kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*len)
value := c.values[i]
var vSrcView, vDstView ml.Tensor
if c.config.PermutedV {
vHeadDim := value.Dim(1)
elemSize := value.Stride(0)
vSrcView = value.View(ctx, elemSize*src, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
vDstView = value.View(ctx, elemSize*dst, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
} else {
vHeadDim := value.Dim(0)
rowSize := value.Stride(2)
vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*len)
vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*len)
}
ctx.Forward(
kSrcView.Copy(ctx, kDstView),
vSrcView.Copy(ctx, vDstView),
)
}
}
@@ -238,8 +324,7 @@ func (c *Causal) defrag() {
pendingLen++
break
} else {
moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
moves++
}
}
@@ -263,8 +348,7 @@ func (c *Causal) defrag() {
}
if pendingLen > 0 {
moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
moves++
}
@@ -305,35 +389,73 @@ func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
key := c.keys[c.curLayer]
value := c.values[c.curLayer]
key = key.View(ctx, key.Stride(2)*c.curCellRange.min,
key.Dim(0), key.Stride(1),
key.Dim(1), key.Stride(2),
c.curMask.Dim(0),
kHeadDim := key.Dim(0)
numKVHeads := key.Dim(1)
rowSize := key.Stride(2)
cachedSize := c.curMask.Dim(0)
key = key.View(ctx, rowSize*c.curCellRange.min,
kHeadDim, key.Stride(1),
numKVHeads, key.Stride(2),
cachedSize,
)
value = value.View(ctx, key.Stride(2)*c.curCellRange.min,
value.Dim(0), value.Stride(1),
value.Dim(1), value.Stride(2),
c.curMask.Dim(0),
)
if c.config.PermutedV {
vHeadDim := value.Dim(1)
elemSize := value.Stride(0)
value = value.View(ctx, elemSize*c.curCellRange.min,
cachedSize, value.Stride(1),
vHeadDim, value.Stride(2),
numKVHeads,
)
} else {
vHeadDim := value.Dim(0)
rowSize := value.Stride(2)
value = value.View(ctx, rowSize*c.curCellRange.min,
vHeadDim, value.Stride(1),
numKVHeads, value.Stride(2),
cachedSize,
)
}
return key, value, c.curMask
}
func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
if c.curBatchSize != key.Dim(2) {
panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, key.Dim(2)))
kHeadDim := key.Dim(0)
vHeadDim := value.Dim(0)
numKVHeads := key.Dim(1)
batchSize := key.Dim(2)
if c.curBatchSize != batchSize {
panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize))
}
if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, key.Dim(0), key.Dim(1), int(c.Capacity))
c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, value.Dim(0), value.Dim(1), int(c.Capacity))
c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity))
if c.config.PermutedV {
c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads)
} else {
c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity))
}
}
ctx.Forward(
key.Copy(ctx, c.keys[c.curLayer].View(ctx, c.keys[c.curLayer].Stride(2)*c.curLoc, key.Dim(0)*key.Dim(1)*key.Dim(2))),
value.Copy(ctx, c.values[c.curLayer].View(ctx, c.values[c.curLayer].Stride(2)*c.curLoc, value.Dim(0)*value.Dim(1)*value.Dim(2))),
)
rowSize := c.keys[c.curLayer].Stride(2)
ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, kHeadDim*numKVHeads*batchSize)))
if c.config.PermutedV {
elemSize := c.values[c.curLayer].Stride(0)
value = value.Permute(ctx, 1, 2, 0, 3)
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)))
} else {
rowSize := c.values[c.curLayer].Stride(2)
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, vHeadDim*numKVHeads*batchSize)))
}
}
func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
@@ -389,9 +511,13 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
continue
}
key = key.View(ctx, key.Stride(2)*seqRange.min,
key.Dim(0), key.Stride(1),
key.Dim(1), key.Stride(2),
kHeadDim := key.Dim(0)
numKVHeads := key.Dim(1)
rowSize := key.Stride(2)
key = key.View(ctx, rowSize*seqRange.min,
kHeadDim, key.Stride(1),
numKVHeads, key.Stride(2),
size,
)

View File

@@ -309,7 +309,7 @@ func (b *testBackend) SystemInfo() string {
type testContext struct{}
func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor {
total := 0
if len(shape) > 0 {
@@ -322,8 +322,12 @@ func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
return &testTensor{dtype: dtype, elementSize: 4, data: make([]float32, total), shape: shape}
}
func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
return c.Empty(dtype, shape...)
}
func (c *testContext) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
t := c.Zeros(ml.DTypeF32, shape...).(*testTensor)
t := c.Empty(ml.DTypeF32, shape...).(*testTensor)
copy(t.data, s)
@@ -391,7 +395,7 @@ func (t *testTensor) Floats() []float32 {
}
func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
out := ctx.Zeros(t.DType(), t.Shape()...).(*testTensor)
out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
for i := range out.data {
out.data[i] = t.data[i] + t2.(*testTensor).data[i]
@@ -468,7 +472,7 @@ func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
context := &testContext{}
view := context.Zeros(t.dtype, s...).(*testTensor)
view := context.Empty(t.dtype, s...).(*testTensor)
view.data = t.data[offset : offset+len(view.data)]
return view

View File

@@ -1,6 +1,8 @@
package kvcache
import (
"fmt"
"github.com/ollama/ollama/ml"
)
@@ -11,6 +13,9 @@ import (
//
// Not currently safe for multiple sequences
type EncoderCache struct {
// config controls mostly backend-specific optimizations
config *ml.CacheConfig
// ** current forward pass **
// the active layer for Get and Put
@@ -40,9 +45,29 @@ func NewEncoderCache() *EncoderCache {
}
func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
if c.config == nil {
var config ml.CacheConfig
if cc, ok := backend.(ml.BackendCacheConfig); ok {
config = cc.CacheConfig()
}
c.config = &config
}
if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
}
c.cacheCtx = backend.NewContext()
}
func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
if c.config != nil {
panic("config cannot be changed after being previously set, either by the model or backend")
}
c.config = &config
}
func (c *EncoderCache) Close() {
c.cacheCtx.Close()
}
@@ -75,9 +100,13 @@ func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
c.encoderPos = c.curPos
c.encoderCached = true
if c.config.PermutedV {
value = value.Permute(ctx, 1, 2, 0, 3)
}
if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
c.keys[c.curLayer] = c.cacheCtx.Zeros(key.DType(), key.Shape()...)
c.values[c.curLayer] = c.cacheCtx.Zeros(value.DType(), value.Shape()...)
c.keys[c.curLayer] = c.cacheCtx.Empty(key.DType(), key.Shape()...)
c.values[c.curLayer] = c.cacheCtx.Empty(value.DType(), value.Shape()...)
}
ctx.Forward(

View File

@@ -28,6 +28,12 @@ func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, capacity int32)
}
}
func (c *WrapperCache) SetConfig(config ml.CacheConfig) {
for _, cache := range c.caches {
cache.SetConfig(config)
}
}
func (c *WrapperCache) Close() {
for _, cache := range c.caches {
cache.Close()

View File

@@ -27,6 +27,35 @@ type Backend interface {
SystemInfo() string
}
// BackendCacheConfig should be implemented by backends that need special output
// from the cache to meet specific requirements. It is frequently implemented in
// conjunction with ScaledDotProductAttention.
type BackendCacheConfig interface {
CacheConfig() CacheConfig
}
// CacheConfig controls optimizations (mostly backend-specific) that may transform
// the output of the cache to work better with specific kernels.
type CacheConfig struct {
// CachePadding specifies the multiple for the number of tokens of cache history
// that will be returned from cache Get for k, v and mask. The capacity of the
// cache itself will also be increased to a multiple of this size if needed.
CachePadding int
// PermutedV performs Permute(ctx, 1, 2, 0, 3) on v tensors stored via Put
// and returns the permuted version via Get. This uses the cache copy operation
// to avoid a Contiguous call on the permuted tensor.
PermutedV bool
// MaskDType specifies the data type for generating the mask. If unset it will
// default to DTypeF32.
MaskDType DType
// MaskBatchPadding specifies the multiple for the batch size dimension in the mask.
// Any position that does not correspond to an actual token will be filled with -Inf.
MaskBatchPadding int
}
// BackendParams controls how the backend loads and executes models
type BackendParams struct {
// NumThreads sets the number of threads to use if running on the CPU
@@ -40,6 +69,9 @@ type BackendParams struct {
// TensorSplit is the fraction of the model to offload to each GPU
TensorSplit []float32
// FlashAttention indicates that we should use a fused flash attention kernel
FlashAttention bool
}
var backends = make(map[string]func(*os.File, BackendParams) (Backend, error))
@@ -61,6 +93,7 @@ func NewBackend(f *os.File, params BackendParams) (Backend, error) {
}
type Context interface {
Empty(dtype DType, shape ...int) Tensor
Zeros(dtype DType, shape ...int) Tensor
FromFloatSlice(s []float32, shape ...int) (Tensor, error)
FromIntSlice(s []int32, shape ...int) (Tensor, error)
@@ -116,6 +149,10 @@ type Tensor interface {
// operation equivalent to following code on a tensor named
// query:
//
// query = query.Permute(ctx, 0, 2, 1, 3)
// key = key.Permute(ctx, 0, 2, 1, 3)
// value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
//
// kq := key.MulmatFullPrec(ctx, query)
//
// kq = kq.Scale(ctx, scale)
@@ -170,7 +207,7 @@ func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {
return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
})
case DTypeF16:
f32 := ctx.Zeros(DTypeF32, t.Shape()...)
f32 := ctx.Empty(DTypeF32, t.Shape()...)
f32 = t.Copy(ctx, f32)
return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string {
return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)

View File

@@ -79,6 +79,8 @@ var devices = sync.OnceValue(func() []device {
})
type Backend struct {
flashAttention bool
meta *fs.GGML
cpus, gpus []Context
tensors map[string]*Context
@@ -192,9 +194,10 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
}
return &Backend{
meta: meta,
cpus: cpus,
gpus: gpus,
flashAttention: params.FlashAttention,
meta: meta,
cpus: cpus,
gpus: gpus,
sched: C.ggml_backend_sched_new(
(*C.ggml_backend_t)(unsafe.Pointer(&backends[0])),
(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&bufts[0])),
@@ -219,7 +222,7 @@ func (b *Backend) Get(name string) ml.Tensor {
for _, c := range append(b.gpus, b.cpus...) {
if t := C.ggml_get_tensor(c.ctx, cname); t != nil {
return &Tensor{t: t}
return &Tensor{b: b, t: t}
}
}
@@ -247,6 +250,14 @@ func (b *Backend) NewContext() ml.Context {
}
}
func (b *Backend) CacheConfig() ml.CacheConfig {
if b.flashAttention {
return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16, MaskBatchPadding: C.GGML_KQ_MASK_PAD}
} else {
return ml.CacheConfig{CachePadding: 32, PermutedV: true}
}
}
type Context struct {
b *Backend
ctx *C.struct_ggml_context
@@ -300,7 +311,7 @@ func shapeToGGML(shape []int) *C.int64_t {
return &sh[0]
}
func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
func newTensor(ctx Context, dtype ml.DType, zero bool, shape []int) ml.Tensor {
if len(shape) < 1 || len(shape) > 4 {
panic("unsupported number of dimensions")
}
@@ -314,19 +325,29 @@ func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
var t *C.struct_ggml_tensor
switch dtype {
case ml.DTypeF32:
t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F32, C.int(len(shape)), shapeToGGML(shape))
t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_F32, C.int(len(shape)), shapeToGGML(shape))
case ml.DTypeF16:
t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F16, C.int(len(shape)), shapeToGGML(shape))
t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_F16, C.int(len(shape)), shapeToGGML(shape))
case ml.DTypeI32:
t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_I32, C.int(len(shape)), shapeToGGML(shape))
t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_I32, C.int(len(shape)), shapeToGGML(shape))
default:
panic("unsupported dtype")
}
b := C.ggml_backend_alloc_buffer(c.backend, C.ggml_nbytes(t))
b := C.ggml_backend_alloc_buffer(ctx.backend, C.ggml_nbytes(t))
C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
C.ggml_set_zero(t)
return &Tensor{t: t}
if zero {
C.ggml_set_zero(t)
}
return &Tensor{b: ctx.b, t: t}
}
func (c Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
return newTensor(c, dtype, false, shape)
}
func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
return newTensor(c, dtype, true, shape)
}
func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype uint32) (ml.Tensor, error) {
@@ -335,7 +356,7 @@ func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype u
if n == 0 {
var shape C.int64_t = 0
t := C.ggml_new_tensor(ctx.ctx, dtype, 1, &shape)
return &Tensor{t: t}, nil
return &Tensor{b: ctx.b, t: t}, nil
}
for _, v := range shape {
@@ -350,7 +371,7 @@ func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype u
b := C.ggml_backend_alloc_buffer(ctx.backend, C.ggml_nbytes(t))
C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
C.ggml_backend_tensor_set(t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t))
return &Tensor{t: t}, nil
return &Tensor{b: ctx.b, t: t}, nil
}
func (c Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
@@ -368,6 +389,7 @@ func (c *Context) Close() {
}
type Tensor struct {
b *Backend
t *C.struct_ggml_tensor
sync func()
}
@@ -434,6 +456,7 @@ func (t *Tensor) DType() ml.DType {
func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_add(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
}
}
@@ -448,24 +471,28 @@ func (t *Tensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
func (t *Tensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_concat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(dim)),
}
}
func (t *Tensor) Contiguous(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_cont(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_mul(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
}
}
func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
}
}
@@ -475,12 +502,13 @@ func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
C.ggml_mul_mat_set_prec(mul, C.GGML_PREC_F32)
return &Tensor{
b: t.b,
t: mul,
}
}
func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
tt := (&Tensor{t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
tt := (&Tensor{b: t.b, t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
if b != nil {
tt = tt.Add(ctx, b)
}
@@ -489,7 +517,7 @@ func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tenso
}
func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
return (&Tensor{t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
return (&Tensor{b: t.b, t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
}
func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
@@ -498,6 +526,7 @@ func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
}
return &Tensor{
b: t.b,
t: C.ggml_pad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
}
}
@@ -508,18 +537,21 @@ func (t *Tensor) Permute(ctx ml.Context, shape ...int) ml.Tensor {
}
return &Tensor{
b: t.b,
t: C.ggml_permute(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
}
}
func (t *Tensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_get_rows(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
}
}
func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_cpy(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
}
}
@@ -528,18 +560,22 @@ func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
switch len(shape) {
case 1:
return &Tensor{
b: t.b,
t: C.ggml_reshape_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0])),
}
case 2:
return &Tensor{
b: t.b,
t: C.ggml_reshape_2d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1])),
}
case 3:
return &Tensor{
b: t.b,
t: C.ggml_reshape_3d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2])),
}
case 4:
return &Tensor{
b: t.b,
t: C.ggml_reshape_4d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2]), C.int64_t(shape[3])),
}
default:
@@ -549,18 +585,21 @@ func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
func (t *Tensor) Scale(ctx ml.Context, s float64) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_scale(ctx.(*Context).ctx, t.t, (C.float)(s)),
}
}
func (t *Tensor) Softmax(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_soft_max(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) Tanh(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_tanh_inplace(ctx.(*Context).ctx, t.t),
}
}
@@ -571,6 +610,7 @@ func (t *Tensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
}
return &Tensor{
b: t.b,
t: C.ggml_unpad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
}
}
@@ -579,10 +619,12 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
switch len(shape) {
case 1:
return &Tensor{
b: t.b,
t: C.ggml_view_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.size_t(offset)),
}
case 3:
return &Tensor{
b: t.b,
t: C.ggml_view_2d(ctx.(*Context).ctx, t.t,
C.int64_t(shape[0]), C.int64_t(shape[2]),
C.size_t(shape[1]),
@@ -590,6 +632,7 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
}
case 5:
return &Tensor{
b: t.b,
t: C.ggml_view_3d(ctx.(*Context).ctx, t.t,
C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]),
C.size_t(shape[1]), C.size_t(shape[3]),
@@ -597,6 +640,7 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
}
case 7:
return &Tensor{
b: t.b,
t: C.ggml_view_4d(ctx.(*Context).ctx, t.t,
C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]), C.int64_t(shape[6]),
C.size_t(shape[1]), C.size_t(shape[3]), C.size_t(shape[5]),
@@ -613,7 +657,7 @@ const (
func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, ropeBase, ropeScale float32) ml.Tensor {
if ropeFactors == nil {
ropeFactors = &Tensor{}
ropeFactors = &Tensor{b: t.b}
}
dequant := t.t
@@ -622,6 +666,7 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
}
return &Tensor{
b: t.b,
t: C.ggml_rope_ext(
ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
C.int(ropeDim),
@@ -639,18 +684,21 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) SILU(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_conv_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1)),
}
}
@@ -661,13 +709,25 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.T
kqMask = mask.(*Tensor).t
}
kq := key.MulmatFullPrec(ctx, t)
kq = &Tensor{
t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
}
query := t.Permute(ctx, 0, 2, 1, 3)
key = key.Permute(ctx, 0, 2, 1, 3)
kqv := value.Mulmat(ctx, kq)
return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
if t.b.flashAttention {
value = value.Permute(ctx, 0, 2, 1, 3)
kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0)
C.ggml_flash_attn_ext_set_prec(kqv, C.GGML_PREC_F32)
return &Tensor{b: t.b, t: kqv}
} else {
kq := key.MulmatFullPrec(ctx, query)
kq = &Tensor{
b: t.b,
t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
}
kqv := value.Mulmat(ctx, kq)
return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
}
}
func (b *Backend) SystemInfo() string {

View File

@@ -3,6 +3,7 @@ package nn
import (
"fmt"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
)
@@ -11,40 +12,50 @@ import (
//
// Parameters:
// - ctx: Context for tensor operations
// - query: Query tensor (Q) with shape [d_k, seq_len_q, heads]
// - key: Key tensor (K) with shape [d_k, seq_len_k, kv_heads]
// - value: Value tensor (V) with shape [seq_len_k, d_v, kv_heads]
// - mask: Optional attention mask that is added to the attention score. If
// provided, should broadcast to [seq_len_k, seq_len_q, heads]
// - query: Query tensor (Q) with shape [d_k, heads, seq_len_q]
// - key: Key tensor (K) with shape [d_k, kv_heads, seq_len_k], can be nil to read from cache only
// - value: Value tensor (V) with shape [d_v, kv_heads, seq_len_k], can be nil to read from cache only
// - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension
// - cache: KV cache to store key/value and get past history, can be nil to only use provided key/value
//
// Returns:
//
// Attention output with shape [d_v, heads, seq_len_q]
func Attention(ctx ml.Context, query, key, value, mask ml.Tensor, scale float64) ml.Tensor {
if query.Dim(0) != key.Dim(0) {
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
if key != nil && value != nil {
if query.Dim(0) != key.Dim(0) {
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
}
if key.Dim(1) != value.Dim(1) {
panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1)))
}
if key.Dim(2) != value.Dim(2) {
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
}
if cache != nil {
cache.Put(ctx, key, value)
}
} else if cache == nil {
panic("key & value tensors must be provided if cache is nil")
}
if mask != nil && query.Dim(1) != mask.Dim(1) {
panic(fmt.Errorf("seq_len_q in attention operation does not match between query(%v) and mask(%v)", query.Dim(1), mask.Dim(1)))
var mask ml.Tensor
if cache != nil {
key, value, mask = cache.Get(ctx)
}
if key.Dim(1) != value.Dim(0) {
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(0)))
}
if mask != nil && key.Dim(1) != mask.Dim(0) {
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and mask(%v)", key.Dim(1), mask.Dim(0)))
}
if key.Dim(2) != value.Dim(2) {
panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
}
if sdpa, ok := query.(ml.ScaledDotProductAttention); ok {
// Only use the fast SDPA implementation if we have a cache, since that's what
// will do any expected backend-specific transformations for us
if sdpa, ok := query.(ml.ScaledDotProductAttention); ok && cache != nil {
return sdpa.ScaledDotProductAttention(ctx, key, value, mask, scale)
} else {
query = query.Permute(ctx, 0, 2, 1, 3)
key = key.Permute(ctx, 0, 2, 1, 3)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
kq := key.MulmatFullPrec(ctx, query)
kq = kq.Scale(ctx, scale)

View File

@@ -81,15 +81,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
cache.Put(ctx, k, v)
k, v, mask := cache.Get(ctx)
q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
kqv := nn.Attention(ctx, q, k, v, mask, scaleFactor)
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
return sa.Output.Forward(ctx, kqv)

View File

@@ -43,7 +43,9 @@ func New(c ml.Config) (model.Model, error) {
TextModel: newTextModel(c),
}
m.Cache = kvcache.NewWrapperCache(kvcache.NewEncoderCache(), kvcache.NewCausalCache(m.TextModel.Shift))
encoderCache := kvcache.NewEncoderCache()
encoderCache.SetConfig(ml.CacheConfig{})
m.Cache = kvcache.NewWrapperCache(encoderCache, kvcache.NewCausalCache(m.TextModel.Shift))
return &m, nil
}

View File

@@ -31,22 +31,15 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
cache.Put(ctx, key, value)
key, value, mask := cache.Get(ctx)
query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)
attention := nn.Attention(ctx, query, key, value, scaleFactor, cache)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return sa.Output.Forward(ctx, attention)
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
// This will only get called for layers in the cache, which are just the self attention layers
// This will only get called for layers in the causal cache, which are just the self attention layers
return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
}
@@ -107,7 +100,7 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
query = ca.QueryNorm.Forward(ctx, query, opts.eps)
var key, value, mask ml.Tensor
var key, value ml.Tensor
if crossAttentionStates != nil {
numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)
@@ -119,16 +112,23 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
cache.Put(ctx, key, value)
} else {
key, value, mask = cache.Get(ctx)
}
query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
key, value, _ = cache.Get(ctx)
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)
query = query.Permute(ctx, 0, 2, 1, 3)
key = key.Permute(ctx, 0, 2, 1, 3)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
kq := key.MulmatFullPrec(ctx, query)
kq = kq.Scale(ctx, scaleFactor)
kq = kq.Softmax(ctx)
kqv := value.Mulmat(ctx, kq)
attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return ca.Output.Forward(ctx, attention)

View File

@@ -818,7 +818,7 @@ func Execute(args []string) error {
batchSize := fs.Int("batch-size", 512, "Batch size")
numGPULayers := fs.Int("n-gpu-layers", 0, "Number of layers to offload to GPU")
mainGPU := fs.Int("main-gpu", 0, "Main GPU")
_ = fs.Bool("flash-attn", false, "Enable flash attention")
flashAttention := fs.Bool("flash-attn", false, "Enable flash attention")
kvSize := fs.Int("ctx-size", 2048, "Context (or KV cache) size")
kvCacheType := fs.String("kv-cache-type", "", "quantization type for KV cache (default: f16)")
port := fs.Int("port", 8080, "Port to expose the server on")
@@ -863,7 +863,6 @@ func Execute(args []string) error {
}
// TODO(jessegross): Parameters that need to be implemented:
// flash-attn
// no-mmap
// mlock
@@ -878,10 +877,11 @@ func Execute(args []string) error {
}
params := ml.BackendParams{
NumThreads: *threads,
NumGPULayers: *numGPULayers,
MainGPU: *mainGPU,
TensorSplit: tensorSplitFloats,
NumThreads: *threads,
NumGPULayers: *numGPULayers,
MainGPU: *mainGPU,
TensorSplit: tensorSplitFloats,
FlashAttention: *flashAttention,
}
server.ready.Add(1)

View File

@@ -27,6 +27,7 @@ import (
"slices"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
@@ -73,19 +74,22 @@ const (
DefaultMaxChunkSize = 8 << 20
)
// DefaultCache returns a new disk cache for storing models. If the
// OLLAMA_MODELS environment variable is set, it uses that directory;
// otherwise, it uses $HOME/.ollama/models.
func DefaultCache() (*blob.DiskCache, error) {
var defaultCache = sync.OnceValues(func() (*blob.DiskCache, error) {
dir := os.Getenv("OLLAMA_MODELS")
if dir == "" {
home, err := os.UserHomeDir()
if err != nil {
return nil, err
}
home, _ := os.UserHomeDir()
home = cmp.Or(home, ".")
dir = filepath.Join(home, ".ollama", "models")
}
return blob.Open(dir)
})
// DefaultCache returns the default cache used by the registry. It is
// configured from the OLLAMA_MODELS environment variable, or defaults to
// $HOME/.ollama/models, or, if an error occurs obtaining the home directory,
// it uses the current working directory.
func DefaultCache() (*blob.DiskCache, error) {
return defaultCache()
}
// Error is the standard error returned by Ollama APIs. It can represent a
@@ -168,6 +172,10 @@ func CompleteName(name string) string {
// Registry is a client for performing push and pull operations against an
// Ollama registry.
type Registry struct {
// Cache is the cache used to store models. If nil, [DefaultCache] is
// used.
Cache *blob.DiskCache
// UserAgent is the User-Agent header to send with requests to the
// registry. If empty, the User-Agent is determined by HTTPClient.
UserAgent string
@@ -206,18 +214,28 @@ type Registry struct {
// It is only used when a layer is larger than [MaxChunkingThreshold].
MaxChunkSize int64
// Mask, if set, is the name used to convert non-fully qualified
// names to fully qualified names. If empty, the default mask
// ("registry.ollama.ai/library/_:latest") is used.
// Mask, if set, is the name used to convert non-fully qualified names
// to fully qualified names. If empty, [DefaultMask] is used.
Mask string
}
func (r *Registry) completeName(name string) names.Name {
func (r *Registry) cache() (*blob.DiskCache, error) {
if r.Cache != nil {
return r.Cache, nil
}
return defaultCache()
}
func (r *Registry) parseName(name string) (names.Name, error) {
mask := defaultMask
if r.Mask != "" {
mask = names.Parse(r.Mask)
}
return names.Merge(names.Parse(name), mask)
n := names.Merge(names.Parse(name), mask)
if !n.IsFullyQualified() {
return names.Name{}, fmt.Errorf("%w: %q", ErrNameInvalid, name)
}
return n, nil
}
// DefaultRegistry returns a new Registry configured from the environment. The
@@ -278,12 +296,17 @@ type PushParams struct {
}
// Push pushes the model with the name in the cache to the remote registry.
func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *PushParams) error {
func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
if p == nil {
p = &PushParams{}
}
m, err := r.ResolveLocal(c, cmp.Or(p.From, name))
c, err := r.cache()
if err != nil {
return err
}
m, err := r.ResolveLocal(cmp.Or(p.From, name))
if err != nil {
return err
}
@@ -306,7 +329,7 @@ func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *
t := traceFromContext(ctx)
scheme, n, _, err := parseName(name, r.Mask)
scheme, n, _, err := r.parseNameExtended(name)
if err != nil {
// This should never happen since ResolveLocal should have
// already validated the name.
@@ -399,8 +422,8 @@ func canRetry(err error) bool {
// chunks of the specified size, and then reassembled and verified. This is
// typically slower than splitting the model up across layers, and is mostly
// utilized for layers of type equal to "application/vnd.ollama.image".
func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) error {
scheme, n, _, err := parseName(name, r.Mask)
func (r *Registry) Pull(ctx context.Context, name string) error {
scheme, n, _, err := r.parseNameExtended(name)
if err != nil {
return err
}
@@ -413,6 +436,11 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err
return fmt.Errorf("%w: no layers", ErrManifestInvalid)
}
c, err := r.cache()
if err != nil {
return err
}
exists := func(l *Layer) bool {
info, err := c.Get(l.Digest)
return err == nil && info.Size == l.Size
@@ -550,10 +578,14 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err
// Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified
// before attempting to unlink the model.
func (r *Registry) Unlink(c *blob.DiskCache, name string) (ok bool, _ error) {
n := r.completeName(name)
if !n.IsFullyQualified() {
return false, fmt.Errorf("%w: %q", ErrNameInvalid, name)
func (r *Registry) Unlink(name string) (ok bool, _ error) {
n, err := r.parseName(name)
if err != nil {
return false, err
}
c, err := r.cache()
if err != nil {
return false, err
}
return c.Unlink(n.String())
}
@@ -626,14 +658,18 @@ type Layer struct {
Size int64 `json:"size"`
}
// ResolveLocal resolves a name to a Manifest in the local cache. The name is
// parsed using [names.Split] but the scheme is ignored.
func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) {
_, n, d, err := parseName(name, r.Mask)
// ResolveLocal resolves a name to a Manifest in the local cache.
func (r *Registry) ResolveLocal(name string) (*Manifest, error) {
_, n, d, err := r.parseNameExtended(name)
if err != nil {
return nil, err
}
c, err := r.cache()
if err != nil {
return nil, err
}
if !d.IsValid() {
// No digest, so resolve the manifest by name.
d, err = c.Resolve(n.String())
if err != nil {
return nil, err
@@ -655,7 +691,7 @@ func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, erro
// Resolve resolves a name to a Manifest in the remote registry.
func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error) {
scheme, n, d, err := parseName(name, r.Mask)
scheme, n, d, err := r.parseNameExtended(name)
if err != nil {
return nil, err
}
@@ -859,7 +895,7 @@ var supportedSchemes = []string{
var supportedSchemesMessage = fmt.Sprintf("supported schemes are %v", strings.Join(supportedSchemes, ", "))
// parseName parses and validates an extended name, returning the scheme, name,
// parseNameExtended parses and validates an extended name, returning the scheme, name,
// and digest.
//
// If the scheme is empty, scheme will be "https". If an unsupported scheme is
@@ -870,8 +906,8 @@ var supportedSchemesMessage = fmt.Sprintf("supported schemes are %v", strings.Jo
//
// If the name is not, once merged with the mask, fully qualified,
// [ErrNameInvalid] wrapped with a display friendly message is returned.
func parseName(s string, mask string) (scheme string, _ names.Name, _ blob.Digest, _ error) {
scheme, name, digest := names.Split(s)
func (r *Registry) parseNameExtended(s string) (scheme string, _ names.Name, _ blob.Digest, _ error) {
scheme, name, digest := splitExtended(s)
scheme = cmp.Or(scheme, "https")
if !slices.Contains(supportedSchemes, scheme) {
err := withPublicMessagef(ErrNameInvalid, "unsupported scheme: %q: %s", scheme, supportedSchemesMessage)
@@ -894,13 +930,33 @@ func parseName(s string, mask string) (scheme string, _ names.Name, _ blob.Diges
}
}
maskName := defaultMask
if mask != "" {
maskName = names.Parse(mask)
}
n := names.Merge(names.Parse(name), maskName)
if !n.IsFullyQualified() {
return "", names.Name{}, blob.Digest{}, fmt.Errorf("%w: %q", ErrNameInvalid, s)
n, err := r.parseName(name)
if err != nil {
return "", names.Name{}, blob.Digest{}, err
}
return scheme, n, d, nil
}
// splitExtended splits an extended name string into its scheme, name, and digest
// parts.
//
// Examples:
//
// http://ollama.com/bmizerany/smol:latest@digest
// https://ollama.com/bmizerany/smol:latest
// ollama.com/bmizerany/smol:latest@digest // returns "https" scheme.
// model@digest
// @digest
func splitExtended(s string) (scheme, name, digest string) {
i := strings.Index(s, "://")
if i >= 0 {
scheme = s[:i]
s = s[i+3:]
}
i = strings.LastIndex(s, "@")
if i >= 0 {
digest = s[i+1:]
s = s[:i]
}
return scheme, s, digest
}

View File

@@ -2,6 +2,7 @@ package ollama
import (
"bytes"
"cmp"
"context"
"encoding/json"
"errors"
@@ -72,6 +73,7 @@ func (rr recordRoundTripper) RoundTrip(req *http.Request) (*http.Response, error
// To simulate a network error, pass a handler that returns a 499 status code.
func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
t.Helper()
c, err := blob.Open(t.TempDir())
if err != nil {
t.Fatal(err)
@@ -85,13 +87,14 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
}
r := &Registry{
Cache: c,
HTTPClient: &http.Client{
Transport: recordRoundTripper(h),
},
}
link := func(name string, manifest string) {
_, n, _, err := parseName(name, r.Mask)
n, err := r.parseName(name)
if err != nil {
panic(err)
}
@@ -151,55 +154,55 @@ func withTraceUnexpected(ctx context.Context) (context.Context, *Trace) {
}
func TestPushZero(t *testing.T) {
rc, c := newClient(t, okHandler)
err := rc.Push(t.Context(), c, "empty", nil)
rc, _ := newClient(t, okHandler)
err := rc.Push(t.Context(), "empty", nil)
if !errors.Is(err, ErrManifestInvalid) {
t.Errorf("err = %v; want %v", err, ErrManifestInvalid)
}
}
func TestPushSingle(t *testing.T) {
rc, c := newClient(t, okHandler)
err := rc.Push(t.Context(), c, "single", nil)
rc, _ := newClient(t, okHandler)
err := rc.Push(t.Context(), "single", nil)
testutil.Check(t, err)
}
func TestPushMultiple(t *testing.T) {
rc, c := newClient(t, okHandler)
err := rc.Push(t.Context(), c, "multiple", nil)
rc, _ := newClient(t, okHandler)
err := rc.Push(t.Context(), "multiple", nil)
testutil.Check(t, err)
}
func TestPushNotFound(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
t.Errorf("unexpected request: %v", r)
})
err := rc.Push(t.Context(), c, "notfound", nil)
err := rc.Push(t.Context(), "notfound", nil)
if !errors.Is(err, fs.ErrNotExist) {
t.Errorf("err = %v; want %v", err, fs.ErrNotExist)
}
}
func TestPushNullLayer(t *testing.T) {
rc, c := newClient(t, nil)
err := rc.Push(t.Context(), c, "null", nil)
rc, _ := newClient(t, nil)
err := rc.Push(t.Context(), "null", nil)
if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
t.Errorf("err = %v; want invalid manifest", err)
}
}
func TestPushSizeMismatch(t *testing.T) {
rc, c := newClient(t, nil)
rc, _ := newClient(t, nil)
ctx, _ := withTraceUnexpected(t.Context())
got := rc.Push(ctx, c, "sizemismatch", nil)
got := rc.Push(ctx, "sizemismatch", nil)
if got == nil || !strings.Contains(got.Error(), "size mismatch") {
t.Errorf("err = %v; want size mismatch", got)
}
}
func TestPushInvalid(t *testing.T) {
rc, c := newClient(t, nil)
err := rc.Push(t.Context(), c, "invalid", nil)
rc, _ := newClient(t, nil)
err := rc.Push(t.Context(), "invalid", nil)
if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
t.Errorf("err = %v; want invalid manifest", err)
}
@@ -207,7 +210,7 @@ func TestPushInvalid(t *testing.T) {
func TestPushExistsAtRemote(t *testing.T) {
var pushed bool
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/uploads/") {
if !pushed {
// First push. Return an uploadURL.
@@ -235,35 +238,35 @@ func TestPushExistsAtRemote(t *testing.T) {
check := testutil.Checker(t)
err := rc.Push(ctx, c, "single", nil)
err := rc.Push(ctx, "single", nil)
check(err)
if !errors.Is(errors.Join(errs...), nil) {
t.Errorf("errs = %v; want %v", errs, []error{ErrCached})
}
err = rc.Push(ctx, c, "single", nil)
err = rc.Push(ctx, "single", nil)
check(err)
}
func TestPushRemoteError(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/blobs/") {
w.WriteHeader(500)
io.WriteString(w, `{"errors":[{"code":"blob_error"}]}`)
return
}
})
got := rc.Push(t.Context(), c, "single", nil)
got := rc.Push(t.Context(), "single", nil)
checkErrCode(t, got, 500, "blob_error")
}
func TestPushLocationError(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Location", ":///x")
w.WriteHeader(http.StatusAccepted)
})
got := rc.Push(t.Context(), c, "single", nil)
got := rc.Push(t.Context(), "single", nil)
wantContains := "invalid upload URL"
if got == nil || !strings.Contains(got.Error(), wantContains) {
t.Errorf("err = %v; want to contain %v", got, wantContains)
@@ -271,14 +274,14 @@ func TestPushLocationError(t *testing.T) {
}
func TestPushUploadRoundtripError(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if r.Host == "blob.store" {
w.WriteHeader(499) // force RoundTrip error on upload
return
}
w.Header().Set("Location", "http://blob.store/blobs/123")
})
got := rc.Push(t.Context(), c, "single", nil)
got := rc.Push(t.Context(), "single", nil)
if !errors.Is(got, errRoundTrip) {
t.Errorf("got = %v; want %v", got, errRoundTrip)
}
@@ -294,20 +297,20 @@ func TestPushUploadFileOpenError(t *testing.T) {
os.Remove(c.GetFile(l.Digest))
},
})
got := rc.Push(ctx, c, "single", nil)
got := rc.Push(ctx, "single", nil)
if !errors.Is(got, fs.ErrNotExist) {
t.Errorf("got = %v; want fs.ErrNotExist", got)
}
}
func TestPushCommitRoundtripError(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/blobs/") {
panic("unexpected")
}
w.WriteHeader(499) // force RoundTrip error
})
err := rc.Push(t.Context(), c, "zero", nil)
err := rc.Push(t.Context(), "zero", nil)
if !errors.Is(err, errRoundTrip) {
t.Errorf("err = %v; want %v", err, errRoundTrip)
}
@@ -321,8 +324,8 @@ func checkNotExist(t *testing.T, err error) {
}
func TestRegistryPullInvalidName(t *testing.T) {
rc, c := newClient(t, nil)
err := rc.Pull(t.Context(), c, "://")
rc, _ := newClient(t, nil)
err := rc.Pull(t.Context(), "://")
if !errors.Is(err, ErrNameInvalid) {
t.Errorf("err = %v; want %v", err, ErrNameInvalid)
}
@@ -337,10 +340,10 @@ func TestRegistryPullInvalidManifest(t *testing.T) {
}
for _, resp := range cases {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, resp)
})
err := rc.Pull(t.Context(), c, "x")
err := rc.Pull(t.Context(), "x")
if !errors.Is(err, ErrManifestInvalid) {
t.Errorf("err = %v; want invalid manifest", err)
}
@@ -363,18 +366,18 @@ func TestRegistryPullNotCached(t *testing.T) {
})
// Confirm that the layer does not exist locally
_, err := rc.ResolveLocal(c, "model")
_, err := rc.ResolveLocal("model")
checkNotExist(t, err)
_, err = c.Get(d)
checkNotExist(t, err)
err = rc.Pull(t.Context(), c, "model")
err = rc.Pull(t.Context(), "model")
check(err)
mw, err := rc.Resolve(t.Context(), "model")
check(err)
mg, err := rc.ResolveLocal(c, "model")
mg, err := rc.ResolveLocal("model")
check(err)
if !reflect.DeepEqual(mw, mg) {
t.Errorf("mw = %v; mg = %v", mw, mg)
@@ -399,7 +402,7 @@ func TestRegistryPullNotCached(t *testing.T) {
func TestRegistryPullCached(t *testing.T) {
cached := blob.DigestFromBytes("exists")
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/blobs/") {
w.WriteHeader(499) // should not be called
return
@@ -422,7 +425,7 @@ func TestRegistryPullCached(t *testing.T) {
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()
err := rc.Pull(ctx, c, "single")
err := rc.Pull(ctx, "single")
testutil.Check(t, err)
want := []int64{6}
@@ -435,30 +438,30 @@ func TestRegistryPullCached(t *testing.T) {
}
func TestRegistryPullManifestNotFound(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
})
err := rc.Pull(t.Context(), c, "notfound")
err := rc.Pull(t.Context(), "notfound")
checkErrCode(t, err, 404, "")
}
func TestRegistryPullResolveRemoteError(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
io.WriteString(w, `{"errors":[{"code":"an_error"}]}`)
})
err := rc.Pull(t.Context(), c, "single")
err := rc.Pull(t.Context(), "single")
checkErrCode(t, err, 500, "an_error")
}
func TestRegistryPullResolveRoundtripError(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/manifests/") {
w.WriteHeader(499) // force RoundTrip error
return
}
})
err := rc.Pull(t.Context(), c, "single")
err := rc.Pull(t.Context(), "single")
if !errors.Is(err, errRoundTrip) {
t.Errorf("err = %v; want %v", err, errRoundTrip)
}
@@ -511,7 +514,7 @@ func TestRegistryPullMixedCachedNotCached(t *testing.T) {
// Check that we pull all layers that we can.
err := rc.Pull(ctx, c, "mixed")
err := rc.Pull(ctx, "mixed")
if err != nil {
t.Fatal(err)
}
@@ -529,7 +532,7 @@ func TestRegistryPullMixedCachedNotCached(t *testing.T) {
}
func TestRegistryPullChunking(t *testing.T) {
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
t.Log("request:", r.URL.Host, r.Method, r.URL.Path, r.Header.Get("Range"))
if r.URL.Host != "blob.store" {
// The production registry redirects to the blob store.
@@ -567,7 +570,7 @@ func TestRegistryPullChunking(t *testing.T) {
},
})
- err := rc.Pull(ctx, c, "remote")
+ err := rc.Pull(ctx, "remote")
testutil.Check(t, err)
want := []int64{0, 3, 6}
@@ -709,25 +712,16 @@ func TestErrorUnmarshal(t *testing.T) {
//
// It is only for testing error messages, not that all invalids and valids are
// covered. Those are in other tests for names.Name and blob.Digest.
- func TestParseNameErrors(t *testing.T) {
+ func TestParseNameExtendedErrors(t *testing.T) {
cases := []struct {
name string
err error
want string
- }{
- {"x", nil, ""},
- {"x@", nil, ""},
- {"", ErrNameInvalid, `invalid or missing name: ""`},
- {"://", ErrNameInvalid, `invalid or missing name: "://"`},
- {"x://", ErrNameInvalid, `unsupported scheme: "x": supported schemes are http, https, https+insecure`},
- {"@sha123-1234", ErrNameInvalid, `invalid digest: "sha123-1234"`},
- {"x@sha123-1234", ErrNameInvalid, `invalid digest: "sha123-1234"`},
- }
+ }{}
+ var r Registry
for _, tt := range cases {
- _, _, _, err := parseName(tt.name, DefaultMask)
+ _, _, _, err := r.parseNameExtended(tt.name)
if !errors.Is(err, tt.err) {
t.Errorf("[%s]: err = %v; want %v", tt.name, err, tt.err)
}
@@ -736,3 +730,89 @@ func TestParseNameErrors(t *testing.T) {
}
}
}
+ func TestParseNameExtended(t *testing.T) {
+ cases := []struct {
+ in string
+ scheme string
+ name string
+ digest string
+ err string
+ }{
+ {in: "http://m", scheme: "http", name: "m"},
+ {in: "https+insecure://m", scheme: "https+insecure", name: "m"},
+ {in: "http+insecure://m", err: "unsupported scheme"},
+ {in: "http://m@sha256:1111111111111111111111111111111111111111111111111111111111111111", scheme: "http", name: "m", digest: "sha256:1111111111111111111111111111111111111111111111111111111111111111"},
+ {in: "", err: "invalid or missing name"},
+ {in: "m", scheme: "https", name: "m"},
+ {in: "://", err: "invalid or missing name"},
+ {in: "@sha256:deadbeef", err: "invalid digest"},
+ {in: "@sha256:deadbeef@sha256:deadbeef", err: "invalid digest"},
+ }
+ for _, tt := range cases {
+ t.Run(tt.in, func(t *testing.T) {
+ var r Registry
+ scheme, n, digest, err := r.parseNameExtended(tt.in)
+ if err != nil {
+ if tt.err == "" {
+ t.Errorf("err = %v; want nil", err)
+ } else if !strings.Contains(err.Error(), tt.err) {
+ t.Errorf("err = %v; want %q", err, tt.err)
+ }
+ } else if tt.err != "" {
+ t.Errorf("err = nil; want %q", tt.err)
+ }
+ if err == nil && !n.IsFullyQualified() {
+ t.Errorf("name = %q; want fully qualified", n)
+ }
+ if scheme != tt.scheme {
+ t.Errorf("scheme = %q; want %q", scheme, tt.scheme)
+ }
+ // smoke-test name is superset of tt.name
+ if !strings.Contains(n.String(), tt.name) {
+ t.Errorf("name = %q; want %q", n, tt.name)
+ }
+ tt.digest = cmp.Or(tt.digest, (&blob.Digest{}).String())
+ if digest.String() != tt.digest {
+ t.Errorf("digest = %q; want %q", digest, tt.digest)
+ }
+ })
+ }
+ }
+ func TestUnlink(t *testing.T) {
+ t.Run("found by name", func(t *testing.T) {
+ rc, _ := newClient(t, nil)
+ // confirm linked
+ _, err := rc.ResolveLocal("single")
+ if err != nil {
+ t.Errorf("unexpected error: %v", err)
+ }
+ // unlink
+ _, err = rc.Unlink("single")
+ testutil.Check(t, err)
+ // confirm unlinked
+ _, err = rc.ResolveLocal("single")
+ if !errors.Is(err, fs.ErrNotExist) {
+ t.Errorf("err = %v; want fs.ErrNotExist", err)
+ }
+ })
+ t.Run("not found by name", func(t *testing.T) {
+ rc, _ := newClient(t, nil)
+ ok, err := rc.Unlink("manifestNotFound")
+ if err != nil {
+ t.Fatal(err)
+ }
+ if ok {
+ t.Error("expected not found")
+ }
+ })
+ }
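
The two new tests above pin down the extended name grammar: [scheme://]name[@digest]. Below is a rough sketch of just the split step, using a hypothetical splitExtended helper; the real parseNameExtended on Registry also validates the scheme, name, and digest, which is where the error strings exercised above come from.

package main

import (
	"fmt"
	"strings"
)

// splitExtended splits an extended name of the form [scheme://]name[@digest].
// Hypothetical sketch only: validation of the three parts is omitted.
func splitExtended(s string) (scheme, name, digest string) {
	if i := strings.Index(s, "://"); i >= 0 {
		scheme, s = s[:i], s[i+len("://"):]
	}
	name, digest, _ = strings.Cut(s, "@")
	return scheme, name, digest
}

func main() {
	for _, in := range []string{"m", "http://m", "m@sha256:1111", "@sha256:deadbeef@sha256:deadbeef"} {
		scheme, name, digest := splitExtended(in)
		fmt.Printf("%q -> scheme=%q name=%q digest=%q\n", in, scheme, name, digest)
	}
}

Note that cutting at the first "@" is what makes "@sha256:deadbeef@sha256:deadbeef" fail as an invalid digest rather than an invalid name, matching the test case above.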

View File

@@ -6,6 +6,9 @@ import (
// Trace is a set of functions that are called to report progress during blob
// downloads and uploads.
+ //
+ // Use [WithTrace] to attach a Trace to a context for use with [Registry.Push]
+ // and [Registry.Pull].
type Trace struct {
// Update is called during [Registry.Push] and [Registry.Pull] to
// report the progress of blob uploads and downloads.
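
As the added doc comment says, [WithTrace] is how callers hook progress reporting into Push and Pull. A minimal sketch of attaching a Trace before a Pull follows; the exact Update signature (layer, bytes transferred so far, error) is an assumption for illustration, since only the field's purpose appears in the hunk above.

package main

import (
	"context"
	"log"

	"github.com/ollama/ollama/server/internal/client/ollama"
)

func main() {
	rc, err := ollama.DefaultRegistry()
	if err != nil {
		log.Fatal(err)
	}
	// Attach a Trace so Pull reports per-layer progress.
	ctx := ollama.WithTrace(context.Background(), &ollama.Trace{
		// Assumed signature: the layer, bytes transferred, and any error.
		Update: func(l *ollama.Layer, n int64, err error) {
			if err != nil {
				log.Printf("layer %v: %v", l.Digest, err)
				return
			}
			log.Printf("layer %v: %d/%d bytes", l.Digest, n, l.Size)
		},
	})
	if err := rc.Pull(ctx, "smol"); err != nil { // "smol" is a placeholder model name
		log.Fatal(err)
	}
}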

View File

@@ -63,25 +63,28 @@ func main() {
}
flag.Parse()
- c, err := ollama.DefaultCache()
- if err != nil {
- log.Fatal(err)
- }
- rc, err := ollama.DefaultRegistry()
- if err != nil {
- log.Fatal(err)
- }
ctx := context.Background()
- err = func() error {
+ err := func() error {
switch cmd := flag.Arg(0); cmd {
case "pull":
- return cmdPull(ctx, rc, c)
+ rc, err := ollama.DefaultRegistry()
+ if err != nil {
+ log.Fatal(err)
+ }
+ return cmdPull(ctx, rc)
case "push":
- return cmdPush(ctx, rc, c)
+ rc, err := ollama.DefaultRegistry()
+ if err != nil {
+ log.Fatal(err)
+ }
+ return cmdPush(ctx, rc)
case "import":
+ c, err := ollama.DefaultCache()
+ if err != nil {
+ log.Fatal(err)
+ }
return cmdImport(ctx, c)
default:
if cmd == "" {
@@ -99,7 +102,7 @@ func main() {
}
}
- func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error {
+ func cmdPull(ctx context.Context, rc *ollama.Registry) error {
model := flag.Arg(1)
if model == "" {
flag.Usage()
@@ -145,7 +148,7 @@ func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error
errc := make(chan error)
go func() {
- errc <- rc.Pull(ctx, c, model)
+ errc <- rc.Pull(ctx, model)
}()
t := time.NewTicker(time.Second)
@@ -161,7 +164,7 @@ func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error
}
}
- func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error {
+ func cmdPush(ctx context.Context, rc *ollama.Registry) error {
args := flag.Args()[1:]
flag := flag.NewFlagSet("push", flag.ExitOnError)
flagFrom := flag.String("from", "", "Use the manifest from a model by another name.")
@@ -177,7 +180,7 @@ func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error
}
from := cmp.Or(*flagFrom, model)
- m, err := rc.ResolveLocal(c, from)
+ m, err := rc.ResolveLocal(from)
if err != nil {
return err
}
@@ -203,7 +206,7 @@ func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error
},
})
- return rc.Push(ctx, c, model, &ollama.PushParams{
+ return rc.Push(ctx, model, &ollama.PushParams{
From: from,
})
}
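
The pattern running through this example binary mirrors the rest of the diff: callers stop threading a *blob.DiskCache alongside the Registry and let the Registry carry it. A hedged sketch of the two wirings visible above, with placeholder paths and model names:

package main

import (
	"context"
	"log"

	"github.com/ollama/ollama/server/internal/cache/blob"
	"github.com/ollama/ollama/server/internal/client/ollama"
)

func main() {
	ctx := context.Background()

	// Wiring 1: DefaultRegistry, as in the pull/push cases of main above.
	rc, err := ollama.DefaultRegistry()
	if err != nil {
		log.Fatal(err)
	}
	if err := rc.Pull(ctx, "smol"); err != nil { // placeholder model name
		log.Fatal(err)
	}

	// Wiring 2: hold an explicit DiskCache on the Registry, as the tests do.
	c, err := blob.Open("/path/to/models") // placeholder directory
	if err != nil {
		log.Fatal(err)
	}
	rc2 := &ollama.Registry{Cache: c}
	if _, err := rc2.ResolveLocal("smol"); err != nil {
		log.Println(err)
	}
}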

View File

@@ -11,7 +11,6 @@ import (
"log/slog"
"net/http"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/client/ollama"
)
@@ -27,7 +26,6 @@ import (
// directly to the blob disk cache.
type Local struct {
Client *ollama.Registry // required
- Cache *blob.DiskCache // required
Logger *slog.Logger // required
// Fallback, if set, is used to handle requests that are not handled by
@@ -199,7 +197,7 @@ func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
if err != nil {
return err
}
- ok, err := s.Client.Unlink(s.Cache, p.model())
+ ok, err := s.Client.Unlink(p.model())
if err != nil {
return err
}

View File

@@ -42,10 +42,10 @@ func newTestServer(t *testing.T) *Local {
t.Fatal(err)
}
rc := &ollama.Registry{
+ Cache: c,
HTTPClient: panicOnRoundTrip,
}
l := &Local{
- Cache: c,
Client: rc,
Logger: testutil.Slogger(t),
}
@@ -87,7 +87,7 @@ func TestServerDelete(t *testing.T) {
s := newTestServer(t)
- _, err := s.Client.ResolveLocal(s.Cache, "smol")
+ _, err := s.Client.ResolveLocal("smol")
check(err)
got := s.send(t, "DELETE", "/api/delete", `{"model": "smol"}`)
@@ -95,7 +95,7 @@ func TestServerDelete(t *testing.T) {
t.Fatalf("Code = %d; want 200", got.Code)
}
- _, err = s.Client.ResolveLocal(s.Cache, "smol")
+ _, err = s.Client.ResolveLocal("smol")
if err == nil {
t.Fatal("expected smol to have been deleted")
}

View File

@@ -34,7 +34,6 @@ import (
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/model/models/mllama"
"github.com/ollama/ollama/openai"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/server/internal/registry"
"github.com/ollama/ollama/template"
@@ -1129,7 +1128,7 @@ func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
}
}
- func (s *Server) GenerateRoutes(c *blob.DiskCache, rc *ollama.Registry) (http.Handler, error) {
+ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
corsConfig := cors.DefaultConfig()
corsConfig.AllowWildcard = true
corsConfig.AllowBrowserExtensions = true
@@ -1197,7 +1196,6 @@ func (s *Server) GenerateRoutes(c *blob.DiskCache, rc *ollama.Registry) (http.Ha
// wrap old with new
rs := &registry.Local{
- Cache: c,
Client: rc,
Logger: slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default()
Fallback: r,
@@ -1258,16 +1256,12 @@ func Serve(ln net.Listener) error {
s := &Server{addr: ln.Addr()}
- c, err := ollama.DefaultCache()
- if err != nil {
- return err
- }
rc, err := ollama.DefaultRegistry()
if err != nil {
return err
}
- h, err := s.GenerateRoutes(c, rc)
+ h, err := s.GenerateRoutes(rc)
if err != nil {
return err
}

View File

@@ -23,7 +23,6 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/openai"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
@@ -490,11 +489,6 @@ func TestRoutes(t *testing.T) {
modelsDir := t.TempDir()
t.Setenv("OLLAMA_MODELS", modelsDir)
- c, err := blob.Open(modelsDir)
- if err != nil {
- t.Fatalf("failed to open models dir: %v", err)
- }
rc := &ollama.Registry{
// This is a temporary measure to allow us to move forward,
// surfacing any code contacting ollama.com we do not intend
@@ -511,7 +505,7 @@ func TestRoutes(t *testing.T) {
}
s := &Server{}
- router, err := s.GenerateRoutes(c, rc)
+ router, err := s.GenerateRoutes(rc)
if err != nil {
t.Fatalf("failed to generate routes: %v", err)
}