Compare commits

..

21 Commits

Author SHA1 Message Date
Bruce MacDonald
159821594c Update ml/backend.go 2025-04-02 09:46:19 -07:00
Bruce MacDonald
cbeb2aab4f Update backend.go 2025-04-01 15:08:59 -07:00
Bruce MacDonald
96df15edfc ml: structured rope config to allow specifying context len
This commit refactors the Rotary Position Embedding (RoPE) implementation across the codebase to use a structured configuration approach instead of individual parameters.

Key changes:
- Add new RoPEConfig struct with fields for dimension, type, base frequency, and scaling
- Add RopeType enum to formalize different RoPE implementation variants
- Add YarnConfig struct and related configuration for YaRN (Yet Another RoPE extensioN) context extension
- Update RoPE method signature across all tensor interfaces and implementations
- Refactor all model implementations (llama, gemma2, gemma3, mllama) to use the new configuration structure

This change improves code organization, makes the RoPE configuration more explicit, and provides better support for different RoPE variants and context extension methods.
2025-04-01 14:03:48 -07:00
Ilian
c001b98087 docs: add TagSpaces to community integrations (#9983) 2025-03-31 17:28:59 -07:00
Abyss-c0re
23fc8e92eb docs: add DeepShell to community projects (#9955)
Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>
2025-03-31 17:23:04 -07:00
湛露先生
4059a297a6 discover: /proc/cpuinfo file open and close. (#9950)
Signed-off-by: zhanluxianshen <zhanluxianshen@163.com>
2025-03-31 17:07:42 -07:00
Bruce MacDonald
66b2539238 runner: clear cache when shift is not possible (#9433)
Clear KV cache when shift operation is not supported by model.
Added KvCacheCanShift() check to handle models that can't perform cache shifts,
falling back to full cache clear while preserving logical token history to
maintain expected behavior when context window fills up.
2025-03-31 12:54:45 -07:00
Blake Mizerany
ef27d52e79 server/internal/client/ollama: cache completed chunks (#9933)
This change adds tracking of download chunks during the pull process so
that subsequent pulls can skip downloading already completed chunks.
This works across restarts of ollama.

Currently, download state will be lost if a prune is triggered during a
pull (e.g. restart or remove). This issue should be addressed in a
follow-up PR.
2025-03-30 23:54:54 -07:00
Jesse Gross
b2a465296d runner: Release semaphore and improve error messages on failures
If we have an error after creating a new sequence but before
finding a slot for it, we return without releasing the semaphore.
This reduces our parallel sequences and eventually leads to deadlock.

In practice this should never happen because once we have acquired
the semaphore, we should always be able to find a slot. However, the
code is clearly not correct.
2025-03-30 19:21:54 -07:00
Jesse Gross
5d097277ef ollamarunner: Ensure batch size limits are not exceeded
With the llama runner, we can generate up to NUM_PARALLEL batches
at once, which will then get broken up to into individual batches
to get executed by llama.cpp (i.e. we add up to 2048 tokens and
this gets split into 4 batches of 512 tokens at default settings).

This splitting can improve parallelism on multi-GPU systems because
the individual batches can move though the pipeline without blocking
on the first one to fully complete. However, we don't yet support
this in the Ollama runner, partially because it makes it hard to
enforce model-specified batch constraints, which didn't exist
previously.

The result is that we will try to execute the full, unsplit batch.
This could result in out of memory or insufficient KV cache space
errors.

This triggers batch breaking when the total inputs from all sequences
exceeds the batch size, rather than per-sequence. In order to ensure
fairness, it also reintroduces round-robinning around sequences so
that we don't let one busy sequence starve the others.
2025-03-30 19:21:01 -07:00
Leandro Borges Ferreira
071a9872cb readme: add Writeopia to community integrations (#10042) 2025-03-30 17:28:06 -07:00
CYJiang
0bd0454ea7 server: organize error types (#9465)
Co-authored-by: Bruce MacDonald <brucewmacdonald@gmail.com>
2025-03-28 11:50:22 -07:00
Jesse Gross
01aa788722 ml: Remove Output from Context interface
Model implementations should use Input for all of their tensors
supplied to the model. This includes tensors that relate to the
outputs, which is confusing since there is also an Output funciton.

Since Output is only used internally in GGML and not used by any
model implementations, we can remove it from the interface to
reduce confusion.
2025-03-27 12:19:43 -07:00
saman-amd
ead27aa9fe Add gfx1200 & gfx1201 support on linux (#9878) 2025-03-27 07:35:19 -07:00
Parth Sareen
b816ff86c9 docs: make context length faq readable (#10006) 2025-03-26 17:34:18 -07:00
molbal
e5d84fb90b docs: add molbal/orca-cli to community integrations (#9909) 2025-03-26 13:39:01 -07:00
Hengky Steen
dd66712e31 docs: add ollamb to community projects 2025-03-26 13:38:05 -07:00
Jesse Gross
f66216e399 ggml: Support heterogeneous KV cache layer sizes in memory estimation
Gemma3 uses sliding windows for its context on 5/6 layers, significantly
reducing memory usage but leading to uneven usage across layers,
which makes allocation to the correct GPU difficult. We currently
estimate very conservatively by assuming all layers are consistent
at the max size.

Llama3.2-vision is also inconsistent between self attention and cross
attention layers - at moment, we calculate the correct total size
and then average this across layers. In some cases, this may lead
to crashes if a large layer is placed on a GPU sized by the average.

This allows memory estimation to calculate per-layer KV cache size
and take this account when placing layers onto GPUs. We already do
this for weights that vary per-tensor, so this is a logical extension.

Fixes #9730
Fixes #9890
2025-03-26 13:16:03 -07:00
Jesse Gross
f4f0992b6e llm: Fix debug logging for memory estimates 2025-03-26 13:16:03 -07:00
Jesse Gross
1feff61977 kvcache: Sliding window cache only needs a single batch total
When computing the size of the cache for sliding window attention,
we don't need to multiple the batch size by the number of parallel
sequences - the batch size is constant.

This also simplifies the check for whether to allocate the cache
size based on capacity or window size as the batch size is already
incorporated into the capacity when handled by the runner.
2025-03-26 13:16:03 -07:00
copeland3300
5e0b904e88 docs: add flags to example linux log output command (#9852) 2025-03-25 09:52:23 -07:00
67 changed files with 1029 additions and 3436 deletions

View File

@@ -86,9 +86,9 @@ if(CMAKE_CUDA_COMPILER)
)
endif()
set(WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX "^gfx(906|908|90a):xnack[+-]$"
set(WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX "^gfx(906|908|90a|1200|1201):xnack[+-]$"
CACHE STRING
"Regular expression describing AMDGPU_TARGETS not supported on Windows. Override to force building these targets. Default \"^gfx(906|908|90a):xnack[+-]$\"."
"Regular expression describing AMDGPU_TARGETS not supported on Windows. Override to force building these targets. Default \"^gfx(906|908|90a|1200|1201):xnack[+-]$\"."
)
check_language(HIP)
@@ -97,7 +97,7 @@ if(CMAKE_HIP_COMPILER)
find_package(hip REQUIRED)
if(NOT AMDGPU_TARGETS)
list(FILTER AMDGPU_TARGETS INCLUDE REGEX "^gfx(900|94[012]|101[02]|1030|110[012])$")
list(FILTER AMDGPU_TARGETS INCLUDE REGEX "^gfx(900|94[012]|101[02]|1030|110[012]|120[01])$")
elseif(WIN32 AND WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX)
list(FILTER AMDGPU_TARGETS EXCLUDE REGEX ${WINDOWS_AMDGPU_TARGETS_EXCLUDE_REGEX})
endif()

View File

@@ -56,7 +56,7 @@
"name": "ROCm 6",
"inherits": [ "ROCm" ],
"cacheVariables": {
"AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-"
"AMDGPU_TARGETS": "gfx900;gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx1200;gfx1201;gfx906:xnack-;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-"
}
}
],

View File

@@ -285,6 +285,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
- [HTML UI](https://github.com/rtcfirefly/ollama-ui)
- [Saddle](https://github.com/jikkuatwork/saddle)
- [TagSpaces](https://www.tagspaces.org) (A platform for file based apps, [utilizing Ollama](https://docs.tagspaces.org/ai/) for the generation of tags and descriptions)
- [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
- [Chatbot UI v2](https://github.com/mckaywrigley/chatbot-ui)
- [Typescript UI](https://github.com/ollama-interface/Ollama-Gui?tab=readme-ov-file)
@@ -394,6 +395,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)
- [Ellama](https://github.com/zeozeozeo/ellama) (Friendly native app to chat with an Ollama instance)
- [screenpipe](https://github.com/mediar-ai/screenpipe) Build agents powered by your screen history
- [Ollamb](https://github.com/hengkysteen/ollamb) (Simple yet rich in features, cross-platform built with Flutter and designed for Ollama. Try the [web demo](https://hengkysteen.github.io/demo/ollamb/).)
- [Writeopia](https://github.com/Writeopia/Writeopia) (Text editor with integration with Ollama)
### Cloud
@@ -433,7 +436,9 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [SwollamaCLI](https://github.com/marcusziade/Swollama) bundled with the Swollama Swift package. [Demo](https://github.com/marcusziade/Swollama?tab=readme-ov-file#cli-usage)
- [aichat](https://github.com/sigoden/aichat) All-in-one LLM CLI tool featuring Shell Assistant, Chat-REPL, RAG, AI tools & agents, with access to OpenAI, Claude, Gemini, Ollama, Groq, and more.
- [PowershAI](https://github.com/rrg92/powershai) PowerShell module that brings AI to terminal on Windows, including support for Ollama
- [DeepShell](https://github.com/Abyss-c0re/deepshell) Your self-hosted AI assistant. Interactive Shell, Files and Folders analysis.
- [orbiton](https://github.com/xyproto/orbiton) Configuration-free text editor and IDE with support for tab completion with Ollama.
- [orca-cli](https://github.com/molbal/orca-cli) Ollama Registry CLI Application - Browse, pull and download models from Ollama Registry in your terminal.
### Apple Vision Pro

View File

@@ -111,6 +111,7 @@ func GetCPUDetails() ([]CPU, error) {
if err != nil {
return nil, err
}
defer file.Close()
return linuxCPUDetails(file)
}
@@ -168,13 +169,11 @@ func linuxCPUDetails(file io.Reader) ([]CPU, error) {
for id, s := range socketByID {
s.CoreCount = len(coreBySocket[id])
s.ThreadCount = 0
for _, tc := range threadsByCoreBySocket[id] {
s.ThreadCount += tc
}
// This only works if HT is enabled, consider a more reliable model, maybe cache size comparisons?
efficiencyCoreCount := 0
for _, threads := range threadsByCoreBySocket[id] {
s.ThreadCount += threads
if threads == 1 {
efficiencyCoreCount++
}

View File

@@ -20,7 +20,13 @@ Please refer to the [GPU docs](./gpu.md).
## How can I specify the context window size?
By default, Ollama uses a context window size of 2048 tokens. This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context length to 8K, use: `OLLAMA_CONTEXT_LENGTH=8192 ollama serve`.
By default, Ollama uses a context window size of 2048 tokens.
This can be overridden with the `OLLAMA_CONTEXT_LENGTH` environment variable. For example, to set the default context window to 8K, use:
```shell
OLLAMA_CONTEXT_LENGTH=8192 ollama serve
```
To change this when using `ollama run`, use `/set parameter`:

View File

@@ -9,7 +9,7 @@ cat ~/.ollama/logs/server.log
On **Linux** systems with systemd, the logs can be found with this command:
```shell
journalctl -u ollama --no-pager
journalctl -u ollama --no-pager --follow --pager-end
```
When you run Ollama in a **container**, the logs go to stdout/stderr in the container:

View File

@@ -413,7 +413,7 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, int64, error) {
}, offset, nil
}
func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialOffload, fullOffload uint64) {
func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) {
embedding := f.KV().EmbeddingLength()
heads := f.KV().HeadCount()
headsKV := f.KV().HeadCountKV()
@@ -426,7 +426,10 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
layers := f.Tensors().GroupLayers()
bytesPerElement := kvCacheBytesPerElement(kvCacheType)
kv = uint64(float64(context*f.KV().BlockCount()*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
kv = make([]uint64, f.KV().BlockCount())
for i := range kv {
kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
}
switch f.KV().Architecture() {
case "llama":
@@ -460,16 +463,14 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
case "mllama":
var visionTokens, tiles uint64 = 1601, 4
if crossAttentionLayers, ok := f.KV()["mllama.attention.cross_attention_layers"].(*array); ok {
kv = headsKV *
(embeddingHeadsK + embeddingHeadsV) * // one for K, one for V
(2* // sizeof(float16)
(f.KV().BlockCount()-uint64(crossAttentionLayers.size))* // num non-cross attention layers
context +
4* // sizeof(float32)
uint64(crossAttentionLayers.size)* // num cross attention layers
visionTokens*
tiles)
crossAttentionLayers := f.KV().Uints("attention.cross_attention_layers")
for i := range kv {
if slices.Contains(crossAttentionLayers, uint32(i)) {
kv[i] = headsKV * (embeddingHeadsK + embeddingHeadsV) *
4 * // sizeof(float32)
visionTokens *
tiles
}
}
fullOffload = max(
@@ -505,6 +506,20 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
4*embeddingHeadsK*context*8+
embedding*embeddingHeadsK*heads*9/16,
)
// Gemma2 also has sliding window attention but we only have an optimized implementation in the Ollama
// engine. Gemma3 always uses the Ollama engine.
if f.KV().Architecture() == "gemma3" {
const gemma3GlobalCacheCount = 6
slidingWindow := (uint64(numParallel) * uint64(f.KV().Uint("attention.sliding_window"))) + batch
for i := range kv {
// Every 6th layer is a global layer, which is the full context size that has already been set. The other
// layers are the smaller local (sliding) layers.
if (i+1)%gemma3GlobalCacheCount != 0 {
kv[i] = uint64(float64(slidingWindow*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement)
}
}
}
case "command-r":
fullOffload = max(
4*batch*(embedding+vocab),

View File

@@ -1,22 +0,0 @@
//go:build go1.24
package grammar
import "testing"
func BenchmarkFromSchema(b *testing.B) {
for tt := range testCases(b) {
b.Run("", func(b *testing.B) {
s := []byte(tt.schema)
b.ReportAllocs()
for b.Loop() {
_, err := FromSchema(nil, s)
if err != nil {
b.Fatalf("GrammarFromSchema: %v", err)
}
}
})
return
}
}

View File

@@ -1,227 +0,0 @@
package grammar
import (
"bytes"
"encoding/json"
"fmt"
"iter"
"strconv"
"github.com/ollama/ollama/grammar/jsonschema"
)
const jsonTerms = `
# Unicode
#
# Unicode characters can be specified directly in the grammar, for example
# hiragana ::= [ぁ-ゟ], or with escapes: 8-bit (\xXX), 16-bit (\uXXXX) or 32-bit
# (\UXXXXXXXX).
unicode ::= \x{hex}{2} | \u{hex}{4} | \U{hex}{8}
# JSON grammar from RFC 7159
null ::= "null"
object ::= "{" (kv ("," kv)*)? "}"
array ::= "[" (value ("," value)*)? "]"
kv ::= string ":" value
integer ::= "0" | [1-9] [0-9]*
number ::= "-"? integer frac? exp?
frac ::= "." [0-9]+
exp ::= ("e" | "E") ("+" | "-") [0-9]+
string ::= "\"" char* "\""
escape ::= ["/" | "b" | "f" | "n" | "r" | "t" | unicode]
char ::= [^"\\] | escape
space ::= (" " | "\t" | "\n" | "\r")*
hex ::= [0-9] | [a-f] | [A-F]
boolean ::= "true" | "false"
value ::= object | array | string | number | boolean | "null"
# User-defined
`
// FromSchema generates a grammar from a JSON schema.
func FromSchema(buf []byte, jsonSchema []byte) ([]byte, error) {
var s *jsonschema.Schema
if err := json.Unmarshal(jsonSchema, &s); err != nil {
return nil, err
}
var g builder
// "root" is the only rule that is guaranteed to exist, so we start
// with its length for padding, and then adjust it as we go.
g.pad = len("root")
for id := range dependencies("root", s) {
g.pad = max(g.pad, len(id))
}
g.b.WriteString(jsonTerms)
ids := make(map[*jsonschema.Schema]string)
for id, s := range dependencies("root", s) {
ids[s] = id
g.define(id)
if err := fromSchema(&g, ids, s); err != nil {
return nil, err
}
}
g.define("root")
if err := fromSchema(&g, ids, s); err != nil {
return nil, err
}
g.define("") // finalize the last rule
return g.b.Bytes(), nil
}
func fromSchema(g *builder, ids map[*jsonschema.Schema]string, s *jsonschema.Schema) error {
switch typ := s.EffectiveType(); typ {
case "array":
if len(s.PrefixItems) == 0 && s.Items == nil {
g.u("array")
} else {
g.q("[")
for i, s := range s.PrefixItems {
if i > 0 {
g.q(",")
}
g.u(ids[s])
}
if s.Items != nil {
g.u("(")
if len(s.PrefixItems) > 0 {
g.q(",")
}
g.u(ids[s.Items])
g.u(")*")
}
g.q("]")
}
case "object":
if len(s.Properties) == 0 {
g.u("object")
} else {
g.q("{")
for i, p := range s.Properties {
name := ids[p]
if i > 0 {
g.q(",")
}
g.q(p.Name)
g.q(":")
g.u(name)
}
g.q("}")
}
case "number":
buildConstrainedNumber(g, s)
case "string":
if len(s.Enum) == 0 {
g.u("string")
} else {
g.u("(")
for i, e := range s.Enum {
if i > 0 {
g.q("|")
}
g.q(string(e))
}
g.u(")")
}
case "boolean", "value", "null", "integer":
g.u(typ)
default:
return fmt.Errorf("%s: unsupported type %q", s.Name, typ)
}
return nil
}
// dependencies returns a sequence of all child dependencies of the schema in
// post-order.
//
// The first value is the id/pointer to the dependency, and the second value
// is the schema.
func dependencies(id string, s *jsonschema.Schema) iter.Seq2[string, *jsonschema.Schema] {
return func(yield func(string, *jsonschema.Schema) bool) {
for i, p := range s.Properties {
id := fmt.Sprintf("%s_%d", id, i)
for did, d := range dependencies(id, p) {
if !yield(did, d) {
return
}
}
if !yield(id, p) {
return
}
}
for i, p := range s.PrefixItems {
id := fmt.Sprintf("tuple_%d", i)
for did, d := range dependencies(id, p) {
id := fmt.Sprintf("%s_%s", id, did)
if !yield(id, d) {
return
}
}
if !yield(id, p) {
return
}
}
if s.Items != nil {
id := fmt.Sprintf("%s_tuple_%d", id, len(s.PrefixItems))
for did, d := range dependencies(id, s.Items) {
if !yield(did, d) {
return
}
}
if !yield(id, s.Items) {
return
}
}
}
}
type builder struct {
b bytes.Buffer
pad int
rules int
items int
}
// define terminates the current rule, if any, and then either starts a new
// rule or does nothing else if the name is empty.
func (b *builder) define(name string) {
if b.rules > 0 {
b.b.WriteString(";\n")
}
if name == "" {
return
}
fmt.Fprintf(&b.b, "% -*s", b.pad, name)
b.b.WriteString(" ::=")
b.rules++
b.items = 0
}
// quote appends a terminal to the current rule.
func (b *builder) q(s string) {
if b.items > 0 {
b.b.WriteString(" ")
}
b.b.WriteString(" ")
b.b.WriteString(strconv.Quote(s))
}
// u appends a non-terminal to the current rule.
func (b *builder) u(s string) {
if b.items > 0 {
b.b.WriteString(" ")
}
b.b.WriteString(" ")
b.b.WriteString(s)
}
func buildConstrainedNumber(b *builder, s *jsonschema.Schema) {
if s.Minimum == 0 && s.Maximum == 0 {
b.u("TODO")
} else {
b.u("number")
}
}

View File

@@ -1,75 +0,0 @@
package grammar
import (
"bufio"
"cmp"
"iter"
"strings"
"testing"
_ "embed"
"github.com/ollama/ollama/grammar/internal/diff"
)
func TestFromSchema(t *testing.T) {
for tt := range testCases(t) {
t.Run(tt.name, func(t *testing.T) {
g, err := FromSchema(nil, []byte(tt.schema))
if err != nil {
t.Fatalf("FromSchema: %v", err)
}
got := string(g)
got = strings.TrimPrefix(got, jsonTerms)
if got != tt.want {
t.Logf("schema:\n%s", tt.schema)
t.Fatal(string(diff.Diff("got", []byte(got), "want", []byte(tt.want))))
}
})
}
}
type testCase struct {
name string
schema string
want string
}
//go:embed testdata/schemas.txt
var tests string
func testCases(t testing.TB) iter.Seq[testCase] {
t.Helper()
return func(yield func(testCase) bool) {
t.Helper()
sc := bufio.NewScanner(strings.NewReader(tests))
name := ""
for sc.Scan() {
line := strings.TrimSpace(sc.Text())
if line == "" {
name = ""
continue
}
if line[0] == '#' {
name = cmp.Or(name, strings.TrimSpace(line[1:]))
continue
}
s := sc.Text()
g := ""
for sc.Scan() {
line = strings.TrimSpace(sc.Text())
if line == "" || line[0] == '#' {
break
}
g += sc.Text() + "\n"
}
if !yield(testCase{name, s, g}) {
return
}
name = strings.TrimSpace(strings.TrimPrefix(line, "#"))
}
if err := sc.Err(); err != nil {
t.Fatalf("error reading tests: %v", err)
}
}
}

View File

@@ -1,261 +0,0 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package diff
import (
"bytes"
"fmt"
"sort"
"strings"
)
// A pair is a pair of values tracked for both the x and y side of a diff.
// It is typically a pair of line indexes.
type pair struct{ x, y int }
// Diff returns an anchored diff of the two texts old and new
// in the “unified diff” format. If old and new are identical,
// Diff returns a nil slice (no output).
//
// Unix diff implementations typically look for a diff with
// the smallest number of lines inserted and removed,
// which can in the worst case take time quadratic in the
// number of lines in the texts. As a result, many implementations
// either can be made to run for a long time or cut off the search
// after a predetermined amount of work.
//
// In contrast, this implementation looks for a diff with the
// smallest number of “unique” lines inserted and removed,
// where unique means a line that appears just once in both old and new.
// We call this an “anchored diff” because the unique lines anchor
// the chosen matching regions. An anchored diff is usually clearer
// than a standard diff, because the algorithm does not try to
// reuse unrelated blank lines or closing braces.
// The algorithm also guarantees to run in O(n log n) time
// instead of the standard O(n²) time.
//
// Some systems call this approach a “patience diff,” named for
// the “patience sorting” algorithm, itself named for a solitaire card game.
// We avoid that name for two reasons. First, the name has been used
// for a few different variants of the algorithm, so it is imprecise.
// Second, the name is frequently interpreted as meaning that you have
// to wait longer (to be patient) for the diff, meaning that it is a slower algorithm,
// when in fact the algorithm is faster than the standard one.
func Diff(oldName string, old []byte, newName string, new []byte) []byte {
if bytes.Equal(old, new) {
return nil
}
x := lines(old)
y := lines(new)
// Print diff header.
var out bytes.Buffer
fmt.Fprintf(&out, "diff %s %s\n", oldName, newName)
fmt.Fprintf(&out, "--- %s\n", oldName)
fmt.Fprintf(&out, "+++ %s\n", newName)
// Loop over matches to consider,
// expanding each match to include surrounding lines,
// and then printing diff chunks.
// To avoid setup/teardown cases outside the loop,
// tgs returns a leading {0,0} and trailing {len(x), len(y)} pair
// in the sequence of matches.
var (
done pair // printed up to x[:done.x] and y[:done.y]
chunk pair // start lines of current chunk
count pair // number of lines from each side in current chunk
ctext []string // lines for current chunk
)
for _, m := range tgs(x, y) {
if m.x < done.x {
// Already handled scanning forward from earlier match.
continue
}
// Expand matching lines as far as possible,
// establishing that x[start.x:end.x] == y[start.y:end.y].
// Note that on the first (or last) iteration we may (or definitely do)
// have an empty match: start.x==end.x and start.y==end.y.
start := m
for start.x > done.x && start.y > done.y && x[start.x-1] == y[start.y-1] {
start.x--
start.y--
}
end := m
for end.x < len(x) && end.y < len(y) && x[end.x] == y[end.y] {
end.x++
end.y++
}
// Emit the mismatched lines before start into this chunk.
// (No effect on first sentinel iteration, when start = {0,0}.)
for _, s := range x[done.x:start.x] {
ctext = append(ctext, "-"+s)
count.x++
}
for _, s := range y[done.y:start.y] {
ctext = append(ctext, "+"+s)
count.y++
}
// If we're not at EOF and have too few common lines,
// the chunk includes all the common lines and continues.
const C = 3 // number of context lines
if (end.x < len(x) || end.y < len(y)) &&
(end.x-start.x < C || (len(ctext) > 0 && end.x-start.x < 2*C)) {
for _, s := range x[start.x:end.x] {
ctext = append(ctext, " "+s)
count.x++
count.y++
}
done = end
continue
}
// End chunk with common lines for context.
if len(ctext) > 0 {
n := end.x - start.x
if n > C {
n = C
}
for _, s := range x[start.x : start.x+n] {
ctext = append(ctext, " "+s)
count.x++
count.y++
}
done = pair{start.x + n, start.y + n}
// Format and emit chunk.
// Convert line numbers to 1-indexed.
// Special case: empty file shows up as 0,0 not 1,0.
if count.x > 0 {
chunk.x++
}
if count.y > 0 {
chunk.y++
}
fmt.Fprintf(&out, "@@ -%d,%d +%d,%d @@\n", chunk.x, count.x, chunk.y, count.y)
for _, s := range ctext {
out.WriteString(s)
}
count.x = 0
count.y = 0
ctext = ctext[:0]
}
// If we reached EOF, we're done.
if end.x >= len(x) && end.y >= len(y) {
break
}
// Otherwise start a new chunk.
chunk = pair{end.x - C, end.y - C}
for _, s := range x[chunk.x:end.x] {
ctext = append(ctext, " "+s)
count.x++
count.y++
}
done = end
}
return out.Bytes()
}
// lines returns the lines in the file x, including newlines.
// If the file does not end in a newline, one is supplied
// along with a warning about the missing newline.
func lines(x []byte) []string {
l := strings.SplitAfter(string(x), "\n")
if l[len(l)-1] == "" {
l = l[:len(l)-1]
} else {
// Treat last line as having a message about the missing newline attached,
// using the same text as BSD/GNU diff (including the leading backslash).
l[len(l)-1] += "\n\\ No newline at end of file\n"
}
return l
}
// tgs returns the pairs of indexes of the longest common subsequence
// of unique lines in x and y, where a unique line is one that appears
// once in x and once in y.
//
// The longest common subsequence algorithm is as described in
// Thomas G. Szymanski, “A Special Case of the Maximal Common
// Subsequence Problem,” Princeton TR #170 (January 1975),
// available at https://research.swtch.com/tgs170.pdf.
func tgs(x, y []string) []pair {
// Count the number of times each string appears in a and b.
// We only care about 0, 1, many, counted as 0, -1, -2
// for the x side and 0, -4, -8 for the y side.
// Using negative numbers now lets us distinguish positive line numbers later.
m := make(map[string]int)
for _, s := range x {
if c := m[s]; c > -2 {
m[s] = c - 1
}
}
for _, s := range y {
if c := m[s]; c > -8 {
m[s] = c - 4
}
}
// Now unique strings can be identified by m[s] = -1+-4.
//
// Gather the indexes of those strings in x and y, building:
// xi[i] = increasing indexes of unique strings in x.
// yi[i] = increasing indexes of unique strings in y.
// inv[i] = index j such that x[xi[i]] = y[yi[j]].
var xi, yi, inv []int
for i, s := range y {
if m[s] == -1+-4 {
m[s] = len(yi)
yi = append(yi, i)
}
}
for i, s := range x {
if j, ok := m[s]; ok && j >= 0 {
xi = append(xi, i)
inv = append(inv, j)
}
}
// Apply Algorithm A from Szymanski's paper.
// In those terms, A = J = inv and B = [0, n).
// We add sentinel pairs {0,0}, and {len(x),len(y)}
// to the returned sequence, to help the processing loop.
J := inv
n := len(xi)
T := make([]int, n)
L := make([]int, n)
for i := range T {
T[i] = n + 1
}
for i := range n {
k := sort.Search(n, func(k int) bool {
return T[k] >= J[i]
})
T[k] = J[i]
L[i] = k + 1
}
k := 0
for _, v := range L {
if k < v {
k = v
}
}
seq := make([]pair, 2+k)
seq[1+k] = pair{len(x), len(y)} // sentinel at end
lastj := n
for i := n - 1; i >= 0; i-- {
if L[i] == k && J[i] < lastj {
seq[k] = pair{xi[i], yi[J[i]]}
k--
}
}
seq[0] = pair{0, 0} // sentinel at start
return seq
}

View File

@@ -1,44 +0,0 @@
// Copyright 2022 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package diff
import (
"bytes"
"path/filepath"
"testing"
"golang.org/x/tools/txtar"
)
func clean(text []byte) []byte {
text = bytes.ReplaceAll(text, []byte("$\n"), []byte("\n"))
text = bytes.TrimSuffix(text, []byte("^D\n"))
return text
}
func Test(t *testing.T) {
files, _ := filepath.Glob("testdata/*.txt")
if len(files) == 0 {
t.Fatalf("no testdata")
}
for _, file := range files {
t.Run(filepath.Base(file), func(t *testing.T) {
a, err := txtar.ParseFile(file)
if err != nil {
t.Fatal(err)
}
if len(a.Files) != 3 || a.Files[2].Name != "diff" {
t.Fatalf("%s: want three files, third named \"diff\"", file)
}
diffs := Diff(a.Files[0].Name, clean(a.Files[0].Data), a.Files[1].Name, clean(a.Files[1].Data))
want := clean(a.Files[2].Data)
if !bytes.Equal(diffs, want) {
t.Fatalf("%s: have:\n%s\nwant:\n%s\n%s", file,
diffs, want, Diff("have", diffs, "want", want))
}
})
}
}

View File

@@ -1,13 +0,0 @@
-- old --
-- new --
a
b
c
-- diff --
diff old new
--- old
+++ new
@@ -0,0 +1,3 @@
+a
+b
+c

View File

@@ -1,13 +0,0 @@
-- old --
a
b
c
-- new --
-- diff --
diff old new
--- old
+++ new
@@ -1,3 +0,0 @@
-a
-b
-c

View File

@@ -1,35 +0,0 @@
Example from Hunt and McIlroy, “An Algorithm for Differential File Comparison.”
https://www.cs.dartmouth.edu/~doug/diff.pdf
-- old --
a
b
c
d
e
f
g
-- new --
w
a
b
x
y
z
e
-- diff --
diff old new
--- old
+++ new
@@ -1,7 +1,7 @@
+w
a
b
-c
-d
+x
+y
+z
e
-f
-g

View File

@@ -1,40 +0,0 @@
-- old --
a
b
c
d
e
f
-- new --
a
B
C
d
e
f
-- diff --
diff old new
--- old
+++ new
@@ -1,8 +1,8 @@
a
$
-b
-
-c
+B
+
+C
$
d
$

View File

@@ -1,38 +0,0 @@
-- old --
1
2
3
4
5
6
7
eight
nine
ten
eleven
-- new --
1
2
3
4
5
6
7
8
9
10
-- diff --
diff old new
--- old
+++ new
@@ -5,7 +5,6 @@
5
6
7
-eight
-nine
-ten
-eleven
+8
+9
+10

View File

@@ -1,9 +0,0 @@
-- old --
a
b
c^D
-- new --
a
b
c^D
-- diff --

View File

@@ -1,18 +0,0 @@
-- old --
a
b
c
-- new --
a
b
c^D
-- diff --
diff old new
--- old
+++ new
@@ -1,3 +1,3 @@
a
b
-c
+c
\ No newline at end of file

View File

@@ -1,18 +0,0 @@
-- old --
a
b
c^D
-- new --
a
b
c
-- diff --
diff old new
--- old
+++ new
@@ -1,3 +1,3 @@
a
b
-c
\ No newline at end of file
+c

View File

@@ -1,62 +0,0 @@
-- old --
1
2
3
4
5
6
7
8
9
10
11
12
13
14
14½
15
16
17
18
19
20
-- new --
1
2
3
4
5
6
8
9
10
11
12
13
14
17
18
19
20
-- diff --
diff old new
--- old
+++ new
@@ -4,7 +4,6 @@
4
5
6
-7
8
9
10
@@ -12,9 +11,6 @@
12
13
14
-14½
-15
-16
17
18
19

View File

@@ -1,5 +0,0 @@
-- old --
hello world
-- new --
hello world
-- diff --

View File

@@ -1,34 +0,0 @@
-- old --
e
pi
4
5
6
7
8
9
10
-- new --
1
2
3
4
5
6
7
8
9
10
-- diff --
diff old new
--- old
+++ new
@@ -1,5 +1,6 @@
-e
-pi
+1
+2
+3
4
5
6

View File

@@ -1,40 +0,0 @@
Another example from Hunt and McIlroy,
“An Algorithm for Differential File Comparison.”
https://www.cs.dartmouth.edu/~doug/diff.pdf
Anchored diff gives up on finding anything,
since there are no unique lines.
-- old --
a
b
c
a
b
b
a
-- new --
c
a
b
a
b
c
-- diff --
diff old new
--- old
+++ new
@@ -1,7 +1,6 @@
-a
-b
-c
-a
-b
-b
-a
+c
+a
+b
+a
+b
+c

View File

@@ -1,171 +0,0 @@
package jsonschema
import (
"bytes"
"encoding/json"
"errors"
)
// Schema holds a JSON schema.
type Schema struct {
// Name is the name of the property. For the parent/root property, this
// is "root". For child properties, this is the name of the property.
Name string `json:"-"`
// Type is the type of the property.
//
// TODO: Union types (e.g. make this a []string).
Type string
// PrefixItems is a list of schemas for each item in a tuple. By
// default, the tuple is "closed." unless Items is set to true or a
// valid Schema.
PrefixItems []*Schema
// Items is the schema for each item in a list.
//
// If it is missing, or its JSON value is "null" or "false", it is nil.
// If the JSON value is "true", it is set to the empty Schema. If the
// JSON value is an object, it will be decoded as a Schema.
Items *Schema
// MinItems specifies the minimum number of items allowed in a list.
MinItems int
// MaxItems specifies the maximum number of items allowed in a list.
MaxItems int
// Properties is the schema for each property of an object.
Properties []*Schema
// Format is the format of the property. This is used to validate the
// property against a specific format.
//
// It is the callers responsibility to validate the property against
// the format.
Format string
// Minimum specifies the minimum value for numeric properties.
Minimum float64
// Maximum specifies the maximum value for numeric properties.
Maximum float64
// Enum is a list of valid values for the property.
Enum []json.RawMessage
}
func (s *Schema) UnmarshalJSON(data []byte) error {
type S Schema
w := struct {
Properties props
Items items
*S
}{
S: (*S)(s),
}
if err := json.Unmarshal(data, &w); err != nil {
return err
}
if w.Items.set {
s.Items = &w.Items.Schema
}
s.Properties = w.Properties
return nil
}
type items struct {
Schema
set bool
}
func (s *items) UnmarshalJSON(data []byte) error {
switch b := data[0]; b {
case 't':
*s = items{set: true}
case '{':
type I items
if err := json.Unmarshal(data, (*I)(s)); err != nil {
return err
}
s.set = true
case 'n', 'f':
default:
return errors.New("invalid Items")
}
return nil
}
// EffectiveType returns the effective type of the schema. If the Type field is
// not empty, it is returned; otherwise:
//
// - If the schema has both Properties and Items, it returns an empty string.
// - If the schema has Properties, it returns "object".
// - If the schema has Items, it returns "array".
// - If the schema has neither Properties nor Items, it returns "value".
//
// The returned string is never empty.
func (d *Schema) EffectiveType() string {
if d.Type == "" {
if len(d.Properties) > 0 {
return "object"
}
if len(d.PrefixItems) > 0 || d.Items != nil {
return "array"
}
return "value"
}
return d.Type
}
// props is an ordered list of properties. The order of the properties
// is the order in which they were defined in the schema.
type props []*Schema
var _ json.Unmarshaler = (*props)(nil)
func (v *props) UnmarshalJSON(data []byte) error {
if len(data) == 0 {
return nil
}
if data[0] != '{' {
return errors.New("expected object")
}
d := json.NewDecoder(bytes.NewReader(data))
// TODO(bmizerany): Consider DisallowUnknownFields. Currently, we, like
// llama.cpp, ignore unknown fields, which could be lead to unexpected
// behavior for clients of this package, since they may not be aware
// that "additionalFields", "itemsPrefix", etc, are being ignored.
//
// For now, just do what llama.cpp does.
t, err := d.Token()
if err != nil {
return err
}
if t != json.Delim('{') {
return errors.New("expected object")
}
for d.More() {
// Use the first token (map key) as the property name, then
// decode the rest of the object fields into a Schema and
// append.
t, err := d.Token()
if err != nil {
return err
}
if t == json.Delim('}') {
return nil
}
s := &Schema{
Name: t.(string),
}
if err := d.Decode(s); err != nil {
return err
}
*v = append(*v, s)
}
return nil
}

View File

@@ -1,104 +0,0 @@
package jsonschema
import (
"encoding/json"
"reflect"
"strings"
"testing"
"github.com/google/go-cmp/cmp"
)
const testSchemaBasic = `
{
"properties": {
"tupleClosedEmpty": { "prefixItems": [] },
"tupleClosedMissing": { "prefixItems": [{}] },
"tupleClosedNull": { "prefixItems": [{}], "items": null },
"tupleClosedFalse": { "prefixItems": [{}], "items": false },
"tupleOpenTrue": { "prefixItems": [{}], "items": true },
"tupleOpenEmpty": { "prefixItems": [{}], "items": {} },
"tupleOpenTyped": { "prefixItems": [{}], "items": {"type": "boolean"} },
"tupleOpenMax": { "prefixItems": [{}], "items": true, "maxItems": 3},
"array": { "items": {"type": "number"} },
"null": { "type": "null" },
"string": { "type": "string" },
"boolean": { "type": "boolean" }
}
}
`
func TestSchemaUnmarshal(t *testing.T) {
var got *Schema
if err := json.Unmarshal([]byte(testSchemaBasic), &got); err != nil {
t.Fatalf("Unmarshal: %v", err)
}
want := &Schema{
Properties: []*Schema{
{Name: "tupleClosedEmpty", PrefixItems: []*Schema{}, Items: nil},
{Name: "tupleClosedMissing", PrefixItems: []*Schema{{}}, Items: nil},
{Name: "tupleClosedNull", PrefixItems: []*Schema{{}}, Items: nil},
{Name: "tupleClosedFalse", PrefixItems: []*Schema{{}}, Items: nil},
{Name: "tupleOpenTrue", PrefixItems: []*Schema{{}}, Items: &Schema{}},
{Name: "tupleOpenEmpty", PrefixItems: []*Schema{{}}, Items: &Schema{}},
{Name: "tupleOpenTyped", PrefixItems: []*Schema{{}}, Items: &Schema{Type: "boolean"}},
{Name: "tupleOpenMax", PrefixItems: []*Schema{{}}, Items: &Schema{}, MaxItems: 3},
{Name: "array", Items: &Schema{Type: "number"}},
{Name: "null", Type: "null"},
{Name: "string", Type: "string"},
{Name: "boolean", Type: "boolean"},
},
}
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("(-want, +got)\n%s", diff)
}
}
func TestEffectiveType(t *testing.T) {
const schema = `
{"properties": {
"o": {"type": "object"},
"a": {"type": "array"},
"n": {"type": "number"},
"s": {"type": "string"},
"z": {"type": "null"},
"b": {"type": "boolean"},
"t0": {"prefixItems": [{}], "items": {"type": "number"}},
"t1": {"items": {"type": "number"}, "maxItems": 3},
"v": {"maxItems": 3}
}}
`
var s *Schema
if err := json.Unmarshal([]byte(schema), &s); err != nil {
t.Fatalf("json.Unmarshal: %v", err)
}
var got []string
for _, p := range s.Properties {
got = append(got, p.EffectiveType())
}
want := strings.Fields(`
object
array
number
string
null
boolean
array
array
value
`)
if !reflect.DeepEqual(want, got) {
t.Errorf("\ngot:\n\t%v\nwant:\n\t%v", got, want)
}
}

View File

@@ -1,76 +0,0 @@
# This file holds tests for JSON schema to EBNF grammar conversions.
#
# The format is a JSON schema, followed by the expected EBNF grammar. Each test
# MAY be preceded by a comment that describes the test (e.g. the test name), followed by
# the JSON schema and the expected EBNF grammar. If no comment is present, the test
# name the tests number in the file (e.g. "#0", "#1", etc.)
#
# Blank lines signify the end or start of a new test. Comments can be added
# anywhere in the file, but they must be preceded by a '#' character and start at
# the beginning of the line.
# default
{}
root ::= value;
{"properties": {}}
root ::= value;
# array
{"properties": {"a": {"type": "array", "items": {"type": "string"}}}}
root_0_tuple_0 ::= string;
root_0 ::= "[" ( root_0_tuple_0 )* "]";
root ::= "{" "a" ":" root_0 "}";
# array with nested array
{"type": "array", "items": {"type": "array", "items": {"type": "string"}}}
root_tuple_0_tuple_0 ::= string;
root_tuple_0 ::= "[" ( root_tuple_0_tuple_0 )* "]";
root ::= "[" ( root_tuple_0 )* "]";
# object
{"properties": {"e": {}}}
root_0 ::= value;
root ::= "{" "e" ":" root_0 "}";
# object with nested object
{"properties": {"o": {"type": "object", "properties": {"e": {}}}}}
root_0_0 ::= value;
root_0 ::= "{" "e" ":" root_0_0 "}";
root ::= "{" "o" ":" root_0 "}";
# boolean
{"type": "boolean"}
root ::= boolean;
# number
{"properties": {"n": {"type": "number", "minimum": 123, "maximum": 4567}}}
root_0 ::= number;
root ::= "{" "n" ":" root_0 "}";
# string
{"type": "string"}
root ::= string;
# string with enum
{"type": "string", "enum": ["a", "b", "c"]}
root ::= ( "\"a\"" "|" "\"b\"" "|" "\"c\"" );
# spaces in key
{"properties": {"a b": {}}}
root_0 ::= value;
root ::= "{" "a b" ":" root_0 "}";
# issue7978
{ "type": "object", "properties": { "steps": { "type": "array", "items": { "type": "object", "properties": { "explanation": { "type": "string" }, "output": { "type": "string" } }, "required": [ "explanation", "output" ], "additionalProperties": false } }, "final_answer": { "type": "string" } }, "required": [ "steps", "final_answer" ], "additionalProperties": false }
root_0_tuple_0_0 ::= string;
root_0_tuple_0_1 ::= string;
root_0_tuple_0 ::= "{" "explanation" ":" root_0_tuple_0_0 "," "output" ":" root_0_tuple_0_1 "}";
root_0 ::= "[" ( root_0_tuple_0 )* "]";
root_1 ::= string;
root ::= "{" "steps" ":" root_0 "," "final_answer" ":" root_1 "}";
# !! # special characters in key
# !! {"properties": {"a!b": {}}}
# !! !invalid character '!' in key
# !!

View File

@@ -119,10 +119,10 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity
}
var cacheSize int
if c.windowSize == math.MaxInt32 || capacity < int(c.windowSize)+maxBatch {
if c.windowSize == math.MaxInt32 || capacity < int(c.windowSize) {
cacheSize = maxSequences * capacity
} else {
cacheSize = maxSequences * (int(c.windowSize) + maxBatch)
cacheSize = (maxSequences * int(c.windowSize)) + maxBatch
}
cacheSize = roundUp(cacheSize, c.config.CachePadding)
c.cells = make([]cacheCell, cacheSize)

View File

@@ -362,7 +362,6 @@ func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
}
func (c *testContext) Input() ml.Context { return c }
func (c *testContext) Output() ml.Context { return c }
func (c *testContext) Layer(int) ml.Context { return c }
func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
@@ -463,7 +462,7 @@ func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0
panic("not implemented")
}
func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim, ropeType uint32, base, scale float32) ml.Tensor {
func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, config ml.RoPEConfig) ml.Tensor {
panic("not implemented")
}

View File

@@ -166,6 +166,10 @@ func (c *Context) KvCacheDefrag() {
C.llama_kv_cache_defrag(c.c)
}
func (c *Context) KvCacheCanShift() bool {
return bool(C.llama_kv_cache_can_shift(c.c))
}
// Get the embeddings for a sequence id
func (c *Context) GetEmbeddingsSeq(seqId int) []float32 {
e := unsafe.Pointer(C.llama_get_embeddings_seq(c.c, C.int(seqId)))

View File

@@ -0,0 +1,103 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Saman <saman.khatir@amd.com>
Date: Wed, 19 Mar 2025 14:02:26 -0700
Subject: [PATCH] add rdna4 support
---
ggml/src/ggml-cuda/common.cuh | 6 ++++--
ggml/src/ggml-cuda/mmq.cu | 2 +-
ggml/src/ggml-cuda/mmq.cuh | 4 ++--
ggml/src/ggml-cuda/mmvq.cu | 4 ++--
ggml/src/ggml-cuda/vendors/hip.h | 4 ++++
5 files changed, 13 insertions(+), 7 deletions(-)
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index adf0d3ec..b24593fc 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -61,11 +61,13 @@
#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000
#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
+#define GGML_CUDA_CC_RDNA4 (GGML_CUDA_CC_OFFSET_AMD + 0x1200) // RX 9000
#define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1)
#define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2)
#define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
-#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3)
+#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA4)
+#define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4)
#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA)
#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
@@ -386,7 +388,7 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(RDNA2)
c = __builtin_amdgcn_sdot4(a, b, c, false);
-#elif defined(RDNA3)
+#elif defined(RDNA3) || defined(RDNA4)
c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
#elif defined(__gfx1010__) || defined(__gfx900__)
int tmp1;
diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu
index 10f2ebb1..933d945c 100644
--- a/ggml/src/ggml-cuda/mmq.cu
+++ b/ggml/src/ggml-cuda/mmq.cu
@@ -149,5 +149,5 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
}
- return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
+ return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
}
diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh
index 0451c65f..66ce2bc9 100644
--- a/ggml/src/ggml-cuda/mmq.cuh
+++ b/ggml/src/ggml-cuda/mmq.cuh
@@ -2577,9 +2577,9 @@ static __device__ void mul_mat_q_process_tile(
template <ggml_type type, int mmq_x, int nwarps, bool need_check>
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
-#if defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
+#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
__launch_bounds__(WARP_SIZE*nwarps, 2)
-#endif // defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
+#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
#else
#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
__launch_bounds__(WARP_SIZE*nwarps, 1)
diff --git a/ggml/src/ggml-cuda/mmvq.cu b/ggml/src/ggml-cuda/mmvq.cu
index 4fb466ca..23ae7abc 100644
--- a/ggml/src/ggml-cuda/mmvq.cu
+++ b/ggml/src/ggml-cuda/mmvq.cu
@@ -62,13 +62,13 @@ static __global__ void mul_mat_vec_q(
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
-#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
+#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4))
constexpr int nwarps = 1;
constexpr int rows_per_cuda_block = 1;
#else
constexpr int nwarps = ncols_y <= 4 ? 4 : 2;
constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
-#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
+#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3) && !defined(RDNA4)
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
const int row0 = rows_per_cuda_block*blockIdx.x;
diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h
index 81964611..a62544b5 100644
--- a/ggml/src/ggml-cuda/vendors/hip.h
+++ b/ggml/src/ggml-cuda/vendors/hip.h
@@ -150,6 +150,10 @@
#define CDNA
#endif
+#if defined(__gfx1200__) || defined(__gfx1201__)
+#define RDNA4
+#endif
+
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
defined(__gfx1150__) || defined(__gfx1151__)
#define RDNA3

View File

@@ -15,12 +15,12 @@ import (
)
// This algorithm looks for a complete fit to determine if we need to unload other models
func PredictServerFit(allGpus discover.GpuInfoList, f *ggml.GGML, adapters, projectors []string, opts api.Options) (bool, uint64) {
func PredictServerFit(allGpus discover.GpuInfoList, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (bool, uint64) {
// Split up the GPUs by type and try them
var estimatedVRAM uint64
for _, gpus := range allGpus.ByLibrary() {
var layerCount int
estimate := EstimateGPULayers(gpus, f, projectors, opts)
estimate := EstimateGPULayers(gpus, f, projectors, opts, numParallel)
layerCount, estimatedVRAM = estimate.Layers, estimate.VRAMSize
if opts.NumGPU < 0 {
if layerCount > 0 && layerCount >= int(f.KV().BlockCount()+1) {
@@ -71,7 +71,7 @@ type MemoryEstimate struct {
// Given a model and one or more GPU targets, predict how many layers and bytes we can load, and the total size
// The GPUs provided must all be the same Library
func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []string, opts api.Options) MemoryEstimate {
func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []string, opts api.Options, numParallel int) MemoryEstimate {
// Graph size for a partial offload, applies to all GPUs
var graphPartialOffload uint64
@@ -137,13 +137,19 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
}
}
kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), kvct)
kv, graphPartialOffload, graphFullOffload := f.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)), numParallel, kvct)
// KV is proportional to the number of layers
layerSize += kv / f.KV().BlockCount()
if len(kv) > 0 {
layerSize += kv[0]
}
var kvTotal uint64
for _, kvLayer := range kv {
kvTotal += kvLayer
}
if graphPartialOffload == 0 {
graphPartialOffload = f.KV().GQA() * kv / 6
graphPartialOffload = f.KV().GQA() * kvTotal / 6
}
if graphFullOffload == 0 {
graphFullOffload = graphPartialOffload
@@ -217,7 +223,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
// Some models have inconsistent layer sizes
if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok {
layerSize = blk.Size()
layerSize += kv / f.KV().BlockCount()
layerSize += kv[i]
memoryWeights += blk.Size()
}
@@ -315,7 +321,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
layersRequested: opts.NumGPU,
layersModel: int(f.KV().BlockCount()) + 1,
availableList: availableList,
kv: kv,
kv: kvTotal,
allocationsList: allocationsList,
memoryWeights: memoryWeights,
memoryLayerOutput: memoryLayerOutput,
@@ -374,7 +380,7 @@ func (m MemoryEstimate) LogValue() slog.Value {
slog.Group(
"weights",
// memory of the weights
"total", format.HumanBytes2(m.memoryWeights),
"total", format.HumanBytes2(m.memoryWeights+m.memoryLayerOutput),
// memory of repeating layers
"repeating", format.HumanBytes2(m.memoryWeights),
// memory of non-repeating layers

View File

@@ -61,7 +61,7 @@ func TestEstimateGPULayers(t *testing.T) {
projectors := []string{}
opts := api.DefaultOptions()
t.Run("cpu", func(t *testing.T) {
estimate := EstimateGPULayers(gpus, ggml, projectors, opts)
estimate := EstimateGPULayers(gpus, ggml, projectors, opts, 1)
assert.Equal(t, 0, estimate.Layers)
assert.Equal(t, uint64(0), estimate.Graph)
})
@@ -112,7 +112,7 @@ func TestEstimateGPULayers(t *testing.T) {
gpus[1].FreeMemory += gpuMinimumMemory + layerSize + s.layer1*layerSize + 1
gpus[0].FreeMemory += max(graphFullOffload, graphPartialOffload)
gpus[1].FreeMemory += max(graphFullOffload, graphPartialOffload)
estimate := EstimateGPULayers(gpus, ggml, projectors, opts)
estimate := EstimateGPULayers(gpus, ggml, projectors, opts, 1)
assert.Equal(t, int(s.expect0+s.expect1), estimate.Layers, "scenario %d: %v", i, s)
assert.Equal(t, fmt.Sprintf("%d,%d", s.expect0, s.expect1), estimate.TensorSplit, "scenario %d: %v", i, s)
var layerSums uint64

View File

@@ -29,7 +29,6 @@ import (
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/grammar"
"github.com/ollama/ollama/llama"
"github.com/ollama/ollama/model"
)
@@ -110,7 +109,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
gpus = discover.GetCPUInfo()
}
estimate := EstimateGPULayers(gpus, f, projectors, opts)
estimate := EstimateGPULayers(gpus, f, projectors, opts, numParallel)
if len(gpus) > 1 || gpus[0].Library != "cpu" {
switch {
case gpus[0].Library == "metal" && estimate.VRAMSize > systemTotalMemory:
@@ -701,9 +700,9 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
}
// User provided a JSON schema
g, err := grammar.FromSchema(nil, req.Format)
if err != nil {
return fmt.Errorf("invalid JSON schema in format: %w", err)
g := llama.SchemaToGrammar(req.Format)
if g == nil {
return fmt.Errorf("invalid JSON schema in format")
}
req.Grammar = string(g)
}
@@ -714,11 +713,6 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
req.Options = &opts
}
if req.Options == nil {
opts := api.DefaultOptions()
req.Options = &opts
}
if err := s.sem.Acquire(ctx, 1); err != nil {
if errors.Is(err, context.Canceled) {
slog.Info("aborting completion request due to client closing the connection")
@@ -733,6 +727,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
if req.Options.NumPredict < 0 || req.Options.NumPredict > 10*s.options.NumCtx {
req.Options.NumPredict = 10 * s.options.NumCtx
}
// Make sure the server is ready
status, err := s.getServerStatusRetry(ctx)
if err != nil {

View File

@@ -110,16 +110,61 @@ type Context interface {
MaxGraphNodes() int
Close()
// Input returns a context appropriate for creating input tensors
// Input returns a context appropriate for creating tensors that are
// inputs to the model (which includes things like output locations)
Input() Context
// Output returns a context appropriate for creating output tensors
Output() Context
// Layer returns a context appropriate for creating intermediate tensors
Layer(int) Context
}
// RopeType represents different RoPE (Rotary Position Embedding) implementation types
type RopeType int
// Available RoPE implementation types
const (
RopeTypeNormal RopeType = iota // Standard RoPE implementation
RopeTypeNeox // NeoX-style RoPE implementation
RopeTypeMRoPE // Multimodal RoPE implementation
RopeTypeVision // Vision-specific RoPE implementation
)
type YarnConfig struct {
YarnCtxTrain int // Context size used during training (for YaRN scaling)
YarnExtFactor float32 // Extension factor for YaRN
YarnAttnFactor float32 // Attention scaling factor for YaRN
YarnBetaFast float32 // Fast decay parameter for YaRN
YarnBetaSlow float32 // Slow decay parameter for YaRN
}
// DefaultYarnConfig returns a default configuration for YaRN (Yet Another Rope Extension)
func DefaultYarnConfig(nCtx int32) *YarnConfig {
return &YarnConfig{
YarnCtxTrain: int(nCtx),
YarnExtFactor: 0.0,
YarnAttnFactor: 1.0,
YarnBetaFast: 32.0,
YarnBetaSlow: 1.0,
}
}
// RoPEConfig holds configuration for Rotary Position Embedding
type RoPEConfig struct {
// Dim is the dimensionality for applying rotary embeddings
Dim uint32
// Type specifies the RoPE implementation variant
Type RopeType
// Base controls frequency decay for the embeddings
Base float32
// Scale allows scaling the effective context length
Scale float32
*YarnConfig
}
type Tensor interface {
Dim(n int) int
Stride(n int) int
@@ -143,7 +188,7 @@ type Tensor interface {
AvgPool2D(ctx Context, k, s int, p float32) Tensor
Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor
RoPE(ctx Context, positionIDs, ropeFactors Tensor, config RoPEConfig) Tensor
Tanh(ctx Context) Tensor
GELU(ctx Context) Tensor

View File

@@ -48,9 +48,6 @@ type Backend struct {
// input is the backend used for inputs
input *C.struct_ggml_backend_buffer_type
// output is the backend used for outputs
output *C.struct_ggml_backend_buffer_type
// layers is the backend used for repeating layers
layers map[int]*C.struct_ggml_backend_buffer_type
@@ -400,8 +397,7 @@ func New(ctx context.Context, r *os.File, params ml.BackendParams) (ml.Backend,
C.size_t(maxGraphNodes),
C._Bool(len(gpus) > 1 && slices.Contains(gpus, output.d)),
),
input: deviceBufferTypes[input.d],
output: deviceBufferTypes[output.d],
input: deviceBufferTypes[input.d],
layers: func() map[int]*C.struct_ggml_backend_buffer_type {
m := make(map[int]*C.struct_ggml_backend_buffer_type)
for i, layer := range layers {
@@ -482,19 +478,6 @@ func (c Context) Input() ml.Context {
return &c
}
func (c Context) Output() ml.Context {
if c.b.output != nil {
return &Context{
b: c.b,
ctx: c.ctx,
buft: c.b.output,
maxGraphNodes: c.maxGraphNodes,
}
}
return &c
}
func (c Context) Layer(i int) ml.Context {
if buft, ok := c.b.layers[i]; ok {
return &Context{
@@ -924,6 +907,8 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
}
}
// GGML RoPE types
// These are the types used in the C implementation of RoPE
const (
ropeTypeNorm C.int = 0
ropeTypeNeox C.int = 2
@@ -931,7 +916,8 @@ const (
ropeTypeVision C.int = 24
)
func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor {
// RoPE applies Rotary Position Embeddings to the tensor
func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, config ml.RoPEConfig) ml.Tensor {
if ropeFactors == nil {
ropeFactors = &Tensor{b: t.b}
}
@@ -941,19 +927,41 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
dequant = C.ggml_cast(ctx.(*Context).ctx, t.t, C.GGML_TYPE_F32)
}
if config.YarnConfig == nil {
config.YarnConfig = ml.DefaultYarnConfig(131072) // 131072 is the default for LLaMA, so it is common at the time of writing
}
// Map Go RopeType to C implementation constants
var ropeTypeC C.int
switch config.Type {
case ml.RopeTypeNormal:
ropeTypeC = ropeTypeNorm
case ml.RopeTypeNeox:
ropeTypeC = ropeTypeNeox
case ml.RopeTypeMRoPE:
ropeTypeC = ropeTypeMrope
case ml.RopeTypeVision:
ropeTypeC = ropeTypeVision
default:
ropeTypeC = ropeTypeNorm
}
return &Tensor{
b: t.b,
t: C.ggml_rope_ext(
ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
C.int(ropeDim),
C.int(ropeType),
131072, // YaRN n_ctx_train
C.float(ropeBase),
C.float(ropeScale),
0., // YaRN ext_factor
1., // YaRN attn_factor
32., // YaRN beta_fast
1., // YaRN beta_slow
ctx.(*Context).ctx,
dequant,
positionIDs.(*Tensor).t,
ropeFactors.(*Tensor).t,
C.int(config.Dim),
ropeTypeC,
C.int(config.YarnCtxTrain),
C.float(config.Base),
C.float(config.Scale),
C.float(config.YarnExtFactor),
C.float(config.YarnAttnFactor),
C.float(config.YarnBetaFast),
C.float(config.YarnBetaSlow),
),
}
}

View File

@@ -61,11 +61,13 @@
#define GGML_CUDA_CC_RDNA1 (GGML_CUDA_CC_OFFSET_AMD + 0x1010) // RX 5000
#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
#define GGML_CUDA_CC_RDNA4 (GGML_CUDA_CC_OFFSET_AMD + 0x1200) // RX 9000
#define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1)
#define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2)
#define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3)
#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3 && cc < GGML_CUDA_CC_RDNA4)
#define GGML_CUDA_CC_IS_RDNA4(cc) (cc >= GGML_CUDA_CC_RDNA4)
#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA)
#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
@@ -386,7 +388,7 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(RDNA2)
c = __builtin_amdgcn_sdot4(a, b, c, false);
#elif defined(RDNA3)
#elif defined(RDNA3) || defined(RDNA4)
c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
#elif defined(__gfx1010__) || defined(__gfx900__)
int tmp1;

View File

@@ -149,5 +149,5 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
}
return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
return (!GGML_CUDA_CC_IS_RDNA4(cc) && !GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
}

View File

@@ -2577,9 +2577,9 @@ static __device__ void mul_mat_q_process_tile(
template <ggml_type type, int mmq_x, int nwarps, bool need_check>
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
#if defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
#if defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
__launch_bounds__(WARP_SIZE*nwarps, 2)
#endif // defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
#endif // defined(RDNA4) || defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
#else
#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
__launch_bounds__(WARP_SIZE*nwarps, 1)

View File

@@ -62,13 +62,13 @@ static __global__ void mul_mat_vec_q(
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4))
constexpr int nwarps = 1;
constexpr int rows_per_cuda_block = 1;
#else
constexpr int nwarps = ncols_y <= 4 ? 4 : 2;
constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3) && !defined(RDNA4)
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
const int row0 = rows_per_cuda_block*blockIdx.x;

View File

@@ -150,6 +150,10 @@
#define CDNA
#endif
#if defined(__gfx1200__) || defined(__gfx1201__)
#define RDNA4
#endif
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
defined(__gfx1150__) || defined(__gfx1151__)
#define RDNA3

View File

@@ -13,10 +13,11 @@ import (
type Options struct {
hiddenSize, numHeads, numKVHeads int
attnKeyLen, attnValLen int
eps, ropeBase, ropeScale float32
eps float32
attnLogitSoftcap float32
finalLogitSoftcap float32
largeModelScaling bool
ropeConfig ml.RoPEConfig
}
type Model struct {
@@ -55,10 +56,15 @@ func New(c ml.Config) (model.Model, error) {
attnKeyLen: int(c.Uint("attention.key_length")),
attnValLen: int(c.Uint("attention.value_length")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base", 10000.0),
ropeScale: c.Float("rope.freq_scale", 1.0),
attnLogitSoftcap: c.Float("attn_logit_softcapping"),
finalLogitSoftcap: c.Float("final_logit_softcapping"),
ropeConfig: ml.RoPEConfig{
Base: c.Float("rope.freq_base", 10000.0),
Scale: c.Float("rope.freq_scale", 1.0),
Dim: c.Uint("attention.key_length"),
Type: ml.RopeTypeNormal,
YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
},
},
}
@@ -78,11 +84,10 @@ type SelfAttention struct {
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
batchSize := hiddenState.Dim(1)
ropeType := uint32(2)
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
q = q.RoPE(ctx, positionIDs, nil, opts.ropeConfig)
if opts.largeModelScaling {
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@@ -92,7 +97,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
k = k.RoPE(ctx, positionIDs, nil, opts.ropeConfig)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
@@ -122,7 +127,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key.RoPE(ctx, shift, nil, uint32(m.Options.attnKeyLen), uint32(2), m.Options.ropeBase, m.Options.ropeScale), nil
return key.RoPE(ctx, shift, nil, m.ropeConfig), nil
}
type MLP struct {

View File

@@ -13,9 +13,11 @@ import (
type TextOptions struct {
hiddenSize, numHeads, numKVHeads int
attnKeyLen, attnValLen int
eps, ropeScale float32
ropeLocalBase, ropeGlobalBase float32
eps float32
largeModelScaling bool
ropeLocalConfig ml.RoPEConfig
ropeGlobalConfig ml.RoPEConfig
}
type TextModel struct {
@@ -56,15 +58,27 @@ func newTextModel(c ml.Config) *TextModel {
),
Layers: make([]TextLayer, numBlocks),
TextOptions: &TextOptions{
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
attnKeyLen: int(c.Uint("attention.key_length", 256)),
attnValLen: int(c.Uint("attention.value_length", 256)),
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
ropeLocalBase: c.Float("rope.local.freq_base", 10000.0),
ropeGlobalBase: c.Float("rope.global.freq_base", 1000000.0),
ropeScale: c.Float("rope.freq_scale", 1.0),
hiddenSize: int(c.Uint("embedding_length")),
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
attnKeyLen: int(c.Uint("attention.key_length", 256)),
attnValLen: int(c.Uint("attention.value_length", 256)),
eps: c.Float("attention.layer_norm_rms_epsilon", 1e-06),
ropeLocalConfig: ml.RoPEConfig{
Base: c.Float("rope.local.freq_base", 10000.0),
Scale: c.Float("rope.freq_scale", 1.0),
Dim: c.Uint("attention.key_length", 256),
Type: ml.RopeTypeNeox,
YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
},
ropeGlobalConfig: ml.RoPEConfig{
Base: c.Float("rope.global.freq_base", 1000000.0),
Scale: c.Float("rope.freq_scale", 1.0),
Dim: c.Uint("attention.key_length", 256),
Type: ml.RopeTypeNeox,
YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
},
},
}
@@ -86,17 +100,16 @@ type TextSelfAttention struct {
func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
batchSize := hiddenState.Dim(1)
ropeType := uint32(2)
ropeBase := opts.ropeLocalBase
ropeConfig := opts.ropeLocalConfig
if (layer+1)%gemmaGlobalCacheCount == 0 {
ropeBase = opts.ropeGlobalBase
ropeConfig = opts.ropeGlobalConfig
}
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, opts.attnKeyLen, opts.numHeads, batchSize)
q = sa.QueryNorm.Forward(ctx, q, opts.eps)
q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
q = q.RoPE(ctx, positionIDs, nil, ropeConfig)
if opts.largeModelScaling {
q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))
@@ -107,7 +120,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, opts.attnKeyLen, opts.numKVHeads, batchSize)
k = sa.KeyNorm.Forward(ctx, k, opts.eps)
k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
k = k.RoPE(ctx, positionIDs, nil, ropeConfig)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, opts.attnValLen, opts.numKVHeads, batchSize)
@@ -120,12 +133,12 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
ropeBase := m.TextOptions.ropeLocalBase
ropeConfig := m.ropeLocalConfig
if (layer+1)%gemmaGlobalCacheCount == 0 {
ropeBase = m.TextOptions.ropeGlobalBase
ropeConfig = m.ropeGlobalConfig
}
return key.RoPE(ctx, shift, nil, uint32(m.TextOptions.attnKeyLen), uint32(2), ropeBase, m.TextOptions.ropeScale), nil
return key.RoPE(ctx, shift, nil, ropeConfig), nil
}
type TextMLP struct {

View File

@@ -14,8 +14,8 @@ import (
type Options struct {
hiddenSize, numHeads, numKVHeads int
eps, ropeBase, ropeScale float32
ropeDim uint32
eps float32
ropeConfig ml.RoPEConfig
}
type Model struct {
@@ -54,9 +54,13 @@ func New(c ml.Config) (model.Model, error) {
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1),
ropeDim: c.Uint("rope.dimension_count"),
ropeConfig: ml.RoPEConfig{
Base: c.Float("rope.freq_base"),
Scale: c.Float("rope.freq_scale", 1),
Dim: c.Uint("rope.dimension_count"),
Type: ml.RopeTypeNormal,
YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
},
},
}
@@ -76,15 +80,14 @@ type SelfAttention struct {
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
batchSize := hiddenState.Dim(1)
headDim := opts.hiddenSize / opts.numHeads
ropeType := uint32(0)
q := sa.Query.Forward(ctx, hiddenState)
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig)
k := sa.Key.Forward(ctx, hiddenState)
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeConfig)
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -97,7 +100,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
}
func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil
return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, m.ropeConfig), nil
}
type MLP struct {

View File

@@ -20,15 +20,14 @@ type TextSelfAttention struct {
func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
batchSize := hiddenState.Dim(1)
headDim := opts.hiddenSize / opts.numHeads
ropeType := uint32(0)
query := sa.Query.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeConfig)
key := sa.Key.Forward(ctx, hiddenState)
key = key.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeConfig)
value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
@@ -43,7 +42,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
// This will only get called for layers in the cache, which are just the self attention layers
if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil
return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeConfig), nil
}
return key, nil
@@ -198,8 +197,8 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs,
type TextModelOptions struct {
hiddenSize, numHeads, numKVHeads int
eps, ropeBase, ropeScale float32
ropeDim uint32
eps float32
ropeConfig ml.RoPEConfig
crossAttentionLayers []uint32
}
@@ -240,10 +239,14 @@ func newTextModel(c ml.Config) *TextModel {
numHeads: int(c.Uint("attention.head_count")),
numKVHeads: int(c.Uint("attention.head_count_kv")),
eps: c.Float("attention.layer_norm_rms_epsilon"),
ropeBase: c.Float("rope.freq_base"),
ropeScale: c.Float("rope.freq_scale", 1),
ropeDim: c.Uint("rope.dimension_count"),
crossAttentionLayers: c.Uints("attention.cross_attention_layers"),
ropeConfig: ml.RoPEConfig{
Base: c.Float("rope.freq_base"),
Scale: c.Float("rope.freq_scale", 1),
Dim: c.Uint("rope.dimension_count"),
Type: ml.RopeTypeNormal,
YarnConfig: ml.DefaultYarnConfig(int32(c.Uint("context_length", 131072))),
},
},
}
}

View File

@@ -32,7 +32,6 @@ type TextProcessor interface {
Encode(s string, addSpecial bool) ([]int32, error)
Decode([]int32) (string, error)
Is(int32, Special) bool
Vocab() *Vocabulary
}
type Vocabulary struct {

View File

@@ -53,10 +53,6 @@ func (spm SentencePieceModel) Is(id int32, special Special) bool {
return spm.vocab.Is(id, special)
}
func (spm SentencePieceModel) Vocab() *Vocabulary {
return spm.vocab
}
func (spm *SentencePieceModel) split(s string) iter.Seq[string] {
return func(yield func(string) bool) {
for m, _ := spm.pre.FindStringMatch(s); m != nil; m, _ = spm.pre.FindNextMatch(m) {

View File

@@ -213,8 +213,16 @@ func (c *InputCache) ShiftDiscard(inputLen int, numKeep int) int {
return discard
}
// Frees up space in the KV cache by deleting the oldest half of history and shifting
// the newest half into that space (saving numKeep inputs at the beginning).
type ErrReprocessInputs struct {
Inputs []input
}
func (e *ErrReprocessInputs) Error() string {
return fmt.Sprintf("kv cache shift not supported, inputs need reprocessing (input count: %v)", len(e.Inputs))
}
// ShiftCacheSlot frees up space in the KV cache by deleting the oldest half of history
// and shifting the newest half into that space (saving numKeep inputs at the beginning).
//
// Assumes that at least 1 entry can be freed up by shifting (i.e. numKeep < numCtx)
func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
@@ -222,7 +230,8 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
return fmt.Errorf("unable to shift context - keep exceeds context (keep: %v context: %v)", numKeep, c.numCtx)
}
discard := c.ShiftDiscard(len(slot.Inputs), numKeep)
inputLen := len(slot.Inputs)
discard := c.ShiftDiscard(inputLen, numKeep)
if discard <= 0 {
return nil
@@ -231,16 +240,42 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int) error {
slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs),
"keep", numKeep, "discard", discard)
// TODO (jessegross): KV cache removal can fail for certain types of models
if !c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard) {
return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v)", slot.Id, numKeep, discard)
}
c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, len(slot.Inputs), -discard)
var shiftFailed bool
for i := numKeep + discard; i < len(slot.Inputs); i++ {
if c.lc.KvCacheCanShift() {
// For models that support shifting, attempt to shift the KV cache
if !c.lc.KvCacheSeqRm(slot.Id, numKeep, numKeep+discard) {
shiftFailed = true
slog.Debug("kv cache removal not supported, clearing cache and returning inputs for reprocessing", "id", slot.Id)
} else {
c.lc.KvCacheSeqAdd(slot.Id, numKeep+discard, inputLen, -discard)
}
} else {
// For models that don't support shifting
shiftFailed = true
slog.Debug("kv cache cannot shift, clearing cache and returning inputs for reprocessing", "id", slot.Id)
}
if shiftFailed {
// Create new input slice with preserved tokens (numKeep + remaining tokens after discard)
newInputs := make([]input, numKeep+inputLen-(numKeep+discard))
copy(newInputs[:numKeep], slot.Inputs[:numKeep])
copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:])
// Clear the entire KV cache
_ = c.lc.KvCacheSeqRm(slot.Id, 0, -1)
// Reset the slot inputs since we've cleared the cache
slot.Inputs = []input{}
// Return error with inputs that need to be reprocessed
return &ErrReprocessInputs{Inputs: newInputs}
}
// Standard shift succeeded - update input array
for i := numKeep + discard; i < inputLen; i++ {
slot.Inputs[i-discard] = slot.Inputs[i]
}
slot.Inputs = slot.Inputs[:len(slot.Inputs)-discard]
slot.Inputs = slot.Inputs[:inputLen-discard]
return nil
}

View File

@@ -389,7 +389,15 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
if len(seq.pendingInputs) == 0 {
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
if err != nil {
return err
var reprocess *ErrReprocessInputs
if errors.As(err, &reprocess) {
// Prepend these inputs to the sequence's inputs queue for reprocessing
seq.inputs = append(reprocess.Inputs, seq.inputs...)
// Continue processing as normal
continue
} else {
return err
}
}
} else {
break
@@ -599,7 +607,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
if errors.Is(err, context.Canceled) {
slog.Info("aborting completion request due to client closing the connection")
} else {
slog.Error("Failed to acquire semaphore", "error", err)
http.Error(w, fmt.Sprintf("Failed to acquire semaphore: %v", err), http.StatusInternalServerError)
}
return
}
@@ -611,6 +619,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, true)
if err != nil {
s.mu.Unlock()
s.seqsSem.Release(1)
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
return
}
@@ -626,6 +635,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
s.mu.Unlock()
if !found {
s.seqsSem.Release(1)
http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
return
}
@@ -691,7 +701,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
if errors.Is(err, context.Canceled) {
slog.Info("aborting embeddings request due to client closing the connection")
} else {
slog.Error("Failed to acquire semaphore", "error", err)
http.Error(w, fmt.Sprintf("Failed to acquire semaphore: %v", err), http.StatusInternalServerError)
}
return
}
@@ -703,6 +713,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs, false)
if err != nil {
s.mu.Unlock()
s.seqsSem.Release(1)
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
return
}
@@ -715,6 +726,7 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
s.mu.Unlock()
if !found {
s.seqsSem.Release(1)
http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
return
}

View File

@@ -239,6 +239,14 @@ func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 {
return discard
}
type ErrReprocessInputs struct {
Inputs []input.Input
}
func (e *ErrReprocessInputs) Error() string {
return fmt.Sprintf("kv cache shift not supported, inputs need reprocessing (input count: %v)", len(e.Inputs))
}
// Frees up space in the KV cache by deleting the oldest half of history and shifting
// the newest half into that space (saving numKeep inputs at the beginning).
//
@@ -258,11 +266,23 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int32) error {
slog.Debug("context limit hit - shifting", "id", slot.Id, "limit", c.numCtx, "input", len(slot.Inputs),
"keep", numKeep, "discard", discard)
// TODO (jessegross): KV cache removal can fail for certain types of models
if c.cache != nil {
err := c.cache.Remove(slot.Id, numKeep, numKeep+discard)
if err != nil {
return fmt.Errorf("unable to remove old kv cache entries (id: %v, keep: %v discard: %v): %w", slot.Id, numKeep, discard, err)
slog.Debug("kv cache removal unsupported, clearing cache and returning inputs for reprocessing",
"id", slot.Id, "error", err)
// Create new input slice with preserved tokens (numKeep + remaining tokens after discard)
newInputs := make([]input.Input, numKeep+inputLen-(numKeep+discard))
copy(newInputs[:numKeep], slot.Inputs[:numKeep])
copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:])
// Reset the cache
_ = c.cache.Remove(slot.Id, 0, -1)
slot.Inputs = []input.Input{}
// Return error with inputs that need to be reprocessed
return &ErrReprocessInputs{Inputs: newInputs}
}
}

View File

@@ -1,10 +1,13 @@
package ollamarunner
import (
"errors"
"fmt"
"image"
"testing"
"time"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/input"
)
@@ -425,3 +428,91 @@ func TestLoadCacheSlot(t *testing.T) {
})
}
}
// Mock implementation of the Cache interface
type mockCache struct {
shouldFail bool
}
// Implement only the methods needed for the test
func (m *mockCache) Remove(seq int, beginIndex, endIndex int32) error {
if m.shouldFail {
return fmt.Errorf("mock cache removal error")
}
return nil
}
// Stub implementations for other interface methods
func (m *mockCache) SetLayer(layer int) {}
func (m *mockCache) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { return nil, nil, nil }
func (m *mockCache) Put(ctx ml.Context, key, value ml.Tensor) {}
func (m *mockCache) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity, maxBatch int) {}
func (m *mockCache) Close() {}
func (m *mockCache) StartForward(ctx ml.Context, batch input.Batch) error { return nil }
func (m *mockCache) CopyPrefix(srcSeq, dstSeq int, len int32) {}
func (m *mockCache) SetConfig(ml.CacheConfig) {}
func TestShiftCacheSlot(t *testing.T) {
tests := []struct {
name string
numCtx int32
inputs []input.Input
numKeep int32
cacheErr bool
wantErr any
wantInputsLen int
}{
{
name: "Normal shift",
numCtx: 10,
inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
numKeep: 2,
cacheErr: false, // No error
wantErr: nil,
wantInputsLen: 6, // After discarding 4 tokens
},
{
name: "Cache removal fails",
numCtx: 10,
inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
numKeep: 2,
cacheErr: true,
wantErr: &ErrReprocessInputs{},
wantInputsLen: 0, // Original inputs should be cleared
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mock := &mockCache{shouldFail: tt.cacheErr}
c := InputCache{
numCtx: tt.numCtx,
cache: mock,
}
slot := &InputCacheSlot{
Id: 123,
Inputs: make([]input.Input, len(tt.inputs)),
}
copy(slot.Inputs, tt.inputs)
err := c.ShiftCacheSlot(slot, tt.numKeep)
if tt.wantErr != nil {
if err == nil {
t.Errorf("Expected error but got nil")
return
}
if !errors.As(err, &tt.wantErr) {
t.Errorf("Expected error of type %T but got %T: %v", tt.wantErr, err, err)
}
} else if err != nil {
t.Errorf("Unexpected error: %v", err)
}
if len(slot.Inputs) != tt.wantInputsLen {
t.Errorf("Slot inputs length after operation: got %v, want %v", len(slot.Inputs), tt.wantInputsLen)
}
})
}
}

View File

@@ -267,6 +267,9 @@ type Server struct {
// KV cache
cache *InputCache
// next sequence for prompt processing to avoid starvation
nextSeq int
// multimodalHash generates hashes for comparing equality
// of non-text data
multimodalHash maphash.Hash
@@ -351,14 +354,19 @@ func (s *Server) processBatch() error {
var batchInputs []int32
var batch input.Batch
for i, seq := range s.seqs {
resumeSeq := -1
seqIdx := s.nextSeq - 1
for range s.seqs {
seqIdx = (seqIdx + 1) % len(s.seqs)
seq := s.seqs[seqIdx]
if seq == nil {
continue
}
// if past the num predict limit
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
s.removeSequence(i, "limit")
s.removeSequence(seqIdx, "limit")
continue
}
@@ -369,16 +377,23 @@ func (s *Server) processBatch() error {
batchSize := s.batchSize
for j, inp := range seq.inputs {
for i, inp := range seq.inputs {
// If we are required to put following inputs into a single batch then extend the
// batch size. Since we are only extending the size the minimum amount possible, this
// will cause a break if we have pending inputs.
// will cause a break if we have existing inputs.
minBatch := 1 + inp.SameBatch
if minBatch > batchSize {
batchSize = minBatch
}
if len(seq.pendingInputs)+minBatch > batchSize {
// Stop if the required batch would put us over the total batch size (including tokens
// added by other sequences). If we haven't been able to add anything yet then pick up
// here again for the next batch to avoid starvation, though we can opportunistically
// check if other sequences can still squeeze something in.
if len(batchInputs)+minBatch > batchSize {
if len(seq.pendingInputs) == 0 && resumeSeq == -1 {
resumeSeq = seqIdx
}
break
}
@@ -392,7 +407,15 @@ func (s *Server) processBatch() error {
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
if err != nil {
return err
var reprocess *ErrReprocessInputs
if errors.As(err, &reprocess) {
// Prepend these inputs to the sequence's inputs queue for reprocessing
seq.inputs = append(reprocess.Inputs, seq.inputs...)
// Skip this sequence but continue processing the rest
continue
} else {
return err
}
}
}
@@ -405,7 +428,7 @@ func (s *Server) processBatch() error {
batch.Sequences = append(batch.Sequences, seq.cache.Id)
seq.iBatch = len(batch.Outputs)
if j+1 == len(seq.inputs) {
if i+1 == len(seq.inputs) {
batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
}
seq.pendingInputs = append(seq.pendingInputs, inp)
@@ -414,6 +437,12 @@ func (s *Server) processBatch() error {
seq.inputs = seq.inputs[len(seq.pendingInputs):]
}
if resumeSeq != -1 {
s.nextSeq = resumeSeq
} else {
s.nextSeq = seqIdx + 1
}
if len(batchInputs) == 0 {
return nil
}
@@ -468,20 +497,6 @@ func (s *Server) processBatch() error {
return fmt.Errorf("failed to sample token: %w", err)
}
if seq.sampler.JSONSampler != nil {
_, err = seq.sampler.JSONSampler.UpdateState([]int32{token})
if err != nil {
return fmt.Errorf("failed to update state: %w", err)
}
}
if seq.sampler.PythonSampler != nil {
err = seq.sampler.PythonSampler.UpdateState(token)
if err != nil {
return fmt.Errorf("failed to update state: %w", err)
}
}
// if it's an end of sequence token, break
if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
// TODO (jmorganca): we should send this back
@@ -576,21 +591,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
}
}
// jsonSampler, err := sample.NewJSONSampler(s.model.(model.TextProcessor), nil)
// if err != nil {
// http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
// return
// }
// jsonSampler = nil
pythonSampler := &sample.PythonSampler{}
functions := []sample.PythonFunction{
{
Name: "add_two_strings",
Args: []string{"s1", "s2"},
Types: []string{"string", "string"},
},
}
pythonSampler.Init(functions, s.model.(model.TextProcessor))
sampler := sample.NewSampler(
req.Options.Temperature,
req.Options.TopK,
@@ -598,9 +598,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
req.Options.MinP,
req.Options.Seed,
grammar,
nil,
pythonSampler,
// nil,
)
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
@@ -620,7 +617,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
if errors.Is(err, context.Canceled) {
slog.Info("aborting completion request due to client closing the connection")
} else {
slog.Error("Failed to acquire semaphore", "error", err)
http.Error(w, fmt.Sprintf("Failed to acquire semaphore: %v", err), http.StatusInternalServerError)
}
return
}
@@ -632,6 +629,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
seq.cache, seq.inputs, err = s.cache.LoadCacheSlot(seq.inputs)
if err != nil {
s.mu.Unlock()
s.seqsSem.Release(1)
http.Error(w, fmt.Sprintf("Failed to load cache: %v", err), http.StatusInternalServerError)
return
}
@@ -645,6 +643,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
s.mu.Unlock()
if !found {
s.seqsSem.Release(1)
http.Error(w, "could not find an available sequence", http.StatusInternalServerError)
return
}

View File

@@ -1,53 +0,0 @@
package sample
var DefaultGrammar = map[string]string{
"unicode": `\x{hex}{2} | \u{hex}{4} | \U{hex}{8}`,
"null": `"null"`,
"object": `"{" (kv ("," kv)*)? "}"`,
"array": `"[" (value ("," value)*)? "]"`,
"kv": `string ":" value`,
"integer": `"0" | [1-9] [0-9]*`,
"number": `"-"? integer frac? exp?`,
"frac": `"." [0-9]+`,
"exp": `("e" | "E") ("+" | "-") [0-9]+`,
"string": `"\"" char* "\""`,
"escape": `["/" | "b" | "f" | "n" | "r" | "t" | unicode]`,
"char": `[^"\\] | escape`,
"space": `(" " | "\t" | "\n" | "\r")*`,
"hex": `[0-9] | [a-f] | [A-F]`,
"boolean": `"true" | "false"`,
"value": `object | array | string | number | boolean | "null"`,
}
const jsonString = `object | array`
type StateMachine struct {
states map[rune]State
}
type State struct {
NextStates []string
// bitmask?
Mask []bool
IsTerminal bool
}
func NewStateMachine(grammar map[string]string, startRule string) *StateMachine {
states := make(map[rune]State)
var cumu string
flag := false
for _, r := range startRule {
if r == '"' {
flag = !flag
}
if flag {
cumu += string(r)
}
}
sm := &StateMachine{
states: states,
}
return sm
}

View File

@@ -1,138 +0,0 @@
package sample
import (
"testing"
)
func TestGrammarParsing(t *testing.T) {
tests := []struct {
name string
grammar map[string]string
startRule string
input string
want bool
}{
{
name: "simple object",
grammar: map[string]string{
"object": `"{" "}"`,
},
startRule: "object",
input: "{}",
want: true,
},
{
name: "simple array",
grammar: map[string]string{
"array": `"[" "]"`,
},
startRule: "array",
input: "[]",
want: true,
},
{
name: "character class",
grammar: map[string]string{
"digit": `[0-9]`,
},
startRule: "digit",
input: "5",
want: true,
},
{
name: "alternation",
grammar: map[string]string{
"bool": `"true" | "false"`,
},
startRule: "bool",
input: "true",
want: true,
},
{
name: "repetition",
grammar: map[string]string{
"digits": `[0-9]+`,
},
startRule: "digits",
input: "123",
want: true,
},
{
name: "nested rules",
grammar: map[string]string{
"value": `object | array`,
"object": `"{" "}"`,
"array": `"[" "]"`,
},
startRule: "value",
input: "{}",
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parser := NewParser(tt.grammar)
machine, err := parser.Parse(tt.startRule)
if err != nil {
t.Fatalf("Parse() error = %v", err)
}
matcher := NewMatcher(machine)
got, err := matcher.Match(tt.input)
if err != nil {
t.Fatalf("Match() error = %v", err)
}
if got != tt.want {
t.Errorf("Match() = %v, want %v", got, tt.want)
}
})
}
}
func TestJSONGrammar(t *testing.T) {
tests := []struct {
name string
input string
want bool
}{
{"empty object", "{}", true},
{"empty array", "[]", true},
{"simple string", `"hello"`, true},
{"simple number", "123", true},
{"simple boolean", "true", true},
{"simple null", "null", true},
{"object with string", `{"key": "value"}`, true},
{"array with numbers", "[1, 2, 3]", true},
{"nested object", `{"obj": {"key": "value"}}`, true},
{"nested array", `[1, [2, 3], 4]`, true},
{"invalid object", "{", false},
{"invalid array", "[1, 2", false},
{"invalid string", `"hello`, false},
}
parser := NewParser(DefaultGrammar)
machine, err := parser.Parse("value")
if err != nil {
t.Fatalf("Parse() error = %v", err)
}
matcher := NewMatcher(machine)
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := matcher.Match(tt.input)
if tt.want {
if err != nil {
t.Errorf("Match() error = %v", err)
}
if !got {
t.Errorf("Match() = false, want true")
}
} else {
if err == nil && got {
t.Errorf("Match() = true, want false")
}
}
})
}
}

View File

@@ -1,160 +0,0 @@
package sample
import (
"fmt"
)
type JSONState int
const (
StateStart JSONState = iota
StateInObject
StateInObjectKey
StateInStructuredKey
StateInStructuredValue
StateNewline
StateTab
StateSpace
StateInString
StateInInt
StateInFloat
StateInBool
StateInNull
StateInColon
StateInComma
StateInTab
StateInSpaceToValue
StateInSpaceEndValue
StateInNewlineEndValue
StateInObjSpace
StateInList
StateInListComma
StateInValue
StateInValueEnd
StateInListEnd
StateInListObjectEnd
StateInNewline
StateInNumber
StateInNumberEnd
StateInStringEnd
StateInObjectKeyEnd
StateTerminate
StateInObjectEnd
StateTransitioningToTerminate
StateInListStartJSON
)
var JSONStates = []JSONState{
StateStart,
StateInObject,
StateInObjectKey,
StateInStructuredKey,
StateInStructuredValue,
StateNewline,
StateTab,
StateSpace,
StateInString,
StateInInt,
StateInFloat,
StateInBool,
StateInNull,
StateInColon,
StateInComma,
StateInTab,
StateInSpaceToValue,
StateInSpaceEndValue,
StateInNewlineEndValue,
StateInObjSpace,
StateInListStartJSON,
StateInList,
StateInListComma,
StateInValue,
StateInValueEnd,
StateInListEnd,
StateInListObjectEnd,
StateInNewline,
StateInNumber,
StateInNumberEnd,
StateInStringEnd,
StateInObjectKeyEnd,
StateTerminate,
StateInObjectEnd,
StateTransitioningToTerminate,
}
func (s JSONState) String() string {
switch s {
case StateStart:
return "StateStart"
case StateInObject:
return "StateInObject"
case StateInObjectKey:
return "StateInObjectKey"
case StateInStructuredKey:
return "StateInStructuredKey"
case StateInStructuredValue:
return "StateInStructuredValue"
case StateNewline:
return "StateNewline"
case StateTab:
return "StateTab"
case StateSpace:
return "StateSpace"
case StateInString:
return "StateInString"
case StateInInt:
return "StateInInt"
case StateInFloat:
return "StateInFloat"
case StateInBool:
return "StateInBool"
case StateInNull:
return "StateInNull"
case StateInColon:
return "StateInColon"
case StateInComma:
return "StateInComma"
case StateInTab:
return "StateInTab"
case StateInSpaceToValue:
return "StateInSpaceToValue"
case StateInSpaceEndValue:
return "StateInSpaceEndValue"
case StateInNewlineEndValue:
return "StateInNewlineEndValue"
case StateInObjSpace:
return "StateInObjSpace"
case StateInList:
return "StateInList"
case StateInListComma:
return "StateInListComma"
case StateInValue:
return "StateInValue"
case StateInValueEnd:
return "StateInValueEnd"
case StateInListEnd:
return "StateInListEnd"
case StateInListObjectEnd:
return "StateInListObjectEnd"
case StateInNewline:
return "StateInNewline"
case StateInNumber:
return "StateInNumber"
case StateInNumberEnd:
return "StateInNumberEnd"
case StateInStringEnd:
return "StateInStringEnd"
case StateInObjectKeyEnd:
return "StateInObjectKeyEnd"
case StateTerminate:
return "StateTerminate"
case StateInObjectEnd:
return "StateInObjectEnd"
case StateTransitioningToTerminate:
return "StateTransitioningToTerminate"
case StateInListStartJSON:
return "StateInListStartJSON"
default:
return fmt.Sprintf("Unknown state: %d", s)
}
}

View File

@@ -1,327 +0,0 @@
package sample
import (
"fmt"
"slices"
"github.com/ollama/ollama/model"
)
/*
Key JSON rules to consider:
1. Whitespace handling:
- Need to handle all valid JSON whitespace characters (\r, spaces between tokens)
- Current code only handles some whitespace cases
2. Number validation:
- Need proper validation for special number cases like -0
- Should handle .5 style decimals
- Need limits on scientific notation (e, E)
3. String escaping:
- Currently marks \ as invalid but should allow escaped sequences:
- \"
- \n
- \u1234 unicode escapes
4. Empty object/array transitions:
- Direct {} and [] cases could be more explicit
- Need clear transitions for these edge cases
5. Nested depth limits:
- No protection against excessive nesting
- Could cause stack overflow with deeply nested structures
*/
// TODO: / should be valid but an escape character
var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ',', '/'}
var (
intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'}
validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'}
)
var validNumberRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '.', '-', '+', 'e', 'E'}
var validBoolRunes = []rune{'t', 'r', 'u', 'e', 'f', 'a', 'l', 's', 'e'}
var validNullRunes = []rune{'n', 'u', 'l', 'l'}
type PDA struct {
State JSONState
TransitionEdges map[rune]*PDA
MaskTokenIDToNode map[int32]*PDA
}
func NewPDANode(state JSONState) *PDA {
return &PDA{
State: state,
TransitionEdges: make(map[rune]*PDA),
MaskTokenIDToNode: make(map[int32]*PDA),
}
}
type PDAGraphBuilder struct {
proc model.TextProcessor
decodedToks []string
stateToNodeMap map[JSONState]*PDA
tokenToStatesMap map[int32][]JSONState
}
func (b *PDAGraphBuilder) BuildGraph() error {
stateToNodeMap := make(map[JSONState]*PDA)
for _, state := range JSONStates {
stateToNodeMap[state] = NewPDANode(state)
}
stateToNodeMap[StateStart].TransitionEdges['{'] = stateToNodeMap[StateInObject]
stateToNodeMap[StateStart].TransitionEdges['['] = stateToNodeMap[StateInListStartJSON]
// TODO: update naming here - and revisit values
stateToNodeMap[StateInListStartJSON].TransitionEdges['{'] = stateToNodeMap[StateInObject]
stateToNodeMap[StateInListStartJSON].TransitionEdges['['] = stateToNodeMap[StateInListStartJSON]
stateToNodeMap[StateInObject].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
stateToNodeMap[StateInObject].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
stateToNodeMap[StateInObject].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
stateToNodeMap[StateInObject].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
// new line
stateToNodeMap[StateInNewline].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
stateToNodeMap[StateInNewline].TransitionEdges['\t'] = stateToNodeMap[StateInTab]
stateToNodeMap[StateInNewline].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
stateToNodeMap[StateInNewline].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
// stateToNodeMap[StateInNewline].TransitionEdges['{'] = stateToNodeMap[StateInObject]
// new line end value
// stateToNodeMap[StateInNewlineEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
stateToNodeMap[StateInNewlineEndValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
stateToNodeMap[StateInNewlineEndValue].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
stateToNodeMap[StateInObjSpace].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
stateToNodeMap[StateInObjSpace].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
// TODO: see if this is needed for formatting
stateToNodeMap[StateInObjSpace].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
stateToNodeMap[StateInTab].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
stateToNodeMap[StateInTab].TransitionEdges['\t'] = stateToNodeMap[StateInNewline]
stateToNodeMap[StateInObjectKey].TransitionEdges[rune(-1)] = stateToNodeMap[StateInObjectKey]
stateToNodeMap[StateInObjectKey].TransitionEdges['"'] = stateToNodeMap[StateInObjectKeyEnd]
stateToNodeMap[StateInObjectKeyEnd].TransitionEdges[':'] = stateToNodeMap[StateInColon]
stateToNodeMap[StateInObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInComma]
stateToNodeMap[StateInObjectEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
// where values should be
// this could be combined but the probl might change, we're alr doing a skip ahead
stateToNodeMap[StateInColon].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
stateToNodeMap[StateInColon].TransitionEdges['\n'] = stateToNodeMap[StateInSpaceToValue]
stateToNodeMap[StateInColon].TransitionEdges['['] = stateToNodeMap[StateInList]
stateToNodeMap[StateInColon].TransitionEdges['{'] = stateToNodeMap[StateInObject]
addValueConnections(stateToNodeMap[StateInColon], stateToNodeMap)
// Leads to a value
stateToNodeMap[StateInSpaceToValue].TransitionEdges['['] = stateToNodeMap[StateInList]
stateToNodeMap[StateInSpaceToValue].TransitionEdges['{'] = stateToNodeMap[StateInObject]
addValueConnections(stateToNodeMap[StateInSpaceToValue], stateToNodeMap)
stateToNodeMap[StateInSpaceToValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
stateToNodeMap[StateInSpaceToValue].TransitionEdges['\n'] = stateToNodeMap[StateInSpaceToValue]
// Values
// string node
stateToNodeMap[StateInString].TransitionEdges[rune(-1)] = stateToNodeMap[StateInString]
stateToNodeMap[StateInString].TransitionEdges['"'] = stateToNodeMap[StateInStringEnd]
// String end node
addEnds(stateToNodeMap[StateInStringEnd], stateToNodeMap)
// stateToNodeMap[StateInStringEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
stateToNodeMap[StateInStringEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
// TODO: add counters for allowable number of decimals, e, E, etc
// number node
for _, r := range validNumberRunes {
stateToNodeMap[StateInNumber].TransitionEdges[r] = stateToNodeMap[StateInNumber]
}
addEnds(stateToNodeMap[StateInNumber], stateToNodeMap)
// stateToNodeMap[StateInNumber].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
stateToNodeMap[StateInNumber].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
// list node
stateToNodeMap[StateInList].TransitionEdges[','] = stateToNodeMap[StateInComma]
stateToNodeMap[StateInList].TransitionEdges['{'] = stateToNodeMap[StateInObject]
stateToNodeMap[StateInList].TransitionEdges[' '] = stateToNodeMap[StateInList]
stateToNodeMap[StateInList].TransitionEdges['\n'] = stateToNodeMap[StateInList]
// early end
stateToNodeMap[StateInList].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
// list end node
stateToNodeMap[StateInListEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
// stateToNodeMap[StateInListEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
stateToNodeMap[StateInListEnd].TransitionEdges[','] = stateToNodeMap[StateInComma]
stateToNodeMap[StateInListEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
// empty list
stateToNodeMap[StateInList].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
addValueConnections(stateToNodeMap[StateInList], stateToNodeMap)
// null node
for _, r := range validNullRunes {
stateToNodeMap[StateInNull].TransitionEdges[r] = stateToNodeMap[StateInNull]
}
addEnds(stateToNodeMap[StateInNull], stateToNodeMap)
stateToNodeMap[StateInNull].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
stateToNodeMap[StateInNull].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
// list comma
// should point to values
stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInListComma]
stateToNodeMap[StateInListComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
stateToNodeMap[StateInListComma].TransitionEdges['\n'] = stateToNodeMap[StateInList]
stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInList]
stateToNodeMap[StateInListComma].TransitionEdges['\t'] = stateToNodeMap[StateInList]
addValueConnections(stateToNodeMap[StateInListComma], stateToNodeMap)
// list object end
stateToNodeMap[StateInListObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInListComma]
stateToNodeMap[StateInListObjectEnd].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
// TODO: not sure if this is needed
stateToNodeMap[StateInListObjectEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
// bool node
for _, r := range validBoolRunes {
stateToNodeMap[StateInBool].TransitionEdges[r] = stateToNodeMap[StateInBool]
}
stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
addEnds(stateToNodeMap[StateInBool], stateToNodeMap)
// stateToNodeMap[StateInBool].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
// comma node
stateToNodeMap[StateInComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
stateToNodeMap[StateInComma].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
stateToNodeMap[StateInComma].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
stateToNodeMap[StateInComma].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
// todo: review this space transition
// stateToNodeMap[StateInComma].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
// space end value
// stateToNodeMap[StateInSpaceEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
stateToNodeMap[StateInSpaceEndValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
stateToNodeMap[StateInSpaceEndValue].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
stateToNodeMap[StateInSpaceEndValue].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
b.stateToNodeMap = stateToNodeMap
if err := b.preComputeValidStates(); err != nil {
return err
}
return nil
}
func addEnds(node *PDA, stateToNodeMap map[JSONState]*PDA) {
node.TransitionEdges[','] = stateToNodeMap[StateInComma]
node.TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
node.TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
}
func addValueConnections(node *PDA, stateToNodeMap map[JSONState]*PDA) {
node.TransitionEdges['"'] = stateToNodeMap[StateInString]
for _, r := range validNumberRunes {
node.TransitionEdges[r] = stateToNodeMap[StateInNumber]
}
// TODO(parthsareen): force the output and shift similar to structured outputs
node.TransitionEdges['t'] = stateToNodeMap[StateInBool]
node.TransitionEdges['f'] = stateToNodeMap[StateInBool]
node.TransitionEdges['n'] = stateToNodeMap[StateInNull]
}
func (b *PDAGraphBuilder) preComputeValidStates() error {
for _, node := range b.stateToNodeMap {
// if node.State == StateInObjectKey {
// if len(b.stateToNodeMap[StateInString].MaskTokenIDToNode) > 0 {
// b.stateToNodeMap[StateInObjectKey].MaskTokenIDToNode = b.stateToNodeMap[StateInString].MaskTokenIDToNode
// fmt.Println("copying string mask to object key mask")
// }
// }
if err := b.CreateMask(node); err != nil {
return err
}
}
return nil
}
func (b *PDAGraphBuilder) preComputeTokenToStatesMap() error {
// TODO: make can be somewhere else too
b.tokenToStatesMap = make(map[int32][]JSONState)
for i, t := range b.decodedToks {
for _, r := range t {
if r == '"' {
b.tokenToStatesMap[int32(i)] = append(b.tokenToStatesMap[int32(i)], StateInString)
}
}
}
return nil
}
// TODO: the mask for obj key and string should be the same?
func (b *PDAGraphBuilder) CreateMask(node *PDA) error {
if node == nil {
return fmt.Errorf("node cannot be nil")
}
for i := range b.decodedToks {
token := b.decodedToks[i]
// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
if b.proc.Is(int32(i), model.SpecialEOS) || b.proc.Is(int32(i), model.SpecialBOS) || token == "" || token == "\"\"" {
continue
}
curNode := node
valid := true
consumedSpecialRunes := make(map[rune]bool)
for _, r := range token {
curNode, valid = isRuneValid(r, curNode, consumedSpecialRunes)
if curNode == nil || !valid {
break
}
}
if valid {
node.MaskTokenIDToNode[int32(i)] = curNode
}
}
return nil
}
func isRuneValid(r rune, curNode *PDA, consumedSpecialRunes map[rune]bool) (*PDA, bool) {
if consumedSpecialRunes[r] {
return nil, false
}
specialRune := slices.Contains(stringInvalidRunes, r)
if specialRune {
if curNode.State == StateInString || curNode.State == StateInObjectKey {
return nil, false
}
}
// Check for specific rune transition
if nextNode, ok := curNode.TransitionEdges[r]; ok {
// fmt.Println("next node", nextNode)
if specialRune {
if curNode.State == nextNode.State {
return nil, false
}
consumedSpecialRunes[r] = true
}
return nextNode, true
}
// Check for sentinel value - if present, any rune is valid
if nextNode, ok := curNode.TransitionEdges[rune(-1)]; ok {
return nextNode, true
}
return nil, false
}

View File

@@ -1,264 +0,0 @@
package sample
import (
"fmt"
"math"
"runtime"
"time"
"github.com/ollama/ollama/model"
)
// TODO: safety in case of invalid json
// TODO: partial JSON matching?
// TODO: interfaces to cleanup with return values
// TODO this interface shouldn't be the sampler - should just use Sampler
// TODO: add penalties for string \n stuff
// TODO: minimize number of fwd passes if there is only one match
// TODO: greedy sample initially and then backtrack if no match
type PushdownSampler struct {
PDAGraphBuilder
curNode *PDA
braceStack []rune
stateCounter uint32
}
// graph should be built once and reused per tokenizer
func NewPushdownSampler(proc model.TextProcessor) (*PushdownSampler, error) {
start := time.Now()
fmt.Println("--------------------------------")
fmt.Println("PDA sampler")
fmt.Println("--------------------------------")
var m runtime.MemStats
runtime.ReadMemStats(&m)
before := m.Alloc
fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024))
vocab := proc.Vocab()
decodedToks := make([]string, len(vocab.Values))
for i := range vocab.Values {
token, err := proc.Decode([]int32{int32(i)})
if err != nil {
return nil, err
}
decodedToks[i] = token
}
gb := &PDAGraphBuilder{
proc: proc,
decodedToks: decodedToks,
}
if err := gb.BuildGraph(); err != nil {
return nil, err
}
runtime.ReadMemStats(&m)
after := m.Alloc
fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024))
fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
fmt.Printf("Graph build time = %v\n", time.Since(start))
// TODO: this can be simplified
return &PushdownSampler{
curNode: gb.stateToNodeMap[StateStart],
PDAGraphBuilder: *gb,
braceStack: []rune{},
stateCounter: 0,
}, nil
}
// TODO: need to add resampling logic if the first sample was not good
// greedy sample + backtrack?
func (s *PushdownSampler) Apply(logits []float32) ([]float32, error) {
switch s.curNode.State {
case StateInString:
return s.maskLogits(logits, s.curNode)
case StateInListEnd:
// force finish if no braces left
if len(s.braceStack) == 0 {
s.curNode = NewPDANode(StateTerminate)
return forceFinish(s, logits)
}
logits, err := s.maskLogits(logits, s.curNode)
if err != nil {
return nil, err
}
return logits, nil
case StateTerminate:
return forceFinish(s, logits)
case StateInObjectEnd:
// force finish if no braces left
if len(s.braceStack) == 0 {
s.curNode = NewPDANode(StateTerminate)
return forceFinish(s, logits)
}
peek := s.braceStack[len(s.braceStack)-1]
if peek == rune('[') {
s.curNode = s.stateToNodeMap[StateInListObjectEnd]
}
logits, err := s.maskLogits(logits, s.curNode)
if err != nil {
return nil, err
}
return logits, nil
case StateInComma:
peek := s.braceStack[len(s.braceStack)-1]
if peek == rune('[') {
s.curNode = s.stateToNodeMap[StateInListComma]
}
logits, err := s.maskLogits(logits, s.curNode)
if err != nil {
return nil, err
}
return logits, nil
default:
fmt.Println("masking logits current state", s.curNode.State)
logits, err := s.maskLogits(logits, s.curNode)
if err != nil {
return nil, err
}
return logits, nil
}
}
func forceFinish(s *PushdownSampler, logits []float32) ([]float32, error) {
for i := range logits {
if s.proc.Is(int32(i), model.SpecialEOS) {
logits[i] = 1.0
} else {
logits[i] = float32(math.Inf(-1))
}
}
return logits, nil
}
func (s *PushdownSampler) UpdateState(tokenSlice []int32) ([]int32, error) {
fmt.Println("current state - updating", s.curNode.State)
mappedString, err := s.proc.Decode(tokenSlice)
if err != nil {
return nil, err
}
fmt.Printf(">>> mappedString: %q\n", mappedString)
// Special handling for EOS token in terminate state
if s.curNode.State == StateTerminate {
for _, tokenID := range tokenSlice {
if s.proc.Is(tokenID, model.SpecialEOS) {
return tokenSlice, nil
}
}
}
// flag := -1
// endBraceRunes := []rune{'}', ']'}
for _, r := range mappedString {
// TODO: if this is enabled again, make sure to appropriately handle the state transitions
// if slices.Contains(endBraceRunes, r) && len(s.braceStack) == 0 {
// fmt.Printf("stack is empty, extra closing brace %c\n", r)
// // flag = i
// break
// }
if r == rune('{') {
s.braceStack = append(s.braceStack, r)
}
if r == rune('[') {
s.braceStack = append(s.braceStack, r)
}
if r == rune('}') {
if len(s.braceStack) == 0 {
return nil, fmt.Errorf("stack is empty, extra closing brace %c", r)
}
top := s.braceStack[len(s.braceStack)-1]
if top != rune('{') {
return nil, fmt.Errorf("unmatched closing brace, got%c, want%c", top, '{')
}
s.braceStack = s.braceStack[:len(s.braceStack)-1]
}
if r == rune(']') {
if len(s.braceStack) == 0 {
return nil, fmt.Errorf("stack is empty, extra closing brace %c", r)
}
top := s.braceStack[len(s.braceStack)-1]
if top != rune('[') {
return nil, fmt.Errorf("unmatched closing brace, got%c, want%c", top, '[')
}
s.braceStack = s.braceStack[:len(s.braceStack)-1]
}
}
// if flag != -1 {
// tokenSlice = tokenSlice[:flag]
// }
// fmt.Println("flag!", flag)
for _, tokenID := range tokenSlice {
// transition to the next node
nextNode, ok := s.curNode.MaskTokenIDToNode[tokenID]
if !ok {
return nil, fmt.Errorf("invalid token: %q", mappedString)
}
fmt.Println("transitioning to", nextNode.State)
// TODO: add a penalty for staying in the same state too long
if nextNode.State == s.curNode.State {
s.stateCounter++
} else {
s.stateCounter = 0
}
s.curNode = nextNode
fmt.Println("updated curNode state", s.curNode.State)
}
return tokenSlice, nil
}
// greedy sample + backtrack?
func (s *PushdownSampler) maskLogits(logits []float32, node *PDA) ([]float32, error) {
// Create a new slice with same length as logits, initialized to -Inf
maskedLogits := make([]float32, len(logits))
for i := range maskedLogits {
maskedLogits[i] = float32(math.Inf(-1))
}
// Only update values for valid token IDs from the mask map
for tokenID := range node.MaskTokenIDToNode {
if int(tokenID) < len(logits) {
maskedLogits[tokenID] = logits[tokenID]
}
}
return maskedLogits, nil
}
func (s *PushdownSampler) fastMaskLogits(logits []float32, node *PDA) ([]float32, error) {
maxLogit := float32(math.Inf(-1))
maxIndex := -1
// Find the maximum logit value among valid tokens
for tokenID := range node.MaskTokenIDToNode {
if int(tokenID) < len(logits) && logits[tokenID] > maxLogit {
maxLogit = logits[tokenID]
maxIndex = int(tokenID)
}
}
if maxIndex == -1 {
return nil, fmt.Errorf("no valid tokens found in mask")
}
logits[0] = float32(maxIndex)
return logits, nil
// return maxIndex, nil
}

View File

@@ -17,14 +17,12 @@ type token struct {
}
type Sampler struct {
rng *rand.Rand
topK int
topP float32
minP float32
temperature float32
grammar *Grammar
JSONSampler *JSONSampler
PythonSampler *PythonSampler
rng *rand.Rand
topK int
topP float32
minP float32
temperature float32
grammar *Grammar
}
func (s *Sampler) Sample(logits []float32) (int32, error) {
@@ -32,19 +30,6 @@ func (s *Sampler) Sample(logits []float32) (int32, error) {
return -1, errors.New("sample: no logits provided to sample")
}
var err error
if s.JSONSampler != nil {
logits, err = s.JSONSampler.Apply(logits)
if err != nil {
return -1, err
}
}
if s.PythonSampler != nil {
logits, err = s.PythonSampler.ApplyMask(logits)
if err != nil {
return -1, err
}
}
tokens := make([]token, len(logits))
for i := range logits {
tokens[i].id = int32(i)
@@ -142,7 +127,7 @@ func (s *Sampler) sample(tokens []token) (token, error) {
}
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar, jsonSampler *JSONSampler, pythonSampler *PythonSampler) Sampler {
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
var rng *rand.Rand
if seed != -1 {
// PCG requires two parameters: sequence and stream
@@ -170,14 +155,12 @@ func NewSampler(temperature float32, topK int, topP float32, minP float32, seed
}
return Sampler{
rng: rng,
topK: topK,
topP: topP,
minP: minP,
temperature: temperature,
grammar: grammar,
JSONSampler: jsonSampler,
PythonSampler: pythonSampler,
rng: rng,
topK: topK,
topP: topP,
minP: minP,
temperature: temperature,
grammar: grammar,
}
}

View File

@@ -1,299 +0,0 @@
package sample
import (
"fmt"
"log/slog"
"runtime"
"time"
"github.com/ollama/ollama/grammar/jsonschema"
"github.com/ollama/ollama/model"
)
type JSONSampler struct {
schema *jsonschema.Schema
propIdx int
propToNodeMap map[string]*PDA
pdaSampler *PushdownSampler
decodedToks []string
}
func NewJSONSampler(proc model.TextProcessor, schema *jsonschema.Schema) (*JSONSampler, error) {
slog.Info("NewJSONSampler", "schema", schema)
if proc == nil {
return nil, fmt.Errorf("TextProcessor cannot be nil")
}
pdaSampler, err := NewPushdownSampler(proc)
if err != nil {
return nil, fmt.Errorf("failed to create PushdownSampler: %w", err)
}
if schema == nil {
return &JSONSampler{
schema: nil,
propIdx: -1,
propToNodeMap: nil,
pdaSampler: pdaSampler,
}, nil
}
// fmt.Println("schema not nil")
so := &JSONSampler{
schema: schema,
propIdx: -1,
propToNodeMap: make(map[string]*PDA),
pdaSampler: pdaSampler,
}
so.schemaToGraph()
// Benchmark token decoding
start := time.Now()
var m runtime.MemStats
runtime.ReadMemStats(&m)
before := m.Alloc
vocab := proc.Vocab()
decodedToks := make([]string, len(vocab.Values))
for i := range vocab.Values {
token, err := proc.Decode([]int32{int32(i)})
if err != nil {
return nil, err
}
decodedToks[i] = token
}
so.decodedToks = decodedToks
runtime.ReadMemStats(&m)
after := m.Alloc
fmt.Printf("Token decode memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
fmt.Printf("Token decode time = %v\n", time.Since(start))
fmt.Println("--------------------------------")
fmt.Println("SOSampler")
fmt.Println("--------------------------------")
// Benchmark this section
start = time.Now()
runtime.ReadMemStats(&m)
before = m.Alloc
// TODO: still messed up
// TODO: recursion use case
// key masks
for _, prop := range so.schema.Properties {
node := so.propToNodeMap[prop.Name]
// propName -> node
curState := node.State
fromNode := node
so.pdaSampler.CreateMask(fromNode)
for curState == StateInStructuredKey {
// there is only one edge
for r, toNode := range fromNode.TransitionEdges {
fmt.Println("rune", r, "edge", toNode.State)
so.pdaSampler.CreateMask(toNode)
fmt.Printf("created mask for %c\n", r)
curState = toNode.State
fmt.Println("next state", curState)
// TODO: theres an extra gen for " right now
fromNode = toNode
}
}
if curState != StateInColon {
return nil, fmt.Errorf("expected state to be StateInColon, got %v", curState)
}
// so.pdaSampler.CreateMask(fromNode)
fromNode = fromNode.TransitionEdges[' ']
so.pdaSampler.CreateMask(fromNode)
curState = fromNode.State
for _, toNode := range fromNode.TransitionEdges {
fmt.Println("toNode", toNode.State)
}
}
// runtime.ReadMemStats(&m)
// after = m.Alloc
// fmt.Printf("Mask creation memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
// fmt.Printf("Mask creation time = %v\n", time.Since(start))
// fmt.Println("--------------------------------")
return so, nil
}
func (s *JSONSampler) schemaToGraph() {
schemaType := s.schema.EffectiveType()
switch schemaType {
case "object":
// TODO: see if we need to connect these to the JSON graph
// each prop is a key
for _, prop := range s.schema.Properties {
// name of key
name := prop.Name
keyNode := &PDA{
State: StateInStructuredKey, // this is unchanging, will impact sampling
TransitionEdges: make(map[rune]*PDA),
MaskTokenIDToNode: make(map[int32]*PDA),
}
prevNode := keyNode
for _, r := range name {
runeNode := &PDA{
State: StateInStructuredKey, // this is unchanging, will impact sampling
TransitionEdges: make(map[rune]*PDA),
MaskTokenIDToNode: make(map[int32]*PDA),
}
// fmt.Println("runeNode created", runeNode.State)
// fmt.Printf("runeNode created %c\n", r)
// since alloc on heap connections wil still map
prevNode.TransitionEdges[r] = runeNode
prevNode = runeNode
}
// point to end of object key node after all chars are done
// prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInObjectKeyEnd]
// link to value node
// Create a node for the end of the key (after the closing quote)
stringEndNode := &PDA{
State: StateInStructuredKey,
TransitionEdges: make(map[rune]*PDA),
MaskTokenIDToNode: make(map[int32]*PDA),
}
prevNode.TransitionEdges['"'] = stringEndNode
prevNode = stringEndNode
// Add transition for colon after key
colonNode := &PDA{
State: StateInColon,
TransitionEdges: make(map[rune]*PDA),
MaskTokenIDToNode: make(map[int32]*PDA),
}
prevNode.TransitionEdges[':'] = colonNode
prevNode = colonNode
// Add transition for space after colon
spaceNode := &PDA{
State: StateInSpaceToValue,
TransitionEdges: make(map[rune]*PDA),
MaskTokenIDToNode: make(map[int32]*PDA),
}
prevNode.TransitionEdges[' '] = spaceNode
prevNode = spaceNode
value := prop.Type
switch value {
case "object":
fmt.Println("object under key: ", name)
case "array":
fmt.Println("array under key: ", name)
case "string":
fmt.Println("string under key: ", name)
prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInString]
case "number":
fmt.Println("number under key: ", name)
for _, r := range validNumberRunes {
prevNode.TransitionEdges[r] = s.pdaSampler.stateToNodeMap[StateInNumber]
}
case "boolean":
fmt.Println("boolean under key: ", name)
prevNode.TransitionEdges['t'] = s.pdaSampler.stateToNodeMap[StateInBool]
prevNode.TransitionEdges['f'] = s.pdaSampler.stateToNodeMap[StateInBool]
prevNode.TransitionEdges['n'] = s.pdaSampler.stateToNodeMap[StateInNull]
}
// points to start of the key
s.propToNodeMap[name] = keyNode
fmt.Println("name", name, "keyNode", keyNode.State)
}
}
// TODO: do values + recursion
}
func (s *JSONSampler) Apply(logits []float32) ([]float32, error) {
if s.schema == nil {
return s.pdaSampler.Apply(logits)
}
switch s.pdaSampler.curNode.State {
// TODO: doesnt account for multi rune case
case StateInObjectKey:
if s.propIdx > len(s.schema.Properties)-1 {
return nil, fmt.Errorf("propIdx out of bounds")
}
// fmt.Println("in object key - structured outputs")
// TODO: this tracking should probably be coming from a stack to track nested objects
// simple case
s.propIdx++
fmt.Println("propIdx", s.propIdx)
prop := s.schema.Properties[s.propIdx]
fmt.Println("prop", prop.Name)
s.pdaSampler.curNode = s.propToNodeMap[prop.Name]
fmt.Println("changed curNode state to", s.pdaSampler.curNode.State)
logits, err := s.pdaSampler.maskLogits(logits, s.pdaSampler.curNode)
if err != nil {
return nil, err
}
return logits, nil
default:
// Will only happen for the last prop - can also be precomputed.
if s.propIdx == len(s.schema.Properties)-1 {
// todo: if i incremenet propidx then i know im in last value as well
switch s.pdaSampler.curNode.State {
case StateInObjectEnd:
fmt.Println("<<<<< in obj end - generating mask for", s.pdaSampler.curNode.State)
s.pdaSampler.curNode.TransitionEdges = make(map[rune]*PDA)
s.pdaSampler.curNode = NewPDANode(StateTerminate)
s.propIdx++
// TODO: this needs to be optimized in some way, computing mask on the fly is expensive
case StateInNumber, StateInString, StateInBool, StateInNull, StateInListEnd:
fmt.Println("<<<<< last prop - generating mask for", s.pdaSampler.curNode.State)
delete(s.pdaSampler.curNode.TransitionEdges, ',')
s.pdaSampler.curNode.MaskTokenIDToNode = make(map[int32]*PDA)
s.pdaSampler.CreateMask(s.pdaSampler.curNode)
s.propIdx++
}
}
return s.pdaSampler.Apply(logits)
}
}
func (s *JSONSampler) UpdateState(tokenSlice []int32) ([]int32, error) {
tokenSlice, err := s.pdaSampler.UpdateState(tokenSlice)
if err != nil {
return nil, err
}
if s.schema == nil {
// Don't need to update state for unconstrained JSON sampling
return tokenSlice, nil
}
switch s.pdaSampler.curNode.State {
case StateInObjectKey:
s.propIdx++
fmt.Println("propIdx", s.propIdx)
prop := s.schema.Properties[s.propIdx]
fmt.Println("prop", prop.Name)
s.pdaSampler.curNode = s.propToNodeMap[prop.Name]
// TODO: this does not work - mike
// str, err := s.pdaSampler.proc.Decode(tokenSlice)
// if err != nil {
// return nil, err
// }
// fmt.Println("str", str)
return tokenSlice, nil
default:
return tokenSlice, nil
}
}

View File

@@ -1,352 +0,0 @@
package sample
import (
"fmt"
"math"
"slices"
"github.com/ollama/ollama/model"
)
type PythonState int
const (
PythonStateStart PythonState = iota
StateInFunction
StateInFunctionArgs
StateInFunctionArgsType
StateInFunctionEnd
PStateInString
PStateInStringEnd
PStateInNumber
PStateInList
PStateInListEnd
PStateInDict
PStateInDictEnd
PStateInTuple
PStateInTupleEnd
PStateTerminate
)
func (s PythonState) String() string {
switch s {
case PythonStateStart:
return "PythonStateStart"
case StateInFunction:
return "StateInFunction"
case StateInFunctionArgs:
return "StateInFunctionArgs"
case StateInFunctionArgsType:
return "StateInFunctionArgsType"
case StateInFunctionEnd:
return "StateInFunctionEnd"
case PStateInString:
return "PStateInString"
case PStateInStringEnd:
return "PStateInStringEnd"
case PStateInNumber:
return "PStateInNumber"
case PStateInList:
return "PStateInList"
case PStateInListEnd:
return "PStateInListEnd"
case PStateInDict:
return "PStateInDict"
case PStateInDictEnd:
return "PStateInDictEnd"
case PStateInTuple:
return "PStateInTuple"
case PStateInTupleEnd:
return "PStateInTupleEnd"
case PStateTerminate:
return "PStateTerminate"
default:
return fmt.Sprintf("PythonState(%d)", s)
}
}
var PythonStates = []PythonState{
PythonStateStart,
StateInFunction,
StateInFunctionArgs,
StateInFunctionArgsType,
StateInFunctionEnd,
PStateInString,
PStateInStringEnd,
PStateInNumber,
PStateInList,
PStateInListEnd,
PStateInDict,
PStateInDictEnd,
PStateInTuple,
PStateInTupleEnd,
PStateTerminate,
}
type Node struct {
State PythonState
TransitionEdges map[rune]*Node
MaskTokenIDToNode map[int32]*Node
}
func NewNode(state PythonState) *Node {
return &Node{
State: state,
TransitionEdges: make(map[rune]*Node),
MaskTokenIDToNode: make(map[int32]*Node),
}
}
type PythonFunction struct {
Name string
Args []string
Types []string
}
type PythonSampler struct {
stateToNodes map[PythonState]*Node
proc model.TextProcessor
decodedToks []string
curNode *Node
completed int
functions []PythonFunction
}
func (s *PythonSampler) Init(functions []PythonFunction, proc model.TextProcessor) error {
s.proc = proc
s.functions = functions
decodedToks := make([]string, len(proc.Vocab().Values))
for i := range proc.Vocab().Values {
token, err := proc.Decode([]int32{int32(i)})
if err != nil {
return err
}
decodedToks[i] = token
}
s.decodedToks = decodedToks
s.BuildGraph()
for _, function := range functions {
prevNode := s.stateToNodes[PythonStateStart]
for _, r := range function.Name {
nextNode := NewNode(StateInFunction)
prevNode.TransitionEdges[r] = nextNode
if err := s.CreateMask(nextNode); err != nil {
return err
}
fmt.Println("prevNode", prevNode.State)
fmt.Printf("transition edge: %q\n", r)
fmt.Println("nextNode", nextNode.State)
prevNode = nextNode
}
prevNode.TransitionEdges['('] = s.stateToNodes[StateInFunctionArgs]
s.CreateMask(prevNode)
prevNode = s.stateToNodes[StateInFunctionArgs]
for i, arg := range function.Args {
for _, r := range arg {
nextNode := NewNode(StateInFunctionArgs)
prevNode.TransitionEdges[r] = nextNode
s.CreateMask(prevNode)
prevNode = nextNode
}
prevNode.TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs]
// prevNode = s.stateToNodes[StateInFunctionArgs]
prevNode.TransitionEdges['='] = NewNode(StateInFunctionArgsType)
s.CreateMask(prevNode)
prevNode = prevNode.TransitionEdges['=']
switch function.Types[i] {
case "string":
prevNode.TransitionEdges['"'] = s.stateToNodes[PStateInString]
s.CreateMask(prevNode.TransitionEdges['"'])
case "number":
prevNode.TransitionEdges['"'] = s.stateToNodes[PStateInNumber]
s.CreateMask(prevNode.TransitionEdges['"'])
}
}
}
s.curNode = s.stateToNodes[PythonStateStart]
fmt.Println("curNode", s.curNode.State)
fmt.Println("transition edges", s.curNode.TransitionEdges)
if err := s.CreateMask(s.curNode); err != nil {
return err
}
fmt.Println("maskTokenIDToNode", s.curNode.MaskTokenIDToNode)
for tokenID, node := range s.curNode.MaskTokenIDToNode {
fmt.Printf("tokenID: %d, node: %v\n", s.decodedToks[tokenID], node.State)
}
return nil
}
func (s *PythonSampler) BuildGraph() error {
s.stateToNodes = make(map[PythonState]*Node)
for _, state := range PythonStates {
s.stateToNodes[state] = NewNode(state)
}
for _, state := range s.stateToNodes {
if err := s.CreateMask(state); err != nil {
return err
}
}
// String
s.stateToNodes[PStateInString].TransitionEdges[rune(-1)] = s.stateToNodes[PStateInString]
s.stateToNodes[PStateInString].TransitionEdges['"'] = s.stateToNodes[PStateInStringEnd]
// String end
s.stateToNodes[PStateInStringEnd].TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs]
// s.stateToNodes[PStateInStringEnd].TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
// Number
for _, r := range validNumberRunes {
s.stateToNodes[PStateInNumber].TransitionEdges[r] = s.stateToNodes[PStateInNumber]
}
s.stateToNodes[PStateInNumber].TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
s.stateToNodes[PStateInNumber].TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs]
s.stateToNodes[PStateInNumber].TransitionEdges[' '] = s.stateToNodes[StateInFunctionArgs]
return nil
}
func (s *PythonSampler) ApplyMask(logits []float32) ([]float32, error) {
if s.curNode.State == PStateTerminate {
logits, err := finish(s, logits)
if err != nil {
return nil, err
}
return logits, nil
}
logits, err := s.maskLogits(logits, s.curNode)
if err != nil {
return nil, err
}
return logits, nil
}
func (s *PythonSampler) UpdateState(token int32) error {
mappedString, err := s.proc.Decode([]int32{token})
if err != nil {
return err
}
fmt.Printf(">>> mappedString: %q\n", mappedString)
if s.curNode.State == PStateTerminate {
if s.proc.Is(token, model.SpecialEOS) {
return nil
}
}
nextNode, ok := s.curNode.MaskTokenIDToNode[token]
if !ok {
return fmt.Errorf("invalid token: %q", mappedString)
}
if mappedString == "\"" {
if s.curNode.State == PStateInStringEnd {
s.completed++
}
if s.completed == len(s.functions) {
s.curNode.TransitionEdges[')'] = s.stateToNodes[PStateTerminate]
s.CreateMask(s.curNode)
}
}
s.curNode = nextNode
fmt.Println("curNode", s.curNode.State)
for r, node := range s.curNode.TransitionEdges {
fmt.Printf("transition edge: %q -> %v\n", r, node.State)
}
if err := s.CreateMask(s.curNode); err != nil {
return err
}
return nil
}
func (s *PythonSampler) CreateMask(node *Node) error {
if node == nil {
return fmt.Errorf("node cannot be nil")
}
for i := range s.decodedToks {
token := s.decodedToks[i]
// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
if s.proc.Is(int32(i), model.SpecialEOS) || s.proc.Is(int32(i), model.SpecialBOS) || token == "" || token == "\"\"" {
continue
}
curNode := node
valid := true
consumedSpecialRunes := make(map[rune]bool)
for _, r := range token {
curNode, valid = isRValid(r, curNode, consumedSpecialRunes)
if curNode == nil || !valid {
break
}
}
if valid {
if curNode.State == StateInFunction {
// fmt.Println("cm curNode", curNode.State)
// fmt.Println("cm token", s.decodedToks[i])
}
node.MaskTokenIDToNode[int32(i)] = curNode
}
}
return nil
}
func isRValid(r rune, curNode *Node, consumedSpecialRunes map[rune]bool) (*Node, bool) {
if consumedSpecialRunes[r] {
return nil, false
}
specialRune := slices.Contains(stringInvalidRunes, r)
if specialRune {
if curNode.State == PStateInString || curNode.State == PStateInStringEnd {
return nil, false
}
}
// Check for specific rune transition
if nextNode, ok := curNode.TransitionEdges[r]; ok {
// fmt.Println("next node", nextNode)
if specialRune {
if curNode.State == nextNode.State {
return nil, false
}
consumedSpecialRunes[r] = true
}
return nextNode, true
}
// Check for sentinel value - if present, any rune is valid
if nextNode, ok := curNode.TransitionEdges[rune(-1)]; ok {
return nextNode, true
}
return nil, false
}
func (s *PythonSampler) maskLogits(logits []float32, node *Node) ([]float32, error) {
// Create a new slice with same length as logits, initialized to -Inf
maskedLogits := make([]float32, len(logits))
for i := range maskedLogits {
maskedLogits[i] = float32(math.Inf(-1))
}
// Only update values for valid token IDs from the mask map
for tokenID := range node.MaskTokenIDToNode {
if int(tokenID) < len(logits) {
maskedLogits[tokenID] = logits[tokenID]
}
}
return maskedLogits, nil
}
func finish(s *PythonSampler, logits []float32) ([]float32, error) {
for i := range logits {
if s.proc.Is(int32(i), model.SpecialEOS) {
logits[i] = 1.0
} else {
logits[i] = float32(math.Inf(-1))
}
}
return logits, nil
}

View File

@@ -29,8 +29,9 @@ import (
const maxRetries = 6
var (
errMaxRetriesExceeded = errors.New("max retries exceeded")
errPartStalled = errors.New("part stalled")
errMaxRetriesExceeded = errors.New("max retries exceeded")
errPartStalled = errors.New("part stalled")
errMaxRedirectsExceeded = errors.New("maximum redirects exceeded (10) for directURL")
)
var blobDownloadManager sync.Map
@@ -236,7 +237,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
newOpts.CheckRedirect = func(req *http.Request, via []*http.Request) error {
if len(via) > 10 {
return errors.New("maximum redirects exceeded (10) for directURL")
return errMaxRedirectsExceeded
}
// if the hostname is the same, allow the redirect

View File

@@ -35,6 +35,7 @@ var (
errCapabilityCompletion = errors.New("completion")
errCapabilityTools = errors.New("tools")
errCapabilityInsert = errors.New("insert")
errInsecureProtocol = errors.New("insecure protocol http")
)
type Capability string
@@ -479,7 +480,7 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
fn(api.ProgressResponse{Status: "retrieving manifest"})
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
return errors.New("insecure protocol http")
return errInsecureProtocol
}
manifest, _, err := GetManifest(mp)
@@ -543,7 +544,7 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
}
if mp.ProtocolScheme == "http" && !regOpts.Insecure {
return errors.New("insecure protocol http")
return errInsecureProtocol
}
fn(api.ProgressResponse{Status: "pulling manifest"})

View File

@@ -421,14 +421,6 @@ func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
return err
}
func canRetry(err error) bool {
var re *Error
if !errors.As(err, &re) {
return false
}
return re.Status >= 500
}
// trackingReader is an io.Reader that tracks the number of bytes read and
// calls the update function with the layer, the number of bytes read.
//
@@ -514,13 +506,40 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
break
}
cacheKey := fmt.Sprintf(
"v1 pull chunksum %s %s %d-%d",
l.Digest,
cs.Digest,
cs.Chunk.Start,
cs.Chunk.End,
)
cacheKeyDigest := blob.DigestFromBytes(cacheKey)
_, err := c.Get(cacheKeyDigest)
if err == nil {
received.Add(cs.Chunk.Size())
t.update(l, cs.Chunk.Size(), ErrCached)
continue
}
wg.Add(1)
g.Go(func() (err error) {
defer func() {
if err == nil {
// Ignore cache key write errors for now. We've already
// reported to trace that the chunk is complete.
//
// Ideally, we should only report completion to trace
// after successful cache commit. This current approach
// works but could trigger unnecessary redownloads if
// the checkpoint key is missing on next pull.
//
// Not incorrect, just suboptimal - fix this in a
// future update.
_ = blob.PutBytes(c, cacheKeyDigest, cacheKey)
received.Add(cs.Chunk.Size())
} else {
err = fmt.Errorf("error downloading %s: %w", cs.Digest.Short(), err)
t.update(l, 0, err)
}
wg.Done()
}()
@@ -563,7 +582,7 @@ func (r *Registry) Pull(ctx context.Context, name string) error {
return err
}
if received.Load() != expected {
return fmt.Errorf("%w: received %d/%d", ErrIncomplete, received.Load(), expected)
return fmt.Errorf("%w: received %d/%d bytes", ErrIncomplete, received.Load(), expected)
}
md := blob.DigestFromBytes(m.Data)
@@ -608,6 +627,30 @@ func (m *Manifest) Layer(d blob.Digest) *Layer {
return nil
}
func (m *Manifest) All() iter.Seq[*Layer] {
return func(yield func(*Layer) bool) {
if !yield(m.Config) {
return
}
for _, l := range m.Layers {
if !yield(l) {
return
}
}
}
}
func (m *Manifest) Size() int64 {
var size int64
if m.Config != nil {
size += m.Config.Size
}
for _, l := range m.Layers {
size += l.Size
}
return size
}
// MarshalJSON implements json.Marshaler.
//
// NOTE: It adds an empty config object to the manifest, which is required by
@@ -750,20 +793,32 @@ func (r *Registry) chunksums(ctx context.Context, name string, l *Layer) iter.Se
return
}
// A chunksums response is a sequence of chunksums in a
// simple, easy to parse line-oriented format.
// The response is a sequence of chunksums.
//
// Example:
// Chunksums are chunks of a larger blob that can be
// downloaded and verified independently.
//
// >> GET /v2/<namespace>/<model>/chunksums/<digest>
// The chunksums endpoint is a GET request that returns a
// sequence of chunksums in the following format:
//
// << HTTP/1.1 200 OK
// << Content-Location: <blobURL>
// <<
// << <digest> <start>-<end>
// << ...
// > GET /v2/<namespace>/<model>/chunksums/<digest>
//
// The blobURL is the URL to download the chunks from.
// < HTTP/1.1 200 OK
// < Content-Location: <blobURL>
// <
// < <digest> <start>-<end>
// < ...
//
// The <blobURL> is the URL to download the chunks from and
// each <digest> is the digest of the chunk, and <start>-<end>
// is the range the chunk in the blob.
//
// Ranges may be used directly in Range headers like
// "bytes=<start>-<end>".
//
// The chunksums returned are guaranteed to be contiguous and
// include all bytes of the layer. If the stream is cut short,
// clients should retry.
chunksumsURL := fmt.Sprintf("%s://%s/v2/%s/%s/chunksums/%s",
scheme,

View File

@@ -9,17 +9,14 @@ import (
"fmt"
"io"
"io/fs"
"math/rand/v2"
"net"
"net/http"
"net/http/httptest"
"os"
"path"
"reflect"
"slices"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/testutil"
@@ -338,15 +335,8 @@ func TestPushCommitRoundtripError(t *testing.T) {
}
}
func checkNotExist(t *testing.T, err error) {
t.Helper()
if !errors.Is(err, fs.ErrNotExist) {
t.Fatalf("err = %v; want fs.ErrNotExist", err)
}
}
func TestRegistryPullInvalidName(t *testing.T) {
rc, _ := newClient(t, nil)
rc, _ := newRegistryClient(t, nil)
err := rc.Pull(t.Context(), "://")
if !errors.Is(err, ErrNameInvalid) {
t.Errorf("err = %v; want %v", err, ErrNameInvalid)
@@ -362,197 +352,16 @@ func TestRegistryPullInvalidManifest(t *testing.T) {
}
for _, resp := range cases {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
rc, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
io.WriteString(w, resp)
})
err := rc.Pull(t.Context(), "x")
err := rc.Pull(t.Context(), "http://example.com/a/b")
if !errors.Is(err, ErrManifestInvalid) {
t.Errorf("err = %v; want invalid manifest", err)
}
}
}
func TestRegistryPullNotCached(t *testing.T) {
check := testutil.Checker(t)
var c *blob.DiskCache
var rc *Registry
d := blob.DigestFromBytes("some data")
rc, c = newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/blobs/") {
io.WriteString(w, "some data")
return
}
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":9}]}`, d)
})
// Confirm that the layer does not exist locally
_, err := rc.ResolveLocal("model")
checkNotExist(t, err)
_, err = c.Get(d)
checkNotExist(t, err)
err = rc.Pull(t.Context(), "model")
check(err)
mw, err := rc.Resolve(t.Context(), "model")
check(err)
mg, err := rc.ResolveLocal("model")
check(err)
if !reflect.DeepEqual(mw, mg) {
t.Errorf("mw = %v; mg = %v", mw, mg)
}
// Confirm successful download
info, err := c.Get(d)
check(err)
if info.Digest != d {
t.Errorf("info.Digest = %v; want %v", info.Digest, d)
}
if info.Size != 9 {
t.Errorf("info.Size = %v; want %v", info.Size, 9)
}
data, err := os.ReadFile(c.GetFile(d))
check(err)
if string(data) != "some data" {
t.Errorf("data = %q; want %q", data, "exists")
}
}
func TestRegistryPullCached(t *testing.T) {
cached := blob.DigestFromBytes("exists")
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/blobs/") {
w.WriteHeader(499) // should not be called
return
}
if strings.Contains(r.URL.Path, "/manifests/") {
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":6}]}`, cached)
}
})
var errs []error
var reads []int64
ctx := WithTrace(t.Context(), &Trace{
Update: func(d *Layer, n int64, err error) {
t.Logf("update %v %d %v", d, n, err)
reads = append(reads, n)
errs = append(errs, err)
},
})
ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
defer cancel()
err := rc.Pull(ctx, "single")
testutil.Check(t, err)
want := []int64{0, 6}
if !errors.Is(errors.Join(errs...), ErrCached) {
t.Errorf("errs = %v; want %v", errs, ErrCached)
}
if !slices.Equal(reads, want) {
t.Errorf("pairs = %v; want %v", reads, want)
}
}
func TestRegistryPullManifestNotFound(t *testing.T) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotFound)
})
err := rc.Pull(t.Context(), "notfound")
checkErrCode(t, err, 404, "")
}
func TestRegistryPullResolveRemoteError(t *testing.T) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusInternalServerError)
io.WriteString(w, `{"errors":[{"code":"an_error"}]}`)
})
err := rc.Pull(t.Context(), "single")
checkErrCode(t, err, 500, "an_error")
}
func TestRegistryPullResolveRoundtripError(t *testing.T) {
rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
if strings.Contains(r.URL.Path, "/manifests/") {
w.WriteHeader(499) // force RoundTrip error
return
}
})
err := rc.Pull(t.Context(), "single")
if !errors.Is(err, errRoundTrip) {
t.Errorf("err = %v; want %v", err, errRoundTrip)
}
}
// TestRegistryPullMixedCachedNotCached tests that cached layers do not
// interfere with pulling layers that are not cached
func TestRegistryPullMixedCachedNotCached(t *testing.T) {
x := blob.DigestFromBytes("xxxxxx")
e := blob.DigestFromBytes("exists")
y := blob.DigestFromBytes("yyyyyy")
for i := range 10 {
t.Logf("iteration %d", i)
digests := []blob.Digest{x, e, y}
rand.Shuffle(len(digests), func(i, j int) {
digests[i], digests[j] = digests[j], digests[i]
})
manifest := fmt.Sprintf(`{
"layers": [
{"digest":"%s","size":6},
{"digest":"%s","size":6},
{"digest":"%s","size":6}
]
}`, digests[0], digests[1], digests[2])
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
switch path.Base(r.URL.Path) {
case "latest":
io.WriteString(w, manifest)
case x.String():
io.WriteString(w, "xxxxxx")
case e.String():
io.WriteString(w, "exists")
case y.String():
io.WriteString(w, "yyyyyy")
default:
panic(fmt.Sprintf("unexpected request: %v", r))
}
})
ctx := WithTrace(t.Context(), &Trace{
Update: func(l *Layer, n int64, err error) {
t.Logf("update %v %d %v", l, n, err)
},
})
// Check that we pull all layers that we can.
err := rc.Pull(ctx, "mixed")
if err != nil {
t.Fatal(err)
}
for _, d := range digests {
info, err := c.Get(d)
if err != nil {
t.Fatalf("Get(%v): %v", d, err)
}
if info.Size != 6 {
t.Errorf("info.Size = %v; want %v", info.Size, 6)
}
}
}
}
func TestRegistryResolveByDigest(t *testing.T) {
check := testutil.Checker(t)
@@ -590,26 +399,6 @@ func TestInsecureSkipVerify(t *testing.T) {
testutil.Check(t, err)
}
func TestCanRetry(t *testing.T) {
cases := []struct {
err error
want bool
}{
{nil, false},
{errors.New("x"), false},
{ErrCached, false},
{ErrManifestInvalid, false},
{ErrNameInvalid, false},
{&Error{Status: 100}, false},
{&Error{Status: 500}, true},
}
for _, tt := range cases {
if got := canRetry(tt.err); got != tt.want {
t.Errorf("CanRetry(%v) = %v; want %v", tt.err, got, tt.want)
}
}
}
func TestErrorUnmarshal(t *testing.T) {
cases := []struct {
name string
@@ -761,17 +550,23 @@ func TestParseNameExtended(t *testing.T) {
func TestUnlink(t *testing.T) {
t.Run("found by name", func(t *testing.T) {
rc, _ := newClient(t, nil)
check := testutil.Checker(t)
rc, _ := newRegistryClient(t, nil)
// make a blob and link it
d := blob.DigestFromBytes("{}")
err := blob.PutBytes(rc.Cache, d, "{}")
check(err)
err = rc.Cache.Link("registry.ollama.ai/library/single:latest", d)
check(err)
// confirm linked
_, err := rc.ResolveLocal("single")
if err != nil {
t.Errorf("unexpected error: %v", err)
}
_, err = rc.ResolveLocal("single")
check(err)
// unlink
_, err = rc.Unlink("single")
testutil.Check(t, err)
check(err)
// confirm unlinked
_, err = rc.ResolveLocal("single")
@@ -780,7 +575,7 @@ func TestUnlink(t *testing.T) {
}
})
t.Run("not found by name", func(t *testing.T) {
rc, _ := newClient(t, nil)
rc, _ := newRegistryClient(t, nil)
ok, err := rc.Unlink("manifestNotFound")
if err != nil {
t.Fatal(err)
@@ -791,78 +586,368 @@ func TestUnlink(t *testing.T) {
})
}
func TestPullChunksums(t *testing.T) {
check := testutil.Checker(t)
// Many tests from here out, in this file are based on a single blob, "abc",
// with the checksum of its sha256 hash. The checksum is:
//
// "abc" -> sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad
//
// Using the literal value instead of a constant with fmt.Xprintf calls proved
// to be the most readable and maintainable approach. The sum is consistently
// used in the tests and unique so searches do not yield false positives.
content := "hello"
var chunksums string
contentDigest := func() blob.Digest {
return blob.DigestFromBytes(content)
func checkRequest(t *testing.T, req *http.Request, method, path string) {
t.Helper()
if got := req.URL.Path; got != path {
t.Errorf("URL = %q, want %q", got, path)
}
rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
switch {
case strings.Contains(r.URL.Path, "/manifests/latest"):
fmt.Fprintf(w, `{"layers":[{"digest":%q,"size":%d}]}`, contentDigest(), len(content))
case strings.HasSuffix(r.URL.Path, "/chunksums/"+contentDigest().String()):
loc := fmt.Sprintf("http://blob.store/v2/library/test/blobs/%s", contentDigest())
w.Header().Set("Content-Location", loc)
io.WriteString(w, chunksums)
case strings.Contains(r.URL.Path, "/blobs/"+contentDigest().String()):
http.ServeContent(w, r, contentDigest().String(), time.Time{}, strings.NewReader(content))
default:
t.Errorf("unexpected request: %v", r)
http.NotFound(w, r)
}
})
if req.Method != method {
t.Errorf("Method = %q, want %q", req.Method, method)
}
}
rc.MaxStreams = 1 // prevent concurrent chunk downloads
rc.ChunkingThreshold = 1 // for all blobs to be chunked
func newRegistryClient(t *testing.T, h http.HandlerFunc) (*Registry, context.Context) {
s := httptest.NewServer(h)
t.Cleanup(s.Close)
cache, err := blob.Open(t.TempDir())
if err != nil {
t.Fatal(err)
}
var mu sync.Mutex
var reads []int64
ctx := WithTrace(t.Context(), &Trace{
Update: func(l *Layer, n int64, err error) {
t.Logf("Update: %v %d %v", l, n, err)
mu.Lock()
reads = append(reads, n)
mu.Unlock()
t.Log("trace:", l.Digest.Short(), n, err)
},
})
chunksums = fmt.Sprintf("%s 0-2\n%s 3-4\n",
blob.DigestFromBytes("hel"),
blob.DigestFromBytes("lo"),
)
err := rc.Pull(ctx, "test")
check(err)
wantReads := []int64{
0, // initial signaling of layer pull starting
3, // first chunk read
2, // second chunk read
}
if !slices.Equal(reads, wantReads) {
t.Errorf("reads = %v; want %v", reads, wantReads)
rc := &Registry{
Cache: cache,
HTTPClient: &http.Client{Transport: &http.Transport{
Dial: func(network, addr string) (net.Conn, error) {
return net.Dial(network, s.Listener.Addr().String())
},
}},
}
return rc, ctx
}
mw, err := rc.Resolve(t.Context(), "test")
check(err)
mg, err := rc.ResolveLocal("test")
check(err)
if !reflect.DeepEqual(mw, mg) {
t.Errorf("mw = %v; mg = %v", mw, mg)
}
for i := range mg.Layers {
_, err = c.Get(mg.Layers[i].Digest)
if err != nil {
t.Errorf("Get(%v): %v", mg.Layers[i].Digest, err)
func TestPullChunked(t *testing.T) {
var steps atomic.Int64
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
switch steps.Add(1) {
case 1:
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
case 2:
checkRequest(t, r, "GET", "/v2/library/abc/chunksums/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab"))
fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c"))
case 3, 4:
checkRequest(t, r, "GET", "/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
switch rng := r.Header.Get("Range"); rng {
case "bytes=0-1":
io.WriteString(w, "ab")
case "bytes=2-2":
t.Logf("writing c")
io.WriteString(w, "c")
default:
t.Errorf("unexpected range %q", rng)
}
default:
t.Errorf("unexpected steps %d: %v", steps.Load(), r)
http.Error(w, "unexpected steps", http.StatusInternalServerError)
}
}
})
// missing chunks
content = "llama"
chunksums = fmt.Sprintf("%s 0-1\n", blob.DigestFromBytes("ll"))
err = rc.Pull(ctx, "missingchunks")
if err == nil {
t.Error("expected error because of missing chunks")
c.ChunkingThreshold = 1 // force chunking
err := c.Pull(ctx, "http://o.com/library/abc")
testutil.Check(t, err)
_, err = c.Cache.Resolve("o.com/library/abc:latest")
testutil.Check(t, err)
if g := steps.Load(); g != 4 {
t.Fatalf("got %d steps, want 4", g)
}
}
func TestPullCached(t *testing.T) {
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
})
check := testutil.Checker(t)
// Premeptively cache the blob
d, err := blob.ParseDigest("sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
check(err)
err = blob.PutBytes(c.Cache, d, []byte("abc"))
check(err)
// Pull only the manifest, which should be enough to resolve the cached blob
err = c.Pull(ctx, "http://o.com/library/abc")
check(err)
}
func TestPullManifestError(t *testing.T) {
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
w.WriteHeader(http.StatusNotFound)
io.WriteString(w, `{"errors":[{"code":"MANIFEST_UNKNOWN"}]}`)
})
err := c.Pull(ctx, "http://o.com/library/abc")
if err == nil {
t.Fatalf("expected error")
}
var got *Error
if !errors.Is(err, ErrModelNotFound) {
t.Fatalf("err = %v, want %v", got, ErrModelNotFound)
}
}
func TestPullLayerError(t *testing.T) {
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
io.WriteString(w, `!`)
})
err := c.Pull(ctx, "http://o.com/library/abc")
if err == nil {
t.Fatalf("expected error")
}
var want *json.SyntaxError
if !errors.As(err, &want) {
t.Fatalf("err = %T, want %T", err, want)
}
}
func TestPullLayerChecksumError(t *testing.T) {
var step atomic.Int64
c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
switch step.Add(1) {
case 1:
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
case 2:
checkRequest(t, r, "GET", "/v2/library/abc/chunksums/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab"))
fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c"))
case 3:
w.WriteHeader(http.StatusNotFound)
io.WriteString(w, `{"errors":[{"code":"BLOB_UNKNOWN"}]}`)
case 4:
io.WriteString(w, "c")
default:
t.Errorf("unexpected steps %d: %v", step.Load(), r)
http.Error(w, "unexpected steps", http.StatusInternalServerError)
}
})
c.MaxStreams = 1
c.ChunkingThreshold = 1 // force chunking
var written atomic.Int64
ctx := WithTrace(t.Context(), &Trace{
Update: func(l *Layer, n int64, err error) {
t.Log("trace:", l.Digest.Short(), n, err)
written.Add(n)
},
})
err := c.Pull(ctx, "http://o.com/library/abc")
var got *Error
if !errors.As(err, &got) || got.Code != "BLOB_UNKNOWN" {
t.Fatalf("err = %v, want %v", err, got)
}
if g := written.Load(); g != 1 {
t.Fatalf("wrote %d bytes, want 1", g)
}
}
func TestPullChunksumStreamError(t *testing.T) {
var step atomic.Int64
c, ctx := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
switch step.Add(1) {
case 1:
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
case 2:
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
// Write one valid chunksum and one invalid chunksum
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab")) // valid
fmt.Fprint(w, "sha256:!") // invalid
case 3:
io.WriteString(w, "ab")
default:
t.Errorf("unexpected steps %d: %v", step.Load(), r)
http.Error(w, "unexpected steps", http.StatusInternalServerError)
}
})
c.ChunkingThreshold = 1 // force chunking
got := c.Pull(ctx, "http://o.com/library/abc")
if !errors.Is(got, ErrIncomplete) {
t.Fatalf("err = %v, want %v", got, ErrIncomplete)
}
}
type flushAfterWriter struct {
w io.Writer
}
func (f *flushAfterWriter) Write(p []byte) (n int, err error) {
n, err = f.w.Write(p)
f.w.(http.Flusher).Flush() // panic if not a flusher
return
}
func TestPullChunksumStreaming(t *testing.T) {
csr, csw := io.Pipe()
defer csw.Close()
var step atomic.Int64
c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
switch step.Add(1) {
case 1:
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
case 2:
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
fw := &flushAfterWriter{w} // ensure client gets data as it arrives by aggressively flushing
_, err := io.Copy(fw, csr)
if err != nil {
t.Errorf("copy: %v", err)
}
case 3:
io.WriteString(w, "ab")
case 4:
io.WriteString(w, "c")
default:
t.Errorf("unexpected steps %d: %v", step.Load(), r)
http.Error(w, "unexpected steps", http.StatusInternalServerError)
}
})
c.ChunkingThreshold = 1 // force chunking
update := make(chan int64, 1)
ctx := WithTrace(t.Context(), &Trace{
Update: func(l *Layer, n int64, err error) {
t.Log("trace:", l.Digest.Short(), n, err)
if n > 0 {
update <- n
}
},
})
errc := make(chan error, 1)
go func() {
errc <- c.Pull(ctx, "http://o.com/library/abc")
}()
// Send first chunksum and ensure it kicks off work immediately
fmt.Fprintf(csw, "%s 0-1\n", blob.DigestFromBytes("ab"))
if g := <-update; g != 2 {
t.Fatalf("got %d, want 2", g)
}
// now send the second chunksum and ensure it kicks off work immediately
fmt.Fprintf(csw, "%s 2-2\n", blob.DigestFromBytes("c"))
if g := <-update; g != 1 {
t.Fatalf("got %d, want 1", g)
}
csw.Close()
testutil.Check(t, <-errc)
}
func TestPullChunksumsCached(t *testing.T) {
var step atomic.Int64
c, _ := newRegistryClient(t, func(w http.ResponseWriter, r *http.Request) {
switch step.Add(1) {
case 1:
checkRequest(t, r, "GET", "/v2/library/abc/manifests/latest")
io.WriteString(w, `{"layers":[{"size":3,"digest":"sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"}]}`)
case 2:
w.Header().Set("Content-Location", "http://blob.store/v2/library/abc/blobs/sha256:ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad")
fmt.Fprintf(w, "%s 0-1\n", blob.DigestFromBytes("ab"))
fmt.Fprintf(w, "%s 2-2\n", blob.DigestFromBytes("c"))
case 3, 4:
switch rng := r.Header.Get("Range"); rng {
case "bytes=0-1":
io.WriteString(w, "ab")
case "bytes=2-2":
io.WriteString(w, "c")
default:
t.Errorf("unexpected range %q", rng)
}
default:
t.Errorf("unexpected steps %d: %v", step.Load(), r)
http.Error(w, "unexpected steps", http.StatusInternalServerError)
}
})
c.MaxStreams = 1 // force serial processing of chunksums
c.ChunkingThreshold = 1 // force chunking
ctx, cancel := context.WithCancel(t.Context())
defer cancel()
// Cancel the pull after the first chunksum is processed, but before
// the second chunksum is processed (which is waiting because
// MaxStreams=1). This should cause the second chunksum to error out
// leaving the blob incomplete.
ctx = WithTrace(ctx, &Trace{
Update: func(l *Layer, n int64, err error) {
if n > 0 {
cancel()
}
},
})
err := c.Pull(ctx, "http://o.com/library/abc")
if !errors.Is(err, context.Canceled) {
t.Fatalf("err = %v, want %v", err, context.Canceled)
}
_, err = c.Cache.Resolve("o.com/library/abc:latest")
if !errors.Is(err, fs.ErrNotExist) {
t.Fatalf("err = %v, want nil", err)
}
// Reset state and pull again to ensure the blob chunks that should
// have been cached are, and the remaining chunk was downloaded, making
// the blob complete.
step.Store(0)
var written atomic.Int64
var cached atomic.Int64
ctx = WithTrace(t.Context(), &Trace{
Update: func(l *Layer, n int64, err error) {
t.Log("trace:", l.Digest.Short(), n, err)
if errors.Is(err, ErrCached) {
cached.Add(n)
}
written.Add(n)
},
})
check := testutil.Checker(t)
err = c.Pull(ctx, "http://o.com/library/abc")
check(err)
_, err = c.Cache.Resolve("o.com/library/abc:latest")
check(err)
if g := written.Load(); g != 3 {
t.Fatalf("wrote %d bytes, want 3", g)
}
if g := cached.Load(); g != 2 { // "ab" should have been cached
t.Fatalf("cached %d bytes, want 3", g)
}
}

