Compare commits
3 Commits
mxyng/gguf
...
mxyng/type
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d05fc26570 | ||
|
|
c457628090 | ||
|
|
e914477bb6 |
@@ -541,7 +541,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [OllamaPlusPlus](https://github.com/HardCodeDev777/OllamaPlusPlus) (Very simple C++ library for Ollama)
|
||||
- [any-llm](https://github.com/mozilla-ai/any-llm) (A single interface to use different llm providers by [mozilla.ai](https://www.mozilla.ai/))
|
||||
- [any-agent](https://github.com/mozilla-ai/any-agent) (A single interface to use and evaluate different agent frameworks by [mozilla.ai](https://www.mozilla.ai/))
|
||||
- [Neuro SAN](https://github.com/cognizant-ai-lab/neuro-san-studio) (Data-driven multi-agent orchestration framework) with [example](https://github.com/cognizant-ai-lab/neuro-san-studio/blob/main/docs/user_guide.md#ollama)
|
||||
|
||||
### Mobile
|
||||
|
||||
|
||||
60
api/types.go
60
api/types.go
@@ -12,6 +12,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/types"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
@@ -64,7 +65,7 @@ type GenerateRequest struct {
|
||||
Context []int `json:"context,omitempty"`
|
||||
|
||||
// Stream specifies whether the response is streaming; it is true by default.
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
Stream types.Null[bool] `json:"stream,omitempty"`
|
||||
|
||||
// Raw set to true means that no formatting will be applied to the prompt.
|
||||
Raw bool `json:"raw,omitempty"`
|
||||
@@ -105,7 +106,7 @@ type ChatRequest struct {
|
||||
Messages []Message `json:"messages"`
|
||||
|
||||
// Stream enables streaming of returned responses; true by default.
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
Stream types.Null[bool] `json:"stream,omitempty"`
|
||||
|
||||
// Format is the format to return the response in (e.g. "json").
|
||||
Format json.RawMessage `json:"format,omitempty"`
|
||||
@@ -286,23 +287,16 @@ func mapToTypeScriptType(jsonType string) string {
|
||||
}
|
||||
}
|
||||
|
||||
type ToolFunctionParameters struct {
|
||||
Type string `json:"type"`
|
||||
Defs any `json:"$defs,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Required []string `json:"required"`
|
||||
Properties map[string]ToolProperty `json:"properties"`
|
||||
}
|
||||
|
||||
func (t *ToolFunctionParameters) String() string {
|
||||
bts, _ := json.Marshal(t)
|
||||
return string(bts)
|
||||
}
|
||||
|
||||
type ToolFunction struct {
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters ToolFunctionParameters `json:"parameters"`
|
||||
Name string `json:"name"`
|
||||
Description string `json:"description"`
|
||||
Parameters struct {
|
||||
Type string `json:"type"`
|
||||
Defs any `json:"$defs,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Required []string `json:"required"`
|
||||
Properties map[string]ToolProperty `json:"properties"`
|
||||
} `json:"parameters"`
|
||||
}
|
||||
|
||||
func (t *ToolFunction) String() string {
|
||||
@@ -388,7 +382,7 @@ type EmbedRequest struct {
|
||||
// this request.
|
||||
KeepAlive *Duration `json:"keep_alive,omitempty"`
|
||||
|
||||
Truncate *bool `json:"truncate,omitempty"`
|
||||
Truncate types.Null[bool] `json:"truncate,omitempty"`
|
||||
|
||||
// Options lists model-specific options.
|
||||
Options map[string]any `json:"options"`
|
||||
@@ -427,9 +421,9 @@ type EmbeddingResponse struct {
|
||||
|
||||
// CreateRequest is the request passed to [Client.Create].
|
||||
type CreateRequest struct {
|
||||
Model string `json:"model"`
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
Quantize string `json:"quantize,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Stream types.Null[bool] `json:"stream,omitempty"`
|
||||
Quantize string `json:"quantize,omitempty"`
|
||||
|
||||
From string `json:"from,omitempty"`
|
||||
Files map[string]string `json:"files,omitempty"`
|
||||
@@ -493,11 +487,11 @@ type CopyRequest struct {
|
||||
|
||||
// PullRequest is the request passed to [Client.Pull].
|
||||
type PullRequest struct {
|
||||
Model string `json:"model"`
|
||||
Insecure bool `json:"insecure,omitempty"` // Deprecated: ignored
|
||||
Username string `json:"username"` // Deprecated: ignored
|
||||
Password string `json:"password"` // Deprecated: ignored
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Insecure bool `json:"insecure,omitempty"`
|
||||
Username string `json:"username"` // Deprecated: ignored
|
||||
Password string `json:"password"` // Deprecated: ignored
|
||||
Stream types.Null[bool] `json:"stream,omitempty"`
|
||||
|
||||
// Deprecated: set the model name with Model instead
|
||||
Name string `json:"name"`
|
||||
@@ -514,11 +508,11 @@ type ProgressResponse struct {
|
||||
|
||||
// PushRequest is the request passed to [Client.Push].
|
||||
type PushRequest struct {
|
||||
Model string `json:"model"`
|
||||
Insecure bool `json:"insecure,omitempty"`
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
Stream *bool `json:"stream,omitempty"`
|
||||
Model string `json:"model"`
|
||||
Insecure bool `json:"insecure,omitempty"`
|
||||
Username string `json:"username"` // Deprecated: ignored
|
||||
Password string `json:"password"` // Deprecated: ignored
|
||||
Stream types.Null[bool] `json:"stream,omitempty"`
|
||||
|
||||
// Deprecated: set the model name with Model instead
|
||||
Name string `json:"name"`
|
||||
@@ -888,7 +882,7 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
|
||||
if t < 0 {
|
||||
d.Duration = time.Duration(math.MaxInt64)
|
||||
} else {
|
||||
d.Duration = time.Duration(t * float64(time.Second))
|
||||
d.Duration = time.Duration(int(t) * int(time.Second))
|
||||
}
|
||||
case string:
|
||||
d.Duration, err = time.ParseDuration(t)
|
||||
|
||||
@@ -17,11 +17,6 @@ func TestKeepAliveParsingFromJSON(t *testing.T) {
|
||||
req string
|
||||
exp *Duration
|
||||
}{
|
||||
{
|
||||
name: "Unset",
|
||||
req: `{ }`,
|
||||
exp: nil,
|
||||
},
|
||||
{
|
||||
name: "Positive Integer",
|
||||
req: `{ "keep_alive": 42 }`,
|
||||
@@ -30,7 +25,7 @@ func TestKeepAliveParsingFromJSON(t *testing.T) {
|
||||
{
|
||||
name: "Positive Float",
|
||||
req: `{ "keep_alive": 42.5 }`,
|
||||
exp: &Duration{42500 * time.Millisecond},
|
||||
exp: &Duration{42 * time.Second},
|
||||
},
|
||||
{
|
||||
name: "Positive Integer String",
|
||||
@@ -441,50 +436,3 @@ func TestThinking_UnmarshalJSON(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolFunctionParameters_String(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
params ToolFunctionParameters
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "simple object with string property",
|
||||
params: ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"name"},
|
||||
Properties: map[string]ToolProperty{
|
||||
"name": {
|
||||
Type: PropertyType{"string"},
|
||||
Description: "The name of the person",
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: `{"type":"object","required":["name"],"properties":{"name":{"type":"string","description":"The name of the person"}}}`,
|
||||
},
|
||||
{
|
||||
name: "marshal failure returns empty string",
|
||||
params: ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Defs: func() any {
|
||||
// Create a cycle that will cause json.Marshal to fail
|
||||
type selfRef struct {
|
||||
Self *selfRef
|
||||
}
|
||||
s := &selfRef{}
|
||||
s.Self = s
|
||||
return s
|
||||
}(),
|
||||
Properties: map[string]ToolProperty{},
|
||||
},
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
result := test.params.String()
|
||||
assert.Equal(t, test.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -172,20 +172,7 @@ func (m *mxfp4) WriteTo(w io.Writer) (int64, error) {
|
||||
blocksDims[i] = int(d)
|
||||
}
|
||||
|
||||
bts := b.Bytes()
|
||||
var tmp [16]byte
|
||||
for i := 0; i < b.Len(); i += 16 {
|
||||
for j := range 8 {
|
||||
// transform a1b2c3 ... x7y8z9 -> 71xa82yb93zc
|
||||
a, b := bts[i+j], bts[i+j+8]
|
||||
tmp[2*j+0] = (a & 0x0F) | (b << 4)
|
||||
tmp[2*j+1] = (a >> 4) | (b & 0xF0)
|
||||
}
|
||||
|
||||
copy(bts[i:i+16], tmp[:])
|
||||
}
|
||||
|
||||
var blocks tensor.Tensor = tensor.New(tensor.WithShape(blocksDims...), tensor.WithBacking(bts))
|
||||
var blocks tensor.Tensor = tensor.New(tensor.WithShape(blocksDims...), tensor.WithBacking(b.Bytes()))
|
||||
|
||||
var s bytes.Buffer
|
||||
if _, err := m.scales.WriteTo(&s); err != nil {
|
||||
@@ -219,5 +206,5 @@ func (m *mxfp4) WriteTo(w io.Writer) (int64, error) {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return int64(len(u8s)), nil
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
@@ -33,8 +33,8 @@ func (t tensorBase) Shape() []uint64 {
|
||||
const (
|
||||
tensorKindFP32 uint32 = iota
|
||||
tensorKindFP16
|
||||
tensorKindMXFP4 = 4
|
||||
tensorKindBF16 = 30
|
||||
tensorKindMXFP4 = 39
|
||||
)
|
||||
|
||||
func (t tensorBase) Kind() uint32 {
|
||||
|
||||
@@ -188,17 +188,17 @@ func (st safetensor) WriteTo(w io.Writer) (int64, error) {
|
||||
|
||||
switch st.Kind() {
|
||||
case tensorKindFP32:
|
||||
return int64(len(f32s) * 4), binary.Write(w, binary.LittleEndian, f32s)
|
||||
return 0, binary.Write(w, binary.LittleEndian, f32s)
|
||||
case tensorKindFP16:
|
||||
f16s := make([]uint16, len(f32s))
|
||||
for i := range f32s {
|
||||
f16s[i] = float16.Fromfloat32(f32s[i]).Bits()
|
||||
}
|
||||
|
||||
return int64(len(f16s) * 2), binary.Write(w, binary.LittleEndian, f16s)
|
||||
return 0, binary.Write(w, binary.LittleEndian, f16s)
|
||||
case tensorKindBF16:
|
||||
u8s := bfloat16.EncodeFloat32(f32s)
|
||||
return int64(len(u8s)), binary.Write(w, binary.LittleEndian, u8s)
|
||||
return 0, binary.Write(w, binary.LittleEndian, u8s)
|
||||
default:
|
||||
return 0, fmt.Errorf("unknown storage type: %d", st.Kind())
|
||||
}
|
||||
|
||||
@@ -7,11 +7,9 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"math"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/fs/util/bufioutil"
|
||||
)
|
||||
|
||||
@@ -277,7 +275,7 @@ type Tensor struct {
|
||||
|
||||
func (t Tensor) block() (n int) {
|
||||
if _, err := fmt.Sscanf(t.Name, "blk.%d.", &n); err != nil {
|
||||
return math.MaxInt
|
||||
return -1
|
||||
}
|
||||
|
||||
return
|
||||
@@ -290,24 +288,24 @@ func (t Tensor) blockSize() uint64 {
|
||||
func (t TensorType) BlockSize() uint64 {
|
||||
switch t {
|
||||
case
|
||||
TensorTypeF32,
|
||||
TensorTypeF16,
|
||||
TensorTypeI8,
|
||||
TensorTypeI16,
|
||||
TensorTypeI32,
|
||||
TensorTypeI64,
|
||||
TensorTypeF64,
|
||||
TensorTypeBF16:
|
||||
0, // F32
|
||||
1, // F16
|
||||
24, // I8
|
||||
25, // I16
|
||||
26, // I32
|
||||
27, // I64
|
||||
28, // F64
|
||||
30: // BF16
|
||||
return 1
|
||||
case
|
||||
TensorTypeQ4_0,
|
||||
TensorTypeQ4_1,
|
||||
TensorTypeQ5_0,
|
||||
TensorTypeQ5_1,
|
||||
TensorTypeQ8_0,
|
||||
TensorTypeQ8_1,
|
||||
tensorTypeIQ4_NL,
|
||||
4, TensorTypeMXFP4:
|
||||
2, // Q4_0
|
||||
3, // Q4_1
|
||||
4, // MXFP4
|
||||
6, // Q5_0
|
||||
7, // Q5_1
|
||||
8, // Q8_0
|
||||
9, // Q8_1
|
||||
20: // IQ4_NL
|
||||
return 32
|
||||
default:
|
||||
return 256
|
||||
@@ -330,6 +328,8 @@ func (t TensorType) TypeSize() uint64 {
|
||||
return 2 + blockSize/2
|
||||
case TensorTypeQ4_1:
|
||||
return 2 + 2 + blockSize/2
|
||||
case TensorTypeMXFP4, 39:
|
||||
return 1 + blockSize/2
|
||||
case TensorTypeQ5_0:
|
||||
return 2 + 4 + blockSize/2
|
||||
case TensorTypeQ5_1:
|
||||
@@ -380,8 +380,6 @@ func (t TensorType) TypeSize() uint64 {
|
||||
return blockSize/8 + blockSize/16 + blockSize/32
|
||||
case TensorTypeBF16:
|
||||
return 2
|
||||
case 4, TensorTypeMXFP4:
|
||||
return 1 + blockSize/2
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
@@ -481,7 +479,7 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string, useFlashAttention bool) (kv []uint64, partialOffload, fullOffload uint64) {
|
||||
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) {
|
||||
context *= uint64(numParallel)
|
||||
|
||||
embedding := f.KV().EmbeddingLength()
|
||||
@@ -679,12 +677,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
||||
kv[i] *= context
|
||||
}
|
||||
}
|
||||
|
||||
partialOffload = 2 * f.KV().HeadCountMax() / cmp.Or(f.KV().HeadCountKVMin(), 1) * kvTotal / 6
|
||||
if useFlashAttention {
|
||||
// rough estimate of graph size with flash attention on
|
||||
partialOffload = (4*uint64(numParallel) + context>>10 + 110) * format.MebiByte
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
@@ -780,13 +773,6 @@ func (f GGML) SupportsFlashAttention() bool {
|
||||
return headCountK != 0 && headCountV != 0 && headCountK == headCountV
|
||||
}
|
||||
|
||||
// FlashAttention checks if the model should enable flash attention
|
||||
func (f GGML) FlashAttention() bool {
|
||||
return slices.Contains([]string{
|
||||
"gptoss", "gpt-oss",
|
||||
}, f.KV().String("general.architecture"))
|
||||
}
|
||||
|
||||
// kvCacheBytesPerElement returns the number of bytes per element for a given KV cache type
|
||||
func kvCacheBytesPerElement(cacheType string) float64 {
|
||||
switch cacheType {
|
||||
|
||||
@@ -533,15 +533,12 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
|
||||
}
|
||||
}
|
||||
|
||||
slices.SortStableFunc(
|
||||
ts,
|
||||
func(a, b *Tensor) int {
|
||||
return cmp.Or(
|
||||
cmp.Compare(a.block(), b.block()),
|
||||
cmp.Compare(a.Name, b.Name),
|
||||
)
|
||||
},
|
||||
)
|
||||
slices.SortStableFunc(ts, func(a, b *Tensor) int {
|
||||
if i, j := a.block(), b.block(); i > 0 && j > 0 {
|
||||
return cmp.Compare(i, j)
|
||||
}
|
||||
return cmp.Compare(a.Name, b.Name)
|
||||
})
|
||||
|
||||
var s uint64
|
||||
for i := range ts {
|
||||
|
||||
@@ -4,8 +4,6 @@ import (
|
||||
"bytes"
|
||||
"math/rand/v2"
|
||||
"os"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
@@ -13,24 +11,24 @@ import (
|
||||
)
|
||||
|
||||
func TestWriteGGUF(t *testing.T) {
|
||||
b := bytes.NewBuffer(make([]byte, 2*3))
|
||||
r := rand.New(rand.NewPCG(0, 0))
|
||||
for range 8 {
|
||||
t.Run("shuffle", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ts := []*Tensor{
|
||||
{Name: "token_embd.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||
{Name: "blk.0.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||
{Name: "blk.1.ffn_up.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||
{Name: "blk.2.ffn_norm.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||
{Name: "blk.1.ffn_down.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||
{Name: "blk.0.attn_k.weight", Shape: []uint64{2, 3}, WriterTo: b},
|
||||
{Name: "output_norm.weight", Shape: []uint64{3, 2}, WriterTo: b},
|
||||
{Name: "output.weight", Shape: []uint64{3, 2}, WriterTo: b},
|
||||
{Name: "token_embd.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "blk.0.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "blk.1.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "blk.2.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "blk.3.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "blk.4.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "blk.5.attn_norm.weight", Shape: []uint64{2, 3}, WriterTo: bytes.NewBuffer(make([]byte, 2*3))},
|
||||
{Name: "output_norm.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewBuffer(make([]byte, 3*2))},
|
||||
{Name: "output.weight", Shape: []uint64{3, 2}, WriterTo: bytes.NewBuffer(make([]byte, 3*2))},
|
||||
}
|
||||
|
||||
rand.Shuffle(len(ts), func(i, j int) {
|
||||
r.Shuffle(len(ts), func(i, j int) {
|
||||
ts[i], ts[j] = ts[j], ts[i]
|
||||
})
|
||||
|
||||
@@ -65,14 +63,14 @@ func TestWriteGGUF(t *testing.T) {
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(Tensors{
|
||||
Offset: 592,
|
||||
Offset: 608,
|
||||
items: []*Tensor{
|
||||
{Name: "blk.0.attn_k.weight", Offset: 0, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.0.attn_norm.weight", Offset: 32, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.0.ffn_norm.weight", Offset: 64, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.1.ffn_down.weight", Offset: 96, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.1.ffn_up.weight", Offset: 128, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.2.ffn_norm.weight", Offset: 160, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.0.attn_norm.weight", Offset: 0, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.1.attn_norm.weight", Offset: 32, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.2.attn_norm.weight", Offset: 64, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.3.attn_norm.weight", Offset: 96, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.4.attn_norm.weight", Offset: 128, Shape: []uint64{2, 3}},
|
||||
{Name: "blk.5.attn_norm.weight", Offset: 160, Shape: []uint64{2, 3}},
|
||||
{Name: "output.weight", Offset: 192, Shape: []uint64{3, 2}},
|
||||
{Name: "output_norm.weight", Offset: 224, Shape: []uint64{3, 2}},
|
||||
{Name: "token_embd.weight", Offset: 256, Shape: []uint64{2, 3}},
|
||||
@@ -83,47 +81,3 @@ func TestWriteGGUF(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkReadArray(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
create := func(tb testing.TB, kv KV) string {
|
||||
tb.Helper()
|
||||
f, err := os.CreateTemp(b.TempDir(), "")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if err := WriteGGUF(f, kv, nil); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
return f.Name()
|
||||
}
|
||||
|
||||
cases := map[string]any{
|
||||
"int32": slices.Repeat([]int32{42}, 1_000_000),
|
||||
"uint32": slices.Repeat([]uint32{42}, 1_000_000),
|
||||
"float32": slices.Repeat([]float32{42.}, 1_000_000),
|
||||
"string": slices.Repeat([]string{"42"}, 1_000_000),
|
||||
}
|
||||
|
||||
for name, bb := range cases {
|
||||
for _, maxArraySize := range []int{-1, 0, 1024} {
|
||||
b.Run(name+"-maxArraySize="+strconv.Itoa(maxArraySize), func(b *testing.B) {
|
||||
p := create(b, KV{"array": bb})
|
||||
for b.Loop() {
|
||||
f, err := os.Open(p)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
if _, err := Decode(f, maxArraySize); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
f.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -146,6 +146,8 @@ func (ftype FileType) ToTensorType() TensorType {
|
||||
return TensorTypeQ4_0
|
||||
case fileTypeQ4_1:
|
||||
return TensorTypeQ4_1
|
||||
case fileTypeMXFP4:
|
||||
return TensorTypeMXFP4 // Formerly unused tensorTypeQ4_2
|
||||
case FileTypeQ8_0:
|
||||
return TensorTypeQ8_0
|
||||
case fileTypeQ5_0:
|
||||
@@ -174,8 +176,6 @@ func (ftype FileType) ToTensorType() TensorType {
|
||||
return TensorTypeQ2_K
|
||||
case FileTypeBF16:
|
||||
return TensorTypeBF16
|
||||
case fileTypeMXFP4:
|
||||
return TensorTypeMXFP4
|
||||
default:
|
||||
slog.Warn("unsupported file type", "type", ftype)
|
||||
return 0 // F32
|
||||
@@ -191,8 +191,8 @@ const (
|
||||
TensorTypeF16
|
||||
TensorTypeQ4_0
|
||||
TensorTypeQ4_1
|
||||
tensorTypeQ4_2
|
||||
tensorTypeQ4_3 // unused by GGML
|
||||
TensorTypeMXFP4 // Formerly unused tensorTypeQ4_2
|
||||
tensorTypeQ4_3 // unused by GGML
|
||||
TensorTypeQ5_0
|
||||
TensorTypeQ5_1
|
||||
TensorTypeQ8_0
|
||||
@@ -226,7 +226,6 @@ const (
|
||||
tensorTypeIQ4_NL_4_4 // unused by GGML
|
||||
tensorTypeIQ4_NL_4_8 // unused by GGML
|
||||
tensorTypeIQ4_NL_8_8 // unused by GGML
|
||||
TensorTypeMXFP4
|
||||
)
|
||||
|
||||
// ParseFileType parses the provided GGUF file type
|
||||
@@ -319,7 +318,7 @@ func (t TensorType) String() string {
|
||||
return "F64"
|
||||
case TensorTypeBF16:
|
||||
return "BF16"
|
||||
case 4, TensorTypeMXFP4:
|
||||
case TensorTypeMXFP4:
|
||||
return "MXFP4"
|
||||
default:
|
||||
return "unknown"
|
||||
|
||||
114
fs/gguf/gguf.go
114
fs/gguf/gguf.go
@@ -35,10 +35,9 @@ type File struct {
|
||||
Magic [4]byte
|
||||
Version uint32
|
||||
|
||||
keyValues *lazy[KeyValue]
|
||||
tensorInfos *lazy[TensorInfo]
|
||||
offset int64
|
||||
n uint64
|
||||
keyValues *lazy[KeyValue]
|
||||
tensors *lazy[TensorInfo]
|
||||
offset int64
|
||||
|
||||
file *os.File
|
||||
reader *bufferedReader
|
||||
@@ -70,12 +69,12 @@ func Open(path string) (f *File, err error) {
|
||||
return nil, fmt.Errorf("%w version %v", ErrUnsupported, f.Version)
|
||||
}
|
||||
|
||||
f.tensorInfos, err = newLazy(f, f.readTensor)
|
||||
f.tensors, err = newLazy(f, f.readTensor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f.tensorInfos.successFunc = func() error {
|
||||
f.tensors.successFunc = func() error {
|
||||
offset := f.reader.offset
|
||||
|
||||
alignment := cmp.Or(f.KeyValue("general.alignment").Int(), 32)
|
||||
@@ -120,15 +119,12 @@ func (f *File) readTensor() (TensorInfo, error) {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
|
||||
tensorInfo := TensorInfo{
|
||||
return TensorInfo{
|
||||
Name: name,
|
||||
Offset: offset,
|
||||
Shape: shape,
|
||||
Type: TensorType(type_),
|
||||
}
|
||||
|
||||
f.n += tensorInfo.NumValues()
|
||||
return tensorInfo, nil
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (f *File) readKeyValue() (KeyValue, error) {
|
||||
@@ -190,20 +186,20 @@ func read[T any](f *File) (t T, err error) {
|
||||
}
|
||||
|
||||
func readString(f *File) (string, error) {
|
||||
bts := f.bts[:8]
|
||||
if _, err := io.ReadFull(f.reader, bts); err != nil {
|
||||
n, err := read[uint64](f)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
n := binary.LittleEndian.Uint64(bts)
|
||||
if int(n) > len(f.bts) {
|
||||
f.bts = make([]byte, n)
|
||||
}
|
||||
|
||||
bts = f.bts[:n]
|
||||
bts := f.bts[:n]
|
||||
if _, err := io.ReadFull(f.reader, bts); err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer clear(bts)
|
||||
|
||||
return string(bts), nil
|
||||
}
|
||||
@@ -249,70 +245,37 @@ func readArray(f *File) (any, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func readArrayData[T any](f *File, n uint64) (*lazy[T], error) {
|
||||
offset := f.reader.offset
|
||||
func readArrayData[T any](f *File, n uint64) (s []T, err error) {
|
||||
s = make([]T, n)
|
||||
for i := range n {
|
||||
e, err := read[T](f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var t T
|
||||
if _, err := f.reader.Discard(int(n) * binary.Size(t)); err != nil {
|
||||
return nil, err
|
||||
s[i] = e
|
||||
}
|
||||
|
||||
sr := io.NewSectionReader(f.file, offset, int64(int(n)*binary.Size(t)))
|
||||
next, stop := iter.Pull(func(yield func(T) bool) {
|
||||
s := make([]T, n)
|
||||
if err := binary.Read(sr, binary.LittleEndian, &s); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, e := range s {
|
||||
if !yield(e) {
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return &lazy[T]{count: n, next: next, stop: stop}, nil
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func readArrayString(f *File, n uint64) (*lazy[string], error) {
|
||||
offset := f.reader.offset
|
||||
|
||||
var size int64
|
||||
for range n {
|
||||
bts := f.bts[:8]
|
||||
if _, err := io.ReadFull(f.reader, bts); err != nil {
|
||||
func readArrayString(f *File, n uint64) (s []string, err error) {
|
||||
s = make([]string, n)
|
||||
for i := range n {
|
||||
e, err := readString(f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
n := int(binary.LittleEndian.Uint64(bts))
|
||||
if _, err := f.reader.Discard(n); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
size += 8 + int64(n)
|
||||
s[i] = e
|
||||
}
|
||||
|
||||
sr := io.NewSectionReader(f.file, offset, size)
|
||||
next, stop := iter.Pull(func(yield func(string) bool) {
|
||||
f := File{reader: newBufferedReader(sr, 16<<10), bts: make([]byte, 4096)}
|
||||
for range n {
|
||||
s, err := readString(&f)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
if !yield(s) {
|
||||
return
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return &lazy[string]{count: n, next: next, stop: stop}, nil
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (f *File) Close() error {
|
||||
f.keyValues.stop()
|
||||
f.tensorInfos.stop()
|
||||
f.tensors.stop()
|
||||
return f.file.Close()
|
||||
}
|
||||
|
||||
@@ -345,15 +308,15 @@ func (f *File) KeyValues() iter.Seq2[int, KeyValue] {
|
||||
}
|
||||
|
||||
func (f *File) TensorInfo(name string) TensorInfo {
|
||||
if index := slices.IndexFunc(f.tensorInfos.values, func(t TensorInfo) bool {
|
||||
if index := slices.IndexFunc(f.tensors.values, func(t TensorInfo) bool {
|
||||
return t.Name == name
|
||||
}); index >= 0 {
|
||||
return f.tensorInfos.values[index]
|
||||
return f.tensors.values[index]
|
||||
}
|
||||
|
||||
// fast-forward through key values if we haven't already
|
||||
_ = f.keyValues.rest()
|
||||
for tensor, ok := f.tensorInfos.next(); ok; tensor, ok = f.tensorInfos.next() {
|
||||
for tensor, ok := f.tensors.next(); ok; tensor, ok = f.tensors.next() {
|
||||
if tensor.Name == name {
|
||||
return tensor
|
||||
}
|
||||
@@ -363,13 +326,13 @@ func (f *File) TensorInfo(name string) TensorInfo {
|
||||
}
|
||||
|
||||
func (f *File) NumTensors() int {
|
||||
return int(f.tensorInfos.count)
|
||||
return int(f.tensors.count)
|
||||
}
|
||||
|
||||
func (f *File) TensorInfos() iter.Seq2[int, TensorInfo] {
|
||||
// fast forward through key values if we haven't already
|
||||
_ = f.keyValues.rest()
|
||||
return f.tensorInfos.All()
|
||||
f.keyValues.rest()
|
||||
return f.tensors.All()
|
||||
}
|
||||
|
||||
func (f *File) TensorReader(name string) (TensorInfo, io.Reader, error) {
|
||||
@@ -379,11 +342,6 @@ func (f *File) TensorReader(name string) (TensorInfo, io.Reader, error) {
|
||||
}
|
||||
|
||||
// fast forward through tensor info if we haven't already
|
||||
_ = f.tensorInfos.rest()
|
||||
return t, io.NewSectionReader(f.file, f.offset+int64(t.Offset), int64(t.NumBytes())), nil
|
||||
}
|
||||
|
||||
func (f *File) NumValues() uint64 {
|
||||
_ = f.tensorInfos.rest()
|
||||
return f.n
|
||||
_ = f.tensors.rest()
|
||||
return t, io.NewSectionReader(f.file, f.offset+int64(t.Offset), t.NumBytes()), nil
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package gguf_test
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -248,43 +247,3 @@ func BenchmarkRead(b *testing.B) {
|
||||
f.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkReadArray(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
create := func(tb testing.TB, kv ggml.KV) string {
|
||||
tb.Helper()
|
||||
f, err := os.CreateTemp(b.TempDir(), "")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if err := ggml.WriteGGUF(f, kv, nil); err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
return f.Name()
|
||||
}
|
||||
|
||||
cases := map[string]any{
|
||||
"int32": slices.Repeat([]int32{42}, 1_000_000),
|
||||
"uint32": slices.Repeat([]uint32{42}, 1_000_000),
|
||||
"float32": slices.Repeat([]float32{42.}, 1_000_000),
|
||||
"string": slices.Repeat([]string{"42"}, 1_000_000),
|
||||
}
|
||||
|
||||
for name, bb := range cases {
|
||||
b.Run(name, func(b *testing.B) {
|
||||
p := create(b, ggml.KV{"array": bb})
|
||||
for b.Loop() {
|
||||
f, err := gguf.Open(p)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
_ = f.KeyValue("array")
|
||||
f.Close()
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,6 @@
|
||||
package gguf
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"iter"
|
||||
"log/slog"
|
||||
"reflect"
|
||||
"slices"
|
||||
)
|
||||
@@ -14,15 +11,32 @@ type KeyValue struct {
|
||||
}
|
||||
|
||||
func (kv KeyValue) Valid() bool {
|
||||
return kv.Key != "" && kv.value != nil
|
||||
return kv.Key != "" && kv.Value.value != nil
|
||||
}
|
||||
|
||||
type Value struct {
|
||||
value any
|
||||
}
|
||||
|
||||
func (v Value) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(v.value)
|
||||
func value[T any](v Value, kinds ...reflect.Kind) (t T) {
|
||||
vv := reflect.ValueOf(v.value)
|
||||
if slices.Contains(kinds, vv.Kind()) {
|
||||
t = vv.Convert(reflect.TypeOf(t)).Interface().(T)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func values[T any](v Value, kinds ...reflect.Kind) (ts []T) {
|
||||
switch vv := reflect.ValueOf(v.value); vv.Kind() {
|
||||
case reflect.Slice:
|
||||
if slices.Contains(kinds, vv.Type().Elem().Kind()) {
|
||||
ts = make([]T, vv.Len())
|
||||
for i := range vv.Len() {
|
||||
ts[i] = vv.Index(i).Convert(reflect.TypeOf(ts[i])).Interface().(T)
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Int returns Value as a signed integer. If it is not a signed integer, it returns 0.
|
||||
@@ -74,44 +88,3 @@ func (v Value) String() string {
|
||||
func (v Value) Strings() (strings []string) {
|
||||
return values[string](v, reflect.String)
|
||||
}
|
||||
|
||||
func value[T any](v Value, kinds ...reflect.Kind) (t T) {
|
||||
vv := reflect.ValueOf(v.value)
|
||||
if slices.Contains(kinds, vv.Kind()) {
|
||||
t = vv.Convert(reflect.TypeOf(t)).Interface().(T)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func values[T any](v Value, kinds ...reflect.Kind) (ts []T) {
|
||||
switch vv := reflect.ValueOf(v.value); vv.Kind() {
|
||||
case reflect.Ptr:
|
||||
out := vv.MethodByName("Values").Call(nil)
|
||||
if len(out) > 0 && out[0].IsValid() {
|
||||
next, stop := iter.Pull(out[0].Seq())
|
||||
defer stop()
|
||||
|
||||
ts = make([]T, vv.Elem().FieldByName("count").Uint())
|
||||
for i := range ts {
|
||||
t, ok := next()
|
||||
if !ok {
|
||||
slog.Error("error reading value", "index", i)
|
||||
return nil
|
||||
}
|
||||
|
||||
ts[i] = t.Convert(reflect.TypeOf(ts[i])).Interface().(T)
|
||||
}
|
||||
|
||||
return ts
|
||||
}
|
||||
|
||||
case reflect.Slice:
|
||||
if slices.Contains(kinds, vv.Type().Elem().Kind()) {
|
||||
ts = make([]T, vv.Len())
|
||||
for i := range vv.Len() {
|
||||
ts[i] = vv.Index(i).Convert(reflect.TypeOf(ts[i])).Interface().(T)
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -21,9 +21,3 @@ func (rs *bufferedReader) Read(p []byte) (n int, err error) {
|
||||
rs.offset += int64(n)
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (rs *bufferedReader) Discard(n int) (discarded int, err error) {
|
||||
discarded, err = rs.Reader.Discard(n)
|
||||
rs.offset += int64(discarded)
|
||||
return discarded, err
|
||||
}
|
||||
|
||||
@@ -16,17 +16,17 @@ func (ti TensorInfo) Valid() bool {
|
||||
return ti.Name != "" && ti.NumBytes() > 0
|
||||
}
|
||||
|
||||
func (ti TensorInfo) NumValues() uint64 {
|
||||
var numItems uint64 = 1
|
||||
func (ti TensorInfo) NumValues() int64 {
|
||||
var numItems int64 = 1
|
||||
for _, dim := range ti.Shape {
|
||||
numItems *= dim
|
||||
numItems *= int64(dim)
|
||||
}
|
||||
return numItems
|
||||
}
|
||||
|
||||
// NumBytes returns the number of bytes in the tensor.
|
||||
func (ti TensorInfo) NumBytes() uint64 {
|
||||
return uint64(float64(ti.NumValues()) * ti.Type.NumBytes())
|
||||
func (ti TensorInfo) NumBytes() int64 {
|
||||
return int64(float64(ti.NumValues()) * ti.Type.NumBytes())
|
||||
}
|
||||
|
||||
func (ti TensorInfo) LogValue() slog.Value {
|
||||
@@ -34,8 +34,8 @@ func (ti TensorInfo) LogValue() slog.Value {
|
||||
slog.String("name", ti.Name),
|
||||
slog.Int64("offset", int64(ti.Offset)),
|
||||
slog.Any("shape", ti.Shape),
|
||||
slog.Uint64("num_values", ti.NumValues()),
|
||||
slog.Uint64("num_bytes", ti.NumBytes()),
|
||||
slog.Int64("num_values", ti.NumValues()),
|
||||
slog.Int64("num_bytes", ti.NumBytes()),
|
||||
slog.Any("type", ti.Type),
|
||||
)
|
||||
}
|
||||
@@ -97,8 +97,6 @@ const (
|
||||
tensorTypeIQ4_NL_4_4
|
||||
tensorTypeIQ4_NL_4_8
|
||||
tensorTypeIQ4_NL_8_8
|
||||
|
||||
TensorTypeMXFP4
|
||||
)
|
||||
|
||||
func (tt TensorType) NumBytes() float64 {
|
||||
@@ -165,8 +163,6 @@ func (tt TensorType) typeSize() int64 {
|
||||
return tt.blockSize()/8 + tt.blockSize()/16 + tt.blockSize()/32
|
||||
case TensorTypeBF16:
|
||||
return 2
|
||||
case 4, TensorTypeMXFP4:
|
||||
return 1 + tt.blockSize() / 2
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
@@ -189,8 +185,7 @@ func (tt TensorType) blockSize() int64 {
|
||||
TensorTypeQ5_1,
|
||||
TensorTypeQ8_0,
|
||||
TensorTypeQ8_1,
|
||||
tensorTypeIQ4_NL,
|
||||
4, TensorTypeMXFP4:
|
||||
tensorTypeIQ4_NL:
|
||||
return 32
|
||||
default:
|
||||
return 256
|
||||
@@ -200,85 +195,83 @@ func (tt TensorType) blockSize() int64 {
|
||||
func (tt TensorType) String() string {
|
||||
switch tt {
|
||||
case TensorTypeF32:
|
||||
return "F32"
|
||||
return "f32"
|
||||
case TensorTypeF16:
|
||||
return "F16"
|
||||
return "f16"
|
||||
case TensorTypeQ4_0:
|
||||
return "Q4_0"
|
||||
return "q4_0"
|
||||
case TensorTypeQ4_1:
|
||||
return "Q4_1"
|
||||
// case tensorTypeQ4_2:
|
||||
// return "Q4_2"
|
||||
return "q4_1"
|
||||
case tensorTypeQ4_2:
|
||||
return "q4_2"
|
||||
case tensorTypeQ4_3:
|
||||
return "Q4_3"
|
||||
return "q4_3"
|
||||
case TensorTypeQ5_0:
|
||||
return "Q5_0"
|
||||
return "q5_0"
|
||||
case TensorTypeQ5_1:
|
||||
return "Q5_1"
|
||||
return "q5_1"
|
||||
case TensorTypeQ8_0:
|
||||
return "Q8_0"
|
||||
return "q8_0"
|
||||
case TensorTypeQ8_1:
|
||||
return "Q8_1"
|
||||
return "q8_1"
|
||||
case TensorTypeQ2_K:
|
||||
return "Q2_K"
|
||||
return "q2_k"
|
||||
case TensorTypeQ3_K:
|
||||
return "Q3_K"
|
||||
return "q3_k"
|
||||
case TensorTypeQ4_K:
|
||||
return "Q4_K"
|
||||
return "q4_k"
|
||||
case TensorTypeQ5_K:
|
||||
return "Q5_K"
|
||||
return "q5_k"
|
||||
case TensorTypeQ6_K:
|
||||
return "Q6_K"
|
||||
return "q6_k"
|
||||
case TensorTypeQ8_K:
|
||||
return "Q8_K"
|
||||
return "q8_k"
|
||||
case tensorTypeIQ2_XXS:
|
||||
return "IQ2_XXS"
|
||||
return "iq2_xxs"
|
||||
case tensorTypeIQ2_XS:
|
||||
return "IQ2_XS"
|
||||
return "iq2_xs"
|
||||
case tensorTypeIQ3_XXS:
|
||||
return "IQ3_XXS"
|
||||
return "iq3_xxs"
|
||||
case tensorTypeIQ1_S:
|
||||
return "IQ1_S"
|
||||
return "iq1_s"
|
||||
case tensorTypeIQ4_NL:
|
||||
return "IQ4_NL"
|
||||
return "iq4_nl"
|
||||
case tensorTypeIQ3_S:
|
||||
return "IQ3_S"
|
||||
return "iq3_s"
|
||||
case tensorTypeIQ2_S:
|
||||
return "IQ2_S"
|
||||
return "iq2_s"
|
||||
case tensorTypeIQ4_XS:
|
||||
return "IQ4_XS"
|
||||
return "iq4_xs"
|
||||
case TensorTypeI8:
|
||||
return "I8"
|
||||
return "i8"
|
||||
case TensorTypeI16:
|
||||
return "I16"
|
||||
return "i16"
|
||||
case TensorTypeI32:
|
||||
return "I32"
|
||||
return "i32"
|
||||
case TensorTypeI64:
|
||||
return "I64"
|
||||
return "i64"
|
||||
case TensorTypeF64:
|
||||
return "F64"
|
||||
return "f64"
|
||||
case tensorTypeIQ1_M:
|
||||
return "IQ1_M"
|
||||
return "iq1_m"
|
||||
case TensorTypeBF16:
|
||||
return "BF16"
|
||||
return "bf16"
|
||||
case tensorTypeQ4_0_4_4:
|
||||
return "Q4_0_4_4"
|
||||
return "q4_0_4_4"
|
||||
case tensorTypeQ4_0_4_8:
|
||||
return "Q4_0_4_8"
|
||||
return "q4_0_4_8"
|
||||
case tensorTypeQ4_0_8_8:
|
||||
return "Q4_0_8_8"
|
||||
return "q4_0_8_8"
|
||||
case tensorTypeTQ1_0:
|
||||
return "TQ1_0"
|
||||
return "tq1_0"
|
||||
case tensorTypeTQ2_0:
|
||||
return "TQ2_0"
|
||||
return "tq2_0"
|
||||
case tensorTypeIQ4_NL_4_4:
|
||||
return "IQ4_NL_4_4"
|
||||
return "iq4_nl_4_4"
|
||||
case tensorTypeIQ4_NL_4_8:
|
||||
return "IQ4_NL_4_8"
|
||||
return "iq4_nl_4_8"
|
||||
case tensorTypeIQ4_NL_8_8:
|
||||
return "IQ4_NL_8_8"
|
||||
case 4, TensorTypeMXFP4:
|
||||
return "MXFP4"
|
||||
return "iq4_nl_8_8"
|
||||
default:
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
@@ -1,130 +0,0 @@
|
||||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: Jesse Gross <jesse@ollama.com>
|
||||
Date: Wed, 27 Aug 2025 14:39:48 -0700
|
||||
Subject: [PATCH] ggml: Enable resetting backend devices
|
||||
|
||||
Touching a CUDA device causes the allocation of a primary context
|
||||
with CUDA data structures (~300 MB of VRAM). If a device is
|
||||
unused then it can be reset to free these data structures.
|
||||
---
|
||||
ggml/include/ggml-backend.h | 1 +
|
||||
ggml/src/ggml-backend-impl.h | 4 ++++
|
||||
ggml/src/ggml-backend.cpp | 8 ++++++++
|
||||
ggml/src/ggml-cuda/ggml-cuda.cu | 17 +++++++++++++++--
|
||||
ggml/src/ggml-cuda/vendors/hip.h | 1 +
|
||||
5 files changed, 29 insertions(+), 2 deletions(-)
|
||||
|
||||
diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
|
||||
index b602a7c78..fda5ceb24 100644
|
||||
--- a/ggml/include/ggml-backend.h
|
||||
+++ b/ggml/include/ggml-backend.h
|
||||
@@ -167,6 +167,7 @@ extern "C" {
|
||||
GGML_API void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_dev_props * props);
|
||||
GGML_API ggml_backend_reg_t ggml_backend_dev_backend_reg(ggml_backend_dev_t device);
|
||||
GGML_API ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * params);
|
||||
+ GGML_API void ggml_backend_dev_reset(ggml_backend_dev_t device);
|
||||
GGML_API ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device);
|
||||
GGML_API ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device);
|
||||
GGML_API ggml_backend_buffer_t ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size);
|
||||
diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h
|
||||
index 81749a5a3..6f10c353b 100644
|
||||
--- a/ggml/src/ggml-backend-impl.h
|
||||
+++ b/ggml/src/ggml-backend-impl.h
|
||||
@@ -178,6 +178,10 @@ extern "C" {
|
||||
ggml_backend_event_t (*event_new) (ggml_backend_dev_t dev);
|
||||
void (*event_free) (ggml_backend_dev_t dev, ggml_backend_event_t event);
|
||||
void (*event_synchronize) (ggml_backend_dev_t dev, ggml_backend_event_t event);
|
||||
+
|
||||
+ // (optional) reset device, clearing existing allocations and context
|
||||
+ // the caller must ensure that there are no outstanding buffers, as these will become invalid
|
||||
+ void (*reset)(ggml_backend_dev_t dev);
|
||||
};
|
||||
|
||||
struct ggml_backend_device {
|
||||
diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
|
||||
index 05a842ed5..6556943b0 100644
|
||||
--- a/ggml/src/ggml-backend.cpp
|
||||
+++ b/ggml/src/ggml-backend.cpp
|
||||
@@ -477,6 +477,14 @@ ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * par
|
||||
return device->iface.init_backend(device, params);
|
||||
}
|
||||
|
||||
+void ggml_backend_dev_reset(ggml_backend_dev_t device) {
|
||||
+ if (device->iface.reset == NULL) {
|
||||
+ return;
|
||||
+ }
|
||||
+
|
||||
+ device->iface.reset(device);
|
||||
+}
|
||||
+
|
||||
ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device) {
|
||||
return device->iface.get_buffer_type(device);
|
||||
}
|
||||
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
index c7f9dc3a5..e43fde523 100644
|
||||
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
@@ -103,6 +103,11 @@ int ggml_cuda_get_device() {
|
||||
return id;
|
||||
}
|
||||
|
||||
+void ggml_cuda_reset_device(int device) {
|
||||
+ ggml_cuda_set_device(device);
|
||||
+ CUDA_CHECK(cudaDeviceReset());
|
||||
+}
|
||||
+
|
||||
static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
|
||||
ggml_cuda_set_device(device);
|
||||
cudaError_t err;
|
||||
@@ -3243,7 +3248,10 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back
|
||||
props->description = ggml_backend_cuda_device_get_description(dev);
|
||||
props->id = ggml_backend_cuda_device_get_id(dev);
|
||||
props->type = ggml_backend_cuda_device_get_type(dev);
|
||||
- ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total);
|
||||
+
|
||||
+ // Memory reporting is disabled to avoid allocation of a CUDA primary context (~300 MB per device).
|
||||
+ // If you need the memory data, call ggml_backend_dev_memory() explicitly.
|
||||
+ props->memory_total = props->memory_free = 0;
|
||||
|
||||
bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr;
|
||||
#ifdef GGML_CUDA_NO_PEER_COPY
|
||||
@@ -3700,6 +3708,11 @@ static void ggml_backend_cuda_device_event_synchronize(ggml_backend_dev_t dev, g
|
||||
CUDA_CHECK(cudaEventSynchronize((cudaEvent_t)event->context));
|
||||
}
|
||||
|
||||
+static void ggml_backend_cuda_device_reset(ggml_backend_dev_t dev) {
|
||||
+ ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
|
||||
+ ggml_cuda_reset_device(ctx->device);
|
||||
+}
|
||||
+
|
||||
static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
|
||||
/* .get_name = */ ggml_backend_cuda_device_get_name,
|
||||
/* .get_description = */ ggml_backend_cuda_device_get_description,
|
||||
@@ -3716,6 +3729,7 @@ static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
|
||||
/* .event_new = */ ggml_backend_cuda_device_event_new,
|
||||
/* .event_free = */ ggml_backend_cuda_device_event_free,
|
||||
/* .event_synchronize = */ ggml_backend_cuda_device_event_synchronize,
|
||||
+ /* .reset = */ ggml_backend_cuda_device_reset,
|
||||
};
|
||||
|
||||
// backend reg
|
||||
@@ -3835,7 +3849,6 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
|
||||
dev_ctx->device = i;
|
||||
dev_ctx->name = GGML_CUDA_NAME + std::to_string(i);
|
||||
|
||||
- ggml_cuda_set_device(i);
|
||||
cudaDeviceProp prop;
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
|
||||
dev_ctx->description = prop.name;
|
||||
diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h
|
||||
index c31f31923..cf22e60d2 100644
|
||||
--- a/ggml/src/ggml-cuda/vendors/hip.h
|
||||
+++ b/ggml/src/ggml-cuda/vendors/hip.h
|
||||
@@ -40,6 +40,7 @@
|
||||
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
|
||||
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
|
||||
#define cudaDeviceProp hipDeviceProp_t
|
||||
+#define cudaDeviceReset hipDeviceReset
|
||||
#define cudaDeviceSynchronize hipDeviceSynchronize
|
||||
#define cudaError_t hipError_t
|
||||
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
|
||||
@@ -195,19 +195,17 @@ func estimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
|
||||
slog.Warn("model missing blk.0 layer size")
|
||||
}
|
||||
|
||||
useFlashAttention := (envconfig.FlashAttention() || f.FlashAttention()) &&
|
||||
discover.GetGPUInfo().FlashAttentionSupported() &&
|
||||
f.SupportsFlashAttention()
|
||||
|
||||
var kvct string
|
||||
if useFlashAttention {
|
||||
if envconfig.FlashAttention() &&
|
||||
discover.GetGPUInfo().FlashAttentionSupported() &&
|
||||
f.SupportsFlashAttention() {
|
||||
requested := strings.ToLower(envconfig.KvCacheType())
|
||||
if requested != "" && f.SupportsKVCacheType(requested) {
|
||||
kvct = requested
|
||||
}
|
||||
}
|
||||
|
||||
kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), numParallel, kvct, useFlashAttention)
|
||||
kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), numParallel, kvct)
|
||||
|
||||
if len(kv) > 0 {
|
||||
layerSize += kv[0]
|
||||
|
||||
@@ -195,11 +195,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
|
||||
// This will disable flash attention unless all GPUs on the system support it, even if we end up selecting a subset
|
||||
// that can handle it.
|
||||
fa := envconfig.FlashAttention()
|
||||
if f.FlashAttention() {
|
||||
slog.Info("model wants flash attention")
|
||||
fa = true
|
||||
}
|
||||
|
||||
if fa && !gpus.FlashAttentionSupported() {
|
||||
slog.Warn("flash attention enabled but not supported by gpu")
|
||||
fa = false
|
||||
|
||||
@@ -535,7 +535,6 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error {
|
||||
const BS = 17 // MXFP4 block size
|
||||
bts := make([]byte, 8*BS*format.KibiByte) // ~128k block aligned
|
||||
var s uint64
|
||||
var tmp [16]byte
|
||||
for s < t.Size() {
|
||||
// Stop if either the parent context has been canceled or if any of the other tensors returned an error
|
||||
if err := ctx.Err(); err != nil {
|
||||
@@ -547,13 +546,37 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error {
|
||||
return err
|
||||
}
|
||||
for j := range n / BS {
|
||||
for i := 1; i < 9; i++ {
|
||||
// transform a1b2c3 ... x7y8z9 -> 71xa82yb93zc
|
||||
a, b := bts[j*BS+i], bts[j*BS+i+8]
|
||||
tmp[2*(i-1)] = (a & 0x0F) | (b << 4)
|
||||
tmp[2*(i-1)+1] = (a >> 4) | (b & 0xF0)
|
||||
for i := 1; i < BS; i++ {
|
||||
// swap nibbles
|
||||
t_lo := bts[j*BS+i] & 0x0F
|
||||
t_hi := bts[j*BS+i] & 0xF0
|
||||
bts[j*BS+i] = (t_lo << 4) | (t_hi >> 4)
|
||||
}
|
||||
// transform aaaa...bbbb... to abababab...
|
||||
oi := 0
|
||||
tmp := [16]byte{}
|
||||
for i := 1; i < 9; i++ {
|
||||
blk_a0 := bts[j*BS+i] & 0xF0
|
||||
blk_a1 := bts[j*BS+i] << 4
|
||||
blk_b0 := bts[j*BS+i+8] >> 4
|
||||
blk_b1 := bts[j*BS+i+8] & 0x0F
|
||||
// swap once more
|
||||
out0 := blk_a0 | blk_b0
|
||||
out1 := blk_a1 | blk_b1
|
||||
out_h0 := out0 & 0xF0
|
||||
out_l0 := out0 & 0x0F
|
||||
out_h1 := out1 & 0xF0
|
||||
out_l1 := out1 & 0x0F
|
||||
out0 = (out_h0 >> 4) | (out_l0 << 4)
|
||||
out1 = (out_h1 >> 4) | (out_l1 << 4)
|
||||
tmp[oi] = out0
|
||||
oi++
|
||||
tmp[oi] = out1
|
||||
oi++
|
||||
}
|
||||
for i := range tmp {
|
||||
bts[j*BS+i+1] = tmp[i]
|
||||
}
|
||||
copy(bts[j*BS+1:j*BS+17], tmp[:])
|
||||
}
|
||||
|
||||
for _, tt := range tts {
|
||||
@@ -629,18 +652,6 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error {
|
||||
})
|
||||
}
|
||||
|
||||
// Cleanup any backend state from devices that we didn't end up using
|
||||
nextDevice:
|
||||
for _, d := range append(gpus, append(accels, cpus...)...) {
|
||||
for _, backend := range b.schedBackends {
|
||||
if d == C.ggml_backend_get_device(backend) {
|
||||
continue nextDevice
|
||||
}
|
||||
}
|
||||
|
||||
C.ggml_backend_dev_reset(d)
|
||||
}
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
1
ml/backend/ggml/ggml/include/ggml-backend.h
vendored
1
ml/backend/ggml/ggml/include/ggml-backend.h
vendored
@@ -167,7 +167,6 @@ extern "C" {
|
||||
GGML_API void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_dev_props * props);
|
||||
GGML_API ggml_backend_reg_t ggml_backend_dev_backend_reg(ggml_backend_dev_t device);
|
||||
GGML_API ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * params);
|
||||
GGML_API void ggml_backend_dev_reset(ggml_backend_dev_t device);
|
||||
GGML_API ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device);
|
||||
GGML_API ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device);
|
||||
GGML_API ggml_backend_buffer_t ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size);
|
||||
|
||||
4
ml/backend/ggml/ggml/src/ggml-backend-impl.h
vendored
4
ml/backend/ggml/ggml/src/ggml-backend-impl.h
vendored
@@ -178,10 +178,6 @@ extern "C" {
|
||||
ggml_backend_event_t (*event_new) (ggml_backend_dev_t dev);
|
||||
void (*event_free) (ggml_backend_dev_t dev, ggml_backend_event_t event);
|
||||
void (*event_synchronize) (ggml_backend_dev_t dev, ggml_backend_event_t event);
|
||||
|
||||
// (optional) reset device, clearing existing allocations and context
|
||||
// the caller must ensure that there are no outstanding buffers, as these will become invalid
|
||||
void (*reset)(ggml_backend_dev_t dev);
|
||||
};
|
||||
|
||||
struct ggml_backend_device {
|
||||
|
||||
8
ml/backend/ggml/ggml/src/ggml-backend.cpp
vendored
8
ml/backend/ggml/ggml/src/ggml-backend.cpp
vendored
@@ -477,14 +477,6 @@ ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * par
|
||||
return device->iface.init_backend(device, params);
|
||||
}
|
||||
|
||||
void ggml_backend_dev_reset(ggml_backend_dev_t device) {
|
||||
if (device->iface.reset == NULL) {
|
||||
return;
|
||||
}
|
||||
|
||||
device->iface.reset(device);
|
||||
}
|
||||
|
||||
ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device) {
|
||||
return device->iface.get_buffer_type(device);
|
||||
}
|
||||
|
||||
17
ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu
vendored
17
ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu
vendored
@@ -103,11 +103,6 @@ int ggml_cuda_get_device() {
|
||||
return id;
|
||||
}
|
||||
|
||||
void ggml_cuda_reset_device(int device) {
|
||||
ggml_cuda_set_device(device);
|
||||
CUDA_CHECK(cudaDeviceReset());
|
||||
}
|
||||
|
||||
static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) {
|
||||
ggml_cuda_set_device(device);
|
||||
cudaError_t err;
|
||||
@@ -3248,10 +3243,7 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back
|
||||
props->description = ggml_backend_cuda_device_get_description(dev);
|
||||
props->id = ggml_backend_cuda_device_get_id(dev);
|
||||
props->type = ggml_backend_cuda_device_get_type(dev);
|
||||
|
||||
// Memory reporting is disabled to avoid allocation of a CUDA primary context (~300 MB per device).
|
||||
// If you need the memory data, call ggml_backend_dev_memory() explicitly.
|
||||
props->memory_total = props->memory_free = 0;
|
||||
ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total);
|
||||
|
||||
bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr;
|
||||
#ifdef GGML_CUDA_NO_PEER_COPY
|
||||
@@ -3708,11 +3700,6 @@ static void ggml_backend_cuda_device_event_synchronize(ggml_backend_dev_t dev, g
|
||||
CUDA_CHECK(cudaEventSynchronize((cudaEvent_t)event->context));
|
||||
}
|
||||
|
||||
static void ggml_backend_cuda_device_reset(ggml_backend_dev_t dev) {
|
||||
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
|
||||
ggml_cuda_reset_device(ctx->device);
|
||||
}
|
||||
|
||||
static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
|
||||
/* .get_name = */ ggml_backend_cuda_device_get_name,
|
||||
/* .get_description = */ ggml_backend_cuda_device_get_description,
|
||||
@@ -3729,7 +3716,6 @@ static const ggml_backend_device_i ggml_backend_cuda_device_interface = {
|
||||
/* .event_new = */ ggml_backend_cuda_device_event_new,
|
||||
/* .event_free = */ ggml_backend_cuda_device_event_free,
|
||||
/* .event_synchronize = */ ggml_backend_cuda_device_event_synchronize,
|
||||
/* .reset = */ ggml_backend_cuda_device_reset,
|
||||
};
|
||||
|
||||
// backend reg
|
||||
@@ -3849,6 +3835,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
|
||||
dev_ctx->device = i;
|
||||
dev_ctx->name = GGML_CUDA_NAME + std::to_string(i);
|
||||
|
||||
ggml_cuda_set_device(i);
|
||||
cudaDeviceProp prop;
|
||||
CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
|
||||
dev_ctx->description = prop.name;
|
||||
|
||||
@@ -40,7 +40,6 @@
|
||||
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
|
||||
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
|
||||
#define cudaDeviceProp hipDeviceProp_t
|
||||
#define cudaDeviceReset hipDeviceReset
|
||||
#define cudaDeviceSynchronize hipDeviceSynchronize
|
||||
#define cudaError_t hipError_t
|
||||
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
|
||||
|
||||
@@ -18,7 +18,7 @@ type Model struct {
|
||||
model.Base
|
||||
model.SentencePieceModel
|
||||
|
||||
*VisionModel `gguf:"v"`
|
||||
*VisionModel `gguf:"v,vision"`
|
||||
*TextModel
|
||||
|
||||
*MultiModalProjector `gguf:"mm"`
|
||||
|
||||
@@ -18,7 +18,7 @@ type Model struct {
|
||||
model.BytePairEncoding
|
||||
ImageProcessor
|
||||
|
||||
*VisionModel `gguf:"v"`
|
||||
*VisionModel `gguf:"v,vision"`
|
||||
*Projector `gguf:"mm"`
|
||||
*TextModel
|
||||
}
|
||||
|
||||
@@ -18,7 +18,7 @@ type Model struct {
|
||||
model.BytePairEncoding
|
||||
|
||||
*TextModel
|
||||
*VisionModel `gguf:"v"`
|
||||
*VisionModel `gguf:"v,vision"`
|
||||
*MultiModalProjector `gguf:"mm"`
|
||||
|
||||
ImageProcessor
|
||||
|
||||
@@ -17,7 +17,7 @@ type Model struct {
|
||||
model.Base
|
||||
model.BytePairEncoding
|
||||
|
||||
*VisionModel `gguf:"v"`
|
||||
*VisionModel `gguf:"v,vision"`
|
||||
*TextModel
|
||||
|
||||
Projector *nn.Linear `gguf:"mm.0"`
|
||||
|
||||
@@ -18,7 +18,7 @@ type Model struct {
|
||||
model.BytePairEncoding
|
||||
|
||||
*TextModel
|
||||
*VisionModel `gguf:"v"`
|
||||
*VisionModel `gguf:"v,vision"`
|
||||
|
||||
ImageProcessor
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/types"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
@@ -571,7 +572,7 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
||||
Messages: messages,
|
||||
Format: format,
|
||||
Options: options,
|
||||
Stream: &r.Stream,
|
||||
Stream: types.NullWithValue(r.Stream),
|
||||
Tools: r.Tools,
|
||||
Think: think,
|
||||
}, nil
|
||||
@@ -650,7 +651,7 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
|
||||
Model: r.Model,
|
||||
Prompt: r.Prompt,
|
||||
Options: options,
|
||||
Stream: &r.Stream,
|
||||
Stream: types.NullWithValue(r.Stream),
|
||||
Suffix: r.Suffix,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -146,7 +146,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||
ch <- api.ProgressResponse{Status: "success"}
|
||||
}()
|
||||
|
||||
if r.Stream != nil && !*r.Stream {
|
||||
if !r.Stream.Value(true) {
|
||||
waitForStream(c, ch)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -31,7 +31,7 @@ import (
|
||||
"github.com/ollama/ollama/discover"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/fs/gguf"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/harmony"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
@@ -189,7 +189,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
// expire the runner
|
||||
if req.Prompt == "" && req.KeepAlive != nil && req.KeepAlive.Duration == 0 {
|
||||
if req.Prompt == "" && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 {
|
||||
s.sched.expireRunner(m)
|
||||
|
||||
c.JSON(http.StatusOK, api.GenerateResponse{
|
||||
@@ -440,7 +440,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
}
|
||||
}()
|
||||
|
||||
if req.Stream != nil && !*req.Stream {
|
||||
if !req.Stream.Value(true) {
|
||||
var r api.GenerateResponse
|
||||
var sbThinking strings.Builder
|
||||
var sbContent strings.Builder
|
||||
@@ -487,12 +487,6 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
truncate := true
|
||||
|
||||
if req.Truncate != nil && !*req.Truncate {
|
||||
truncate = false
|
||||
}
|
||||
|
||||
var input []string
|
||||
|
||||
switch i := req.Input.(type) {
|
||||
@@ -534,14 +528,14 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
f, err := gguf.Open(m.ModelPath)
|
||||
kvData, _, err := getModelData(m.ModelPath, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var count int
|
||||
truncate := req.Truncate.Value(true)
|
||||
for i, s := range input {
|
||||
tokens, err := r.Tokenize(c.Request.Context(), s)
|
||||
if err != nil {
|
||||
@@ -549,7 +543,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
ctxLen := min(opts.NumCtx, int(f.KeyValue("context_length").Int()))
|
||||
ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
|
||||
if len(tokens) > ctxLen {
|
||||
if !truncate {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"})
|
||||
@@ -702,7 +696,7 @@ func (s *Server) PullHandler(c *gin.Context) {
|
||||
}
|
||||
}()
|
||||
|
||||
if req.Stream != nil && !*req.Stream {
|
||||
if !req.Stream.Value(true) {
|
||||
waitForStream(c, ch)
|
||||
return
|
||||
}
|
||||
@@ -757,7 +751,7 @@ func (s *Server) PushHandler(c *gin.Context) {
|
||||
}
|
||||
}()
|
||||
|
||||
if req.Stream != nil && !*req.Stream {
|
||||
if !req.Stream.Value(true) {
|
||||
waitForStream(c, ch)
|
||||
return
|
||||
}
|
||||
@@ -952,63 +946,53 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
|
||||
fmt.Fprint(&sb, m.String())
|
||||
resp.Modelfile = sb.String()
|
||||
|
||||
f, err := gguf.Open(m.ModelPath)
|
||||
kvData, tensors, err := getModelData(m.ModelPath, req.Verbose)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
resp.ModelInfo = make(map[string]any, f.NumKeyValues())
|
||||
for _, keyValue := range f.KeyValues() {
|
||||
if !slices.Contains([]string{"general.name", "tokenizer.chat_template"}, keyValue.Key) {
|
||||
resp.ModelInfo[keyValue.Key] = keyValue.Value
|
||||
}
|
||||
}
|
||||
delete(kvData, "general.name")
|
||||
delete(kvData, "tokenizer.chat_template")
|
||||
resp.ModelInfo = kvData
|
||||
|
||||
resp.Tensors = make([]api.Tensor, f.NumTensors())
|
||||
for i, tensorInfo := range f.TensorInfos() {
|
||||
resp.Tensors[i] = api.Tensor{
|
||||
Name: tensorInfo.Name,
|
||||
Type: tensorInfo.Type.String(),
|
||||
Shape: tensorInfo.Shape,
|
||||
}
|
||||
tensorData := make([]api.Tensor, len(tensors.Items()))
|
||||
for cnt, t := range tensors.Items() {
|
||||
tensorData[cnt] = api.Tensor{Name: t.Name, Type: t.Type(), Shape: t.Shape}
|
||||
}
|
||||
resp.ModelInfo["general.parameter_count"] = f.NumValues()
|
||||
resp.Tensors = tensorData
|
||||
|
||||
if len(m.ProjectorPaths) > 0 {
|
||||
f, err := gguf.Open(m.ProjectorPaths[0])
|
||||
projectorData, _, err := getModelData(m.ProjectorPaths[0], req.Verbose)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
resp.ProjectorInfo = make(map[string]any, f.NumKeyValues())
|
||||
for _, keyValue := range f.KeyValues() {
|
||||
resp.ProjectorInfo[keyValue.Key] = keyValue.Value
|
||||
}
|
||||
resp.ProjectorInfo = projectorData
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func getModelData(digest string, verbose bool) ([]gguf.KeyValue, []gguf.TensorInfo, error) {
|
||||
f, err := gguf.Open(digest)
|
||||
func getModelData(digest string, verbose bool) (ggml.KV, ggml.Tensors, error) {
|
||||
maxArraySize := 0
|
||||
if verbose {
|
||||
maxArraySize = -1
|
||||
}
|
||||
data, err := llm.LoadModel(digest, maxArraySize)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
keyValues := make([]gguf.KeyValue, f.NumKeyValues())
|
||||
for i, keyValue := range f.KeyValues() {
|
||||
keyValues[i] = keyValue
|
||||
return nil, ggml.Tensors{}, err
|
||||
}
|
||||
|
||||
tensorInfos := make([]gguf.TensorInfo, f.NumTensors())
|
||||
for i, info := range f.TensorInfos() {
|
||||
tensorInfos[i] = info
|
||||
kv := data.KV()
|
||||
|
||||
if !verbose {
|
||||
for k := range kv {
|
||||
if t, ok := kv[k].([]any); len(t) > 5 && ok {
|
||||
kv[k] = []any{}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return keyValues, tensorInfos, nil
|
||||
return kv, data.Tensors(), nil
|
||||
}
|
||||
|
||||
func (s *Server) ListHandler(c *gin.Context) {
|
||||
@@ -1555,7 +1539,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
// expire the runner
|
||||
if len(req.Messages) == 0 && req.KeepAlive != nil && req.KeepAlive.Duration == 0 {
|
||||
if len(req.Messages) == 0 && req.KeepAlive != nil && int(req.KeepAlive.Seconds()) == 0 {
|
||||
model, err := GetModel(req.Model)
|
||||
if err != nil {
|
||||
switch {
|
||||
@@ -1786,7 +1770,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
}
|
||||
}()
|
||||
|
||||
if req.Stream != nil && !*req.Stream {
|
||||
if !req.Stream.Value(true) {
|
||||
var resp api.ChatResponse
|
||||
var toolCalls []api.ToolCall
|
||||
var sbThinking strings.Builder
|
||||
|
||||
@@ -22,8 +22,6 @@ import (
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
)
|
||||
|
||||
var stream bool = false
|
||||
|
||||
func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) (string, string) {
|
||||
t.Helper()
|
||||
t.Setenv("OLLAMA_MODELS", cmp.Or(os.Getenv("OLLAMA_MODELS"), t.TempDir()))
|
||||
@@ -118,7 +116,7 @@ func TestCreateFromBin(t *testing.T) {
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Name: "test",
|
||||
Files: map[string]string{"test.gguf": digest},
|
||||
Stream: &stream,
|
||||
Stream: streamFalse,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
@@ -148,7 +146,7 @@ func TestCreateFromModel(t *testing.T) {
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Name: "test",
|
||||
Files: map[string]string{"test.gguf": digest},
|
||||
Stream: &stream,
|
||||
Stream: streamFalse,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
@@ -162,7 +160,7 @@ func TestCreateFromModel(t *testing.T) {
|
||||
w = createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Name: "test2",
|
||||
From: "test",
|
||||
Stream: &stream,
|
||||
Stream: streamFalse,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
@@ -192,7 +190,7 @@ func TestCreateRemovesLayers(t *testing.T) {
|
||||
Name: "test",
|
||||
Files: map[string]string{"test.gguf": digest},
|
||||
Template: "{{ .Prompt }}",
|
||||
Stream: &stream,
|
||||
Stream: streamFalse,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
@@ -213,7 +211,7 @@ func TestCreateRemovesLayers(t *testing.T) {
|
||||
Name: "test",
|
||||
Files: map[string]string{"test.gguf": digest},
|
||||
Template: "{{ .System }} {{ .Prompt }}",
|
||||
Stream: &stream,
|
||||
Stream: streamFalse,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
@@ -243,7 +241,7 @@ func TestCreateUnsetsSystem(t *testing.T) {
|
||||
Name: "test",
|
||||
Files: map[string]string{"test.gguf": digest},
|
||||
System: "Say hi!",
|
||||
Stream: &stream,
|
||||
Stream: streamFalse,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
@@ -264,7 +262,7 @@ func TestCreateUnsetsSystem(t *testing.T) {
|
||||
Name: "test",
|
||||
Files: map[string]string{"test.gguf": digest},
|
||||
System: "",
|
||||
Stream: &stream,
|
||||
Stream: streamFalse,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
@@ -297,7 +295,7 @@ func TestCreateMergeParameters(t *testing.T) {
|
||||
"top_k": 10,
|
||||
"stop": []string{"USER:", "ASSISTANT:"},
|
||||
},
|
||||
Stream: &stream,
|
||||
Stream: streamFalse,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
@@ -322,7 +320,7 @@ func TestCreateMergeParameters(t *testing.T) {
|
||||
"temperature": 0.6,
|
||||
"top_p": 0.7,
|
||||
},
|
||||
Stream: &stream,
|
||||
Stream: streamFalse,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
@@ -381,7 +379,7 @@ func TestCreateMergeParameters(t *testing.T) {
|
||||
"top_p": 0.7,
|
||||
"stop": []string{"<|endoftext|>"},
|
||||
},
|
||||
Stream: &stream,
|
||||
Stream: streamFalse,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
@@ -441,7 +439,7 @@ func TestCreateReplacesMessages(t *testing.T) {
|
||||
Content: "Oh, my god.",
|
||||
},
|
||||
},
|
||||
Stream: &stream,
|
||||
Stream: streamFalse,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
@@ -475,7 +473,7 @@ func TestCreateReplacesMessages(t *testing.T) {
|
||||
Content: "A test. And a thumping good one at that, I'd wager.",
|
||||
},
|
||||
},
|
||||
Stream: &stream,
|
||||
Stream: streamFalse,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
@@ -536,7 +534,7 @@ func TestCreateTemplateSystem(t *testing.T) {
|
||||
Files: map[string]string{"test.gguf": digest},
|
||||
Template: "{{ .System }} {{ .Prompt }}",
|
||||
System: "Say bye!",
|
||||
Stream: &stream,
|
||||
Stream: streamFalse,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
@@ -578,7 +576,7 @@ func TestCreateTemplateSystem(t *testing.T) {
|
||||
Name: "test",
|
||||
Files: map[string]string{"test.gguf": digest},
|
||||
Template: "{{ .Prompt",
|
||||
Stream: &stream,
|
||||
Stream: streamFalse,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
@@ -592,7 +590,7 @@ func TestCreateTemplateSystem(t *testing.T) {
|
||||
Name: "test",
|
||||
Files: map[string]string{"test.gguf": digest},
|
||||
Template: "{{ if .Prompt }}",
|
||||
Stream: &stream,
|
||||
Stream: streamFalse,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
@@ -606,7 +604,7 @@ func TestCreateTemplateSystem(t *testing.T) {
|
||||
Name: "test",
|
||||
Files: map[string]string{"test.gguf": digest},
|
||||
Template: "{{ Prompt }}",
|
||||
Stream: &stream,
|
||||
Stream: streamFalse,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusBadRequest {
|
||||
@@ -627,7 +625,7 @@ func TestCreateLicenses(t *testing.T) {
|
||||
Name: "test",
|
||||
Files: map[string]string{"test.gguf": digest},
|
||||
License: []string{"MIT", "Apache-2.0"},
|
||||
Stream: &stream,
|
||||
Stream: streamFalse,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
@@ -678,7 +676,7 @@ func TestCreateDetectTemplate(t *testing.T) {
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Name: "test",
|
||||
Files: map[string]string{"test.gguf": digest},
|
||||
Stream: &stream,
|
||||
Stream: streamFalse,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
@@ -698,7 +696,7 @@ func TestCreateDetectTemplate(t *testing.T) {
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Name: "test",
|
||||
Files: map[string]string{"test.gguf": digest},
|
||||
Stream: &stream,
|
||||
Stream: streamFalse,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/ollama/ollama/discover"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/types"
|
||||
)
|
||||
|
||||
func TestGenerateDebugRenderOnly(t *testing.T) {
|
||||
@@ -53,7 +54,6 @@ func TestGenerateDebugRenderOnly(t *testing.T) {
|
||||
go s.sched.Run(t.Context())
|
||||
|
||||
// Create a test model
|
||||
stream := false
|
||||
_, digest := createBinFile(t, ggml.KV{
|
||||
"general.architecture": "llama",
|
||||
"llama.block_count": uint32(1),
|
||||
@@ -82,7 +82,7 @@ func TestGenerateDebugRenderOnly(t *testing.T) {
|
||||
Model: "test-model",
|
||||
Files: map[string]string{"file.gguf": digest},
|
||||
Template: "{{ .Prompt }}",
|
||||
Stream: &stream,
|
||||
Stream: streamFalse,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
@@ -172,7 +172,7 @@ func TestGenerateDebugRenderOnly(t *testing.T) {
|
||||
}
|
||||
t.Run(tt.name+streamSuffix, func(t *testing.T) {
|
||||
req := tt.request
|
||||
req.Stream = &stream
|
||||
req.Stream = types.NullWithValue(stream)
|
||||
w := createRequest(t, s.GenerateHandler, req)
|
||||
|
||||
if tt.expectDebug {
|
||||
@@ -246,7 +246,6 @@ func TestChatDebugRenderOnly(t *testing.T) {
|
||||
go s.sched.Run(t.Context())
|
||||
|
||||
// Create a test model
|
||||
stream := false
|
||||
_, digest := createBinFile(t, ggml.KV{
|
||||
"general.architecture": "llama",
|
||||
"llama.block_count": uint32(1),
|
||||
@@ -275,7 +274,7 @@ func TestChatDebugRenderOnly(t *testing.T) {
|
||||
Model: "test-model",
|
||||
Files: map[string]string{"file.gguf": digest},
|
||||
Template: "{{ if .Tools }}{{ .Tools }}{{ end }}{{ range .Messages }}{{ .Role }}: {{ .Content }}\n{{ end }}",
|
||||
Stream: &stream,
|
||||
Stream: streamFalse,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
@@ -377,7 +376,7 @@ func TestChatDebugRenderOnly(t *testing.T) {
|
||||
}
|
||||
t.Run(tt.name+streamSuffix, func(t *testing.T) {
|
||||
req := tt.request
|
||||
req.Stream = &stream
|
||||
req.Stream = types.NullWithValue(stream)
|
||||
w := createRequest(t, s.ChatHandler, req)
|
||||
|
||||
if tt.expectDebug {
|
||||
|
||||
@@ -126,7 +126,7 @@ func TestGenerateChat(t *testing.T) {
|
||||
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
|
||||
{{- end }}
|
||||
{{ end }}`,
|
||||
Stream: &stream,
|
||||
Stream: streamFalse,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
@@ -182,7 +182,7 @@ func TestGenerateChat(t *testing.T) {
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "bert",
|
||||
Files: map[string]string{"bert.gguf": digest},
|
||||
Stream: &stream,
|
||||
Stream: streamFalse,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
@@ -288,7 +288,7 @@ func TestGenerateChat(t *testing.T) {
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
Stream: &stream,
|
||||
Stream: streamFalse,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
@@ -318,7 +318,7 @@ func TestGenerateChat(t *testing.T) {
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
Stream: &stream,
|
||||
Stream: streamFalse,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
@@ -340,7 +340,7 @@ func TestGenerateChat(t *testing.T) {
|
||||
{Role: "system", Content: "You can perform magic tricks."},
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
Stream: &stream,
|
||||
Stream: streamFalse,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
@@ -363,7 +363,7 @@ func TestGenerateChat(t *testing.T) {
|
||||
{Role: "system", Content: "You can perform magic tricks."},
|
||||
{Role: "user", Content: "Help me write tests."},
|
||||
},
|
||||
Stream: &stream,
|
||||
Stream: streamFalse,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
@@ -422,15 +422,13 @@ func TestGenerateChat(t *testing.T) {
|
||||
EvalDuration: 1,
|
||||
}
|
||||
|
||||
streamRequest := true
|
||||
|
||||
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "test-system",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "What's the weather in Seattle?"},
|
||||
},
|
||||
Tools: tools,
|
||||
Stream: &streamRequest,
|
||||
Stream: streamTrue,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
@@ -551,7 +549,7 @@ func TestGenerateChat(t *testing.T) {
|
||||
{Role: "user", Content: "What's the weather in Seattle?"},
|
||||
},
|
||||
Tools: tools,
|
||||
Stream: &stream,
|
||||
Stream: streamFalse,
|
||||
})
|
||||
|
||||
wg.Wait()
|
||||
@@ -666,7 +664,7 @@ func TestGenerate(t *testing.T) {
|
||||
{{- if .Prompt }}User: {{ .Prompt }} {{ end }}
|
||||
{{- if .Response }}Assistant: {{ .Response }} {{ end }}
|
||||
`,
|
||||
Stream: &stream,
|
||||
Stream: streamFalse,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
@@ -704,7 +702,7 @@ func TestGenerate(t *testing.T) {
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "bert",
|
||||
Files: map[string]string{"file.gguf": digest},
|
||||
Stream: &stream,
|
||||
Stream: streamFalse,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
@@ -825,7 +823,7 @@ func TestGenerate(t *testing.T) {
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test",
|
||||
Prompt: "Hello!",
|
||||
Stream: &stream,
|
||||
Stream: streamFalse,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
@@ -853,7 +851,7 @@ func TestGenerate(t *testing.T) {
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test-system",
|
||||
Prompt: "Hello!",
|
||||
Stream: &stream,
|
||||
Stream: streamFalse,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
@@ -873,7 +871,7 @@ func TestGenerate(t *testing.T) {
|
||||
Model: "test-system",
|
||||
Prompt: "Hello!",
|
||||
System: "You can perform magic tricks.",
|
||||
Stream: &stream,
|
||||
Stream: streamFalse,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
@@ -895,7 +893,7 @@ func TestGenerate(t *testing.T) {
|
||||
Template: `{{- if .System }}{{ .System }} {{ end }}
|
||||
{{- if .Prompt }}### USER {{ .Prompt }} {{ end }}
|
||||
{{- if .Response }}### ASSISTANT {{ .Response }} {{ end }}`,
|
||||
Stream: &stream,
|
||||
Stream: streamFalse,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
@@ -957,7 +955,7 @@ func TestGenerate(t *testing.T) {
|
||||
Model: "test-system",
|
||||
Prompt: "Help me write tests.",
|
||||
Raw: true,
|
||||
Stream: &stream,
|
||||
Stream: streamFalse,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
@@ -1040,7 +1038,7 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
|
||||
{{- if eq .Role "user" }}user: {{ .Content }}
|
||||
{{ else if eq .Role "assistant" }}assistant: {{ if .Thinking }}<think>{{ .Thinking }}</think>{{ end }}{{ .Content }}
|
||||
{{ end }}{{ end }}<think>`,
|
||||
Stream: &stream,
|
||||
Stream: streamFalse,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
@@ -1066,13 +1064,12 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
|
||||
}
|
||||
mock.CompletionFn = nil
|
||||
|
||||
streamRequest := false
|
||||
req := api.ChatRequest{
|
||||
Model: "test-thinking",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: userContent},
|
||||
},
|
||||
Stream: &streamRequest,
|
||||
Stream: streamFalse,
|
||||
}
|
||||
if think {
|
||||
req.Think = &api.ThinkValue{Value: think}
|
||||
@@ -1165,7 +1162,7 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
|
||||
Model: "test-thinking",
|
||||
Messages: []api.Message{{Role: "user", Content: "Analyze this complex problem"}},
|
||||
Think: &api.ThinkValue{Value: think},
|
||||
Stream: &stream,
|
||||
Stream: streamFalse,
|
||||
})
|
||||
|
||||
wg.Wait()
|
||||
|
||||
@@ -291,12 +291,11 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
|
||||
// Create a simple test model
|
||||
_, digest := createHarmonyTestModel(t)
|
||||
|
||||
streamFalse := false
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "harmony-test-streaming",
|
||||
Files: map[string]string{"test.gguf": digest},
|
||||
Template: `<|start|><|end|>{{ with .Tools }}{{ end }}{{ .Prompt }}`,
|
||||
Stream: &streamFalse,
|
||||
Stream: streamFalse,
|
||||
})
|
||||
|
||||
if w.Code != 200 {
|
||||
@@ -304,11 +303,10 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test chat endpoint with streaming
|
||||
streamTrue := true
|
||||
w = createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "harmony-test-streaming",
|
||||
Messages: []api.Message{{Role: "user", Content: "Hello"}},
|
||||
Stream: &streamTrue,
|
||||
Stream: streamTrue,
|
||||
Tools: getTestTools(),
|
||||
})
|
||||
|
||||
@@ -441,12 +439,11 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {
|
||||
|
||||
// Create model
|
||||
_, digest := createHarmonyTestModel(t)
|
||||
streamFalse := false
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "gpt-oss",
|
||||
Files: map[string]string{"test.gguf": digest},
|
||||
Template: `<|start|><|end|>{{ .Tools }}{{ .Prompt }}`,
|
||||
Stream: &streamFalse,
|
||||
Stream: streamFalse,
|
||||
})
|
||||
|
||||
if w.Code != 200 {
|
||||
@@ -454,11 +451,10 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test streaming
|
||||
streamTrue := true
|
||||
w = createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "gpt-oss",
|
||||
Messages: []api.Message{{Role: "user", Content: "Hello"}},
|
||||
Stream: &streamTrue,
|
||||
Stream: streamTrue,
|
||||
Tools: getTestTools(),
|
||||
})
|
||||
|
||||
@@ -625,12 +621,11 @@ func TestChatHarmonyParserStreaming(t *testing.T) {
|
||||
_, digest := createHarmonyTestModel(t)
|
||||
|
||||
// Create model with passthrough template
|
||||
stream := false
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "harmony-test",
|
||||
Files: map[string]string{"file.gguf": digest},
|
||||
Template: `<|start|><|end|>{{ with .Tools }}{{ end }}{{ .Prompt }}`,
|
||||
Stream: &stream,
|
||||
Stream: streamFalse,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
@@ -638,11 +633,10 @@ func TestChatHarmonyParserStreaming(t *testing.T) {
|
||||
}
|
||||
|
||||
// Test chat endpoint with streaming
|
||||
streamTrue := true
|
||||
w = createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "harmony-test",
|
||||
Messages: []api.Message{{Role: "user", Content: "Hello"}},
|
||||
Stream: &streamTrue,
|
||||
Stream: streamTrue,
|
||||
Tools: getTestTools(),
|
||||
})
|
||||
|
||||
|
||||
@@ -28,10 +28,16 @@ import (
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/openai"
|
||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||
"github.com/ollama/ollama/types"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
|
||||
var (
|
||||
streamFalse = types.NullWithValue(false)
|
||||
streamTrue = types.NullWithValue(true)
|
||||
)
|
||||
|
||||
func createTestFile(t *testing.T, name string) (string, string) {
|
||||
t.Helper()
|
||||
|
||||
@@ -332,11 +338,10 @@ func TestRoutes(t *testing.T) {
|
||||
Path: "/api/create",
|
||||
Setup: func(t *testing.T, req *http.Request) {
|
||||
_, digest := createTestFile(t, "ollama-model")
|
||||
stream := false
|
||||
createReq := api.CreateRequest{
|
||||
Name: "t-bone",
|
||||
Files: map[string]string{"test.gguf": digest},
|
||||
Stream: &stream,
|
||||
Stream: streamFalse,
|
||||
}
|
||||
jsonData, err := json.Marshal(createReq)
|
||||
if err != nil {
|
||||
@@ -638,7 +643,7 @@ func TestManifestCaseSensitivity(t *testing.T) {
|
||||
// version.
|
||||
Name: wantStableName,
|
||||
Files: map[string]string{"test.gguf": digest},
|
||||
Stream: &stream,
|
||||
Stream: streamFalse,
|
||||
}))
|
||||
checkManifestList()
|
||||
|
||||
@@ -646,14 +651,14 @@ func TestManifestCaseSensitivity(t *testing.T) {
|
||||
checkOK(createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Name: name(),
|
||||
Files: map[string]string{"test.gguf": digest},
|
||||
Stream: &stream,
|
||||
Stream: streamFalse,
|
||||
}))
|
||||
checkManifestList()
|
||||
|
||||
t.Logf("pulling")
|
||||
checkOK(createRequest(t, s.PullHandler, api.PullRequest{
|
||||
Name: name(),
|
||||
Stream: &stream,
|
||||
Stream: streamFalse,
|
||||
Insecure: true,
|
||||
}))
|
||||
checkManifestList()
|
||||
|
||||
@@ -41,7 +41,13 @@ func TestParser(t *testing.T) {
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_temperature",
|
||||
Description: "Retrieve the temperature for a given location",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Parameters: struct {
|
||||
Type string `json:"type"`
|
||||
Defs any `json:"$defs,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Required []string `json:"required"`
|
||||
Properties map[string]api.ToolProperty `json:"properties"`
|
||||
}{
|
||||
Type: "object",
|
||||
Required: []string{"city"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
@@ -63,7 +69,13 @@ func TestParser(t *testing.T) {
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_conditions",
|
||||
Description: "Retrieve the current weather conditions for a given location",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Parameters: struct {
|
||||
Type string `json:"type"`
|
||||
Defs any `json:"$defs,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Required []string `json:"required"`
|
||||
Properties map[string]api.ToolProperty `json:"properties"`
|
||||
}{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {
|
||||
@@ -93,7 +105,13 @@ func TestParser(t *testing.T) {
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_address",
|
||||
Description: "Get the address of a given location",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Parameters: struct {
|
||||
Type string `json:"type"`
|
||||
Defs any `json:"$defs,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Required []string `json:"required"`
|
||||
Properties map[string]api.ToolProperty `json:"properties"`
|
||||
}{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {
|
||||
@@ -109,7 +127,13 @@ func TestParser(t *testing.T) {
|
||||
Function: api.ToolFunction{
|
||||
Name: "add",
|
||||
Description: "Add two numbers",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Parameters: struct {
|
||||
Type string `json:"type"`
|
||||
Defs any `json:"$defs,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Required []string `json:"required"`
|
||||
Properties map[string]api.ToolProperty `json:"properties"`
|
||||
}{
|
||||
Type: "object",
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"a": {
|
||||
|
||||
53
types/null.go
Normal file
53
types/null.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
// Null represents a value of any type T that may be null.
//
// The zero value of Null[T] is null (no value set). A Null[T] encodes to the
// JSON literal null while unset and to the JSON encoding of its value once a
// value has been set.
type Null[T any] struct {
	value T
	valid bool
}

// NullWithValue creates a new, valid Null[T] holding value.
func NullWithValue[T any](value T) Null[T] {
	return Null[T]{value: value, valid: true}
}

// Value returns the value of the Null[T] if set. Otherwise it returns the
// first element of defaultValue when one is provided, or the zero value of T.
// Any additional elements of defaultValue are ignored.
func (n Null[T]) Value(defaultValue ...T) T {
	if n.valid {
		return n.value
	}

	if len(defaultValue) > 0 {
		return defaultValue[0]
	}

	var zero T
	return zero
}

// SetValue sets the value of the Null[T] and marks it as valid.
func (n *Null[T]) SetValue(t T) {
	n.value = t
	n.valid = true
}

// MarshalJSON implements [json.Marshaler]. An unset Null[T] is encoded as the
// JSON literal null; otherwise the held value is encoded.
func (n Null[T]) MarshalJSON() ([]byte, error) {
	if n.valid {
		return json.Marshal(n.value)
	}

	return []byte("null"), nil
}

// UnmarshalJSON implements [json.Unmarshaler].
//
// The JSON literal null resets the receiver to the unset state so that the
// output of MarshalJSON round-trips, even when the receiver previously held a
// value. Any other input is decoded into a fresh T first, so a failed decode
// leaves the receiver untouched instead of partially overwritten.
func (n *Null[T]) UnmarshalJSON(data []byte) error {
	if string(data) == "null" {
		*n = Null[T]{}
		return nil
	}

	var v T
	if err := json.Unmarshal(data, &v); err != nil {
		return err
	}

	n.value = v
	n.valid = true
	return nil
}
|
||||
53
types/null_test.go
Normal file
53
types/null_test.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package types_test
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/types"
|
||||
)
|
||||
|
||||
func TestNull(t *testing.T) {
|
||||
var s types.Null[string]
|
||||
if val := s.Value(); val != "" {
|
||||
t.Errorf("expected Value to return zero value '', got '%s'", val)
|
||||
}
|
||||
|
||||
if val := s.Value("default"); val != "default" {
|
||||
t.Errorf("expected Value to return default value 'default', got '%s'", val)
|
||||
}
|
||||
|
||||
if bts, err := json.Marshal(s); err != nil {
|
||||
t.Errorf("unexpected error during MarshalJSON: %v", err)
|
||||
} else if want := "null"; string(bts) != want {
|
||||
t.Errorf("expected marshaled JSON to be %s, got %s", want, string(bts))
|
||||
}
|
||||
|
||||
s.SetValue("foo")
|
||||
if val := s.Value(); val != "foo" {
|
||||
t.Errorf("expected Value to return 'foo', got '%s'", val)
|
||||
}
|
||||
|
||||
s = types.NullValue("bar")
|
||||
if val := s.Value(); val != "bar" {
|
||||
t.Errorf("expected Value to return 'bar', got '%s'", val)
|
||||
}
|
||||
|
||||
if bts, err := json.Marshal(s); err != nil {
|
||||
t.Errorf("unexpected error during MarshalJSON: %v", err)
|
||||
} else if want := `"bar"`; string(bts) != want {
|
||||
t.Errorf("expected marshaled JSON to be %s, got %s", want, string(bts))
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(`null`), &s); err != nil {
|
||||
t.Errorf("unexpected error during UnmarshalJSON: %v", err)
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(`"baz"`), &s); err != nil {
|
||||
t.Errorf("unexpected error during UnmarshalJSON: %v", err)
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(`1.2345`), &s); err == nil {
|
||||
t.Error("expected error during UnmarshalJSON with invalid JSON, got nil")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user