Compare commits
2 commits: drifkin/ar...parth/serv

| Author | SHA1 | Date |
|---|---|---|
|  | 3bc9d42e2e |  |
|  | 4053c489b4 |  |
@@ -51,7 +51,7 @@ see if the change were accepted.

The title should look like:

  <package>: <short description>
  <package>: <short description>

The package is the most affected Go package. If the change does not affect Go
code, then use the directory name instead. Changes to a single well-known
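An illustrative title in this format (hypothetical, not quoted from the file) would be:

  docs: fix typo in api.md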
@@ -104,8 +104,8 @@ COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12

FROM --platform=linux/arm64 scratch AS arm64
COPY --from=cuda-11 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_v11
COPY --from=cuda-12 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_v12
COPY --from=jetpack-5 dist/lib/ollama/cuda_v11 /lib/ollama/cuda_jetpack5
COPY --from=jetpack-6 dist/lib/ollama/cuda_v12 /lib/ollama/cuda_jetpack6
COPY --from=jetpack-5 dist/lib/ollama/cuda_v11 lib/ollama/cuda_jetpack5
COPY --from=jetpack-6 dist/lib/ollama/cuda_v12 lib/ollama/cuda_jetpack6

FROM scratch AS rocm
COPY --from=rocm-6 dist/lib/ollama/rocm /lib/ollama/rocm
@@ -291,7 +291,7 @@ See the [API documentation](./docs/api.md) for all endpoints.

- [Typescript UI](https://github.com/ollama-interface/Ollama-Gui?tab=readme-ov-file)
- [Minimalistic React UI for Ollama Models](https://github.com/richawo/minimal-llm-ui)
- [Ollamac](https://github.com/kevinhermawan/Ollamac)
- [big-AGI](https://github.com/enricoros/big-AGI)
- [big-AGI](https://github.com/enricoros/big-AGI/blob/main/docs/config-local-ollama.md)
- [Cheshire Cat assistant framework](https://github.com/cheshire-cat-ai/core)
- [Amica](https://github.com/semperai/amica)
- [chatd](https://github.com/BruceMacD/chatd)
@@ -348,7 +348,7 @@ See the [API documentation](./docs/api.md) for all endpoints.

- [PartCAD](https://github.com/openvmp/partcad/) (CAD model generation with OpenSCAD and CadQuery)
- [Ollama4j Web UI](https://github.com/ollama4j/ollama4j-web-ui) - Java-based Web UI for Ollama built with Vaadin, Spring Boot and Ollama4j
- [PyOllaMx](https://github.com/kspviswa/pyOllaMx) - macOS application capable of chatting with both Ollama and Apple MLX models.
- [Cline](https://github.com/cline/cline) - Formerly known as Claude Dev is a VSCode extension for multi-file/whole-repo coding
- [Claude Dev](https://github.com/saoudrizwan/claude-dev) - VSCode extension for multi-file/whole-repo coding
- [Cherry Studio](https://github.com/kangfenmao/cherry-studio) (Desktop client with Ollama support)
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy focused LLM chat interface with optional encryption)
- [Archyve](https://github.com/nickthecook/archyve) (RAG-enabling document library)
@@ -440,7 +440,6 @@ See the [API documentation](./docs/api.md) for all endpoints.

- [DeepShell](https://github.com/Abyss-c0re/deepshell) Your self-hosted AI assistant. Interactive Shell, Files and Folders analysis.
- [orbiton](https://github.com/xyproto/orbiton) Configuration-free text editor and IDE with support for tab completion with Ollama.
- [orca-cli](https://github.com/molbal/orca-cli) Ollama Registry CLI Application - Browse, pull and download models from Ollama Registry in your terminal.
- [GGUF-to-Ollama](https://github.com/jonathanhecl/gguf-to-ollama) - Importing GGUF to Ollama made easy (multiplatform)

### Apple Vision Pro

52  api/types.go
@@ -163,65 +163,19 @@ func (t *ToolCallFunctionArguments) String() string {

type Tool struct {
	Type     string       `json:"type"`
	Items    any          `json:"items,omitempty"`
	Function ToolFunction `json:"function"`
}

// PropertyType can be either a string or an array of strings
type PropertyType []string

// UnmarshalJSON implements the json.Unmarshaler interface
func (pt *PropertyType) UnmarshalJSON(data []byte) error {
	// Try to unmarshal as a string first
	var s string
	if err := json.Unmarshal(data, &s); err == nil {
		*pt = []string{s}
		return nil
	}

	// If that fails, try to unmarshal as an array of strings
	var a []string
	if err := json.Unmarshal(data, &a); err != nil {
		return err
	}
	*pt = a
	return nil
}

// MarshalJSON implements the json.Marshaler interface
func (pt PropertyType) MarshalJSON() ([]byte, error) {
	if len(pt) == 1 {
		// If there's only one type, marshal as a string
		return json.Marshal(pt[0])
	}
	// Otherwise marshal as an array
	return json.Marshal([]string(pt))
}

// String returns a string representation of the PropertyType
func (pt PropertyType) String() string {
	if len(pt) == 0 {
		return ""
	}
	if len(pt) == 1 {
		return pt[0]
	}
	return fmt.Sprintf("%v", []string(pt))
}

type ToolFunction struct {
	Name        string `json:"name"`
	Description string `json:"description"`
	Parameters  struct {
		Type       string   `json:"type"`
		Defs       any      `json:"$defs,omitempty"`
		Items      any      `json:"items,omitempty"`
		Required   []string `json:"required"`
		Properties map[string]struct {
			Type        PropertyType `json:"type"`
			Items       any          `json:"items,omitempty"`
			Description string       `json:"description"`
			Enum        []any        `json:"enum,omitempty"`
			Type        string       `json:"type"`
			Description string       `json:"description"`
			Enum        []string     `json:"enum,omitempty"`
		} `json:"properties"`
	} `json:"parameters"`
}
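One side of this hunk carries a PropertyType helper that accepts a JSON Schema "type" given either as a bare string or as an array of strings. A minimal, self-contained sketch of that round-trip behavior (illustrative only, not part of either branch) follows:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// PropertyType mirrors the helper shown in the hunk above: a JSON Schema
// "type" value that may arrive as a bare string or as an array of strings.
type PropertyType []string

func (pt *PropertyType) UnmarshalJSON(data []byte) error {
	var s string
	if err := json.Unmarshal(data, &s); err == nil {
		*pt = []string{s} // a single string becomes a one-element slice
		return nil
	}
	var a []string
	if err := json.Unmarshal(data, &a); err != nil {
		return err
	}
	*pt = a
	return nil
}

func main() {
	var single, multi PropertyType
	_ = json.Unmarshal([]byte(`"string"`), &single)            // -> [string]
	_ = json.Unmarshal([]byte(`["string", "number"]`), &multi) // -> [string number]
	fmt.Println(single, multi)
}
```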
@@ -231,144 +231,3 @@ func TestMessage_UnmarshalJSON(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolFunction_UnmarshalJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
wantErr string
|
||||
}{
|
||||
{
|
||||
name: "valid enum with same types",
|
||||
input: `{
|
||||
"name": "test",
|
||||
"description": "test function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"required": ["test"],
|
||||
"properties": {
|
||||
"test": {
|
||||
"type": "string",
|
||||
"description": "test prop",
|
||||
"enum": ["a", "b", "c"]
|
||||
}
|
||||
}
|
||||
}
|
||||
}`,
|
||||
wantErr: "",
|
||||
},
|
||||
{
|
||||
name: "empty enum array",
|
||||
input: `{
|
||||
"name": "test",
|
||||
"description": "test function",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"required": ["test"],
|
||||
"properties": {
|
||||
"test": {
|
||||
"type": "string",
|
||||
"description": "test prop",
|
||||
"enum": []
|
||||
}
|
||||
}
|
||||
}
|
||||
}`,
|
||||
wantErr: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var tf ToolFunction
|
||||
err := json.Unmarshal([]byte(tt.input), &tf)
|
||||
|
||||
if tt.wantErr != "" {
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), tt.wantErr)
|
||||
} else {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPropertyType_UnmarshalJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected PropertyType
|
||||
}{
|
||||
{
|
||||
name: "string type",
|
||||
input: `"string"`,
|
||||
expected: PropertyType{"string"},
|
||||
},
|
||||
{
|
||||
name: "array of types",
|
||||
input: `["string", "number"]`,
|
||||
expected: PropertyType{"string", "number"},
|
||||
},
|
||||
{
|
||||
name: "array with single type",
|
||||
input: `["string"]`,
|
||||
expected: PropertyType{"string"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
var pt PropertyType
|
||||
if err := json.Unmarshal([]byte(test.input), &pt); err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(pt) != len(test.expected) {
|
||||
t.Errorf("Length mismatch: got %v, expected %v", len(pt), len(test.expected))
|
||||
}
|
||||
|
||||
for i, v := range pt {
|
||||
if v != test.expected[i] {
|
||||
t.Errorf("Value mismatch at index %d: got %v, expected %v", i, v, test.expected[i])
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPropertyType_MarshalJSON(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input PropertyType
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "single type",
|
||||
input: PropertyType{"string"},
|
||||
expected: `"string"`,
|
||||
},
|
||||
{
|
||||
name: "multiple types",
|
||||
input: PropertyType{"string", "number"},
|
||||
expected: `["string","number"]`,
|
||||
},
|
||||
{
|
||||
name: "empty type",
|
||||
input: PropertyType{},
|
||||
expected: `[]`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
data, err := json.Marshal(test.input)
|
||||
if err != nil {
|
||||
t.Errorf("Unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if string(data) != test.expected {
|
||||
t.Errorf("Marshaled data mismatch: got %v, expected %v", string(data), test.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1381,6 +1381,7 @@ func NewCLI() *cobra.Command {
|
||||
envVars["OLLAMA_NOPRUNE"],
|
||||
envVars["OLLAMA_ORIGINS"],
|
||||
envVars["OLLAMA_SCHED_SPREAD"],
|
||||
envVars["OLLAMA_TMPDIR"],
|
||||
envVars["OLLAMA_FLASH_ATTENTION"],
|
||||
envVars["OLLAMA_KV_CACHE_TYPE"],
|
||||
envVars["OLLAMA_LLM_LIBRARY"],
|
||||
|
||||
@@ -26,6 +26,7 @@ When you run Ollama on **Windows**, there are a few different locations. You can

- `explorer %LOCALAPPDATA%\Ollama` to view logs. The most recent server logs will be in `server.log` and older logs will be in `server-#.log`
- `explorer %LOCALAPPDATA%\Programs\Ollama` to browse the binaries (The installer adds this to your user PATH)
- `explorer %HOMEPATH%\.ollama` to browse where models and configuration is stored
- `explorer %TEMP%` where temporary executable files are stored in one or more `ollama*` directories

To enable additional debug logging to help troubleshoot problems, first **Quit the running app from the tray menu** then in a powershell terminal
@@ -68,6 +69,10 @@ If you run into problems on Linux and want to install an older version, or you'd

curl -fsSL https://ollama.com/install.sh | OLLAMA_VERSION=0.5.7 sh
```

## Linux tmp noexec

If your system is configured with the "noexec" flag where Ollama stores its temporary executable files, you can specify an alternate location by setting OLLAMA_TMPDIR to a location writable by the user ollama runs as. For example OLLAMA_TMPDIR=/usr/share/ollama/

## Linux docker

If Ollama initially works on the GPU in a docker container, but then switches to running on CPU after some period of time with errors in the server log reporting GPU discovery failures, this can be resolved by disabling systemd cgroup management in Docker. Edit `/etc/docker/daemon.json` on the host and add `"exec-opts": ["native.cgroupdriver=cgroupfs"]` to the docker configuration.
@@ -62,6 +62,7 @@ the explorer window by hitting `<Ctrl>+R` and type in:

- *upgrade.log* contains log output for upgrades
- `explorer %LOCALAPPDATA%\Programs\Ollama` contains the binaries (The installer adds this to your user PATH)
- `explorer %HOMEPATH%\.ollama` contains models and configuration
- `explorer %TEMP%` contains temporary executable files in one or more `ollama*` directories

## Uninstall
278  fs/ggml/ggml.go
@@ -6,7 +6,6 @@ import (
	"fmt"
	"io"
	"log/slog"
	"reflect"
	"slices"
	"strings"
@@ -53,80 +52,32 @@ func (kv KV) EmbeddingLength() uint64 {
|
||||
return uint64(kv.Uint("embedding_length"))
|
||||
}
|
||||
|
||||
func (kv KV) HeadCounts() []uint64 {
|
||||
return kv.UintOrArrayAsArray("attention.head_count", kv.BlockCount(), 1)
|
||||
func (kv KV) HeadCount() uint64 {
|
||||
return uint64(kv.Uint("attention.head_count"))
|
||||
}
|
||||
|
||||
func (kv KV) HeadCountKVs() []uint64 {
|
||||
return kv.UintOrArrayAsArray("attention.head_count_kv", kv.BlockCount(), 1)
|
||||
func (kv KV) HeadCountKV() uint64 {
|
||||
return uint64(kv.Uint("attention.head_count_kv", 1))
|
||||
}
|
||||
|
||||
func (kv KV) EmbeddingHeadCount() []uint64 {
|
||||
headCount := kv.HeadCounts()
|
||||
embeddingHeadCount := make([]uint64, len(headCount))
|
||||
for i, heads := range headCount {
|
||||
if heads == 0 {
|
||||
embeddingHeadCount[i] = 0
|
||||
} else {
|
||||
embeddingHeadCount[i] = kv.EmbeddingLength() / heads
|
||||
}
|
||||
func (kv KV) EmbeddingHeadCount() uint64 {
|
||||
if heads := kv.HeadCount(); heads > 0 {
|
||||
return kv.EmbeddingLength() / heads
|
||||
}
|
||||
|
||||
return embeddingHeadCount
|
||||
return 0
|
||||
}
|
||||
|
||||
func (kv KV) FillArrayOrDefault(key string, defaultValue []uint64) []uint64 {
|
||||
length := len(defaultValue)
|
||||
if v, ok := keyValueUntyped(kv, key); ok {
|
||||
switch v := v.(type) {
|
||||
case uint32:
|
||||
return FillArray(uint64(v), length)
|
||||
case uint64:
|
||||
return FillArray(v, length)
|
||||
case int32:
|
||||
return FillArray(uint64(v), length)
|
||||
default:
|
||||
slog.Warn("unsupported type", "key", key, "type", reflect.TypeOf(v))
|
||||
}
|
||||
}
|
||||
|
||||
return defaultValue
|
||||
func (kv KV) EmbeddingHeadCountK() uint64 {
|
||||
return uint64(kv.Uint("attention.key_length", uint32(kv.EmbeddingHeadCount())))
|
||||
}
|
||||
|
||||
func (kv KV) EmbeddingHeadCountK() []uint64 {
|
||||
return kv.FillArrayOrDefault("attention.key_length", kv.EmbeddingHeadCount())
|
||||
func (kv KV) EmbeddingHeadCountV() uint64 {
|
||||
return uint64(kv.Uint("attention.value_length", uint32(kv.EmbeddingHeadCount())))
|
||||
}
|
||||
|
||||
func (kv KV) EmbeddingHeadCountV() []uint64 {
|
||||
return kv.FillArrayOrDefault("attention.value_length", kv.EmbeddingHeadCount())
|
||||
}
|
||||
|
||||
func (kv KV) GQAMax() uint64 {
|
||||
heads := kv.HeadCounts()
|
||||
headsKV := kv.HeadCountKVs()
|
||||
if len(heads) != len(headsKV) {
|
||||
slog.Warn("head count and head count kv are not the same length")
|
||||
return 0
|
||||
}
|
||||
if len(heads) == 0 {
|
||||
slog.Warn("head count is empty")
|
||||
return 0
|
||||
}
|
||||
|
||||
maxGQA := uint64(0)
|
||||
for i := range heads {
|
||||
head := heads[i]
|
||||
headKV := headsKV[i]
|
||||
if head == 0 || headKV == 0 {
|
||||
return 0
|
||||
}
|
||||
gqa := head / headKV
|
||||
if gqa > maxGQA {
|
||||
maxGQA = gqa
|
||||
}
|
||||
}
|
||||
|
||||
return maxGQA
|
||||
func (kv KV) GQA() uint64 {
|
||||
return kv.HeadCount() / kv.HeadCountKV()
|
||||
}
|
||||
|
||||
func (kv KV) ContextLength() uint64 {
|
||||
@@ -153,41 +104,6 @@ func (kv KV) Bool(key string, defaultValue ...bool) bool {
|
||||
return keyValue(kv, key, append(defaultValue, false)...)
|
||||
}
|
||||
|
||||
func (kv KV) UintOrArrayAsArray(key string, n uint64, defaultSingleValue ...uint64) []uint64 {
|
||||
var singleValue *uint64
|
||||
if v, ok := keyValueUntyped(kv, key); ok {
|
||||
switch v := v.(type) {
|
||||
case *array:
|
||||
switch v.values[0].(type) {
|
||||
case int32, uint32, uint64:
|
||||
values, ok := AsUint64Array(v.values)
|
||||
if ok {
|
||||
return values
|
||||
}
|
||||
default:
|
||||
slog.Warn("unexpected array value type", "key", key, "type", reflect.TypeOf(v))
|
||||
}
|
||||
case uint32:
|
||||
val := uint64(v)
|
||||
singleValue = &val
|
||||
case int32:
|
||||
val := uint64(v)
|
||||
singleValue = &val
|
||||
}
|
||||
}
|
||||
if singleValue == nil {
|
||||
slog.Warn("falling back to default")
|
||||
singleValue = &defaultSingleValue[0]
|
||||
}
|
||||
|
||||
values := make([]uint64, n)
|
||||
for i := range values {
|
||||
values[i] = *singleValue
|
||||
}
|
||||
|
||||
return values
|
||||
}
|
||||
|
||||
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
|
||||
r := keyValue(kv, key, &array{})
|
||||
s := make([]string, r.size)
|
||||
@@ -225,24 +141,16 @@ func (kv KV) OllamaEngineRequired() bool {
|
||||
}
|
||||
|
||||
func keyValue[T string | uint32 | uint64 | float32 | *array | bool](kv KV, key string, defaultValue ...T) T {
|
||||
if val, ok := keyValueUntyped(kv, key); ok {
|
||||
return val.(T)
|
||||
}
|
||||
|
||||
slog.Warn("key not found", "key", key, "default", defaultValue[0])
|
||||
return defaultValue[0]
|
||||
}
|
||||
|
||||
func keyValueUntyped(kv KV, key string) (any, bool) {
|
||||
if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
|
||||
key = kv.Architecture() + "." + key
|
||||
}
|
||||
|
||||
if val, ok := kv[key]; ok {
|
||||
return val, true
|
||||
return val.(T)
|
||||
}
|
||||
|
||||
return nil, false
|
||||
slog.Warn("key not found", "key", key, "default", defaultValue[0])
|
||||
return defaultValue[0]
|
||||
}
|
||||
|
||||
type Tensors struct {
|
||||
@@ -510,22 +418,12 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
|
||||
|
||||
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) {
|
||||
embedding := f.KV().EmbeddingLength()
|
||||
heads := f.KV().HeadCounts()
|
||||
headsKV := f.KV().HeadCountKVs()
|
||||
heads := f.KV().HeadCount()
|
||||
headsKV := f.KV().HeadCountKV()
|
||||
vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array).size)
|
||||
|
||||
embeddingHeads := f.KV().EmbeddingHeadCount()
|
||||
maxEmbeddingHeads, ok := MaxValue(embeddingHeads)
|
||||
if !ok {
|
||||
maxEmbeddingHeads = 1
|
||||
slog.Warn("failed to get max embedding heads")
|
||||
}
|
||||
embeddingHeadsK := f.KV().EmbeddingHeadCountK()
|
||||
maxEmbeddingHeadsK, ok := MaxValue(embeddingHeadsK)
|
||||
if !ok {
|
||||
maxEmbeddingHeadsK = 1
|
||||
slog.Warn("failed to get max embedding headsK")
|
||||
}
|
||||
embeddingHeadsV := f.KV().EmbeddingHeadCountV()
|
||||
|
||||
layers := f.Tensors().GroupLayers()
|
||||
@@ -533,30 +431,19 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
||||
bytesPerElement := kvCacheBytesPerElement(kvCacheType)
|
||||
kv = make([]uint64, f.KV().BlockCount())
|
||||
for i := range kv {
|
||||
kv[i] = uint64(float64(context*(embeddingHeadsK[i]+embeddingHeadsV[i])*headsKV[i]) * bytesPerElement)
|
||||
}
|
||||
|
||||
maxHeads, ok := MaxValue(heads)
|
||||
if !ok {
|
||||
maxHeads = 1
|
||||
slog.Warn("failed to get max heads")
|
||||
}
|
||||
maxHeadsKV, ok := MaxValue(headsKV)
|
||||
if !ok {
|
||||
maxHeadsKV = 1
|
||||
slog.Warn("failed to get max headsKV")
|
||||
kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
|
||||
}
|
||||
|
||||
switch f.KV().Architecture() {
|
||||
case "llama":
|
||||
fullOffload = max(
|
||||
4*batch*(1+4*embedding+context*(1+maxHeads)),
|
||||
4*batch*(1+4*embedding+context*(1+heads)),
|
||||
4*batch*(embedding+vocab),
|
||||
)
|
||||
|
||||
partialOffload = 4 * batch * embedding
|
||||
partialOffload += max(
|
||||
4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*maxHeads+maxEmbeddingHeads*maxHeadsKV),
|
||||
4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embeddingHeads*headsKV),
|
||||
4*batch*(embedding+vocab)+embedding*vocab*105/128,
|
||||
)
|
||||
|
||||
@@ -564,16 +451,16 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
||||
// mixtral 8x22b
|
||||
ff := uint64(f.KV()["llama.feed_forward_length"].(uint32))
|
||||
partialOffload = max(
|
||||
3*ffnGateExpsWeight.Size()+4*batch*(2*ff+maxHeadsKV+embedding+context+maxEmbeddingHeads*maxHeadsKV),
|
||||
4*(context*batch*maxHeads+context*maxEmbeddingHeads*maxHeadsKV+batch*1024+maxEmbeddingHeads*maxHeadsKV*batch),
|
||||
3*ffnGateExpsWeight.Size()+4*batch*(2*ff+headsKV+embedding+context+embeddingHeads*headsKV),
|
||||
4*(context*batch*heads+context*embeddingHeads*headsKV+batch*1024+embeddingHeads*headsKV*batch),
|
||||
)
|
||||
} else if ffnGateWeight, ok := layers["blk.0"]["ffn_gate.0.weight"]; ok {
|
||||
// mixtral 8x7b
|
||||
ffnGateWeight1 := ffnGateWeight.Shape[1]
|
||||
fullOffload = 4 * batch * (2 + 3*embedding + context*(1+maxHeads) + 2*maxHeadsKV + ffnGateWeight1)
|
||||
fullOffload = 4 * batch * (2 + 3*embedding + context*(1+heads) + 2*headsKV + ffnGateWeight1)
|
||||
partialOffload = max(
|
||||
4*batch*(3+maxEmbeddingHeads*maxHeadsKV+embedding+context*(1+maxHeads)+ffnGateWeight1)+(embedding*embedding+3*embedding*maxHeadsKV*ffnGateWeight1)*9/16,
|
||||
4*batch*(1+2*embedding+context*(1+maxHeads))+embedding*(6*context*maxHeadsKV/maxHeads+embedding*9/16),
|
||||
4*batch*(3+embeddingHeads*headsKV+embedding+context*(1+heads)+ffnGateWeight1)+(embedding*embedding+3*embedding*headsKV*ffnGateWeight1)*9/16,
|
||||
4*batch*(1+2*embedding+context*(1+heads))+embedding*(6*context*headsKV/heads+embedding*9/16),
|
||||
)
|
||||
}
|
||||
case "mllama":
|
||||
@@ -582,7 +469,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
||||
crossAttentionLayers := f.KV().Uints("attention.cross_attention_layers")
|
||||
for i := range kv {
|
||||
if slices.Contains(crossAttentionLayers, uint32(i)) {
|
||||
kv[i] = headsKV[i] * (embeddingHeadsK[i] + embeddingHeadsV[i]) *
|
||||
kv[i] = headsKV * (embeddingHeadsK + embeddingHeadsV) *
|
||||
4 * // sizeof(float32)
|
||||
visionTokens *
|
||||
tiles
|
||||
@@ -590,7 +477,7 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
||||
}
|
||||
|
||||
fullOffload = max(
|
||||
4*batch*(2+3*embedding+maxEmbeddingHeadsK*maxHeads+context*(1+maxHeads)),
|
||||
4*batch*(2+3*embedding+embeddingHeadsK*heads+context*(1+heads)),
|
||||
// vocab graph
|
||||
4*batch*(embedding+vocab),
|
||||
)
|
||||
@@ -604,23 +491,23 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
||||
|
||||
partialOffload = max(
|
||||
4*(batch*
|
||||
(2*embedding+1+context*(1+maxHeads)+maxEmbeddingHeadsK*maxHeads)+
|
||||
(2*embedding+1+context*(1+heads)+embeddingHeadsK*heads)+
|
||||
ropeFreqsCount+
|
||||
maxEmbeddingHeadsK*context*maxHeadsKV),
|
||||
embeddingHeadsK*context*headsKV),
|
||||
// vocab graph
|
||||
4*batch*(embedding+vocab)+embedding*vocab*105/128,
|
||||
)
|
||||
case "gemma", "gemma2", "gemma3":
|
||||
fullOffload = max(
|
||||
4*batch*(embedding+vocab),
|
||||
4*batch*(2+context+context*maxHeads+2*embedding+2*maxEmbeddingHeadsK*maxHeads),
|
||||
4*batch*(2+context+context*heads+2*embedding+2*embeddingHeadsK*heads),
|
||||
)
|
||||
|
||||
partialOffload = max(
|
||||
4*embedding*batch+embedding*vocab*105/128+4*vocab*batch,
|
||||
4*batch*(2*embedding+1+2*maxEmbeddingHeadsK*maxHeads+context+context*maxHeads)+
|
||||
4*maxEmbeddingHeadsK*context*8+
|
||||
embedding*embedding*maxEmbeddingHeadsK*maxHeads*9/16,
|
||||
4*batch*(2*embedding+1+2*embeddingHeadsK*heads+context+context*heads)+
|
||||
4*embeddingHeadsK*context*8+
|
||||
embedding*embeddingHeadsK*heads*9/16,
|
||||
)
|
||||
|
||||
// Gemma2 also has sliding window attention but we only have an optimized implementation in the Ollama
|
||||
@@ -632,42 +519,42 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
||||
// Every 6th layer is a global layer, which is the full context size that has already been set. The other
|
||||
// layers are the smaller local (sliding) layers.
|
||||
if (i+1)%gemma3GlobalCacheCount != 0 {
|
||||
kv[i] = uint64(float64(slidingWindow*(embeddingHeadsK[i]+embeddingHeadsV[i])*headsKV[i]) * bytesPerElement)
|
||||
kv[i] = uint64(float64(slidingWindow*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
|
||||
}
|
||||
}
|
||||
}
|
||||
case "command-r":
|
||||
fullOffload = max(
|
||||
4*batch*(embedding+vocab),
|
||||
4*batch*(2+4*embedding+context*(1+maxHeads)),
|
||||
4*batch*(2+4*embedding+context*(1+heads)),
|
||||
)
|
||||
|
||||
partialOffload = max(
|
||||
4*batch*(embedding+vocab)+embedding*vocab*105/128,
|
||||
4*batch*(1+2*embedding+context*(1+maxHeads))+4*embedding*context+embedding*embedding*9/16,
|
||||
4*batch*(1+2*embedding+context*(1+heads))+4*embedding*context+embedding*embedding*9/16,
|
||||
)
|
||||
case "qwen2":
|
||||
fullOffload = max(
|
||||
4*batch*(embedding+vocab),
|
||||
4*batch*(1+2*embedding+context+context*maxHeads),
|
||||
4*batch*(1+2*embedding+context+context*heads),
|
||||
)
|
||||
|
||||
partialOffload = max(
|
||||
4*batch*(embedding+vocab)+embedding*vocab*105/128,
|
||||
4*(batch*(1+2*embedding+context*(1+maxHeads))+embedding*(1+context)),
|
||||
4*(batch*(1+2*embedding+context*(1+heads))+embedding*(1+context)),
|
||||
)
|
||||
case "phi2":
|
||||
fullOffload = max(
|
||||
4*batch*(embedding+vocab),
|
||||
4*batch*(1+4*embedding+context+context*maxHeads),
|
||||
4*batch*(1+4*embedding+context+context*heads),
|
||||
)
|
||||
|
||||
partialOffload = max(
|
||||
4*batch*(2*embedding+vocab)+embedding*vocab*105/128,
|
||||
4*batch*(2+3*embedding+context+context*maxHeads),
|
||||
4*batch*(2+3*embedding+context+context*heads),
|
||||
)
|
||||
case "stablelm":
|
||||
fullOffload = 4 * batch * (context*(1+maxHeads) + 3*embedding + 2)
|
||||
fullOffload = 4 * batch * (context*(1+heads) + 3*embedding + 2)
|
||||
partialOffload = max(
|
||||
4*batch*(vocab+2*embedding),
|
||||
fullOffload,
|
||||
@@ -675,12 +562,12 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
||||
case "deepseek2":
|
||||
fullOffload = max(
|
||||
4*batch*(3*embedding+vocab),
|
||||
4*batch*(3*embedding+2+context*(1+maxHeadsKV)+2*maxEmbeddingHeadsK*maxHeadsKV),
|
||||
4*batch*(3*embedding+2+context*(1+headsKV)+2*embeddingHeadsK*headsKV),
|
||||
)
|
||||
|
||||
partialOffload = max(
|
||||
4*batch*(3*embedding+vocab)+embedding*vocab*105/128,
|
||||
4*batch*(2*embedding+1+2*maxEmbeddingHeadsK*maxHeadsKV+context+context*maxHeadsKV)+4*maxEmbeddingHeadsK*context*maxHeadsKV+embedding*embedding*maxEmbeddingHeadsK*maxHeadsKV*9/16,
|
||||
4*batch*(2*embedding+1+2*embeddingHeadsK*headsKV+context+context*headsKV)+4*embeddingHeadsK*context*headsKV+embedding*embeddingHeadsK*headsKV*9/16,
|
||||
)
|
||||
case "chatglm":
|
||||
fullOffload = 4 * batch * (embedding + vocab)
|
||||
@@ -691,8 +578,8 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
||||
4*batch*(2+
|
||||
2*embedding+
|
||||
context+
|
||||
context*maxHeads+
|
||||
maxEmbeddingHeadsK*maxHeads+
|
||||
context*heads+
|
||||
embeddingHeadsK*heads+
|
||||
qkvBias.Shape[0]),
|
||||
)
|
||||
|
||||
@@ -700,11 +587,11 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
|
||||
partialOffload,
|
||||
4*batch*(1+
|
||||
2*embedding+
|
||||
maxEmbeddingHeadsK*maxHeads+
|
||||
embeddingHeadsK*heads+
|
||||
context+
|
||||
context*maxHeads)+
|
||||
4*maxEmbeddingHeadsK*context+
|
||||
4*context*maxEmbeddingHeadsK+
|
||||
context*heads)+
|
||||
4*embeddingHeadsK*context+
|
||||
4*context*embeddingHeadsK+
|
||||
4*qkvBias.Shape[0],
|
||||
)
|
||||
}
|
||||
@@ -776,15 +663,9 @@ func (f GGML) SupportsFlashAttention() bool {
|
||||
}
|
||||
|
||||
// Check head counts match and are non-zero
|
||||
headCount := f.KV().HeadCounts()
|
||||
embeddingHeadCountK := f.KV().EmbeddingHeadCountK()
|
||||
embeddingHeadCountV := f.KV().EmbeddingHeadCountV()
|
||||
for i := range headCount {
|
||||
if embeddingHeadCountK[i] != embeddingHeadCountV[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
headCountK := f.KV().EmbeddingHeadCountK()
|
||||
headCountV := f.KV().EmbeddingHeadCountV()
|
||||
return headCountK != 0 && headCountV != 0 && headCountK == headCountV
|
||||
}
|
||||
|
||||
// kvCacheBytesPerElement returns the number of bytes per element for a given KV cache type
|
||||
@@ -798,54 +679,3 @@ func kvCacheBytesPerElement(cacheType string) float64 {
|
||||
return 2 // f16 (default)
|
||||
}
|
||||
}
|
||||
|
||||
func AsUint64Array(v []any) ([]uint64, bool) {
|
||||
switch v[0].(type) {
|
||||
case uint32:
|
||||
values := make([]uint64, len(v))
|
||||
for i, v := range v {
|
||||
values[i] = uint64(v.(uint32))
|
||||
}
|
||||
return values, true
|
||||
case uint64:
|
||||
values := make([]uint64, len(v))
|
||||
for i, v := range v {
|
||||
values[i] = v.(uint64)
|
||||
}
|
||||
return values, true
|
||||
case int32:
|
||||
values := make([]uint64, len(v))
|
||||
for i, val := range v {
|
||||
val := val.(int32)
|
||||
if val < 0 {
|
||||
slog.Warn("negative value in int32 array", "value", val)
|
||||
return nil, false
|
||||
}
|
||||
values[i] = uint64(val)
|
||||
}
|
||||
return values, true
|
||||
}
|
||||
return nil, false
|
||||
}
|
||||
|
||||
func MaxValue(values []uint64) (uint64, bool) {
|
||||
if len(values) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
|
||||
max := values[0]
|
||||
for _, v := range values {
|
||||
if v > max {
|
||||
max = v
|
||||
}
|
||||
}
|
||||
return max, true
|
||||
}
|
||||
|
||||
func FillArray[T any](value T, n int) []T {
|
||||
values := make([]T, n)
|
||||
for i := range values {
|
||||
values[i] = value
|
||||
}
|
||||
return values
|
||||
}
|
||||
|
||||
@@ -52,8 +52,8 @@ func TestMaxQueue(t *testing.T) {
	embedCtx := ctx

	var genwg sync.WaitGroup
	genwg.Add(1)
	go func() {
		genwg.Add(1)
		defer genwg.Done()
		slog.Info("Starting generate request")
		DoGenerate(ctx, t, client, req, resp, 45*time.Second, 5*time.Second)

@@ -71,8 +71,8 @@ func TestMaxQueue(t *testing.T) {
	counterMu := sync.Mutex{}

	var embedwg sync.WaitGroup
	for i := 0; i < threadCount; i++ {
		embedwg.Add(1)
		go func(i int) {
			embedwg.Add(1)
			defer embedwg.Done()
			slog.Info("embed started", "id", i)
			embedReq := api.EmbeddingRequest{
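The paired lines in these hunks move `Add(1)` from inside the goroutine to before the `go` statement. As a hedged illustration of why that ordering matters (this snippet is not part of the diff), calling `Add` inside the goroutine can race with `Wait`, which may return before the goroutine has even incremented the counter:

```go
package main

import (
	"fmt"
	"sync"
)

func main() {
	var wg sync.WaitGroup

	// Correct: reserve the counter before the goroutine is scheduled,
	// so Wait below cannot observe a zero counter prematurely.
	wg.Add(1)
	go func() {
		defer wg.Done()
		fmt.Println("work done")
	}()

	// Racy variant (avoid): if Add(1) ran inside the goroutine instead,
	// Wait could return before the goroutine ever executed Add.
	wg.Wait()
}
```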
@@ -56,9 +56,8 @@ type Cache interface {

	// StartForward is called before the start of the model's forward pass.
	// For each token in the coming batch, there must be a corresponding
	// entry in positions and seqs. reserve is to preallocate memory
	// without actually storing data in the cache.
	StartForward(ctx ml.Context, batch input.Batch, reserve bool) error
	// entry in positions and seqs.
	StartForward(ctx ml.Context, batch input.Batch) error

	// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
	CopyPrefix(srcSeq, dstSeq int, len int32)
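As a rough, hypothetical sketch of the `reserve` idea (names and wiring assumed, not taken from either branch): a reservation pass sizes the cache for a worst-case batch without recording any batch metadata, while a normal pass records the batch as usual.

```go
package main

import "fmt"

// Simplified stand-in for a KV cache; not the repository's implementation.
type batch struct {
	positions []int32
	sequences []int
}

type cache struct {
	cells     []int32 // position stored in each cell
	allocated bool
}

// startForward mirrors the shape of the interface above: when reserve is true
// it allocates worst-case buffers but stores nothing in the cache.
func (c *cache) startForward(b batch, reserve bool) {
	if !c.allocated {
		c.cells = make([]int32, 4096) // assumed worst-case capacity
		c.allocated = true
	}
	if reserve {
		return // memory is sized; no metadata is updated
	}
	copy(c.cells, b.positions) // a real pass records the batch
}

func main() {
	c := &cache{}
	worst := batch{positions: make([]int32, 512), sequences: make([]int, 512)}
	c.startForward(worst, true) // reservation pass at load time
	c.startForward(batch{positions: []int32{0, 1, 2}, sequences: []int{0, 0, 0}}, false)
	fmt.Println("cells:", len(c.cells))
}
```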
@@ -146,60 +146,51 @@ func (c *Causal) Close() {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||
func (c *Causal) StartForward(ctx ml.Context, batch input.Batch) error {
|
||||
c.curBatchSize = len(batch.Positions)
|
||||
c.curSequences = batch.Sequences
|
||||
c.curPositions = batch.Positions
|
||||
c.opts.Except = nil
|
||||
|
||||
if !reserve {
|
||||
c.updateSlidingWindow()
|
||||
|
||||
var err error
|
||||
c.curLoc, err = c.findStartLoc()
|
||||
if errors.Is(err, ErrKvCacheFull) {
|
||||
c.defrag()
|
||||
c.curLoc, err = c.findStartLoc()
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.curCellRange = newRange()
|
||||
for i, pos := range batch.Positions {
|
||||
seq := batch.Sequences[i]
|
||||
|
||||
c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
|
||||
|
||||
seqRange, ok := c.cellRanges[seq]
|
||||
if !ok {
|
||||
seqRange = newRange()
|
||||
}
|
||||
|
||||
if c.curLoc+i > seqRange.max {
|
||||
seqRange.max = c.curLoc + i
|
||||
}
|
||||
if seqRange.max > c.curCellRange.max {
|
||||
c.curCellRange.max = seqRange.max
|
||||
}
|
||||
|
||||
if c.curLoc+i < seqRange.min {
|
||||
seqRange.min = c.curLoc + i
|
||||
}
|
||||
if seqRange.min < c.curCellRange.min {
|
||||
c.curCellRange.min = seqRange.min
|
||||
}
|
||||
c.cellRanges[seq] = seqRange
|
||||
}
|
||||
} else {
|
||||
// If we are reserving memory, don't update any of the cache metadata but set the size
|
||||
// to the worst case.
|
||||
c.curLoc = 0
|
||||
c.curCellRange.min = 0
|
||||
c.curCellRange.max = len(c.cells) - 1
|
||||
}
|
||||
c.updateSlidingWindow()
|
||||
|
||||
var err error
|
||||
c.curLoc, err = c.findStartLoc()
|
||||
if errors.Is(err, ErrKvCacheFull) {
|
||||
c.defrag()
|
||||
c.curLoc, err = c.findStartLoc()
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.curCellRange = newRange()
|
||||
for i, pos := range batch.Positions {
|
||||
seq := batch.Sequences[i]
|
||||
|
||||
c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
|
||||
|
||||
seqRange, ok := c.cellRanges[seq]
|
||||
if !ok {
|
||||
seqRange = newRange()
|
||||
}
|
||||
|
||||
if c.curLoc+i > seqRange.max {
|
||||
seqRange.max = c.curLoc + i
|
||||
}
|
||||
if seqRange.max > c.curCellRange.max {
|
||||
c.curCellRange.max = seqRange.max
|
||||
}
|
||||
|
||||
if c.curLoc+i < seqRange.min {
|
||||
seqRange.min = c.curLoc + i
|
||||
}
|
||||
if seqRange.min < c.curCellRange.min {
|
||||
c.curCellRange.min = seqRange.min
|
||||
}
|
||||
c.cellRanges[seq] = seqRange
|
||||
}
|
||||
|
||||
c.curMask, err = c.buildMask(ctx)
|
||||
|
||||
return err
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/fs"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/model/input"
|
||||
)
|
||||
@@ -280,7 +281,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
|
||||
context := backend.NewContext()
|
||||
defer context.Close()
|
||||
|
||||
err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs}, false)
|
||||
err := cache.StartForward(context, input.Batch{Positions: test.pos, Sequences: test.seqs})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
@@ -314,7 +315,7 @@ func TestCanResume(t *testing.T) {
|
||||
err := cache.StartForward(context, input.Batch{
|
||||
Positions: []int32{0, 1, 2, 3},
|
||||
Sequences: []int{0, 0, 0, 0},
|
||||
}, false)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("StartForward failed: %v", err)
|
||||
}
|
||||
@@ -341,7 +342,7 @@ func TestCanResume(t *testing.T) {
|
||||
err = cache.StartForward(context, input.Batch{
|
||||
Positions: []int32{4, 5},
|
||||
Sequences: []int{0, 0},
|
||||
}, false)
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatalf("StartForward failed: %v", err)
|
||||
}
|
||||
@@ -371,8 +372,14 @@ func TestCanResume(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
type testBackend struct {
|
||||
ml.Backend
|
||||
type testBackend struct{}
|
||||
|
||||
func (b *testBackend) Config() fs.Config {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (b *testBackend) Get(name string) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (b *testBackend) NewContext() ml.Context {
|
||||
@@ -383,10 +390,12 @@ func (b *testBackend) NewContextSize(int) ml.Context {
|
||||
return &testContext{}
|
||||
}
|
||||
|
||||
type testContext struct {
|
||||
ml.Context
|
||||
func (b *testBackend) SystemInfo() string {
|
||||
return "not implemented"
|
||||
}
|
||||
|
||||
type testContext struct{}
|
||||
|
||||
func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor {
|
||||
total := 0
|
||||
|
||||
@@ -431,8 +440,6 @@ func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
|
||||
|
||||
func (c *testContext) Compute(...ml.Tensor) {}
|
||||
|
||||
func (c *testContext) Reserve() error { return nil }
|
||||
|
||||
func (c *testContext) MaxGraphNodes() int {
|
||||
return 10
|
||||
}
|
||||
@@ -440,8 +447,6 @@ func (c *testContext) MaxGraphNodes() int {
|
||||
func (c *testContext) Close() {}
|
||||
|
||||
type testTensor struct {
|
||||
ml.Tensor
|
||||
|
||||
dtype ml.DType
|
||||
elementSize int
|
||||
data []float32
|
||||
@@ -469,6 +474,10 @@ func (t *testTensor) DType() ml.DType {
|
||||
return t.dtype
|
||||
}
|
||||
|
||||
func (t *testTensor) Bytes() []byte {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Floats() []float32 {
|
||||
out := make([]float32, len(t.data))
|
||||
copy(out, t.data)
|
||||
@@ -493,6 +502,64 @@ func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
return out
|
||||
}
|
||||
|
||||
func (t *testTensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Softmax(ctx ml.Context) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) LayerNorm(ctx ml.Context, weight, bias ml.Tensor, eps float32) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) RMSNorm(ctx ml.Context, weight ml.Tensor, eps float32) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Scale(ctx ml.Context, s float64) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) AvgPool1D(ctx ml.Context, k, s, p int) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim, ropeType uint32, base, scale float32) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) IM2Col(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Cos(ctx ml.Context) ml.Tensor { panic("not implemented") }
|
||||
func (t *testTensor) Sin(ctx ml.Context) ml.Tensor { panic("not implemented") }
|
||||
func (t *testTensor) Tanh(ctx ml.Context) ml.Tensor { panic("not implemented") }
|
||||
func (t *testTensor) GELU(ctx ml.Context) ml.Tensor { panic("not implemented") }
|
||||
func (t *testTensor) SILU(ctx ml.Context) ml.Tensor { panic("not implemented") }
|
||||
|
||||
func (t *testTensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
|
||||
offset /= t.elementSize
|
||||
|
||||
@@ -515,7 +582,43 @@ func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
|
||||
return view
|
||||
}
|
||||
|
||||
func (t *testTensor) Permute(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Contiguous(ctx ml.Context) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Repeat(ctx ml.Context, dim, n int) ml.Tensor { panic("not implemented") }
|
||||
|
||||
func (t *testTensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (t *testTensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
copy(t2.(*testTensor).data, t.data)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (t *testTensor) Duplicate(ctx ml.Context) ml.Tensor { panic("not implemented") }
|
||||
|
||||
@@ -27,11 +27,6 @@ type EncoderCache struct {
|
||||
// anything will be stored)
|
||||
curPos int32
|
||||
|
||||
// curReserve indicates that this forward pass is only for
|
||||
// memory reservation and we should not update our metadata
|
||||
// based on it.
|
||||
curReserve bool
|
||||
|
||||
// ** cache metadata **
|
||||
|
||||
// was something stored in the cache?
|
||||
@@ -88,14 +83,12 @@ func (c *EncoderCache) Close() {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||
func (c *EncoderCache) StartForward(ctx ml.Context, batch input.Batch) error {
|
||||
// We work with the most recent image
|
||||
if len(batch.Multimodal) > 0 {
|
||||
c.curPos = batch.Positions[batch.Multimodal[len(batch.Multimodal)-1].Index]
|
||||
}
|
||||
|
||||
c.curReserve = reserve
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -112,10 +105,8 @@ func (c *EncoderCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||
}
|
||||
|
||||
func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
if !c.curReserve {
|
||||
c.encoderPos = c.curPos
|
||||
c.encoderCached = true
|
||||
}
|
||||
c.encoderPos = c.curPos
|
||||
c.encoderCached = true
|
||||
|
||||
if c.config.PermutedV {
|
||||
value = value.Permute(ctx, 1, 2, 0, 3)
|
||||
|
||||
@@ -41,9 +41,9 @@ func (c *WrapperCache) Close() {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
|
||||
func (c *WrapperCache) StartForward(ctx ml.Context, batch input.Batch) error {
|
||||
for i, cache := range c.caches {
|
||||
err := cache.StartForward(ctx, batch, reserve)
|
||||
err := cache.StartForward(ctx, batch)
|
||||
if err != nil {
|
||||
// unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
|
||||
for j := i - 1; j >= 0; j-- {
|
||||
|
||||
@@ -149,7 +149,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
|
||||
}
|
||||
|
||||
if graphPartialOffload == 0 {
|
||||
graphPartialOffload = f.KV().GQAMax() * kvTotal / 6
|
||||
graphPartialOffload = f.KV().GQA() * kvTotal / 6
|
||||
}
|
||||
if graphFullOffload == 0 {
|
||||
graphFullOffload = graphPartialOffload
|
||||
|
||||
@@ -97,13 +97,6 @@ type Context interface {

	Forward(...Tensor) Context
	Compute(...Tensor)

	// Reserve is analogous to Compute but rather than executing a
	// graph, simply preallocates memory. Typically called with a
	// worst case graph to ensure all resources are available for
	// for future inference.
	Reserve() error

	MaxGraphNodes() int
	Close()
@@ -10,7 +10,6 @@ import "C"
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
@@ -43,12 +42,8 @@ func devices() []*C.struct_ggml_backend_device {
|
||||
}
|
||||
|
||||
type Backend struct {
|
||||
meta *fsggml.GGML
|
||||
|
||||
sched *C.struct_ggml_backend_sched
|
||||
schedBackends []*C.struct_ggml_backend
|
||||
schedBufts []*C.struct_ggml_backend_buffer_type
|
||||
|
||||
meta *fsggml.GGML
|
||||
sched *C.struct_ggml_backend_sched
|
||||
tensors map[string]*C.struct_ggml_tensor
|
||||
|
||||
// input is the backend used for inputs
|
||||
@@ -286,10 +281,6 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend,
|
||||
}
|
||||
|
||||
b := C.ggml_backend_alloc_ctx_tensors_from_buft(c, bt)
|
||||
if b == nil {
|
||||
return nil, fmt.Errorf("unable to allocate memory from device %v for model weights", C.GoString(C.ggml_backend_buft_name(bt)))
|
||||
}
|
||||
|
||||
C.ggml_backend_buffer_set_usage(b, C.GGML_BACKEND_BUFFER_USAGE_WEIGHTS)
|
||||
bbs[c] = b
|
||||
}
|
||||
@@ -328,14 +319,7 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend,
|
||||
tts[i] = tt
|
||||
}
|
||||
|
||||
// Create a new FD for each goroutine so that each FD is read sequentially, rather than
|
||||
// seeking around within an FD shared between all goroutines.
|
||||
file, err := os.Open(r.Name())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
sr := io.NewSectionReader(file, int64(meta.Tensors().Offset+t.Offset), int64(t.Size()))
|
||||
sr := io.NewSectionReader(r, int64(meta.Tensors().Offset+t.Offset), int64(t.Size()))
|
||||
bts := make([]byte, 128*format.KibiByte)
|
||||
|
||||
var s uint64
|
||||
@@ -394,6 +378,8 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend,
|
||||
schedBackends = append(schedBackends, b)
|
||||
schedBufts = append(schedBufts, bt)
|
||||
|
||||
slog.Info("compute graph", "backend", C.GoString(C.ggml_backend_name(b)), "buffer_type", C.GoString(C.ggml_backend_buft_name(bt)))
|
||||
|
||||
if C.ggml_backend_is_cpu(b) {
|
||||
// set number of threads for cpu backend
|
||||
C.ggml_backend_cpu_set_n_threads(b, C.int(Threads(params.NumThreads)))
|
||||
@@ -412,9 +398,7 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend,
|
||||
C.size_t(maxGraphNodes),
|
||||
C._Bool(len(gpus) > 1 && slices.Contains(gpus, output.d)),
|
||||
),
|
||||
schedBackends: schedBackends,
|
||||
schedBufts: schedBufts,
|
||||
input: deviceBufferTypes[input.d],
|
||||
input: deviceBufferTypes[input.d],
|
||||
layers: func() map[int]*C.struct_ggml_backend_buffer_type {
|
||||
m := make(map[int]*C.struct_ggml_backend_buffer_type)
|
||||
for i, layer := range layers {
|
||||
@@ -539,24 +523,6 @@ func (c Context) Compute(tensors ...ml.Tensor) {
|
||||
}
|
||||
}
|
||||
|
||||
func (c Context) Reserve() error {
|
||||
if !C.ggml_backend_sched_reserve(c.b.sched, c.graph) {
|
||||
C.ggml_backend_sched_reset(c.b.sched)
|
||||
return errors.New("failed to reserve graph")
|
||||
}
|
||||
|
||||
slog.Debug("compute graph", "nodes", C.ggml_graph_n_nodes(c.graph), "splits", C.ggml_backend_sched_get_n_splits(c.b.sched))
|
||||
for i := range c.b.schedBackends {
|
||||
size := C.ggml_backend_sched_get_buffer_size(c.b.sched, c.b.schedBackends[i])
|
||||
slog.Info("compute graph", "backend", C.GoString(C.ggml_backend_name(c.b.schedBackends[i])), "buffer_type", C.GoString(C.ggml_backend_buft_name(c.b.schedBufts[i])),
|
||||
"size", format.HumanBytes2(uint64(size)))
|
||||
}
|
||||
|
||||
C.ggml_backend_sched_reset(c.b.sched)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c Context) MaxGraphNodes() int {
|
||||
return c.maxGraphNodes
|
||||
}
|
||||
@@ -574,9 +540,9 @@ func pad(length, pad C.size_t) C.size_t {
|
||||
return ((length + pad - 1) / pad) * pad
|
||||
}
|
||||
|
||||
func (c Context) newTensor(dtype ml.DType, shape []int) (ml.Tensor, error) {
|
||||
func (c Context) newTensor(dtype ml.DType, shape []int) ml.Tensor {
|
||||
if c.buft == nil {
|
||||
panic("set Input or Layer before creating tensors")
|
||||
panic("set Input, Output, or Layer before creating tensors")
|
||||
}
|
||||
|
||||
var cdtype uint32
|
||||
@@ -597,7 +563,7 @@ func (c Context) newTensor(dtype ml.DType, shape []int) (ml.Tensor, error) {
|
||||
|
||||
if len(shape) < 1 || shape[0] == 0 {
|
||||
var shape C.int64_t = 0
|
||||
return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape)}, nil
|
||||
return &Tensor{b: c.b, t: C.ggml_new_tensor(c.ctx, cdtype, 1, &shape)}
|
||||
} else if len(shape) > 4 {
|
||||
panic("unsupported number of dimensions")
|
||||
}
|
||||
@@ -611,29 +577,16 @@ func (c Context) newTensor(dtype ml.DType, shape []int) (ml.Tensor, error) {
|
||||
t := C.ggml_new_tensor(c.ctx, cdtype, C.int(len(shape)), shapeToGGML(shape))
|
||||
size := pad(C.ggml_backend_buft_get_alloc_size(c.buft, t), C.ggml_backend_buft_get_alignment(c.buft))
|
||||
b := C.ggml_backend_buft_alloc_buffer(c.buft, size)
|
||||
if b == nil {
|
||||
return nil, fmt.Errorf("unable to allocate %v from device %v for new tensor", format.HumanBytes2(uint64(size)), C.GoString(C.ggml_backend_buft_name(c.buft)))
|
||||
}
|
||||
|
||||
C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
|
||||
return &Tensor{b: c.b, t: t}, nil
|
||||
return &Tensor{b: c.b, t: t}
|
||||
}
|
||||
|
||||
func (c Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
|
||||
t, err := c.newTensor(dtype, shape)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return t
|
||||
return c.newTensor(dtype, shape)
|
||||
}
|
||||
|
||||
func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
|
||||
t, err := c.newTensor(dtype, shape)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
t := c.newTensor(dtype, shape)
|
||||
C.ggml_set_zero(t.(*Tensor).t)
|
||||
return t
|
||||
}
|
||||
@@ -661,11 +614,7 @@ func (c Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t, err := c.newTensor(ml.DTypeF32, shape)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t := c.newTensor(ml.DTypeF32, shape)
|
||||
if len(s) > 0 {
|
||||
C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
|
||||
}
|
||||
@@ -678,11 +627,7 @@ func (c Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t, err := c.newTensor(ml.DTypeI32, shape)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
t := c.newTensor(ml.DTypeI32, shape)
|
||||
if len(s) > 0 {
|
||||
C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t))
|
||||
}
|
||||
|
||||
@@ -299,7 +299,7 @@ func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Ten
|
||||
|
||||
cache := m.Config().Cache
|
||||
if cache != nil {
|
||||
err := cache.StartForward(ctx, batch, false)
|
||||
err := cache.StartForward(ctx, batch)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -281,31 +281,27 @@ func TestChatMiddleware(t *testing.T) {
|
||||
Description: "Get the current weather",
|
||||
Parameters: struct {
|
||||
Type string `json:"type"`
|
||||
Defs any `json:"$defs,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Required []string `json:"required"`
|
||||
Properties map[string]struct {
|
||||
Type api.PropertyType `json:"type"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Description string `json:"description"`
|
||||
Enum []any `json:"enum,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Enum []string `json:"enum,omitempty"`
|
||||
} `json:"properties"`
|
||||
}{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]struct {
|
||||
Type api.PropertyType `json:"type"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Description string `json:"description"`
|
||||
Enum []any `json:"enum,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Enum []string `json:"enum,omitempty"`
|
||||
}{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Type: "string",
|
||||
Description: "The city and state",
|
||||
},
|
||||
"unit": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Enum: []any{"celsius", "fahrenheit"},
|
||||
Type: "string",
|
||||
Enum: []string{"celsius", "fahrenheit"},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -11,13 +11,10 @@ import (
|
||||
"os"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/sync/errgroup"
|
||||
"golang.org/x/text/encoding/unicode"
|
||||
"golang.org/x/text/transform"
|
||||
|
||||
@@ -147,25 +144,12 @@ func fileDigestMap(path string) (map[string]string, error) {
|
||||
files = []string{path}
|
||||
}
|
||||
|
||||
var mu sync.Mutex
|
||||
var g errgroup.Group
|
||||
g.SetLimit(max(runtime.GOMAXPROCS(0)-1, 1))
|
||||
for _, f := range files {
|
||||
g.Go(func() error {
|
||||
digest, err := digestForFile(f)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
fl[f] = digest
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
if err := g.Wait(); err != nil {
|
||||
return nil, err
|
||||
digest, err := digestForFile(f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
fl[f] = digest
|
||||
}
|
||||
|
||||
return fl, nil
|
||||
|
||||
@@ -448,7 +448,7 @@ func (m *mockCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor)
|
||||
func (m *mockCache) Put(ctx ml.Context, key, value ml.Tensor) {}
|
||||
func (m *mockCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {}
|
||||
func (m *mockCache) Close() {}
|
||||
func (m *mockCache) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error { return nil }
|
||||
func (m *mockCache) StartForward(ctx ml.Context, batch input.Batch) error { return nil }
|
||||
func (m *mockCache) CopyPrefix(srcSeq, dstSeq int, len int32) {}
|
||||
func (m *mockCache) SetConfig(ml.CacheConfig) {}
|
||||
func (m *mockCache) CanResume(seq int, pos int32) bool { return true }
|
||||
|
||||
@@ -728,51 +728,6 @@ func (m *multiLPath) String() string {
|
||||
return strings.Join(*m, ", ")
|
||||
}
|
||||
|
||||
func (s *Server) reserveWorstCaseGraph() error {
|
||||
ctx := s.model.Backend().NewContext()
|
||||
defer ctx.Close()
|
||||
|
||||
var batch input.Batch
|
||||
|
||||
inputs := make([]int32, s.batchSize)
|
||||
batch.Positions = make([]int32, len(inputs))
|
||||
batch.Sequences = make([]int, len(inputs))
|
||||
for i := range inputs {
|
||||
batch.Positions[i] = int32(i)
|
||||
}
|
||||
|
||||
batch.Outputs = make([]int32, s.parallel)
|
||||
for i := range batch.Outputs {
|
||||
batch.Outputs[i] = int32(i)
|
||||
}
|
||||
|
||||
var err error
|
||||
batch.Inputs, err = ctx.Input().FromIntSlice(inputs, len(inputs))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cache := s.model.Config().Cache
|
||||
if cache != nil {
|
||||
err := cache.StartForward(ctx, batch, true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
t, err := s.model.Forward(ctx, batch)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = ctx.Forward(t).Reserve()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) loadModel(
|
||||
ctx context.Context,
|
||||
mpath string,
|
||||
@@ -810,11 +765,6 @@ func (s *Server) loadModel(
|
||||
s.seqs = make([]*Sequence, s.parallel)
|
||||
s.seqsSem = semaphore.NewWeighted(int64(s.parallel))
|
||||
|
||||
err = s.reserveWorstCaseGraph()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
s.status = llm.ServerStatusReady
|
||||
s.ready.Done()
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"text/template/parse"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
@@ -62,6 +63,7 @@ type Model struct {
|
||||
Digest string
|
||||
Options map[string]any
|
||||
Messages []api.Message
|
||||
ToolPrefix string
|
||||
|
||||
Template *template.Template
|
||||
}
|
||||
@@ -260,7 +262,7 @@ func GetModel(name string) (*Model, error) {
        return nil, err
    }

    model := &Model{
    m := &Model{
        Name:      mp.GetFullTagname(),
        ShortName: mp.GetShortTagname(),
        Digest:    digest,
@@ -279,7 +281,7 @@ func GetModel(name string) (*Model, error) {
        }
        defer configFile.Close()

        if err := json.NewDecoder(configFile).Decode(&model.Config); err != nil {
        if err := json.NewDecoder(configFile).Decode(&m.Config); err != nil {
            return nil, err
        }
    }
@@ -292,16 +294,16 @@ func GetModel(name string) (*Model, error) {

        switch layer.MediaType {
        case "application/vnd.ollama.image.model":
            model.ModelPath = filename
            model.ParentModel = layer.From
            m.ModelPath = filename
            m.ParentModel = layer.From
        case "application/vnd.ollama.image.embed":
            // Deprecated in versions > 0.1.2
            // TODO: remove this warning in a future version
            slog.Info("WARNING: model contains embeddings, but embeddings in modelfiles have been deprecated and will be ignored.")
        case "application/vnd.ollama.image.adapter":
            model.AdapterPaths = append(model.AdapterPaths, filename)
            m.AdapterPaths = append(m.AdapterPaths, filename)
        case "application/vnd.ollama.image.projector":
            model.ProjectorPaths = append(model.ProjectorPaths, filename)
            m.ProjectorPaths = append(m.ProjectorPaths, filename)
        case "application/vnd.ollama.image.prompt",
            "application/vnd.ollama.image.template":
            bts, err := os.ReadFile(filename)
@@ -309,7 +311,7 @@ func GetModel(name string) (*Model, error) {
                return nil, err
            }

            model.Template, err = template.Parse(string(bts))
            m.Template, err = template.Parse(string(bts))
            if err != nil {
                return nil, err
            }
@@ -319,7 +321,7 @@ func GetModel(name string) (*Model, error) {
                return nil, err
            }

            model.System = string(bts)
            m.System = string(bts)
        case "application/vnd.ollama.image.params":
            params, err := os.Open(filename)
            if err != nil {
@@ -328,7 +330,7 @@ func GetModel(name string) (*Model, error) {
            defer params.Close()

            // parse model options parameters into a map so that we can see which fields have been specified explicitly
            if err = json.NewDecoder(params).Decode(&model.Options); err != nil {
            if err = json.NewDecoder(params).Decode(&m.Options); err != nil {
                return nil, err
            }
        case "application/vnd.ollama.image.messages":
@@ -338,7 +340,7 @@ func GetModel(name string) (*Model, error) {
            }
            defer msgs.Close()

            if err = json.NewDecoder(msgs).Decode(&model.Messages); err != nil {
            if err = json.NewDecoder(msgs).Decode(&m.Messages); err != nil {
                return nil, err
            }
        case "application/vnd.ollama.image.license":
@@ -346,11 +348,50 @@ func GetModel(name string) (*Model, error) {
            if err != nil {
                return nil, err
            }
            model.License = append(model.License, string(bts))
            m.License = append(m.License, string(bts))
        }
    }

    return model, nil
    capabilities := m.Capabilities()
    if slices.Contains(capabilities, model.CapabilityTools) {
        m.addToolPrefix()
    }

    return m, nil
}

// HasToolPrefix checks if the completion starts with the tool prefix, ignoring whitespace
func (m *Model) HasToolPrefix(sb strings.Builder) bool {
    text := strings.ReplaceAll(strings.TrimSpace(sb.String()), " ", "")
    toolString := strings.ReplaceAll(strings.TrimSpace(m.ToolPrefix), " ", "")

    if len(text) < len(toolString) {
        return text == toolString[:len(text)]
    }
    return text[:len(toolString)] == toolString
}

// Figure out what's between the start of the tools block and the JSON response, and use it as a marker. Usually that's
// {{- if .ToolCalls}}this text{{ range .ToolCalls}}or maybe this text{{.name}}
func (m *Model) addToolPrefix() {
    // create a subtree from the node that ranges over .ToolCalls
    var previousNode parse.Node
    toolCallsTemplate := m.Template.Subtree(func(node parse.Node) bool {
        if rangeNode, ok := node.(*parse.RangeNode); ok {
            return slices.Contains(template.Identifiers(rangeNode.Pipe), "ToolCalls")
        }
        previousNode = node
        return false
    })
    if textNode, ok := previousNode.(*parse.TextNode); ok {
        m.ToolPrefix = strings.TrimSpace(textNode.String())
    }
    if len(m.ToolPrefix) == 0 && len(toolCallsTemplate.Root.Nodes) > 0 {
        rangeNode, ok := toolCallsTemplate.Root.Nodes[0].(*parse.RangeNode)
        if ok && len(rangeNode.List.Nodes) > 0 {
            m.ToolPrefix = rangeNode.List.Nodes[0].String()
        }
    }
}

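The whitespace-insensitive comparison done by HasToolPrefix is easier to see in isolation. The following is a minimal standalone sketch of that check only; hasToolPrefix and the example marker "[TOOL_CALLS] [" are hypothetical stand-ins for illustration, not code from this patch.

package main

import (
    "fmt"
    "strings"
)

// hasToolPrefix reports whether the partial completion seen so far could still
// be the start of a tool-call block. Both strings are compared with all spaces
// removed, so whitespace emitted by the model does not break the match.
func hasToolPrefix(partial, prefix string) bool {
    text := strings.ReplaceAll(strings.TrimSpace(partial), " ", "")
    marker := strings.ReplaceAll(strings.TrimSpace(prefix), " ", "")

    if len(text) < len(marker) {
        // Not enough output yet: it still matches if it is a prefix of the marker.
        return text == marker[:len(text)]
    }
    return text[:len(marker)] == marker
}

func main() {
    prefix := "[TOOL_CALLS] ["

    fmt.Println(hasToolPrefix(" [TOOL", prefix))                       // true: could still become a tool call
    fmt.Println(hasToolPrefix("[TOOL_CALLS] [{\"name\": ...", prefix)) // true: marker fully matched
    fmt.Println(hasToolPrefix("The weather in", prefix))               // false: stream this as normal content
}
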
func CopyModel(src, dst model.Name) error {

@@ -6,6 +6,7 @@ import (
    "fmt"
    "os"
    "path/filepath"
    "strings"
    "testing"

    "github.com/google/go-cmp/cmp"
@@ -28,19 +29,20 @@ func readFile(t *testing.T, base, name string) *bytes.Buffer {
func TestExecuteWithTools(t *testing.T) {
    p := filepath.Join("testdata", "tools")
    cases := []struct {
        model  string
        output string
        ok     bool
        model      string
        output     string
        ok         bool
        wellFormed bool
    }{
        {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
        {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
        {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true, true},
        {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]

The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true},
        {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"To }]`, false},
The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true, false},
        {"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"To }]`, false, false},
        {"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:

[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
        {"mistral", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true, false},
        {"mistral", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false, false},
        {"command-r-plus", "Action: ```json" + `
[
{
@@ -58,16 +60,17 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`,
}
}
]
` + "```", true},
        {"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
        {"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
        {"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
` + "```", true, true},
        {"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false, false},
        {"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true, true},
        {"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false, false},
        {"llama3-groq-tool-use", `<tool_call>
{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}
</tool_call>`, true},
        {"xlam", `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, true},
        {"nemotron", `<toolcall>{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]} </toolcall>`, true},
</tool_call>`, true, true},
        {"xlam", `### Response:
{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, true, true},
        {"nemotron", `<toolcall> {"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]} </toolcall>`, true, true},
    }

    var tools []api.Tool
@@ -119,6 +122,21 @@ The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`,
            }
        })

        t.Run("prefix", func(t *testing.T) {
            m := &Model{Template: tmpl}
            m.addToolPrefix()

            if tt.wellFormed {
                if len(m.ToolPrefix) == 0 {
                    t.Fatalf("No tool prefix detected")
                }

                if !strings.HasPrefix(strings.TrimSpace(tt.output), m.ToolPrefix) {
                    t.Fatalf("incorrect tool prefix: \"%s\", \"%s\"", m.ToolPrefix, tt.output)
                }
            }
        })

        t.Run("parse", func(t *testing.T) {
            m := &Model{Template: tmpl}
            actual, ok := m.parseToolCalls(tt.output)
@@ -177,3 +195,64 @@ func TestParseObjects(t *testing.T) {
        })
    }
}

func TestAddToolPrefix(t *testing.T) {
    tests := []struct {
        name     string
        template string
        want     string
    }{
        {
            name:     "prefix_from_previous_text_node",
            template: `Previous text node{{- range .ToolCalls}}{{.name}}{{end}}`,
            want:     "Previous text node",
        },
        {
            name:     "prefix_from_range_node",
            template: `{{- range .ToolCalls}}[TOOL_CALLS]{{.name}}{{end}}`,
            want:     "[TOOL_CALLS]",
        },
        {
            name:     "prefix_with_extra_whitespace",
            template: ` Previous text with spaces {{- range .ToolCalls}}{{.name}}{{end}}`,
            want:     "Previous text with spaces",
        },
        {
            name:     "prefix_with_newlines",
            template: "First line\nSecond line\n{{- range .ToolCalls}}{{.name}}{{end}}",
            want:     "First line\nSecond line",
        },
        {
            name: "tool_calls_json_template",
            template: `{{ if .Content }}{{ .Content }}{{- else if .ToolCalls }}<tool_call>
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}{{ end }}</tool_call>
{{ end }}`,
            want: `<tool_call>`,
        },
        {
            name: "mistral_tool_calls_template",
            template: `{{- if .Content }} {{ .Content }}
{{- else if .ToolCalls }}[TOOL_CALLS] [
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
{{- end }}]
{{- end }}</s>`,
            want: "[TOOL_CALLS] [",
        },
    }

    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            tmpl, err := template.Parse(tt.template)
            if err != nil {
                t.Fatalf("failed to parse template: %v", err)
            }

            m := &Model{Template: tmpl}
            m.addToolPrefix()

            if m.ToolPrefix != tt.want {
                t.Errorf("incorrect tool prefix:\ngot: %q\nwant: %q", m.ToolPrefix, tt.want)
            }
        })
    }
}

@@ -1526,6 +1526,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
        defer close(ch)
        var sb strings.Builder
        var toolCallIndex int = 0
        var mightBeTools bool = true
        buf := make([]api.ChatResponse, 0)
        if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
            Prompt: prompt,
            Images: images,
@@ -1551,18 +1553,29 @@ func (s *Server) ChatHandler(c *gin.Context) {
                res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
            }

            // TODO: tool call checking and filtering should be moved outside of this callback once streaming
            // however this was a simple change for now without reworking streaming logic of this (and other)
            // handlers
            if req.Stream != nil && !*req.Stream || len(req.Tools) == 0 {
            // If we know we're not streaming
            if req.Stream != nil && !*req.Stream || len(req.Tools) == 0 || !mightBeTools {
                ch <- res
                return
            }

            sb.WriteString(r.Content)

            // Buffer up responses while we're unsure whether to stream.
            buf = append(buf, res)

            // not a tools response, continue streaming.
            if !m.HasToolPrefix(sb) {
                mightBeTools = false
                for _, item := range buf {
                    ch <- item
                }
                return
            }

            // Streaming tool calls:
            // If tools are recognized, use a flag to track the sending of a tool downstream
            // This ensures that content is cleared from the message on the last chunk sent
            sb.WriteString(r.Content)
            if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
                res.Message.ToolCalls = toolCalls
                for i := range toolCalls {
@@ -1573,8 +1586,12 @@ func (s *Server) ChatHandler(c *gin.Context) {
                    sb.Reset()
                    ch <- res
                    return
                } else {
                    if !strings.HasPrefix(sb.String(), "{") {
                        ch <- res
                        return
                    }
                }

                if r.Done {
                    // Send any remaining content if no tool calls were detected
                    if toolCallIndex == 0 {

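The buffering strategy in the handler hunk above is easier to follow in isolation. The sketch below is a simplified, hypothetical model of the same idea (a channel of strings instead of api.ChatResponse values, and a plain prefix string instead of a Model): chunks are held back while the accumulated text still matches the tool prefix, flushed as soon as it clearly is not a tool call, and returned whole at the end so the caller can parse tool calls from it. The names relay and hasToolPrefix are illustrative only.

package main

import (
    "fmt"
    "strings"
)

// relay buffers streamed chunks while the accumulated text could still be the
// start of a tool-call block. If the text stops matching the prefix, every
// buffered chunk is forwarded to out and streaming continues normally; if the
// stream ends while still matching, the accumulated text is returned for
// tool-call parsing instead of being streamed.
func relay(chunks []string, toolPrefix string, out chan<- string) (accumulated string, isToolCall bool) {
    var sb strings.Builder
    var buf []string
    mightBeTools := true

    for _, chunk := range chunks {
        if !mightBeTools {
            out <- chunk
            continue
        }

        sb.WriteString(chunk)
        buf = append(buf, chunk)

        if !hasToolPrefix(sb.String(), toolPrefix) {
            // Not a tool call after all: flush everything held back so far.
            mightBeTools = false
            for _, held := range buf {
                out <- held
            }
        }
    }
    return sb.String(), mightBeTools
}

// hasToolPrefix mirrors the whitespace-insensitive prefix check used while streaming.
func hasToolPrefix(partial, prefix string) bool {
    text := strings.ReplaceAll(strings.TrimSpace(partial), " ", "")
    marker := strings.ReplaceAll(strings.TrimSpace(prefix), " ", "")
    if len(text) < len(marker) {
        return text == marker[:len(text)]
    }
    return text[:len(marker)] == marker
}

func main() {
    out := make(chan string, 16)

    text, tool := relay([]string{"[TOOL_", "CALLS] [", `{"name": "get_current_weather"}]`}, "[TOOL_CALLS] [", out)
    fmt.Println(tool, text) // true: nothing was streamed, parse tool calls from text

    _, tool = relay([]string{"The ", "weather ", "is sunny."}, "[TOOL_CALLS] [", out)
    close(out)
    fmt.Println(tool) // false: the chunks were forwarded to out as normal content
    for chunk := range out {
        fmt.Print(chunk)
    }
    fmt.Println()
}
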
||||
Description: "Get the current weather",
|
||||
Parameters: struct {
|
||||
Type string `json:"type"`
|
||||
Defs any `json:"$defs,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Required []string `json:"required"`
|
||||
Properties map[string]struct {
|
||||
Type api.PropertyType `json:"type"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Description string `json:"description"`
|
||||
Enum []any `json:"enum,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Enum []string `json:"enum,omitempty"`
|
||||
} `json:"properties"`
|
||||
}{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]struct {
|
||||
Type api.PropertyType `json:"type"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Description string `json:"description"`
|
||||
Enum []any `json:"enum,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Enum []string `json:"enum,omitempty"`
|
||||
}{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Type: "string",
|
||||
Description: "The city and state",
|
||||
},
|
||||
"unit": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Enum: []any{"celsius", "fahrenheit"},
|
||||
Type: "string",
|
||||
Enum: []string{"celsius", "fahrenheit"},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -471,31 +467,27 @@ func TestGenerateChat(t *testing.T) {
|
||||
Description: "Get the current weather",
|
||||
Parameters: struct {
|
||||
Type string `json:"type"`
|
||||
Defs any `json:"$defs,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Required []string `json:"required"`
|
||||
Properties map[string]struct {
|
||||
Type api.PropertyType `json:"type"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Description string `json:"description"`
|
||||
Enum []any `json:"enum,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Enum []string `json:"enum,omitempty"`
|
||||
} `json:"properties"`
|
||||
}{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]struct {
|
||||
Type api.PropertyType `json:"type"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Description string `json:"description"`
|
||||
Enum []any `json:"enum,omitempty"`
|
||||
Type string `json:"type"`
|
||||
Description string `json:"description"`
|
||||
Enum []string `json:"enum,omitempty"`
|
||||
}{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Type: "string",
|
||||
Description: "The city and state",
|
||||
},
|
||||
"unit": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Enum: []any{"celsius", "fahrenheit"},
|
||||
Type: "string",
|
||||
Enum: []string{"celsius", "fahrenheit"},
|
||||
},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -667,19 +667,13 @@ func (runner *runnerRef) waitForVRAMRecovery() chan any {
    return finished
}

type ByDurationAndName []*runnerRef
type ByDuration []*runnerRef

func (a ByDurationAndName) Len() int      { return len(a) }
func (a ByDurationAndName) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a ByDurationAndName) Less(i, j int) bool {
    // Primary sort by session duration (uint64 to handle negatives)
    d1 := uint64(a[i].sessionDuration)
    d2 := uint64(a[j].sessionDuration)
    if d1 != d2 {
        return d1 < d2
    }
    // Secondary sort by model path lex order
    return a[i].modelPath < a[j].modelPath
func (a ByDuration) Len() int      { return len(a) }
func (a ByDuration) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a ByDuration) Less(i, j int) bool {
    // uint64 to turn negative time (never unload) to largest
    return uint64(a[i].sessionDuration) < uint64(a[j].sessionDuration)
}

// TODO - future consideration to pick runners based on size
@@ -781,7 +775,7 @@ func (s *Scheduler) findRunnerToUnload() *runnerRef {

    // In the future we can enhance the algorithm to be smarter about picking the optimal runner to unload
    // e.g., if we have multiple options, will one make room for the request?
    sort.Sort(ByDurationAndName(runnerList))
    sort.Sort(ByDuration(runnerList))

    // First try to find a runner that's already idle
    for _, runner := range runnerList {

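The uint64 cast in the Less methods above is what keeps "never unload" runners at the back of the list: a negative time.Duration, reinterpreted as uint64, becomes a very large value. A small standalone sketch of that trick, assuming illustrative data rather than the scheduler's actual runner list:

package main

import (
    "fmt"
    "sort"
    "time"
)

// sortByDuration orders session durations ascending, treating negative values
// (used to mean "never unload") as the largest possible values by
// reinterpreting them as uint64.
func sortByDuration(durations []time.Duration) {
    sort.Slice(durations, func(i, j int) bool {
        return uint64(durations[i]) < uint64(durations[j])
    })
}

func main() {
    durations := []time.Duration{5 * time.Minute, -1, 30 * time.Second}
    sortByDuration(durations)
    fmt.Println(durations) // [30s 5m0s -1ns]: the "never unload" entry sorts last
}
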