View File

@@ -31,9 +31,10 @@ const (
var (
ErrInvalidImageFormat = errors.New("invalid image format")
ErrInvalidDigestFormat = errors.New("invalid digest format")
ErrInvalidProtocol = errors.New("invalid protocol scheme")
ErrInsecureProtocol = errors.New("insecure protocol http")
ErrInvalidDigestFormat = errors.New("invalid digest format")
ErrModelPathInvalid = errors.New("invalid model path")
)
func ParseModelPath(name string) ModelPath {
@@ -73,8 +74,6 @@ func ParseModelPath(name string) ModelPath {
return mp
}
var errModelPathInvalid = errors.New("invalid model path")
func (mp ModelPath) GetNamespaceRepository() string {
return fmt.Sprintf("%s/%s", mp.Namespace, mp.Repository)
}

View File

@@ -777,7 +777,7 @@ func (s *Server) ShowHandler(c *gin.Context) {
func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
name := model.ParseName(req.Model)
if !name.IsValid() {
return nil, errModelPathInvalid
return nil, ErrModelPathInvalid
}
name, err := getExistingName(name)
if err != nil {

View File

@@ -711,7 +711,7 @@ func pickBestFullFitByLibrary(req *LlmRequest, f *ggml.GGML, gpus discover.GpuIn
req.opts.NumCtx = req.origNumCtx * p
if !envconfig.SchedSpread() {
for _, g := range sgl {
if ok, estimatedVRAM = llm.PredictServerFit([]discover.GpuInfo{g}, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
if ok, estimatedVRAM = llm.PredictServerFit([]discover.GpuInfo{g}, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, p); ok {
slog.Info("new model will fit in available VRAM in single GPU, loading", "model", req.model.ModelPath, "gpu", g.ID, "parallel", p, "available", g.FreeMemory, "required", format.HumanBytes2(estimatedVRAM))
*numParallel = p
return []discover.GpuInfo{g}
@@ -727,7 +727,7 @@ func pickBestFullFitByLibrary(req *LlmRequest, f *ggml.GGML, gpus discover.GpuIn
// Now try all the GPUs
for _, p := range numParallelToTry {
req.opts.NumCtx = req.origNumCtx * p
if ok, estimatedVRAM = llm.PredictServerFit(sgl, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts); ok {
if ok, estimatedVRAM = llm.PredictServerFit(sgl, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, p); ok {
slog.Info("new model will fit in available VRAM, loading", "model", req.model.ModelPath, "library", sgl[0].Library, "parallel", p, "required", format.HumanBytes2(estimatedVRAM))
*numParallel = p
return sgl
@@ -750,7 +750,7 @@ func pickBestPartialFitByLibrary(req *LlmRequest, f *ggml.GGML, gpus discover.Gp
var bestEstimate uint64
var bestFit int
for i, gl := range byLibrary {
_, estimatedVRAM := llm.PredictServerFit(gl, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts)
_, estimatedVRAM := llm.PredictServerFit(gl, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, *numParallel)
if estimatedVRAM > bestEstimate {
bestEstimate = estimatedVRAM
bestFit = i
@@ -825,7 +825,7 @@ func (s *Scheduler) expireRunner(model *Model) {
// If not, pick a runner to unload, else return nil and the request can be loaded
func (s *Scheduler) maybeFindCPURunnerToUnload(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList) *runnerRef {
slog.Debug("evaluating if CPU model load will fit in available system memory")
estimate := llm.EstimateGPULayers(gpus, f, req.model.ProjectorPaths, req.opts)
estimate := llm.EstimateGPULayers(gpus, f, req.model.ProjectorPaths, req.opts, req.opts.NumCtx/req.origNumCtx)
if estimate.TotalSize <= gpus[0].FreeMemory {
slog.Debug("cpu inference mode, model fits in available system memory", "model", format.HumanBytes2(estimate.TotalSize), "available", format.HumanBytes2(gpus[0].FreeMemory))
return nil