Compare commits


8 Commits

Author SHA1 Message Date
ofrancon
4383a3ab7a readme: add Neuro SAN to community integrations (#12109) 2025-08-28 12:27:13 -07:00
Jesse Gross
9d97e6a9f1 ggml: Avoid allocating CUDA primary context on unused GPUs
The recent memory management changes caused all GPUs to be visible
to the runner, regardless of whether they are ultimately used. This
caused CUDA to allocate a primary context (~300 MB of VRAM) on
each GPU, for each model. This is unnecessary, so we can both avoid
touching GPUs that we exclude in the early stage of allocation and
free the memory for any that we touch but don't use.

The issue will continue to exist for the old engine, since it touches
all devices during initialization.
2025-08-27 16:24:18 -07:00
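
A minimal, self-contained Go sketch of the cleanup this commit describes, mirroring the labeled loop added to the ggml backend later in this compare: every device enumerated during load that no scheduled backend ended up using gets reset, releasing its primary context. The Device and Backend types here are hypothetical stand-ins for the cgo device handles used in the real change.

package main

import "fmt"

// Device is a stand-in for a ggml backend device handle.
type Device struct{ Name string }

// Reset stands in for ggml_backend_dev_reset, which is a no-op for
// backends that don't implement the optional reset hook.
func (d *Device) Reset() { fmt.Println("reset", d.Name) }

// Backend is a stand-in for a scheduled backend bound to one device.
type Backend struct{ Dev *Device }

// resetUnused resets every enumerated device that no scheduled backend uses.
func resetUnused(devices []*Device, scheduled []*Backend) {
nextDevice:
	for _, d := range devices {
		for _, b := range scheduled {
			if b.Dev == d {
				continue nextDevice // device is in use; keep its context
			}
		}
		d.Reset() // touched during enumeration but unused: free its VRAM
	}
}

func main() {
	cuda0, cuda1 := &Device{Name: "CUDA0"}, &Device{Name: "CUDA1"}
	resetUnused([]*Device{cuda0, cuda1}, []*Backend{{Dev: cuda0}})
}
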
Michael Yang
1081532430 fix keep alive (#12041) 2025-08-27 11:51:25 -07:00
Michael Yang
59412fbb43 convert(gptoss): mxfp4 to ggml layout to avoid jit conversion (#12018)
* convert: return bytes written

* ggml flavor mxfp4

* simplify jit conversion

* comment
2025-08-26 16:41:02 -07:00
Michael Yang
86834a2797 convert: fix tensor sorting (#12015)
there are two bugs here.

1. the check for a layer id is incorrect and should be >= 0, since layer
   0 is valid
2. if both tensors have a layer identifier, it will only compare the
   layer ids, which will return 0 if the tensors are in the same layer.
   instead it should fall back to comparing the full tensor name
2025-08-26 13:57:46 -07:00
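
A minimal sketch of the corrected ordering this commit describes, matching the cmp.Or comparator added to WriteGGUF later in this compare. The Tensor type and block() helper are simplified here for illustration; block() returning math.MaxInt for non-block tensors mirrors the related Tensor.block() change elsewhere in this diff.

package main

import (
	"cmp"
	"fmt"
	"math"
	"slices"
)

type Tensor struct{ Name string }

// block extracts the layer id from names like "blk.3.attn_norm.weight";
// tensors outside a block sort after all block tensors.
func (t Tensor) block() int {
	var n int
	if _, err := fmt.Sscanf(t.Name, "blk.%d.", &n); err != nil {
		return math.MaxInt
	}
	return n
}

func main() {
	ts := []Tensor{
		{Name: "output.weight"},
		{Name: "blk.0.ffn_norm.weight"},
		{Name: "blk.0.attn_norm.weight"}, // same block id: name breaks the tie
		{Name: "blk.1.ffn_up.weight"},
	}
	slices.SortStableFunc(ts, func(a, b Tensor) int {
		return cmp.Or(
			cmp.Compare(a.block(), b.block()), // layer 0 compares correctly
			cmp.Compare(a.Name, b.Name),       // tie-break within a layer
		)
	})
	for _, t := range ts {
		fmt.Println(t.Name)
	}
}
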
Michael Yang
85ccf7354d gptoss: enable flash attention by default (#11996) 2025-08-26 13:34:45 -07:00
Michael Yang
30fb7e19f8 remove extra field attr (#11205) 2025-08-25 09:58:16 -07:00
Jeffrey Morgan
d3450dd52e api: implement stringer for ToolFunctionParameters (#12038) 2025-08-22 16:26:48 -07:00
35 changed files with 450 additions and 326 deletions

View File

@@ -541,6 +541,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [OllamaPlusPlus](https://github.com/HardCodeDev777/OllamaPlusPlus) (Very simple C++ library for Ollama)
- [any-llm](https://github.com/mozilla-ai/any-llm) (A single interface to use different llm providers by [mozilla.ai](https://www.mozilla.ai/))
- [any-agent](https://github.com/mozilla-ai/any-agent) (A single interface to use and evaluate different agent frameworks by [mozilla.ai](https://www.mozilla.ai/))
- [Neuro SAN](https://github.com/cognizant-ai-lab/neuro-san-studio) (Data-driven multi-agent orchestration framework) with [example](https://github.com/cognizant-ai-lab/neuro-san-studio/blob/main/docs/user_guide.md#ollama)
### Mobile

View File

@@ -12,7 +12,6 @@ import (
"time"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types"
"github.com/ollama/ollama/types/model"
)
@@ -65,7 +64,7 @@ type GenerateRequest struct {
Context []int `json:"context,omitempty"`
// Stream specifies whether the response is streaming; it is true by default.
Stream types.Null[bool] `json:"stream,omitempty"`
Stream *bool `json:"stream,omitempty"`
// Raw set to true means that no formatting will be applied to the prompt.
Raw bool `json:"raw,omitempty"`
@@ -106,7 +105,7 @@ type ChatRequest struct {
Messages []Message `json:"messages"`
// Stream enables streaming of returned responses; true by default.
Stream types.Null[bool] `json:"stream,omitempty"`
Stream *bool `json:"stream,omitempty"`
// Format is the format to return the response in (e.g. "json").
Format json.RawMessage `json:"format,omitempty"`
@@ -287,16 +286,23 @@ func mapToTypeScriptType(jsonType string) string {
}
}
type ToolFunctionParameters struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"`
Properties map[string]ToolProperty `json:"properties"`
}
func (t *ToolFunctionParameters) String() string {
bts, _ := json.Marshal(t)
return string(bts)
}
type ToolFunction struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"`
Properties map[string]ToolProperty `json:"properties"`
} `json:"parameters"`
Name string `json:"name"`
Description string `json:"description"`
Parameters ToolFunctionParameters `json:"parameters"`
}
func (t *ToolFunction) String() string {
@@ -382,7 +388,7 @@ type EmbedRequest struct {
// this request.
KeepAlive *Duration `json:"keep_alive,omitempty"`
Truncate types.Null[bool] `json:"truncate,omitempty"`
Truncate *bool `json:"truncate,omitempty"`
// Options lists model-specific options.
Options map[string]any `json:"options"`
@@ -421,9 +427,9 @@ type EmbeddingResponse struct {
// CreateRequest is the request passed to [Client.Create].
type CreateRequest struct {
Model string `json:"model"`
Stream types.Null[bool] `json:"stream,omitempty"`
Quantize string `json:"quantize,omitempty"`
Model string `json:"model"`
Stream *bool `json:"stream,omitempty"`
Quantize string `json:"quantize,omitempty"`
From string `json:"from,omitempty"`
Files map[string]string `json:"files,omitempty"`
@@ -487,11 +493,11 @@ type CopyRequest struct {
// PullRequest is the request passed to [Client.Pull].
type PullRequest struct {
Model string `json:"model"`
Insecure bool `json:"insecure,omitempty"`
Username string `json:"username"` // Deprecated: ignored
Password string `json:"password"` // Deprecated: ignored
Stream types.Null[bool] `json:"stream,omitempty"`
Model string `json:"model"`
Insecure bool `json:"insecure,omitempty"` // Deprecated: ignored
Username string `json:"username"` // Deprecated: ignored
Password string `json:"password"` // Deprecated: ignored
Stream *bool `json:"stream,omitempty"`
// Deprecated: set the model name with Model instead
Name string `json:"name"`
@@ -508,11 +514,11 @@ type ProgressResponse struct {
// PushRequest is the request passed to [Client.Push].
type PushRequest struct {
Model string `json:"model"`
Insecure bool `json:"insecure,omitempty"`
Username string `json:"username"` // Deprecated: ignored
Password string `json:"password"` // Deprecated: ignored
Stream types.Null[bool] `json:"stream,omitempty"`
Model string `json:"model"`
Insecure bool `json:"insecure,omitempty"`
Username string `json:"username"`
Password string `json:"password"`
Stream *bool `json:"stream,omitempty"`
// Deprecated: set the model name with Model instead
Name string `json:"name"`
@@ -882,7 +888,7 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
if t < 0 {
d.Duration = time.Duration(math.MaxInt64)
} else {
d.Duration = time.Duration(int(t) * int(time.Second))
d.Duration = time.Duration(t * float64(time.Second))
}
case string:
d.Duration, err = time.ParseDuration(t)

View File

@@ -17,6 +17,11 @@ func TestKeepAliveParsingFromJSON(t *testing.T) {
req string
exp *Duration
}{
{
name: "Unset",
req: `{ }`,
exp: nil,
},
{
name: "Positive Integer",
req: `{ "keep_alive": 42 }`,
@@ -25,7 +30,7 @@ func TestKeepAliveParsingFromJSON(t *testing.T) {
{
name: "Positive Float",
req: `{ "keep_alive": 42.5 }`,
exp: &Duration{42 * time.Second},
exp: &Duration{42500 * time.Millisecond},
},
{
name: "Positive Integer String",
@@ -436,3 +441,50 @@ func TestThinking_UnmarshalJSON(t *testing.T) {
})
}
}
func TestToolFunctionParameters_String(t *testing.T) {
tests := []struct {
name string
params ToolFunctionParameters
expected string
}{
{
name: "simple object with string property",
params: ToolFunctionParameters{
Type: "object",
Required: []string{"name"},
Properties: map[string]ToolProperty{
"name": {
Type: PropertyType{"string"},
Description: "The name of the person",
},
},
},
expected: `{"type":"object","required":["name"],"properties":{"name":{"type":"string","description":"The name of the person"}}}`,
},
{
name: "marshal failure returns empty string",
params: ToolFunctionParameters{
Type: "object",
Defs: func() any {
// Create a cycle that will cause json.Marshal to fail
type selfRef struct {
Self *selfRef
}
s := &selfRef{}
s.Self = s
return s
}(),
Properties: map[string]ToolProperty{},
},
expected: "",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
result := test.params.String()
assert.Equal(t, test.expected, result)
})
}
}

View File

@@ -172,7 +172,20 @@ func (m *mxfp4) WriteTo(w io.Writer) (int64, error) {
blocksDims[i] = int(d)
}
var blocks tensor.Tensor = tensor.New(tensor.WithShape(blocksDims...), tensor.WithBacking(b.Bytes()))
bts := b.Bytes()
var tmp [16]byte
for i := 0; i < b.Len(); i += 16 {
for j := range 8 {
// transform a1b2c3 ... x7y8z9 -> 71xa82yb93zc
a, b := bts[i+j], bts[i+j+8]
tmp[2*j+0] = (a & 0x0F) | (b << 4)
tmp[2*j+1] = (a >> 4) | (b & 0xF0)
}
copy(bts[i:i+16], tmp[:])
}
var blocks tensor.Tensor = tensor.New(tensor.WithShape(blocksDims...), tensor.WithBacking(bts))
var s bytes.Buffer
if _, err := m.scales.WriteTo(&s); err != nil {
@@ -206,5 +219,5 @@ func (m *mxfp4) WriteTo(w io.Writer) (int64, error) {
return 0, err
}
return 0, nil
return int64(len(u8s)), nil
}

View File

@@ -33,8 +33,8 @@ func (t tensorBase) Shape() []uint64 {
const (
tensorKindFP32 uint32 = iota
tensorKindFP16
tensorKindMXFP4 = 4
tensorKindBF16 = 30
tensorKindMXFP4 = 39
)
func (t tensorBase) Kind() uint32 {

View File

@@ -188,17 +188,17 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) {
switch st.Kind() {
case tensorKindFP32:
return 0, binary.Write(w, binary.LittleEndian, f32s)
return int64(len(f32s) * 4), binary.Write(w, binary.LittleEndian, f32s)
case tensorKindFP16:
f16s := make([]uint16, len(f32s))
for i := range f32s {
f16s[i] = float16.Fromfloat32(f32s[i]).Bits()
}
return 0, binary.Write(w, binary.LittleEndian, f16s)
return int64(len(f16s) * 2), binary.Write(w, binary.LittleEndian, f16s)
case tensorKindBF16:
u8s := bfloat16.EncodeFloat32(f32s)
return 0, binary.Write(w, binary.LittleEndian, u8s)
return int64(len(u8s)), binary.Write(w, binary.LittleEndian, u8s)
default:
return 0, fmt.Errorf("unknown storage type: %d", st.Kind())
}

View File

@@ -7,9 +7,11 @@ import (
"fmt"
"io"
"log/slog"
"math"
"slices"
"strings"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/fs/util/bufioutil"
)
@@ -275,7 +277,7 @@ type Tensor struct {
func (t Tensor) block() (n int) {
if _, err := fmt.Sscanf(t.Name, "blk.%d.", &n); err != nil {
return -1
return math.MaxInt
}
return
@@ -288,24 +290,24 @@ func (t Tensor) blockSize() uint64 {
func (t TensorType) BlockSize() uint64 {
switch t {
case
0, // F32
1, // F16
24, // I8
25, // I16
26, // I32
27, // I64
28, // F64
30: // BF16
TensorTypeF32,
TensorTypeF16,
TensorTypeI8,
TensorTypeI16,
TensorTypeI32,
TensorTypeI64,
TensorTypeF64,
TensorTypeBF16:
return 1
case
2, // Q4_0
3, // Q4_1
4, // MXFP4
6, // Q5_0
7, // Q5_1
8, // Q8_0
9, // Q8_1
20: // IQ4_NL
TensorTypeQ4_0,
TensorTypeQ4_1,
TensorTypeQ5_0,
TensorTypeQ5_1,
TensorTypeQ8_0,
TensorTypeQ8_1,
tensorTypeIQ4_NL,
4, TensorTypeMXFP4:
return 32
default:
return 256
@@ -328,8 +330,6 @@ func (t TensorType) TypeSize() uint64 {
return 2 + blockSize/2
case TensorTypeQ4_1:
return 2 + 2 + blockSize/2
case TensorTypeMXFP4, 39:
return 1 + blockSize/2
case TensorTypeQ5_0:
return 2 + 4 + blockSize/2
case TensorTypeQ5_1:
@@ -380,6 +380,8 @@ func (t TensorType) TypeSize() uint64 {
return blockSize/8 + blockSize/16 + blockSize/32
case TensorTypeBF16:
return 2
case 4, TensorTypeMXFP4:
return 1 + blockSize/2
default:
return 0
}
@@ -479,7 +481,7 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) {
}, nil
}
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) {
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string, useFlashAttention bool) (kv []uint64, partialOffload, fullOffload uint64) {
context *= uint64(numParallel)
embedding := f.KV().EmbeddingLength()
@@ -677,7 +679,12 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
kv[i] *= context
}
}
partialOffload = 2 * f.KV().HeadCountMax() / cmp.Or(f.KV().HeadCountKVMin(), 1) * kvTotal / 6
if useFlashAttention {
// rough estimate of graph size with flash attention on
partialOffload = (4*uint64(numParallel) + context>>10 + 110) * format.MebiByte
}
}
return
@@ -773,6 +780,13 @@ func (f GGML) SupportsFlashAttention() bool {
return headCountK != 0 && headCountV != 0 && headCountK == headCountV
}
// FlashAttention checks if the model should enable flash attention
func (f GGML) FlashAttention() bool {
return slices.Contains([]string{
"gptoss", "gpt-oss",
}, f.KV().String("general.architecture"))
}
// kvCacheBytesPerElement returns the number of bytes per element for a given KV cache type
func kvCacheBytesPerElement(cacheType string) float64 {
switch cacheType {

View File

@@ -533,12 +533,15 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
}
}
slices.SortStableFunc(ts, func(a, b *Tensor) int {
if i, j := a.block(), b.block(); i > 0 && j > 0 {
return cmp.Compare(i, j)
}
return cmp.Compare(a.Name, b.Name)
})
slices.SortStableFunc(
ts,
func(a, b *Tensor) int {
return cmp.Or(
cmp.Compare(a.block(), b.block()),
cmp.Compare(a.Name, b.Name),
)
},
)
var s uint64
for i := range ts {

View File

@@ -11,24 +11,24 @@ import (
)
func TestWriteGGUF(t *testing.T) {
r := rand.New(rand.NewPCG(0, 0))
b := bytes.NewBuffer(make([]byte, 2*3))
for range 8 {
t.Run("shuffle", func(t *testing.T) {
t.Parallel()
ts := []*Tensor{
{Name: "token_embd.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "blk.1.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "blk.2.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "blk.3.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "blk.4.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "blk.5.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
{Name: "output_norm.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewBuffer(make([]byte, 3*2))},
{Name: "output.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewBuffer(make([]byte, 3*2))},
{Name: "token_embd.weight", Shape: []uint64{2, 3}, WriterTo: b},
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
{Name: "blk.1.ffn_up.weight", Shape: []uint64{2, 3}, WriterTo: b},
{Name: "blk.2.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
{Name: "blk.1.ffn_down.weight", Shape: []uint64{2, 3}, WriterTo: b},
{Name: "blk.0.attn_k.weight", Shape: []uint64{2, 3}, WriterTo: b},
{Name: "output_norm.weight", Shape: []uint64{3, 2}, WriterTo: b},
{Name: "output.weight", Shape: []uint64{3, 2}, WriterTo: b},
}
r.Shuffle(len(ts), func(i, j int) {
rand.Shuffle(len(ts), func(i, j int) {
ts[i], ts[j] = ts[j], ts[i]
})
@@ -63,14 +63,14 @@ func TestWriteGGUF(t *testing.T) {
}
if diff := cmp.Diff(Tensors{
Offset: 608,
Offset: 592,
items: []*Tensor{
{Name: "blk.0.attn_norm.weight", Offset: 0, Shape: []uint64{2, 3}},
{Name: "blk.1.attn_norm.weight", Offset: 32, Shape: []uint64{2, 3}},
{Name: "blk.2.attn_norm.weight", Offset: 64, Shape: []uint64{2, 3}},
{Name: "blk.3.attn_norm.weight", Offset: 96, Shape: []uint64{2, 3}},
{Name: "blk.4.attn_norm.weight", Offset: 128, Shape: []uint64{2, 3}},
{Name: "blk.5.attn_norm.weight", Offset: 160, Shape: []uint64{2, 3}},
{Name: "blk.0.attn_k.weight", Offset: 0, Shape: []uint64{2, 3}},
{Name: "blk.0.attn_norm.weight", Offset: 32, Shape: []uint64{2, 3}},
{Name: "blk.0.ffn_norm.weight", Offset: 64, Shape: []uint64{2, 3}},
{Name: "blk.1.ffn_down.weight", Offset: 96, Shape: []uint64{2, 3}},
{Name: "blk.1.ffn_up.weight", Offset: 128, Shape: []uint64{2, 3}},
{Name: "blk.2.ffn_norm.weight", Offset: 160, Shape: []uint64{2, 3}},
{Name: "output.weight", Offset: 192, Shape: []uint64{3, 2}},
{Name: "output_norm.weight", Offset: 224, Shape: []uint64{3, 2}},
{Name: "token_embd.weight", Offset: 256, Shape: []uint64{2, 3}},

View File

@@ -146,8 +146,6 @@ func (ftype FileType) ToTensorType() TensorType {
return TensorTypeQ4_0
case fileTypeQ4_1:
return TensorTypeQ4_1
case fileTypeMXFP4:
return TensorTypeMXFP4 // Formerly unused tensorTypeQ4_2
case FileTypeQ8_0:
return TensorTypeQ8_0
case fileTypeQ5_0:
@@ -176,6 +174,8 @@ func (ftype FileType) ToTensorType() TensorType {
return TensorTypeQ2_K
case FileTypeBF16:
return TensorTypeBF16
case fileTypeMXFP4:
return TensorTypeMXFP4
default:
slog.Warn("unsupported file type", "type", ftype)
return 0 // F32
@@ -191,8 +191,8 @@ const (
TensorTypeF16
TensorTypeQ4_0
TensorTypeQ4_1
TensorTypeMXFP4 // Formerly unused tensorTypeQ4_2
tensorTypeQ4_3 // unused by GGML
tensorTypeQ4_2
tensorTypeQ4_3 // unused by GGML
TensorTypeQ5_0
TensorTypeQ5_1
TensorTypeQ8_0
@@ -226,6 +226,7 @@ const (
tensorTypeIQ4_NL_4_4 // unused by GGML
tensorTypeIQ4_NL_4_8 // unused by GGML
tensorTypeIQ4_NL_8_8 // unused by GGML
TensorTypeMXFP4
)
// ParseFileType parses the provided GGUF file type
@@ -318,7 +319,7 @@ func (t TensorType) String() string {
return "F64"
case TensorTypeBF16:
return "BF16"
case TensorTypeMXFP4:
case 4, TensorTypeMXFP4:
return "MXFP4"
default:
return "unknown"

View File

@@ -0,0 +1,130 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Jesse Gross <jesse@ollama.com>
Date: Wed, 27 Aug 2025 14:39:48 -0700
Subject: [PATCH] ggml: Enable resetting backend devices
Touching a CUDA device causes the allocation of a primary context
with CUDA data structures (~300 MB of VRAM). If a device is
unused then it can be reset to free these data structures.
---
ggml/include/ggml-backend.h | 1 +
ggml/src/ggml-backend-impl.h | 4 ++++
ggml/src/ggml-backend.cpp | 8 ++++++++
ggml/src/ggml-cuda/ggml-cuda.cu | 17 +++++++++++++++--
ggml/src/ggml-cuda/vendors/hip.h | 1 +
5 files changed, 29 insertions(+), 2 deletions(-)
diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
index b602a7c78..fda5ceb24 100644
--- a/ggml/include/ggml-backend.h
+++ b/ggml/include/ggml-backend.h
@@ -167,6 +167,7 @@ extern "C" {
GGML_API void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_dev_props * props);
GGML_API ggml_backend_reg_t ggml_backend_dev_backend_reg(ggml_backend_dev_t device);
GGML_API ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * params);
+ GGML_API void ggml_backend_dev_reset(ggml_backend_dev_t device);
GGML_API ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device);
GGML_API ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device);
GGML_API ggml_backend_buffer_t ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size);
diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h
index 81749a5a3..6f10c353b 100644
--- a/ggml/src/ggml-backend-impl.h
+++ b/ggml/src/ggml-backend-impl.h
@@ -178,6 +178,10 @@ extern "C" {
ggml_backend_event_t (*event_new) (ggml_backend_dev_t dev);
void (*event_free) (ggml_backend_dev_t dev, ggml_backend_event_t event);
void (*event_synchronize) (ggml_backend_dev_t dev, ggml_backend_event_t event);
+
+ // (optional) reset device, clearing existing allocations and context
+ // the caller must ensure that there are no outstanding buffers, as these will become invalid
+ void (*reset)(ggml_backend_dev_t dev);
};
struct ggml_backend_device {
diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
index 05a842ed5..6556943b0 100644
--- a/ggml/src/ggml-backend.cpp
+++ b/ggml/src/ggml-backend.cpp
@@ -477,6 +477,14 @@ ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * par
return device->iface.init_backend(device, params);
}
+void ggml_backend_dev_reset(ggml_backend_dev_t device) {
+ if (device->iface.reset == NULL) {
+ return;
+ }
+
+ device->iface.reset(device);
+}
+
ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device) {
return device->iface.get_buffer_type(device);
}
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index c7f9dc3a5..e43fde523 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -103,6 +103,11 @@ int ggml_cuda_get_device() {
return id;
}
+void ggml_cuda_reset_device(int device) {
+ ggml_cuda_set_device(device);
+ CUDA_CHECK(cudaDeviceReset());
+}
+
static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
ggml_cuda_set_device(device);
cudaError_t err;
@@ -3243,7 +3248,10 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back
props->description = ggml_backend_cuda_device_get_description(dev);
props->id = ggml_backend_cuda_device_get_id(dev);
props->type = ggml_backend_cuda_device_get_type(dev);
- ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total);
+
+ // Memory reporting is disabled to avoid allocation of a CUDA primary context (~300 MB per device).
+ // If you need the memory data, call ggml_backend_dev_memory() explicitly.
+ props->memory_total = props->memory_free = 0;
bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr;
#ifdef GGML_CUDA_NO_PEER_COPY
@@ -3700,6 +3708,11 @@ static void ggml_backend_cuda_device_event_synchronize(ggml_backend_dev_t dev, g
CUDA_CHECK(cudaEventSynchronize((cudaEvent_t)event->context));
}
+static void ggml_backend_cuda_device_reset(ggml_backend_dev_t dev) {
+ ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
+ ggml_cuda_reset_device(ctx->device);
+}
+
static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
/* .get_name = */ ggml_backend_cuda_device_get_name,
/* .get_description = */ ggml_backend_cuda_device_get_description,
@@ -3716,6 +3729,7 @@ static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
/* .event_new = */ ggml_backend_cuda_device_event_new,
/* .event_free = */ ggml_backend_cuda_device_event_free,
/* .event_synchronize = */ ggml_backend_cuda_device_event_synchronize,
+ /* .reset = */ ggml_backend_cuda_device_reset,
};
// backend reg
@@ -3835,7 +3849,6 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
dev_ctx->device = i;
dev_ctx->name = GGML_CUDA_NAME + std::to_string(i);
- ggml_cuda_set_device(i);
cudaDeviceProp prop;
CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
dev_ctx->description = prop.name;
diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h
index c31f31923..cf22e60d2 100644
--- a/ggml/src/ggml-cuda/vendors/hip.h
+++ b/ggml/src/ggml-cuda/vendors/hip.h
@@ -40,6 +40,7 @@
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
#define cudaDeviceProp hipDeviceProp_t
+#define cudaDeviceReset hipDeviceReset
#define cudaDeviceSynchronize hipDeviceSynchronize
#define cudaError_t hipError_t
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled

View File

@@ -195,17 +195,19 @@ func estimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
slog.Warn("model missing blk.0 layer size")
}
var kvct string
if envconfig.FlashAttention() &&
useFlashAttention := (envconfig.FlashAttention() || f.FlashAttention()) &&
discover.GetGPUInfo().FlashAttentionSupported() &&
f.SupportsFlashAttention() {
f.SupportsFlashAttention()
var kvct string
if useFlashAttention {
requested := strings.ToLower(envconfig.KvCacheType())
if requested != "" && f.SupportsKVCacheType(requested) {
kvct = requested
}
}
kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), numParallel, kvct)
kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), numParallel, kvct, useFlashAttention)
if len(kv) > 0 {
layerSize += kv[0]

View File

@@ -195,6 +195,11 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
// This will disable flash attention unless all GPUs on the system support it, even if we end up selecting a subset
// that can handle it.
fa := envconfig.FlashAttention()
if f.FlashAttention() {
slog.Info("model wants flash attention")
fa = true
}
if fa && !gpus.FlashAttentionSupported() {
slog.Warn("flash attention enabled but not supported by gpu")
fa = false

View File

@@ -535,6 +535,7 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error {
const BS = 17 // MXFP4 block size
bts := make([]byte, 8*BS*format.KibiByte) // ~128k block aligned
var s uint64
var tmp [16]byte
for s < t.Size() {
// Stop if either the parent context has been canceled or if any of the other tensors returned an error
if err := ctx.Err(); err != nil {
@@ -546,37 +547,13 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error {
return err
}
for j := range n / BS {
for i := 1; i < BS; i++ {
// swap nibbles
t_lo := bts[j*BS+i] & 0x0F
t_hi := bts[j*BS+i] & 0xF0
bts[j*BS+i] = (t_lo << 4) | (t_hi >> 4)
}
// transform aaaa...bbbb... to abababab...
oi := 0
tmp := [16]byte{}
for i := 1; i < 9; i++ {
blk_a0 := bts[j*BS+i] & 0xF0
blk_a1 := bts[j*BS+i] << 4
blk_b0 := bts[j*BS+i+8] >> 4
blk_b1 := bts[j*BS+i+8] & 0x0F
// swap once more
out0 := blk_a0 | blk_b0
out1 := blk_a1 | blk_b1
out_h0 := out0 & 0xF0
out_l0 := out0 & 0x0F
out_h1 := out1 & 0xF0
out_l1 := out1 & 0x0F
out0 = (out_h0 >> 4) | (out_l0 << 4)
out1 = (out_h1 >> 4) | (out_l1 << 4)
tmp[oi] = out0
oi++
tmp[oi] = out1
oi++
}
for i := range tmp {
bts[j*BS+i+1] = tmp[i]
// transform a1b2c3 ... x7y8z9 -> 71xa82yb93zc
a, b := bts[j*BS+i], bts[j*BS+i+8]
tmp[2*(i-1)] = (a & 0x0F) | (b << 4)
tmp[2*(i-1)+1] = (a >> 4) | (b & 0xF0)
}
copy(bts[j*BS+1:j*BS+17], tmp[:])
}
for _, tt := range tts {
@@ -652,6 +629,18 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error {
})
}
// Cleanup any backend state from devices that we didn't end up using
nextDevice:
for _, d := range append(gpus, append(accels, cpus...)...) {
for _, backend := range b.schedBackends {
if d == C.ggml_backend_get_device(backend) {
continue nextDevice
}
}
C.ggml_backend_dev_reset(d)
}
if err := g.Wait(); err != nil {
return err
}

View File

@@ -167,6 +167,7 @@ extern "C" {
GGML_API void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_dev_props * props);
GGML_API ggml_backend_reg_t ggml_backend_dev_backend_reg(ggml_backend_dev_t device);
GGML_API ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * params);
GGML_API void ggml_backend_dev_reset(ggml_backend_dev_t device);
GGML_API ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device);
GGML_API ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device);
GGML_API ggml_backend_buffer_t ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size);

View File

@@ -178,6 +178,10 @@ extern "C" {
ggml_backend_event_t (*event_new) (ggml_backend_dev_t dev);
void (*event_free) (ggml_backend_dev_t dev, ggml_backend_event_t event);
void (*event_synchronize) (ggml_backend_dev_t dev, ggml_backend_event_t event);
// (optional) reset device, clearing existing allocations and context
// the caller must ensure that there are no outstanding buffers, as these will become invalid
void (*reset)(ggml_backend_dev_t dev);
};
struct ggml_backend_device {

View File

@@ -477,6 +477,14 @@ ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * par
return device->iface.init_backend(device, params);
}
void ggml_backend_dev_reset(ggml_backend_dev_t device) {
if (device->iface.reset == NULL) {
return;
}
device->iface.reset(device);
}
ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device) {
return device->iface.get_buffer_type(device);
}

View File

@@ -103,6 +103,11 @@ int ggml_cuda_get_device() {
return id;
}
void ggml_cuda_reset_device(int device) {
ggml_cuda_set_device(device);
CUDA_CHECK(cudaDeviceReset());
}
static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
ggml_cuda_set_device(device);
cudaError_t err;
@@ -3243,7 +3248,10 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back
props->description = ggml_backend_cuda_device_get_description(dev);
props->id = ggml_backend_cuda_device_get_id(dev);
props->type = ggml_backend_cuda_device_get_type(dev);
ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total);
// Memory reporting is disabled to avoid allocation of a CUDA primary context (~300 MB per device).
// If you need the memory data, call ggml_backend_dev_memory() explicitly.
props->memory_total = props->memory_free = 0;
bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr;
#ifdef GGML_CUDA_NO_PEER_COPY
@@ -3700,6 +3708,11 @@ static void ggml_backend_cuda_device_event_synchronize(ggml_backend_dev_t dev, g
CUDA_CHECK(cudaEventSynchronize((cudaEvent_t)event->context));
}
static void ggml_backend_cuda_device_reset(ggml_backend_dev_t dev) {
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
ggml_cuda_reset_device(ctx->device);
}
static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
/* .get_name = */ ggml_backend_cuda_device_get_name,
/* .get_description = */ ggml_backend_cuda_device_get_description,
@@ -3716,6 +3729,7 @@ static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
/* .event_new = */ ggml_backend_cuda_device_event_new,
/* .event_free = */ ggml_backend_cuda_device_event_free,
/* .event_synchronize = */ ggml_backend_cuda_device_event_synchronize,
/* .reset = */ ggml_backend_cuda_device_reset,
};
// backend reg
@@ -3835,7 +3849,6 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
dev_ctx->device = i;
dev_ctx->name = GGML_CUDA_NAME + std::to_string(i);
ggml_cuda_set_device(i);
cudaDeviceProp prop;
CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
dev_ctx->description = prop.name;

View File

@@ -40,6 +40,7 @@
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
#define cudaDeviceProp hipDeviceProp_t
#define cudaDeviceReset hipDeviceReset
#define cudaDeviceSynchronize hipDeviceSynchronize
#define cudaError_t hipError_t
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled

View File

@@ -18,7 +18,7 @@ type Model struct {
model.Base
model.SentencePieceModel
*VisionModel `gguf:"v,vision"`
*VisionModel `gguf:"v"`
*TextModel
*MultiModalProjector `gguf:"mm"`

View File

@@ -18,7 +18,7 @@ type Model struct {
model.BytePairEncoding
ImageProcessor
*VisionModel `gguf:"v,vision"`
*VisionModel `gguf:"v"`
*Projector `gguf:"mm"`
*TextModel
}

View File

@@ -18,7 +18,7 @@ type Model struct {
model.BytePairEncoding
*TextModel
*VisionModel `gguf:"v,vision"`
*VisionModel `gguf:"v"`
*MultiModalProjector `gguf:"mm"`
ImageProcessor

View File

@@ -17,7 +17,7 @@ type Model struct {
model.Base
model.BytePairEncoding
*VisionModel `gguf:"v,vision"`
*VisionModel `gguf:"v"`
*TextModel
Projector *nn.Linear `gguf:"mm.0"`

View File

@@ -18,7 +18,7 @@ type Model struct {
model.BytePairEncoding
*TextModel
*VisionModel `gguf:"v,vision"`
*VisionModel `gguf:"v"`
ImageProcessor
}

View File

@@ -17,7 +17,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/types"
"github.com/ollama/ollama/types/model"
)
@@ -572,7 +571,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
Messages: messages,
Format: format,
Options: options,
Stream: types.NullWithValue(r.Stream),
Stream: &r.Stream,
Tools: r.Tools,
Think: think,
}, nil
@@ -651,7 +650,7 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
Model: r.Model,
Prompt: r.Prompt,
Options: options,
Stream: types.NullWithValue(r.Stream),
Stream: &r.Stream,
Suffix: r.Suffix,
}, nil
}

View File

@@ -146,7 +146,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
ch <- api.ProgressResponse{Status: "success"}
}()
if !r.Stream.Value(true) {
if r.Stream != nil && !*r.Stream {
waitForStream(c, ch)
return
}

View File

@@ -189,7 +189,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
}
// expire the runner
if req.Prompt == "" && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 {
if req.Prompt == "" && req.KeepAlive != nil && req.KeepAlive.Duration == 0 {
s.sched.expireRunner(m)
c.JSON(http.StatusOK, api.GenerateResponse{
@@ -440,7 +440,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
}
}()
if !req.Stream.Value(true) {
if req.Stream != nil && !*req.Stream {
var r api.GenerateResponse
var sbThinking strings.Builder
var sbContent strings.Builder
@@ -487,6 +487,12 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return
}
truncate := true
if req.Truncate != nil && !*req.Truncate {
truncate = false
}
var input []string
switch i := req.Input.(type) {
@@ -535,7 +541,6 @@ func (s *Server) EmbedHandler(c *gin.Context) {
}
var count int
truncate := req.Truncate.Value(true)
for i, s := range input {
tokens, err := r.Tokenize(c.Request.Context(), s)
if err != nil {
@@ -696,7 +701,7 @@ func (s *Server) PullHandler(c *gin.Context) {
}
}()
if !req.Stream.Value(true) {
if req.Stream != nil && !*req.Stream {
waitForStream(c, ch)
return
}
@@ -751,7 +756,7 @@ func (s *Server) PushHandler(c *gin.Context) {
}
}()
if !req.Stream.Value(true) {
if req.Stream != nil && !*req.Stream {
waitForStream(c, ch)
return
}
@@ -1539,7 +1544,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
// expire the runner
if len(req.Messages) == 0 && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 {
if len(req.Messages) == 0 && req.KeepAlive != nil && req.KeepAlive.Duration == 0 {
model, err := GetModel(req.Model)
if err != nil {
switch {
@@ -1770,7 +1775,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
}()
if !req.Stream.Value(true) {
if req.Stream != nil && !*req.Stream {
var resp api.ChatResponse
var toolCalls []api.ToolCall
var sbThinking strings.Builder

View File

@@ -22,6 +22,8 @@ import (
"github.com/ollama/ollama/fs/ggml"
)
var stream bool = false
func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) (string, string) {
t.Helper()
t.Setenv("OLLAMA_MODELS", cmp.Or(os.Getenv("OLLAMA_MODELS"), t.TempDir()))
@@ -116,7 +118,7 @@ func TestCreateFromBin(t *testing.T) {
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "test",
Files: map[string]string{"test.gguf": digest},
Stream: streamFalse,
Stream: &stream,
})
if w.Code != http.StatusOK {
@@ -146,7 +148,7 @@ func TestCreateFromModel(t *testing.T) {
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "test",
Files: map[string]string{"test.gguf": digest},
Stream: streamFalse,
Stream: &stream,
})
if w.Code != http.StatusOK {
@@ -160,7 +162,7 @@ func TestCreateFromModel(t *testing.T) {
w = createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "test2",
From: "test",
Stream: streamFalse,
Stream: &stream,
})
if w.Code != http.StatusOK {
@@ -190,7 +192,7 @@ func TestCreateRemovesLayers(t *testing.T) {
Name: "test",
Files: map[string]string{"test.gguf": digest},
Template: "{{ .Prompt }}",
Stream: streamFalse,
Stream: &stream,
})
if w.Code != http.StatusOK {
@@ -211,7 +213,7 @@ func TestCreateRemovesLayers(t *testing.T) {
Name: "test",
Files: map[string]string{"test.gguf": digest},
Template: "{{ .System }} {{ .Prompt }}",
Stream: streamFalse,
Stream: &stream,
})
if w.Code != http.StatusOK {
@@ -241,7 +243,7 @@ func TestCreateUnsetsSystem(t *testing.T) {
Name: "test",
Files: map[string]string{"test.gguf": digest},
System: "Say hi!",
Stream: streamFalse,
Stream: &stream,
})
if w.Code != http.StatusOK {
@@ -262,7 +264,7 @@ func TestCreateUnsetsSystem(t *testing.T) {
Name: "test",
Files: map[string]string{"test.gguf": digest},
System: "",
Stream: streamFalse,
Stream: &stream,
})
if w.Code != http.StatusOK {
@@ -295,7 +297,7 @@ func TestCreateMergeParameters(t *testing.T) {
"top_k": 10,
"stop": []string{"USER:", "ASSISTANT:"},
},
Stream: streamFalse,
Stream: &stream,
})
if w.Code != http.StatusOK {
@@ -320,7 +322,7 @@ func TestCreateMergeParameters(t *testing.T) {
"temperature": 0.6,
"top_p": 0.7,
},
Stream: streamFalse,
Stream: &stream,
})
if w.Code != http.StatusOK {
@@ -379,7 +381,7 @@ func TestCreateMergeParameters(t *testing.T) {
"top_p": 0.7,
"stop": []string{"<|endoftext|>"},
},
Stream: streamFalse,
Stream: &stream,
})
if w.Code != http.StatusOK {
@@ -439,7 +441,7 @@ func TestCreateReplacesMessages(t *testing.T) {
Content: "Oh, my god.",
},
},
Stream: streamFalse,
Stream: &stream,
})
if w.Code != http.StatusOK {
@@ -473,7 +475,7 @@ func TestCreateReplacesMessages(t *testing.T) {
Content: "A test. And a thumping good one at that, I'd wager.",
},
},
Stream: streamFalse,
Stream: &stream,
})
if w.Code != http.StatusOK {
@@ -534,7 +536,7 @@ func TestCreateTemplateSystem(t *testing.T) {
Files: map[string]string{"test.gguf": digest},
Template: "{{ .System }} {{ .Prompt }}",
System: "Say bye!",
Stream: streamFalse,
Stream: &stream,
})
if w.Code != http.StatusOK {
@@ -576,7 +578,7 @@ func TestCreateTemplateSystem(t *testing.T) {
Name: "test",
Files: map[string]string{"test.gguf": digest},
Template: "{{ .Prompt",
Stream: streamFalse,
Stream: &stream,
})
if w.Code != http.StatusBadRequest {
@@ -590,7 +592,7 @@ func TestCreateTemplateSystem(t *testing.T) {
Name: "test",
Files: map[string]string{"test.gguf": digest},
Template: "{{ if .Prompt }}",
Stream: streamFalse,
Stream: &stream,
})
if w.Code != http.StatusBadRequest {
@@ -604,7 +606,7 @@ func TestCreateTemplateSystem(t *testing.T) {
Name: "test",
Files: map[string]string{"test.gguf": digest},
Template: "{{ Prompt }}",
Stream: streamFalse,
Stream: &stream,
})
if w.Code != http.StatusBadRequest {
@@ -625,7 +627,7 @@ func TestCreateLicenses(t *testing.T) {
Name: "test",
Files: map[string]string{"test.gguf": digest},
License: []string{"MIT", "Apache-2.0"},
Stream: streamFalse,
Stream: &stream,
})
if w.Code != http.StatusOK {
@@ -676,7 +678,7 @@ func TestCreateDetectTemplate(t *testing.T) {
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "test",
Files: map[string]string{"test.gguf": digest},
Stream: streamFalse,
Stream: &stream,
})
if w.Code != http.StatusOK {
@@ -696,7 +698,7 @@ func TestCreateDetectTemplate(t *testing.T) {
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "test",
Files: map[string]string{"test.gguf": digest},
Stream: streamFalse,
Stream: &stream,
})
if w.Code != http.StatusOK {

View File

@@ -12,7 +12,6 @@ import (
"github.com/ollama/ollama/discover"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/types"
)
func TestGenerateDebugRenderOnly(t *testing.T) {
@@ -54,6 +53,7 @@ func TestGenerateDebugRenderOnly(t *testing.T) {
go s.sched.Run(t.Context())
// Create a test model
stream := false
_, digest := createBinFile(t, ggml.KV{
"general.architecture": "llama",
"llama.block_count": uint32(1),
@@ -82,7 +82,7 @@ func TestGenerateDebugRenderOnly(t *testing.T) {
Model: "test-model",
Files: map[string]string{"file.gguf": digest},
Template: "{{ .Prompt }}",
Stream: streamFalse,
Stream: &stream,
})
if w.Code != http.StatusOK {
@@ -172,7 +172,7 @@ func TestGenerateDebugRenderOnly(t *testing.T) {
}
t.Run(tt.name+streamSuffix, func(t *testing.T) {
req := tt.request
req.Stream = types.NullWithValue(stream)
req.Stream = &stream
w := createRequest(t, s.GenerateHandler, req)
if tt.expectDebug {
@@ -246,6 +246,7 @@ func TestChatDebugRenderOnly(t *testing.T) {
go s.sched.Run(t.Context())
// Create a test model
stream := false
_, digest := createBinFile(t, ggml.KV{
"general.architecture": "llama",
"llama.block_count": uint32(1),
@@ -274,7 +275,7 @@ func TestChatDebugRenderOnly(t *testing.T) {
Model: "test-model",
Files: map[string]string{"file.gguf": digest},
Template: "{{ if .Tools }}{{ .Tools }}{{ end }}{{ range .Messages }}{{ .Role }}: {{ .Content }}\n{{ end }}",
Stream: streamFalse,
Stream: &stream,
})
if w.Code != http.StatusOK {
@@ -376,7 +377,7 @@ func TestChatDebugRenderOnly(t *testing.T) {
}
t.Run(tt.name+streamSuffix, func(t *testing.T) {
req := tt.request
req.Stream = types.NullWithValue(stream)
req.Stream = &stream
w := createRequest(t, s.ChatHandler, req)
if tt.expectDebug {

View File

@@ -126,7 +126,7 @@ func TestGenerateChat(t *testing.T) {
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
{{- end }}
{{ end }}`,
Stream: streamFalse,
Stream: &stream,
})
if w.Code != http.StatusOK {
@@ -182,7 +182,7 @@ func TestGenerateChat(t *testing.T) {
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "bert",
Files: map[string]string{"bert.gguf": digest},
Stream: streamFalse,
Stream: &stream,
})
if w.Code != http.StatusOK {
@@ -288,7 +288,7 @@ func TestGenerateChat(t *testing.T) {
Messages: []api.Message{
{Role: "user", Content: "Hello!"},
},
Stream: streamFalse,
Stream: &stream,
})
if w.Code != http.StatusOK {
@@ -318,7 +318,7 @@ func TestGenerateChat(t *testing.T) {
Messages: []api.Message{
{Role: "user", Content: "Hello!"},
},
Stream: streamFalse,
Stream: &stream,
})
if w.Code != http.StatusOK {
@@ -340,7 +340,7 @@ func TestGenerateChat(t *testing.T) {
{Role: "system", Content: "You can perform magic tricks."},
{Role: "user", Content: "Hello!"},
},
Stream: streamFalse,
Stream: &stream,
})
if w.Code != http.StatusOK {
@@ -363,7 +363,7 @@ func TestGenerateChat(t *testing.T) {
{Role: "system", Content: "You can perform magic tricks."},
{Role: "user", Content: "Help me write tests."},
},
Stream: streamFalse,
Stream: &stream,
})
if w.Code != http.StatusOK {
@@ -422,13 +422,15 @@ func TestGenerateChat(t *testing.T) {
EvalDuration: 1,
}
streamRequest := true
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test-system",
Messages: []api.Message{
{Role: "user", Content: "What's the weather in Seattle?"},
},
Tools: tools,
Stream: streamTrue,
Stream: &streamRequest,
})
if w.Code != http.StatusOK {
@@ -549,7 +551,7 @@ func TestGenerateChat(t *testing.T) {
{Role: "user", Content: "What's the weather in Seattle?"},
},
Tools: tools,
Stream: streamFalse,
Stream: &stream,
})
wg.Wait()
@@ -664,7 +666,7 @@ func TestGenerate(t *testing.T) {
{{- if .Prompt }}User: {{ .Prompt }} {{ end }}
{{- if .Response }}Assistant: {{ .Response }} {{ end }}
`,
Stream: streamFalse,
Stream: &stream,
})
if w.Code != http.StatusOK {
@@ -702,7 +704,7 @@ func TestGenerate(t *testing.T) {
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "bert",
Files: map[string]string{"file.gguf": digest},
Stream: streamFalse,
Stream: &stream,
})
if w.Code != http.StatusOK {
@@ -823,7 +825,7 @@ func TestGenerate(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test",
Prompt: "Hello!",
Stream: streamFalse,
Stream: &stream,
})
if w.Code != http.StatusOK {
@@ -851,7 +853,7 @@ func TestGenerate(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-system",
Prompt: "Hello!",
Stream: streamFalse,
Stream: &stream,
})
if w.Code != http.StatusOK {
@@ -871,7 +873,7 @@ func TestGenerate(t *testing.T) {
Model: "test-system",
Prompt: "Hello!",
System: "You can perform magic tricks.",
Stream: streamFalse,
Stream: &stream,
})
if w.Code != http.StatusOK {
@@ -893,7 +895,7 @@ func TestGenerate(t *testing.T) {
Template: `{{- if .System }}{{ .System }} {{ end }}
{{- if .Prompt }}### USER {{ .Prompt }} {{ end }}
{{- if .Response }}### ASSISTANT {{ .Response }} {{ end }}`,
Stream: streamFalse,
Stream: &stream,
})
if w.Code != http.StatusOK {
@@ -955,7 +957,7 @@ func TestGenerate(t *testing.T) {
Model: "test-system",
Prompt: "Help me write tests.",
Raw: true,
Stream: streamFalse,
Stream: &stream,
})
if w.Code != http.StatusOK {
@@ -1038,7 +1040,7 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
{{- if eq .Role "user" }}user: {{ .Content }}
{{ else if eq .Role "assistant" }}assistant: {{ if .Thinking }}<think>{{ .Thinking }}</think>{{ end }}{{ .Content }}
{{ end }}{{ end }}<think>`,
Stream: streamFalse,
Stream: &stream,
})
if w.Code != http.StatusOK {
@@ -1064,12 +1066,13 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
}
mock.CompletionFn = nil
streamRequest := false
req := api.ChatRequest{
Model: "test-thinking",
Messages: []api.Message{
{Role: "user", Content: userContent},
},
Stream: streamFalse,
Stream: &streamRequest,
}
if think {
req.Think = &api.ThinkValue{Value: think}
@@ -1162,7 +1165,7 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
Model: "test-thinking",
Messages: []api.Message{{Role: "user", Content: "Analyze this complex problem"}},
Think: &api.ThinkValue{Value: think},
Stream: streamFalse,
Stream: &stream,
})
wg.Wait()

View File

@@ -291,11 +291,12 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
// Create a simple test model
_, digest := createHarmonyTestModel(t)
streamFalse := false
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "harmony-test-streaming",
Files: map[string]string{"test.gguf": digest},
Template: `<|start|><|end|>{{ with .Tools }}{{ end }}{{ .Prompt }}`,
Stream: streamFalse,
Stream: &streamFalse,
})
if w.Code != 200 {
@@ -303,10 +304,11 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
}
// Test chat endpoint with streaming
streamTrue := true
w = createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "harmony-test-streaming",
Messages: []api.Message{{Role: "user", Content: "Hello"}},
Stream: streamTrue,
Stream: &streamTrue,
Tools: getTestTools(),
})
@@ -439,11 +441,12 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {
// Create model
_, digest := createHarmonyTestModel(t)
streamFalse := false
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "gpt-oss",
Files: map[string]string{"test.gguf": digest},
Template: `<|start|><|end|>{{ .Tools }}{{ .Prompt }}`,
Stream: streamFalse,
Stream: &streamFalse,
})
if w.Code != 200 {
@@ -451,10 +454,11 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {
}
// Test streaming
streamTrue := true
w = createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "gpt-oss",
Messages: []api.Message{{Role: "user", Content: "Hello"}},
Stream: streamTrue,
Stream: &streamTrue,
Tools: getTestTools(),
})
@@ -621,11 +625,12 @@ func TestChatHarmonyParserStreaming(t *testing.T) {
_, digest := createHarmonyTestModel(t)
// Create model with passthrough template
stream := false
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "harmony-test",
Files: map[string]string{"file.gguf": digest},
Template: `<|start|><|end|>{{ with .Tools }}{{ end }}{{ .Prompt }}`,
Stream: streamFalse,
Stream: &stream,
})
if w.Code != http.StatusOK {
@@ -633,10 +638,11 @@ func TestChatHarmonyParserStreaming(t *testing.T) {
}
// Test chat endpoint with streaming
streamTrue := true
w = createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "harmony-test",
Messages: []api.Message{{Role: "user", Content: "Hello"}},
Stream: streamTrue,
Stream: &streamTrue,
Tools: getTestTools(),
})

View File

@@ -28,16 +28,10 @@ import (
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/openai"
"github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/types"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
)
var (
streamFalse = types.NullWithValue(false)
streamTrue = types.NullWithValue(true)
)
func createTestFile(t *testing.T, name string) (string, string) {
t.Helper()
@@ -338,10 +332,11 @@ func TestRoutes(t *testing.T) {
Path: "/api/create",
Setup: func(t *testing.T, req *http.Request) {
_, digest := createTestFile(t, "ollama-model")
stream := false
createReq := api.CreateRequest{
Name: "t-bone",
Files: map[string]string{"test.gguf": digest},
Stream: streamFalse,
Stream: &stream,
}
jsonData, err := json.Marshal(createReq)
if err != nil {
@@ -643,7 +638,7 @@ func TestManifestCaseSensitivity(t *testing.T) {
// version.
Name: wantStableName,
Files: map[string]string{"test.gguf": digest},
Stream: streamFalse,
Stream: &stream,
}))
checkManifestList()
@@ -651,14 +646,14 @@ func TestManifestCaseSensitivity(t *testing.T) {
checkOK(createRequest(t, s.CreateHandler, api.CreateRequest{
Name: name(),
Files: map[string]string{"test.gguf": digest},
Stream: streamFalse,
Stream: &stream,
}))
checkManifestList()
t.Logf("pulling")
checkOK(createRequest(t, s.PullHandler, api.PullRequest{
Name: name(),
Stream: streamFalse,
Stream: &stream,
Insecure: true,
}))
checkManifestList()

View File

@@ -41,13 +41,7 @@ func TestParser(t *testing.T) {
Function: api.ToolFunction{
Name: "get_temperature",
Description: "Retrieve the temperature for a given location",
Parameters: struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"`
Properties map[string]api.ToolProperty `json:"properties"`
}{
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"city"},
Properties: map[string]api.ToolProperty{
@@ -69,13 +63,7 @@ func TestParser(t *testing.T) {
Function: api.ToolFunction{
Name: "get_conditions",
Description: "Retrieve the current weather conditions for a given location",
Parameters: struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"`
Properties map[string]api.ToolProperty `json:"properties"`
}{
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: map[string]api.ToolProperty{
"location": {
@@ -105,13 +93,7 @@ func TestParser(t *testing.T) {
Function: api.ToolFunction{
Name: "get_address",
Description: "Get the address of a given location",
Parameters: struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"`
Properties map[string]api.ToolProperty `json:"properties"`
}{
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: map[string]api.ToolProperty{
"location": {
@@ -127,13 +109,7 @@ func TestParser(t *testing.T) {
Function: api.ToolFunction{
Name: "add",
Description: "Add two numbers",
Parameters: struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"`
Properties map[string]api.ToolProperty `json:"properties"`
}{
Parameters: api.ToolFunctionParameters{
Type: "object",
Properties: map[string]api.ToolProperty{
"a": {

View File

@@ -1,53 +0,0 @@
package types
import (
"encoding/json"
)
// Null represents a value of any type T that may be null.
type Null[T any] struct {
value T
valid bool
}
// NullWithValue creates a new, valid Null[T].
func NullWithValue[T any](value T) Null[T] {
return Null[T]{value: value, valid: true}
}
// Value returns the value of the Type[T] if set, otherwise it returns the provided default value or the zero value of T.
func (n Null[T]) Value(defaultValue ...T) T {
if n.valid {
return n.value
}
if len(defaultValue) > 0 {
return defaultValue[0]
}
var zero T
return zero
}
// SetValue sets the value of the Type[T].
func (n *Null[T]) SetValue(t T) {
n.value = t
n.valid = true
}
// MarshalJSON implements [json.Marshaler].
func (n Null[T]) MarshalJSON() ([]byte, error) {
if n.valid {
return json.Marshal(n.value)
}
return []byte("null"), nil
}
// UnmarshalJSON implements [json.Unmarshaler].
func (n *Null[T]) UnmarshalJSON(data []byte) error {
if string(data) != "null" {
if err := json.Unmarshal(data, &n.value); err != nil {
return err
}
n.valid = true
}
return nil
}

View File

@@ -1,53 +0,0 @@
package types_test
import (
"encoding/json"
"testing"
"github.com/ollama/ollama/types"
)
func TestNull(t *testing.T) {
var s types.Null[string]
if val := s.Value(); val != "" {
t.Errorf("expected Value to return zero value '', got '%s'", val)
}
if val := s.Value("default"); val != "default" {
t.Errorf("expected Value to return default value 'default', got '%s'", val)
}
if bts, err := json.Marshal(s); err != nil {
t.Errorf("unexpected error during MarshalJSON: %v", err)
} else if want := "null"; string(bts) != want {
t.Errorf("expected marshaled JSON to be %s, got %s", want, string(bts))
}
s.SetValue("foo")
if val := s.Value(); val != "foo" {
t.Errorf("expected Value to return 'foo', got '%s'", val)
}
s = types.NullValue("bar")
if val := s.Value(); val != "bar" {
t.Errorf("expected Value to return 'bar', got '%s'", val)
}
if bts, err := json.Marshal(s); err != nil {
t.Errorf("unexpected error during MarshalJSON: %v", err)
} else if want := `"bar"`; string(bts) != want {
t.Errorf("expected marshaled JSON to be %s, got %s", want, string(bts))
}
if err := json.Unmarshal([]byte(`null`), &s); err != nil {
t.Errorf("unexpected error during UnmarshalJSON: %v", err)
}
if err := json.Unmarshal([]byte(`"baz"`), &s); err != nil {
t.Errorf("unexpected error during UnmarshalJSON: %v", err)
}
if err := json.Unmarshal([]byte(`1.2345`), &s); err == nil {
t.Error("expected error during UnmarshalJSON with invalid JSON, got nil")
}
}