Compare commits

..

1 Commits

Author SHA1 Message Date
jmorganca
19638cec55 add docs.json 2025-08-17 13:12:39 -07:00
26 changed files with 343 additions and 1349 deletions

View File

@@ -411,8 +411,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [ollama launcher](https://github.com/NGC13009/ollama-launcher) (A launcher for Ollama, aiming to provide users with convenient functions such as ollama server launching, management, or configuration.)
- [ai-hub](https://github.com/Aj-Seven/ai-hub) (AI Hub supports multiple models via API keys and Chat support via Ollama API.)
- [Mayan EDMS](https://gitlab.com/mayan-edms/mayan-edms) (Open source document management system to organize, tag, search, and automate your files with powerful Ollama driven workflows.)
- [Serene Pub](https://github.com/doolijb/serene-pub) (Beginner friendly, open source AI Roleplaying App for Windows, Mac OS and Linux. Search, download and use models with Ollama all inside the app.)
- [Andes](https://github.com/aqerd/andes) (A Visual Studio Code extension that provides a local UI interface for Ollama models)
### Cloud
@@ -539,8 +537,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Nichey](https://github.com/goodreasonai/nichey) is a Python package for generating custom wikis for your research topic
- [Ollama for D](https://github.com/kassane/ollama-d)
- [OllamaPlusPlus](https://github.com/HardCodeDev777/OllamaPlusPlus) (Very simple C++ library for Ollama)
- [any-llm](https://github.com/mozilla-ai/any-llm) (A single interface to use different llm providers by [mozilla.ai](https://www.mozilla.ai/))
- [any-agent](https://github.com/mozilla-ai/any-agent) (A single interface to use and evaluate different agent frameworks by [mozilla.ai](https://www.mozilla.ai/))
### Mobile

View File

@@ -90,10 +90,6 @@ type GenerateRequest struct {
// (request that thinking _not_ be used) and unset (use the old behavior
// before this option was introduced)
Think *ThinkValue `json:"think,omitempty"`
// DebugRenderOnly is a debug option that, when set to true, returns the rendered
// template instead of calling the model.
DebugRenderOnly bool `json:"_debug_render_only,omitempty"`
}
// ChatRequest describes a request sent by [Client.Chat].
@@ -124,10 +120,6 @@ type ChatRequest struct {
// responding. Can be a boolean (true/false) or a string ("high", "medium", "low")
// for supported models.
Think *ThinkValue `json:"think,omitempty"`
// DebugRenderOnly is a debug option that, when set to true, returns the rendered
// template instead of calling the model.
DebugRenderOnly bool `json:"_debug_render_only,omitempty"`
}
type Tools []Tool
@@ -316,19 +308,6 @@ type ChatResponse struct {
Metrics
}
// DebugInfo contains debug information for template rendering
type DebugInfo struct {
RenderedTemplate string `json:"rendered_template"`
ImageCount int `json:"image_count,omitempty"`
}
// DebugTemplateResponse is returned when _debug_render_only is set to true
type DebugTemplateResponse struct {
Model string `json:"model"`
CreatedAt time.Time `json:"created_at"`
DebugInfo DebugInfo `json:"_debug_info"`
}
type Metrics struct {
TotalDuration time.Duration `json:"total_duration,omitempty"`
LoadDuration time.Duration `json:"load_duration,omitempty"`

75
docs/docs.json Normal file
View File

@@ -0,0 +1,75 @@
{
"$schema": "https://mintlify.com/docs.json",
"theme": "mint",
"background": {
"color": {
"light": "#ffffff",
"dark": "#000000"
}
},
"appearance": {
"default": "light"
},
"styling": {
"codeblocks": "system"
},
"contextual": {
"options": ["copy", "chatgpt", "claude", "view"]
},
"fonts": {
"heading": {
"family": "Inter"
},
"body": {
"family": "Inter"
}
},
"name": "Ollama",
"colors": {
"primary": "#000",
"light": "#b5b5b5",
"dark": "#fff"
},
"favicon": "/ollama.png",
"logo": {
"light": "/ollama.png",
"dark": "/favicon.svg"
},
"navigation": {
"tabs": [
{
"tab": "Documentation",
"groups": [
{
"group": "Home",
"pages": ["index", "quickstart", "faq", "troubleshooting"]
},
{
"group": "Platforms",
"pages": ["linux", "windows", "docker"]
},
{
"group": "Features",
"pages": [
"modelfile",
"apis",
"openai",
"import",
"gpu",
"benchmark"
]
}
]
},
{
"tab": "Development",
"groups": [
{
"group": " ",
"pages": ["development", "examples", "template"]
}
]
}
]
}
}

View File

@@ -2,13 +2,10 @@
This directory contains integration tests to exercise Ollama end-to-end to verify behavior
By default, these tests are disabled so `go test ./...` will exercise only unit tests. To run integration tests you must pass the integration tag. `go test -tags=integration ./...` Some tests require additional tags to enable to allow scoped testing to keep the duration reasonable. For example, testing a broad set of models requires `-tags=integration,models` and a longer timeout (~60m or more depending on the speed of your GPU.). To view the current set of tag combinations use `find integration -type f | xargs grep "go:build"`
By default, these tests are disabled so `go test ./...` will exercise only unit tests. To run integration tests you must pass the integration tag. `go test -tags=integration ./...`
The integration tests have 2 modes of operating.
1. By default, they will start the server on a random port, run the tests, and then shutdown the server.
2. If `OLLAMA_TEST_EXISTING` is set to a non-empty string, the tests will run against an existing running server, which can be remote based on your `OLLAMA_HOST` environment variable
> [!IMPORTANT]
> Before running the tests locally without the "test existing" setting, compile ollama from the top of the source tree `go build .` in addition to GPU support with cmake if applicable on your platform. The integration tests expect to find an ollama binary at the top of the tree.
2. If `OLLAMA_TEST_EXISTING` is set to a non-empty string, the tests will run against an existing running server, which can be remote

View File

@@ -66,7 +66,7 @@ func TestContextExhaustion(t *testing.T) {
DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived"}, 120*time.Second, 10*time.Second)
}
// Send multiple generate requests with prior context and ensure the response is coherant and expected
// Send multiple requests with prior context and ensure the response is coherant and expected
func TestGenerateWithHistory(t *testing.T) {
modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
req, resp := GenerateRequests()
@@ -111,56 +111,5 @@ func TestGenerateWithHistory(t *testing.T) {
}(i)
}
wg.Wait()
}
// Send multiple chat requests with prior context and ensure the response is coherant and expected
func TestChatWithHistory(t *testing.T) {
modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
req, resp := ChatRequests()
numParallel := 2
iterLimit := 2
softTimeout, hardTimeout := getTimeouts(t)
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
// Get the server running (if applicable) warm the model up with a single initial empty request
slog.Info("loading", "model", modelOverride)
err := client.Generate(ctx,
&api.GenerateRequest{Model: modelOverride, KeepAlive: &api.Duration{Duration: 10 * time.Second}},
func(response api.GenerateResponse) error { return nil },
)
if err != nil {
t.Fatalf("failed to load model %s: %s", modelOverride, err)
}
var wg sync.WaitGroup
wg.Add(numParallel)
for i := range numParallel {
go func(i int) {
defer wg.Done()
k := i % len(req)
req[k].Model = modelOverride
for j := 0; j < iterLimit; j++ {
if time.Now().Sub(started) > softTimeout {
slog.Info("exceeded soft timeout, winding down test")
return
}
slog.Info("Starting", "thread", i, "iter", j)
// On slower GPUs it can take a while to process the concurrent requests
// so we allow a much longer initial timeout
assistant := DoChat(ctx, t, client, req[k], resp[k], 120*time.Second, 20*time.Second)
if assistant == nil {
t.Fatalf("didn't get an assistant response for context")
}
req[k].Messages = append(req[k].Messages,
*assistant,
api.Message{Role: "user", Content: "tell me more!"},
)
}
}(i)
}
wg.Wait()
}

View File

@@ -19,8 +19,6 @@ import (
)
func TestMaxQueue(t *testing.T) {
t.Skip("this test needs to be re-evaluated to use a proper embedding model")
if os.Getenv("OLLAMA_TEST_EXISTING") != "" {
t.Skip("Max Queue test requires spawning a local server so we can adjust the queue size")
return

View File

@@ -567,76 +567,6 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
}
}
func ChatRequests() ([]api.ChatRequest, [][]string) {
genReqs, results := GenerateRequests()
reqs := make([]api.ChatRequest, len(genReqs))
for i := range reqs {
reqs[i].Model = genReqs[i].Model
reqs[i].Stream = genReqs[i].Stream
reqs[i].KeepAlive = genReqs[i].KeepAlive
reqs[i].Messages = []api.Message{
{
Role: "user",
Content: genReqs[i].Prompt,
},
}
}
return reqs, results
}
func DoChat(ctx context.Context, t *testing.T, client *api.Client, req api.ChatRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) *api.Message {
stallTimer := time.NewTimer(initialTimeout)
var buf bytes.Buffer
role := "assistant"
fn := func(response api.ChatResponse) error {
// fmt.Print(".")
role = response.Message.Role
buf.Write([]byte(response.Message.Content))
if !stallTimer.Reset(streamTimeout) {
return errors.New("stall was detected while streaming response, aborting")
}
return nil
}
stream := true
req.Stream = &stream
done := make(chan int)
var genErr error
go func() {
genErr = client.Chat(ctx, &req, fn)
done <- 0
}()
select {
case <-stallTimer.C:
if buf.Len() == 0 {
t.Errorf("generate never started. Timed out after :%s", initialTimeout.String())
} else {
t.Errorf("generate stalled. Response so far:%s", buf.String())
}
case <-done:
if genErr != nil && strings.Contains(genErr.Error(), "model requires more system memory") {
slog.Warn("model is too large for the target test system", "model", req.Model, "error", genErr)
return nil
}
require.NoError(t, genErr, "failed with %s request Messages %s ", req.Model, req.Messages)
// Verify the response contains the expected data
response := buf.String()
atLeastOne := false
for _, resp := range anyResp {
if strings.Contains(strings.ToLower(response), resp) {
atLeastOne = true
break
}
}
require.True(t, atLeastOne, "%s: none of %v found in \"%s\" -- request was:%v", req.Model, anyResp, response, req.Messages)
slog.Info("test pass", "model", req.Model, "messages", req.Messages, "contains", anyResp, "response", response)
case <-ctx.Done():
t.Error("outer test context done while waiting for generate")
}
return &api.Message{Role: role, Content: buf.String()}
}
func skipUnderMinVRAM(t *testing.T, gb uint64) {
// TODO use info API in the future
if s := os.Getenv("OLLAMA_MAX_VRAM"); s != "" {

View File

@@ -962,7 +962,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
const int64_t n_vocab = vocab.n_tokens();
const int64_t n_embd = hparams.n_embd;
const bool output_all = false;
// when computing embeddings, all tokens are output
const bool output_all = cparams.embeddings;
if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) {
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);

View File

@@ -13,7 +13,7 @@ checks.
1 file changed, 18 insertions(+)
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 57eae461..c7f9dc3a 100644
index 57eae461..9db0c8b5 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2671,12 +2671,24 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud

View File

@@ -1,23 +0,0 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Michael Yang <git@mxy.ng>
Date: Mon, 18 Aug 2025 16:58:39 -0700
Subject: [PATCH] decode: disable output_all
---
src/llama-context.cpp | 3 +--
1 file changed, 1 insertion(+), 2 deletions(-)
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 26a5cf9c..6ece5263 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -962,8 +962,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
const int64_t n_vocab = vocab.n_tokens();
const int64_t n_embd = hparams.n_embd;
- // when computing embeddings, all tokens are output
- const bool output_all = cparams.embeddings;
+ const bool output_all = false;
if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) {
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);

View File

@@ -31,7 +31,6 @@ import (
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/harmony"
"github.com/ollama/ollama/llama"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
@@ -652,9 +651,7 @@ func (s *ollamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requ
if !success {
s.initModel(ctx, LoadRequest{}, LoadOperationClose)
}
if s.mem != nil {
s.mem.Log(slog.LevelInfo)
}
s.mem.Log(slog.LevelInfo)
}()
slog.Info("loading model", "model layers", s.totalLayers, "requested", s.options.NumGPU)
@@ -1332,9 +1329,7 @@ type CompletionRequest struct {
Images []ImageData
Options *api.Options
Grammar string // set before sending the request to the subprocess
FunctionNameMap *harmony.FunctionNameMap
PrefillContent *bool
Grammar string // set before sending the request to the subprocess
}
// DoneReason represents the reason why a completion response is done
@@ -1361,15 +1356,13 @@ func (d DoneReason) String() string {
}
type CompletionResponse struct {
Content string `json:"content"`
Thinking string `json:"thinking"`
ToolCalls []api.ToolCall `json:"tool_calls"`
DoneReason DoneReason `json:"done_reason"`
Done bool `json:"done"`
PromptEvalCount int `json:"prompt_eval_count"`
PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
EvalCount int `json:"eval_count"`
EvalDuration time.Duration `json:"eval_duration"`
Content string `json:"content"`
DoneReason DoneReason `json:"done_reason"`
Done bool `json:"done"`
PromptEvalCount int `json:"prompt_eval_count"`
PromptEvalDuration time.Duration `json:"prompt_eval_duration"`
EvalCount int `json:"eval_count"`
EvalDuration time.Duration `json:"eval_duration"`
}
func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
@@ -1487,7 +1480,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
return fmt.Errorf("error unmarshalling llm prediction response: %v", err)
}
switch {
case lastToken != "" && (strings.TrimSpace(c.Content) == lastToken || strings.TrimSpace(c.Thinking) == lastToken):
case strings.TrimSpace(c.Content) == lastToken:
tokenRepeat++
default:
lastToken = strings.TrimSpace(c.Content)
@@ -1500,14 +1493,16 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
return ctx.Err()
}
if c.Content != "" {
fn(CompletionResponse{
Content: c.Content,
})
}
if c.Done {
fn(c)
return nil
}
if c.Content != "" || c.Thinking != "" || len(c.ToolCalls) > 0 {
fn(c)
}
}
}

View File

@@ -400,8 +400,6 @@ type Tensor interface {
Bytes() []byte
Floats() []float32
BackendSetFromIntSlice(s []int32)
Neg(ctx Context) Tensor
Add(ctx Context, t2 Tensor) Tensor
Sub(ctx Context, t2 Tensor) Tensor

View File

@@ -82,7 +82,6 @@ type Backend struct {
// to the name that is used by the model definition
tensorLoadTargets map[string][]string
schedMu sync.Mutex // Only one Compute can run at a time
sched C.ggml_backend_sched_t
schedBackends []C.ggml_backend_t
schedBufts []C.ggml_backend_buffer_type_t
@@ -770,8 +769,6 @@ func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
}
func (c *Context) Compute(tensors ...ml.Tensor) {
c.b.schedMu.Lock()
defer c.b.schedMu.Unlock()
if status := C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph); status != C.GGML_STATUS_SUCCESS {
panic(fmt.Errorf("error computing ggml graph: %v", status))
}
@@ -1040,12 +1037,6 @@ func (t *Tensor) Floats() (data []float32) {
return
}
func (t *Tensor) BackendSetFromIntSlice(s []int32) {
if len(s) > 0 {
C.ggml_backend_tensor_set(t.t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.t))
}
}
func (t *Tensor) DType() ml.DType {
switch t.t._type {
case C.GGML_TYPE_F32:

View File

@@ -64,7 +64,7 @@ type MultimodalProcessor interface {
// This function is also responsible for updating MultimodalHash for any Multimodal
// that is modified to ensure that there is a unique hash value that accurately
// represents the contents.
PostTokenize([]*input.Input) ([]*input.Input, error)
PostTokenize([]input.Input) ([]input.Input, error)
}
// Base implements the common fields and methods for all models
@@ -278,13 +278,13 @@ func canNil(t reflect.Type) bool {
t.Kind() == reflect.Slice
}
func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Tensor, ml.Tensor, error) {
func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Tensor, error) {
if len(batch.Positions) != len(batch.Sequences) {
return nil, nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(batch.Positions), len(batch.Sequences))
}
if len(batch.Positions) < 1 {
return nil, nil, errors.New("batch size cannot be less than 1")
return nil, errors.New("batch size cannot be less than 1")
}
batch.Inputs = ctx.Input().FromIntSlice(inputs, len(inputs))
@@ -293,16 +293,16 @@ func Forward(ctx ml.Context, m Model, inputs []int32, batch input.Batch) (ml.Ten
if cache != nil {
err := cache.StartForward(ctx, batch, false)
if err != nil {
return nil, nil, err
return nil, err
}
}
t, err := m.Forward(ctx, batch)
if err != nil {
return nil, nil, err
return nil, err
}
ctx.Forward(t)
ctx.Forward(t).Compute(t)
return batch.Inputs, t, nil
return t, nil
}

View File

@@ -112,8 +112,8 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
return []input.Multimodal{{Tensor: visionOutputs}}, nil
}
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
var result []*input.Input
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
var result []input.Input
for _, inp := range inputs {
if len(inp.Multimodal) == 0 {
@@ -122,17 +122,17 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
inputMultimodal := inp.Multimodal[0].Tensor
result = append(result,
&input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
&input.Input{Token: 255999}, // "<start_of_image>""
&input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
input.Input{Token: 108, SameBatch: inputMultimodal.Dim(1) + 3}, // "\n\n"
input.Input{Token: 255999}, // "<start_of_image>""
input.Input{Multimodal: []input.Multimodal{{Tensor: inputMultimodal}}, MultimodalHash: inp.MultimodalHash}, // image data is on the first placeholder
)
// add image token placeholders
result = append(result, slices.Repeat([]*input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)
result = append(result, slices.Repeat([]input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)
result = append(result,
&input.Input{Token: 256000}, // <end_of_image>
&input.Input{Token: 108}, // "\n\n"
input.Input{Token: 256000}, // <end_of_image>
input.Input{Token: 108}, // "\n\n"
)
}
}

View File

@@ -134,16 +134,16 @@ type separator struct {
y bool
}
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
var result []*input.Input
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
var result []input.Input
for _, inp := range inputs {
if len(inp.Multimodal) == 0 {
result = append(result, inp)
continue
}
var imageInputs []*input.Input
imageInputs = append(imageInputs, &input.Input{Token: 200080}) // <|image_start|>
var imageInputs []input.Input
imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_start|>
for i, mm := range inp.Multimodal {
patchesPerChunk := mm.Tensor.Dim(1)
@@ -151,20 +151,20 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
if i < len(inp.Multimodal)-1 {
separator := mm.Data.(*separator)
imageInputs = append(imageInputs, &input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
imageInputs = append(imageInputs, slices.Repeat([]*input.Input{{Token: 200092}}, patchesPerChunk-1)...)
imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
if separator.x {
imageInputs = append(imageInputs, &input.Input{Token: 200084}) // <|tile_x_separator|>
imageInputs = append(imageInputs, input.Input{Token: 200084}) // <|tile_x_separator|>
}
if separator.y {
imageInputs = append(imageInputs, &input.Input{Token: 200085}) // <|tile_y_separator|>
imageInputs = append(imageInputs, input.Input{Token: 200085}) // <|tile_y_separator|>
}
} else {
imageInputs = append(imageInputs, &input.Input{Token: 200090}) // <|image|>
imageInputs = append(imageInputs, &input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
imageInputs = append(imageInputs, slices.Repeat([]*input.Input{{Token: 200092}}, patchesPerChunk-1)...)
imageInputs = append(imageInputs, &input.Input{Token: 200080}) // <|image_end|>
imageInputs = append(imageInputs, input.Input{Token: 200090}) // <|image|>
imageInputs = append(imageInputs, input.Input{Token: 200092, Multimodal: []input.Multimodal{{Tensor: mm.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: patchesPerChunk}) // <|patch|>
imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 200092}}, patchesPerChunk-1)...)
imageInputs = append(imageInputs, input.Input{Token: 200080}) // <|image_end|>
}
}

View File

@@ -133,22 +133,22 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
// [IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_END]
// Each sequence of [IMG]...[IMG] is a set of patches of vision embeddings
// that can be processed together.
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
var result []*input.Input
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
var result []input.Input
for _, inp := range inputs {
if len(inp.Multimodal) == 0 {
result = append(result, inp)
} else {
for i, row := range inp.Multimodal {
// [IMG]
result = append(result, &input.Input{Token: 10, Multimodal: []input.Multimodal{{Tensor: row.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: row.Tensor.Dim(1)})
result = append(result, slices.Repeat([]*input.Input{{Token: 10}}, row.Tensor.Dim(1)-1)...)
result = append(result, input.Input{Token: 10, Multimodal: []input.Multimodal{{Tensor: row.Tensor}}, MultimodalHash: inp.MultimodalHash, SameBatch: row.Tensor.Dim(1)})
result = append(result, slices.Repeat([]input.Input{{Token: 10}}, row.Tensor.Dim(1)-1)...)
if i == len(inp.Multimodal)-1 {
// [IMG_END]
result = append(result, &input.Input{Token: 13})
result = append(result, input.Input{Token: 13})
} else {
// [IMG_BREAK]
result = append(result, &input.Input{Token: 12})
result = append(result, input.Input{Token: 12})
}
}
}

View File

@@ -90,7 +90,7 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
return []input.Multimodal{{Tensor: projectedOutputs}}, nil
}
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
for i := range inputs {
if inputs[i].Multimodal != nil {
inputs[i].Token = 128256 // <|image|>

View File

@@ -89,8 +89,8 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input
}
// PostTokenize arranges Qwen-2.5-VL's inputs for the forward pass
func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
var result []*input.Input
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
var result []input.Input
var (
imageToken int32 = 151655
@@ -112,16 +112,16 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
return nil, fmt.Errorf("failed to encode image prompt: %w", err)
}
for i := range pre {
result = append(result, &input.Input{Token: pre[i]})
result = append(result, input.Input{Token: pre[i]})
}
patchesPerChunk := inp.Multimodal[0].Tensor.Dim(1)
// First add the vision start token
result = append(result, &input.Input{Token: visionStartToken})
result = append(result, input.Input{Token: visionStartToken})
// Add the image token with the multimodal tensor data at the first position
result = append(result, &input.Input{
result = append(result, input.Input{
Token: imageToken,
Multimodal: inp.Multimodal,
MultimodalHash: inp.MultimodalHash,
@@ -129,9 +129,9 @@ func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
})
// Add the placeholder tokens for the remaining positions (tokensPerGrid-1)
result = append(result, slices.Repeat([]*input.Input{{Token: imageToken}}, patchesPerChunk-1)...)
result = append(result, slices.Repeat([]input.Input{{Token: imageToken}}, patchesPerChunk-1)...)
result = append(result, &input.Input{Token: visionEndToken})
result = append(result, input.Input{Token: visionEndToken})
}
}

View File

@@ -86,7 +86,7 @@ type InputCacheSlot struct {
Id int
// Inputs that are stored in the KV cache
Inputs []*input.Input
Inputs []input.Input
// is this cache actively being processed as part of a sequence?
InUse bool
@@ -95,7 +95,7 @@ type InputCacheSlot struct {
lastUsed time.Time
}
func (c *InputCache) LoadCacheSlot(prompt []*input.Input) (*InputCacheSlot, []*input.Input, error) {
func (c *InputCache) LoadCacheSlot(prompt []input.Input) (*InputCacheSlot, []input.Input, error) {
var slot *InputCacheSlot
var numPast int32
var err error
@@ -146,7 +146,7 @@ func (c *InputCache) LoadCacheSlot(prompt []*input.Input) (*InputCacheSlot, []*i
return slot, prompt, nil
}
func (c *InputCache) findLongestCacheSlot(prompt []*input.Input) (*InputCacheSlot, int32, error) {
func (c *InputCache) findLongestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) {
longest := int32(-1)
var longestSlot *InputCacheSlot
@@ -169,7 +169,7 @@ func (c *InputCache) findLongestCacheSlot(prompt []*input.Input) (*InputCacheSlo
return longestSlot, longest, nil
}
func (c *InputCache) findBestCacheSlot(prompt []*input.Input) (*InputCacheSlot, int32, error) {
func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) {
oldest := time.Now()
var oldestSlot *InputCacheSlot
@@ -205,7 +205,7 @@ func (c *InputCache) findBestCacheSlot(prompt []*input.Input) (*InputCacheSlot,
if longest > 0 && longestSlot != oldestSlot {
slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total",
len(longestSlot.Inputs))
oldestSlot.Inputs = make([]*input.Input, longest)
oldestSlot.Inputs = make([]input.Input, longest)
copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
if c.cache != nil {
c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest)
@@ -215,7 +215,7 @@ func (c *InputCache) findBestCacheSlot(prompt []*input.Input) (*InputCacheSlot,
return oldestSlot, longest, nil
}
func countCommonPrefix(a []*input.Input, b []*input.Input) int32 {
func countCommonPrefix(a []input.Input, b []input.Input) int32 {
var count int32
for i := range a {
@@ -250,7 +250,7 @@ func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 {
}
type ErrReprocessInputs struct {
Inputs []*input.Input
Inputs []input.Input
}
func (e *ErrReprocessInputs) Error() string {
@@ -283,13 +283,13 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int32) error {
"id", slot.Id, "error", err)
// Create new input slice with preserved tokens (numKeep + remaining tokens after discard)
newInputs := make([]*input.Input, numKeep+inputLen-(numKeep+discard))
newInputs := make([]input.Input, numKeep+inputLen-(numKeep+discard))
copy(newInputs[:numKeep], slot.Inputs[:numKeep])
copy(newInputs[numKeep:], slot.Inputs[numKeep+discard:])
// Reset the cache
_ = c.cache.Remove(slot.Id, 0, math.MaxInt32)
slot.Inputs = []*input.Input{}
slot.Inputs = []input.Input{}
// Return error with inputs that need to be reprocessed
return &ErrReprocessInputs{Inputs: newInputs}

View File

@@ -13,50 +13,50 @@ import (
func TestCountCommon(t *testing.T) {
tests := []struct {
name string
t1 []*input.Input
t2 []*input.Input
t1 []input.Input
t2 []input.Input
expected int32
}{
{
name: "Equal",
t1: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
t2: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
t1: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
expected: 3,
},
{
name: "Prefix",
t1: []*input.Input{{Token: 1}},
t2: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
t1: []input.Input{{Token: 1}},
t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
expected: 1,
},
{
name: "Image Prefix",
t1: []*input.Input{{MultimodalHash: 1}},
t2: []*input.Input{{MultimodalHash: 1}, {MultimodalHash: 2}, {MultimodalHash: 3}},
t1: []input.Input{{MultimodalHash: 1}},
t2: []input.Input{{MultimodalHash: 1}, {MultimodalHash: 2}, {MultimodalHash: 3}},
expected: 1,
},
{
name: "Mixed",
t1: []*input.Input{{Token: 1}, {MultimodalHash: 1}},
t2: []*input.Input{{Token: 1}, {MultimodalHash: 1}, {Token: 5}},
t1: []input.Input{{Token: 1}, {MultimodalHash: 1}},
t2: []input.Input{{Token: 1}, {MultimodalHash: 1}, {Token: 5}},
expected: 2,
},
{
name: "Mixed, Same Length",
t1: []*input.Input{{Token: 1}, {MultimodalHash: 1}},
t2: []*input.Input{{Token: 1}, {MultimodalHash: 2}},
t1: []input.Input{{Token: 1}, {MultimodalHash: 1}},
t2: []input.Input{{Token: 1}, {MultimodalHash: 2}},
expected: 1,
},
{
name: "Empty",
t1: []*input.Input{},
t2: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
t1: []input.Input{},
t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
expected: 0,
},
{
name: "Both Empty",
t1: []*input.Input{},
t2: []*input.Input{},
t1: []input.Input{},
t2: []input.Input{},
expected: 0,
},
}
@@ -80,7 +80,7 @@ func TestFindCacheSlot(t *testing.T) {
tests := []struct {
name string
cache InputCache
prompt []*input.Input
prompt []input.Input
longest expected
best expected
}{
@@ -89,18 +89,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{
{
Id: 0,
Inputs: []*input.Input{},
Inputs: []input.Input{},
InUse: false,
lastUsed: time.Time{},
},
{
Id: 1,
Inputs: []*input.Input{},
Inputs: []input.Input{},
InUse: false,
lastUsed: time.Time{},
},
}},
prompt: []*input.Input{{Token: 1}},
prompt: []input.Input{{Token: 1}},
longest: expected{result: 0, len: 0},
best: expected{result: 0, len: 0},
},
@@ -109,18 +109,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{
{
Id: 0,
Inputs: []*input.Input{{Token: 1}},
Inputs: []input.Input{{Token: 1}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-2 * time.Second),
},
}},
prompt: []*input.Input{{Token: 1}, {Token: 2}},
prompt: []input.Input{{Token: 1}, {Token: 2}},
longest: expected{result: 1, len: 2},
best: expected{result: 1, len: 2},
},
@@ -129,18 +129,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{
{
Id: 0,
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []*input.Input{},
Inputs: []input.Input{},
InUse: false,
lastUsed: time.Time{},
},
}},
prompt: []*input.Input{{Token: 2}},
prompt: []input.Input{{Token: 2}},
longest: expected{result: 0, len: 0},
best: expected{result: 1, len: 0},
},
@@ -150,19 +150,19 @@ func TestFindCacheSlot(t *testing.T) {
slots: []InputCacheSlot{
{
Id: 0,
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []*input.Input{},
Inputs: []input.Input{},
InUse: false,
lastUsed: time.Time{},
},
},
},
prompt: []*input.Input{{Token: 1}},
prompt: []input.Input{{Token: 1}},
longest: expected{result: 0, len: 1},
best: expected{result: 1, len: 1},
},
@@ -171,18 +171,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{
{
Id: 0,
Inputs: []*input.Input{{Token: 1}},
Inputs: []input.Input{{Token: 1}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-2 * time.Second),
},
}},
prompt: []*input.Input{{Token: 2}, {Token: 3}},
prompt: []input.Input{{Token: 2}, {Token: 3}},
longest: expected{result: 0, len: 0},
best: expected{result: 1, len: 0},
},
@@ -191,18 +191,18 @@ func TestFindCacheSlot(t *testing.T) {
cache: InputCache{slots: []InputCacheSlot{
{
Id: 0,
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: true,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []*input.Input{{Token: 1}},
Inputs: []input.Input{{Token: 1}},
InUse: false,
lastUsed: time.Now().Add(-2 * time.Second),
},
}},
prompt: []*input.Input{{Token: 1}, {Token: 2}},
prompt: []input.Input{{Token: 1}, {Token: 2}},
longest: expected{result: 1, len: 1},
best: expected{result: 1, len: 2},
},
@@ -300,7 +300,7 @@ func TestLoadCacheSlot(t *testing.T) {
tests := []struct {
name string
cache InputCache
prompt []*input.Input
prompt []input.Input
wantErr bool
expectedSlotId int
expectedPrompt int // expected length of remaining prompt
@@ -312,19 +312,19 @@ func TestLoadCacheSlot(t *testing.T) {
slots: []InputCacheSlot{
{
Id: 0,
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []*input.Input{},
Inputs: []input.Input{},
InUse: false,
lastUsed: time.Now().Add(-2 * time.Second),
},
},
},
prompt: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
wantErr: false,
expectedSlotId: 0,
expectedPrompt: 1, // Only token 3 remains
@@ -336,19 +336,19 @@ func TestLoadCacheSlot(t *testing.T) {
slots: []InputCacheSlot{
{
Id: 0,
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
{
Id: 1,
Inputs: []*input.Input{},
Inputs: []input.Input{},
InUse: false,
lastUsed: time.Now().Add(-2 * time.Second),
},
},
},
prompt: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
wantErr: false,
expectedSlotId: 0,
expectedPrompt: 1, // Only token 3 remains
@@ -360,13 +360,13 @@ func TestLoadCacheSlot(t *testing.T) {
slots: []InputCacheSlot{
{
Id: 0,
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: false,
lastUsed: time.Now().Add(-time.Second),
},
},
},
prompt: []*input.Input{{Token: 1}, {Token: 2}},
prompt: []input.Input{{Token: 1}, {Token: 2}},
wantErr: false,
expectedSlotId: 0,
expectedPrompt: 1, // Should leave 1 token for sampling
@@ -378,13 +378,13 @@ func TestLoadCacheSlot(t *testing.T) {
slots: []InputCacheSlot{
{
Id: 0,
Inputs: []*input.Input{{Token: 1}, {Token: 2}},
Inputs: []input.Input{{Token: 1}, {Token: 2}},
InUse: true,
lastUsed: time.Now().Add(-time.Second),
},
},
},
prompt: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
prompt: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
wantErr: true,
expectedSlotId: -1,
expectedPrompt: -1,
@@ -452,7 +452,7 @@ func TestShiftCacheSlot(t *testing.T) {
tests := []struct {
name string
numCtx int32
inputs []*input.Input
inputs []input.Input
numKeep int32
cacheErr bool
wantErr any
@@ -461,7 +461,7 @@ func TestShiftCacheSlot(t *testing.T) {
{
name: "Normal shift",
numCtx: 10,
inputs: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
numKeep: 2,
cacheErr: false, // No error
wantErr: nil,
@@ -470,7 +470,7 @@ func TestShiftCacheSlot(t *testing.T) {
{
name: "Cache removal fails",
numCtx: 10,
inputs: []*input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
inputs: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}, {Token: 4}, {Token: 5}, {Token: 6}, {Token: 7}, {Token: 8}, {Token: 9}, {Token: 10}},
numKeep: 2,
cacheErr: true,
wantErr: &ErrReprocessInputs{},
@@ -487,7 +487,7 @@ func TestShiftCacheSlot(t *testing.T) {
}
slot := &InputCacheSlot{
Id: 123,
Inputs: make([]*input.Input, len(tt.inputs)),
Inputs: make([]input.Input, len(tt.inputs)),
}
copy(slot.Inputs, tt.inputs)

View File

@@ -17,7 +17,6 @@ import (
"reflect"
"regexp"
"runtime"
"runtime/debug"
"strconv"
"strings"
"sync"
@@ -29,7 +28,6 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/harmony"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
@@ -53,10 +51,10 @@ type Sequence struct {
iBatch int
// prompt inputs left to evaluate
inputs []*input.Input
inputs []input.Input
// inputs that have been added to a batch but not yet submitted to Forward
pendingInputs []*input.Input
pendingInputs []input.Input
// tokens that have been generated but not returned yet (e.g. for stop sequences)
pendingResponses []string
@@ -88,12 +86,6 @@ type Sequence struct {
// true if an embedding are to be returned instead of text generation
embeddingOnly bool
// true if the sequence if finished and marked for removal on next pass
finished bool
// True if we have to skip this sequence to shift the cache
skipForShift bool
doneReason llm.DoneReason
// Metrics
@@ -190,8 +182,8 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
// inputs processes the prompt and images into a list of inputs
// by splitting the prompt on [img-<n>] tags, tokenizing text and
// decoding images
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input, []ml.Context, multimodalStore, error) {
var inputs []*input.Input
func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, []ml.Context, multimodalStore, error) {
var inputs []input.Input
var ctxs []ml.Context
var mmStore multimodalStore
@@ -218,7 +210,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input,
}
for _, t := range tokens {
inputs = append(inputs, &input.Input{Token: t})
inputs = append(inputs, input.Input{Token: t})
}
// image - decode and store
@@ -251,7 +243,7 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input,
mmStore.addMultimodal(imageEmbeddings)
inputs = append(inputs, &input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
inputs = append(inputs, input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
postTokenize = true
}
}
@@ -267,27 +259,6 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input,
return inputs, ctxs, mmStore, nil
}
type batchState struct {
id int
ctx ml.Context
modelInput ml.Tensor
modelOutput ml.Tensor
batchInputs []*input.Input
batch input.Batch
seqs []*Sequence // full set of seqs at the time this batch was initiated
initSeqIdx int // The initial value for the set of sequences evaluated (s.nextSeq - 1)
// Signaled when this batches inputs are ready and compute can proceed
inputsReadyCh chan struct{}
// Signaling when Compute is about to begin on this batch, and
// seqs have been updated to prepare for the next batch
computeStartedCh chan struct{}
// Signaled when this batches outputs are complete and the next batch can proceed
outputsReadyCh chan struct{}
}
type Server struct {
// modelPath is the location of the model to be loaded
modelPath string
@@ -319,16 +290,6 @@ type Server struct {
// TODO (jmorganca): make this n_batch
batchSize int
// Used to signal a hard failure during async processing which will panic the runner
hardErrCh chan error
// A prior batch that's still being processed
// only read or written by forwardBatch
pendingBatch *batchState
// Simple counter used only for trace logging batches
batchID int
// protects access to everything below this line
// this is context state needed for decoding
mu sync.Mutex
@@ -389,132 +350,45 @@ func flushPending(seq *Sequence) bool {
}
}
func (s *Server) finishSequence(seqIndex int, reason llm.DoneReason) {
func (s *Server) removeSequence(seqIndex int, reason llm.DoneReason) {
seq := s.seqs[seqIndex]
// finish could be called multiple times since we prepare 1 batch ahead
// and multiple scenarios can lead to finishing a sequence
// ensure only the first finish called is processed
if seq.finished {
return
}
flushPending(seq)
seq.doneReason = reason
seq.finished = true
close(seq.responses)
close(seq.embedding)
seq.cache.InUse = false
}
func (s *Server) removeFinishedSequence(seqIndex int) {
s.seqs[seqIndex] = nil
s.seqsSem.Release(1)
}
// track batch state between forwardBatch, computeBatch and predictForwardBatch
func (s *Server) run(ctx context.Context) {
s.ready.Wait()
var bs *batchState
for {
select {
case <-ctx.Done():
return
case err := <-s.hardErrCh:
panic(err)
default:
var err error
bs, err = s.forwardBatch()
err := s.processBatch()
if err != nil {
panic(err)
}
if bs == nil {
continue
}
go s.computeBatch(bs)
}
}
}
// forwardBatch will calculate a batch.
func (s *Server) forwardBatch() (*batchState, error) {
inputsReady := false
var inputsReadyCh chan struct{}
// If we have a pending batch still processing, wait until Compute has started
// before setting up the next batch so the seqs inputs are ready to receive their
// token values and we get the correct input pointers for the batchInputs
if s.pendingBatch != nil {
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch waiting for compute to start", "pendingBatch.id", s.pendingBatch.id)
<-s.pendingBatch.computeStartedCh
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch compute started, setting up next batch", "pendingBatch.id", s.pendingBatch.id, "id", s.batchID)
inputsReadyCh = s.pendingBatch.outputsReadyCh // Chain the ouputs from the pending batch to the next inputs batch
} else {
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch no pending batch detected", "batchID", s.batchID)
inputsReady = true // No pendingBatch, so the inputs will be ready in the seqs immediately
inputsReadyCh = make(chan struct{}, 1)
}
func (s *Server) processBatch() error {
s.mu.Lock()
for s.allNil() {
s.cond.Wait() // Wait until an item is added
}
defer s.mu.Unlock()
// If new sequences have been added with an active batch we delay preparing the next batch
// until Compute has finished
if s.pendingBatch != nil {
for seqIdx := range s.seqs {
if s.seqs[seqIdx] != s.pendingBatch.seqs[seqIdx] {
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch seqs changed, waiting for compute to finish to pick up new sequence(s)", "pendingBatch.id", s.pendingBatch.id)
s.mu.Unlock() // release the lock so computeBatch can finish up
<-s.pendingBatch.outputsReadyCh
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch pending batch outputs ready", "pendingBatch.id", s.pendingBatch.id)
s.mu.Lock()
inputsReady = true // pendingBatch completed, so the inputs are ready in the seqs
break
}
}
}
// Clear pending Batch - we'll set it if we have a batch with any inputs
s.pendingBatch = nil
// Remove any finished sequences before recording the active set of seqs in the batch
for seqIdx := range s.seqs {
seq := s.seqs[seqIdx]
if seq == nil {
continue
}
if seq.finished {
s.removeFinishedSequence(seqIdx)
continue
}
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
s.finishSequence(seqIdx, llm.DoneReasonLength)
s.removeFinishedSequence(seqIdx)
continue
}
}
// next batch
nb := &batchState{
id: s.batchID,
initSeqIdx: s.nextSeq - 1,
seqs: make([]*Sequence, len(s.seqs)),
inputsReadyCh: inputsReadyCh,
computeStartedCh: make(chan struct{}, 1),
outputsReadyCh: make(chan struct{}, 1),
}
ctx := s.model.Backend().NewContext()
nb.ctx = ctx
defer ctx.Close()
// Record the sequences at the time we create the batch so we can detect if new sequences are added on the next pass
copy(nb.seqs, s.seqs)
// Prepare the seqs and batch, but defer the input token values as we may not be ready yet
var batchInputs []*input.Input
var batchInputs []int32
var batch input.Batch
resumeSeq := -1
@@ -522,13 +396,20 @@ func (s *Server) forwardBatch() (*batchState, error) {
for range s.seqs {
seqIdx = (seqIdx + 1) % len(s.seqs)
seq := s.seqs[seqIdx]
if seq == nil {
continue
}
// if past the num predict limit
if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
s.removeSequence(seqIdx, llm.DoneReasonLength)
continue
}
if !s.cache.enabled {
seq.inputs = append(seq.cache.Inputs, seq.inputs...)
seq.cache.Inputs = []*input.Input{}
seq.cache.Inputs = []input.Input{}
}
batchSize := s.batchSize
@@ -568,21 +449,18 @@ func (s *Server) forwardBatch() (*batchState, error) {
// Prepend these inputs to the sequence's inputs queue for reprocessing
seq.inputs = append(reprocess.Inputs, seq.inputs...)
// Skip this sequence but continue processing the rest
seq.skipForShift = true // cleared in computeBatch below for the next batch
continue
} else {
ctx.Close()
return nil, err
return err
}
}
}
batchInputs = append(batchInputs, seq.inputs[i])
batchInputs = append(batchInputs, inp.Token)
if inp.Multimodal != nil {
mm, err := seq.mmStore.getMultimodal(s.model.Backend(), ctx, inp.Multimodal, false)
if err != nil {
ctx.Close()
return nil, err
return err
}
batch.Multimodal = append(batch.Multimodal, input.MultimodalIndex{Index: len(batchInputs) - 1, Multimodal: mm})
}
@@ -590,13 +468,10 @@ func (s *Server) forwardBatch() (*batchState, error) {
batch.Positions = append(batch.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
batch.Sequences = append(batch.Sequences, seq.cache.Id)
// TODO BUG HERE!!!
// Somehow sometimes iBatch isn't set correctly
seq.iBatch = len(batch.Outputs)
if i+1 == len(seq.inputs) {
batch.Outputs = append(batch.Outputs, int32(len(batchInputs)-1))
}
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch iBatch", "batchID", s.batchID, "seqIdx", seqIdx, "seq.iBatch", seq.iBatch, "i+1", i+1, "len(seq.inputs)", len(seq.inputs))
seq.pendingInputs = append(seq.pendingInputs, inp)
}
@@ -610,138 +485,36 @@ func (s *Server) forwardBatch() (*batchState, error) {
}
if len(batchInputs) == 0 {
slog.Log(context.TODO(), logutil.LevelTrace, "forwardBatch no batchInputs, going idle", "batchID", s.batchID)
ctx.Close()
return nil, nil
return nil
}
s.batchID++
var err error
// Actual batchInputs values will be injected into the modelInput tensor before calling Compute
nb.modelInput, nb.modelOutput, err = model.Forward(ctx, s.model, make([]int32, len(batchInputs)), batch)
modelOutput, err := model.Forward(ctx, s.model, batchInputs, batch)
if err != nil {
ctx.Close()
return nil, fmt.Errorf("failed to build graph: %w", err)
}
nb.batchInputs = batchInputs
nb.batch = batch
// computeBatch will close the context in the batch upon completion
s.pendingBatch = nb
if inputsReady {
nb.inputsReadyCh <- struct{}{}
return fmt.Errorf("failed to decode batch: %w", err)
}
return nb, nil
}
logits := modelOutput.Floats()
// Async processing of the next batch
func (s *Server) computeBatch(bs *batchState) {
if bs == nil || bs.ctx == nil {
// Nothing to compute
return
}
defer bs.ctx.Close()
// Wait until inputs are ready
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: waiting for inputs to be ready", "batchID", bs.id)
<-bs.inputsReadyCh
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: inputs are ready", "batchID", bs.id)
// Once we complete, signal the next batch of inputs are ready
// This will unblock the next computeBatch, or forwardBatch if new seqs come in
defer func() {
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: outputs are ready", "batchID", bs.id)
bs.outputsReadyCh <- struct{}{}
}()
s.mu.Lock()
// Gather the actual input token values now that they're ready
batchInputs := make([]int32, len(bs.batchInputs))
for i := range batchInputs {
batchInputs[i] = bs.batchInputs[i].Token
}
// TODO the following logic could be run in a go routine to possibly speed up getting to Compute
// Now we run part of the decoding algorithm to adjust the seq.inputs with placeholder tokens
// so that forwardBatch can build a batchInputs set which will eventually contain the actual
// decoded tokens.
promptProcessing := make([]bool, len(s.seqs)) // track seq's we skip
nextBatchTokens := make([]*input.Input, len(s.seqs))
iBatches := make([]int, len(s.seqs)) // Record the iBatch values before releasing the lock
for i, seq := range s.seqs {
iBatches[i] = -1
if seq == nil {
continue
}
// Skip over any newly added sequences
if bs.seqs[i] == nil {
continue
}
// After calling Forward, pending inputs are now in the cache
if len(seq.pendingInputs) > 0 {
seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
seq.pendingInputs = []*input.Input{}
seq.pendingInputs = []input.Input{}
}
// don't sample prompt processing
if len(seq.inputs) != 0 {
if !s.cache.enabled {
s.hardErrCh <- fmt.Errorf("caching disabled but unable to fit entire input in a batch")
return
return errors.New("caching disabled but unable to fit entire input in a batch")
}
// Record so we can skip during Decode
promptProcessing[i] = true
continue
}
seq.numPredicted++
nextToken := &input.Input{Token: 0} // placeholder we'll fill in after Compute/Floats
seq.inputs = []*input.Input{nextToken}
nextBatchTokens[i] = nextToken
iBatches[i] = seq.iBatch
}
// At this point the seqs are ready for forwardBatch to move forward so unblock
s.mu.Unlock()
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: signaling computeStartedCh", "batchID", bs.id)
bs.computeStartedCh <- struct{}{}
bs.modelInput.BackendSetFromIntSlice(batchInputs)
bs.ctx.Compute(bs.modelOutput)
logits := bs.modelOutput.Floats()
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: logits ready", "batchID", bs.id)
s.mu.Lock()
defer s.mu.Unlock()
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: decoding", "batchID", bs.id)
for i, seq := range s.seqs {
if seq == nil {
continue
}
// Skip over any newly added sequences
if bs.seqs[i] == nil {
continue
}
// Detect if the sequence we're processing has already been completed and replaced
// with a new sequence
if seq != bs.seqs[i] {
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: sequence replaced, discarding its results", "batchID", bs.id, "seqIdx", i)
continue
}
// don't sample prompt processing
if promptProcessing[i] {
continue
}
if seq.numPredicted == 1 {
seq.startGenerationTime = time.Now()
}
@@ -749,46 +522,35 @@ func (s *Server) computeBatch(bs *batchState) {
// if done processing the prompt, generate an embedding and return
if seq.embeddingOnly {
// TODO(jessegross): Embedding support
slog.Warn("generation of embedding outputs not yet supported", "id", bs.id, "seqIdx", i)
s.finishSequence(i, llm.DoneReasonStop)
slog.Warn("generation of embedding outputs not yet supported")
s.removeSequence(i, llm.DoneReasonStop)
continue
}
// sample a token
vocabSize := len(logits) / len(bs.batch.Outputs)
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: vocab details", "batchID", bs.id, "seqIdx", i, "len(logits)", len(logits), "len(bs.batch.Outputs)", len(bs.batch.Outputs), "vocabSize", vocabSize, "seq.iBatch", seq.iBatch)
token, err := seq.sampler.Sample(logits[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
if err != nil {
s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)
return
}
vocabSize := len(logits) / len(batch.Outputs)
nextBatchTokens[i].Token = token
token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
if err != nil {
return fmt.Errorf("failed to sample token: %w", err)
}
// if it's an end of sequence token, break
if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) {
// TODO (jmorganca): we should send this back
// as it's important for the /api/generate context
// seq.responses <- piece
slog.Log(context.TODO(), logutil.LevelTrace, "computeBatch: EOS", "batchID", bs.id, "seqIdx", i)
s.finishSequence(i, llm.DoneReasonStop)
s.removeSequence(i, llm.DoneReasonStop)
continue
}
piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
if err != nil {
s.hardErrCh <- fmt.Errorf("failed to decode token: %w", err)
return
return err
}
if nextBatchTokens[i] == nil {
slog.Error("batch corrupted", "id", bs.id, "batch", bs.batch, "seqIdx", i, "seq", seq)
s.hardErrCh <- fmt.Errorf("expected a single token during decode")
return
}
// fill in the final selected token value to replace the placeholder in the next batch
// nextBatchTokensWritten++
seq.inputs = []input.Input{{Token: token}}
seq.pendingResponses = append(seq.pendingResponses, piece)
sequence := strings.Join(seq.pendingResponses, "")
@@ -813,10 +575,9 @@ func (s *Server) computeBatch(bs *batchState) {
if tokenTruncated || origLen == newLen {
tokenLen--
}
seq.cache.Inputs = seq.cache.Inputs[:tokenLen]
s.finishSequence(i, llm.DoneReasonStop)
s.removeSequence(i, llm.DoneReasonStop)
continue
}
@@ -829,9 +590,11 @@ func (s *Server) computeBatch(bs *batchState) {
}
if !flushPending(seq) {
s.finishSequence(i, llm.DoneReasonConnectionClosed)
s.removeSequence(i, llm.DoneReasonConnectionClosed)
}
}
return nil
}
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
@@ -841,15 +604,6 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
}
var harmonyMessageHandler *harmony.HarmonyMessageHandler
var harmonyToolParser *harmony.HarmonyToolCallAccumulator
if req.FunctionNameMap != nil {
harmonyMessageHandler = harmony.NewHarmonyMessageHandler()
harmonyMessageHandler.FunctionNameMap = req.FunctionNameMap
harmonyMessageHandler.HarmonyParser.AddImplicitStartOrPrefill(req.PrefillContent)
harmonyToolParser = harmonyMessageHandler.CreateToolParser()
}
if req.Options == nil {
opts := api.DefaultOptions()
req.Options = &opts
@@ -940,16 +694,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
case content, ok := <-seq.responses:
if ok {
var thinking string
if harmonyMessageHandler != nil {
var toolContent string
content, thinking, toolContent = harmonyMessageHandler.AddContent(content, harmonyToolParser)
harmonyToolParser.Add(toolContent)
}
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
Content: content,
Thinking: thinking,
Content: content,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
close(seq.quit)
@@ -958,29 +704,7 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
flusher.Flush()
} else {
var toolCalls []api.ToolCall
if harmonyMessageHandler != nil {
toolName, toolContent := harmonyToolParser.Drain()
if toolName != nil {
*toolName = strings.TrimPrefix(*toolName, "functions.")
*toolName = harmonyMessageHandler.FunctionNameMap.OriginalFromConverted(*toolName)
var args api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
http.Error(w, fmt.Sprintf("failed to unmarshal tool call function arguments: %v", err), http.StatusInternalServerError)
close(seq.quit)
return
}
toolCalls = append(toolCalls, api.ToolCall{
Function: api.ToolCallFunction{
Name: *toolName,
Arguments: args,
},
})
}
}
if err := json.NewEncoder(w).Encode(&llm.CompletionResponse{
ToolCalls: toolCalls,
Done: true,
DoneReason: seq.doneReason,
PromptEvalCount: seq.numPromptInputs,
@@ -1012,10 +736,7 @@ func (s *Server) reserveWorstCaseGraph() error {
defer ctx.Close()
var err error
inputs := make([]*input.Input, s.batchSize)
for i := range inputs {
inputs[i] = &input.Input{}
}
inputs := make([]input.Input, s.batchSize)
mmStore := newMultimodalStore()
// Multimodal strategy:
@@ -1057,11 +778,8 @@ func (s *Server) reserveWorstCaseGraph() error {
}
if len(inputs) < s.batchSize {
newInputs := make([]*input.Input, s.batchSize)
newInputs := make([]input.Input, s.batchSize)
copy(newInputs, inputs)
for i := len(inputs); i < s.batchSize; i++ {
newInputs[i] = &input.Input{}
}
inputs = newInputs
}
}
@@ -1124,7 +842,6 @@ func (s *Server) allocModel(
// Convert memory allocation panics to errors
defer func() {
if r := recover(); r != nil {
debug.PrintStack()
if err, ok := r.(error); ok {
panicErr = err
} else {
@@ -1294,7 +1011,6 @@ func Execute(args []string) error {
server := &Server{
modelPath: *mpath,
status: llm.ServerStatusLaunched,
hardErrCh: make(chan error, 1),
}
server.cond = sync.NewCond(&server.mu)

View File

@@ -1,39 +1,36 @@
package harmony
package server
import (
"context"
"encoding/json"
"fmt"
"log/slog"
"maps"
"slices"
"strings"
"unicode"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/template"
)
type harmonyParserState int
func ShouldUseHarmony(modelFamily string, template *template.Template) bool {
if slices.Contains([]string{"gptoss", "gpt-oss"}, modelFamily) {
// heuristic to check whether the template expects to be parsed via harmony:
// search for harmony tags that are nearly always used
if template.Contains("<|start|>") && template.Contains("<|end|>") {
return true
}
}
return false
}
const (
harmonyParserState_LookingForMessageStart harmonyParserState = iota
harmonyParserState_ParsingHeader
harmonyParserState_ParsingContent
)
func shouldUseHarmony(model Model) bool {
if slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) {
// heuristic to check whether the template expects to be parsed via harmony:
// search for harmony tags that are nearly always used
if model.Template.Contains("<|start|>") && model.Template.Contains("<|end|>") {
return true
}
}
return false
}
func (s harmonyParserState) String() string {
switch s {
// we're looking for the message start tag
@@ -91,18 +88,17 @@ func (s *HarmonyParser) AddImplicitStart() {
s.acc.WriteString("<|start|>assistant")
}
// AddImplicitStartOrPrefill adds content or thinking to the accumulator else adds start tag
func (s *HarmonyParser) AddImplicitStartOrPrefill(prefillContentOrThinking *bool) {
if prefillContentOrThinking != nil {
if *prefillContentOrThinking {
func (s *HarmonyParser) AddImplicitStartOrPrefill(lastMessage *api.Message) {
if lastMessage != nil && lastMessage.Role == "assistant" {
// handle prefilling conditions
if lastMessage.Content != "" {
s.acc.WriteString("<|start|>assistant<|channel|>final<|message|>")
return
} else {
} else if lastMessage.Thinking != "" {
s.acc.WriteString("<|start|>assistant<|channel|>analysis<|message|>")
return
}
}
s.AddImplicitStart()
}
@@ -279,21 +275,19 @@ const (
// HarmonyMessageHandler processes harmony events and accumulates content appropriately.
// This is a higher level interface that maps harmony concepts into ollama concepts
type HarmonyMessageHandler struct {
state harmonyMessageState
HarmonyParser *HarmonyParser
FunctionNameMap *FunctionNameMap
state harmonyMessageState
harmonyParser *HarmonyParser
}
// NewHarmonyMessageHandler creates a new message handler
func NewHarmonyMessageHandler() *HarmonyMessageHandler {
return &HarmonyMessageHandler{
state: harmonyMessageState_Normal,
HarmonyParser: &HarmonyParser{
harmonyParser: &HarmonyParser{
MessageStartTag: "<|start|>",
MessageEndTag: "<|end|>",
HeaderEndTag: "<|message|>",
},
FunctionNameMap: NewFunctionNameMap(),
}
}
@@ -304,7 +298,7 @@ func (h *HarmonyMessageHandler) AddContent(content string, toolParser *HarmonyTo
thinkingSb := strings.Builder{}
toolContentSb := strings.Builder{}
events := h.HarmonyParser.AddContent(content)
events := h.harmonyParser.AddContent(content)
for _, event := range events {
switch event := event.(type) {
case HarmonyEventHeaderComplete:
@@ -384,129 +378,3 @@ func (a *HarmonyToolCallAccumulator) Drain() (*string, string) {
func (a *HarmonyToolCallAccumulator) Content() string {
return a.acc.String()
}
// FunctionNameMap maps a user-specified function name to a valid function
// name for harmony (which look like TypeScript identifiers). This is needed to
// transform user-specified function names, which might contain characters that
// are not allowed in TypeScript identifiers
type FunctionNameMap struct {
userToHarmony map[string]string
harmonyToUser map[string]string
}
func (m FunctionNameMap) MarshalJSON() ([]byte, error) {
// necessary to avoid exposing map internals
type alias struct {
UserToHarmony map[string]string `json:"userToHarmony"`
HarmonyToUser map[string]string `json:"harmonyToUser"`
}
return json.Marshal(alias{
UserToHarmony: m.userToHarmony,
HarmonyToUser: m.harmonyToUser,
})
}
func (m *FunctionNameMap) UnmarshalJSON(b []byte) error {
type alias struct {
UserToHarmony map[string]string `json:"userToHarmony"`
HarmonyToUser map[string]string `json:"harmonyToUser"`
}
var a alias
if err := json.Unmarshal(b, &a); err != nil {
return err
}
if m.userToHarmony == nil {
m.userToHarmony = make(map[string]string)
}
if m.harmonyToUser == nil {
m.harmonyToUser = make(map[string]string)
}
maps.Copy(m.userToHarmony, a.UserToHarmony)
maps.Copy(m.harmonyToUser, a.HarmonyToUser)
return nil
}
func NewFunctionNameMap() *FunctionNameMap {
return &FunctionNameMap{
userToHarmony: make(map[string]string),
harmonyToUser: make(map[string]string),
}
}
func (m *FunctionNameMap) ConvertAndAdd(userFunctionName string) string {
harmonyFunctionName := m.deriveName(userFunctionName)
m.userToHarmony[userFunctionName] = harmonyFunctionName
m.harmonyToUser[harmonyFunctionName] = userFunctionName
return harmonyFunctionName
}
// OriginalFromConverted looks up the reverse-mapping of a previously-converted
// user->harmony function name. To unmap reliably, the mapping must exist, as
// the conversion process is not reversible without the appropriate state
func (m *FunctionNameMap) OriginalFromConverted(harmonyFunctionName string) string {
if userFunctionName, ok := m.harmonyToUser[harmonyFunctionName]; ok {
return userFunctionName
}
slog.Warn("harmony parser: no reverse mapping found for function name", "harmonyFunctionName", harmonyFunctionName)
// fallback to the original function name if we can't find a mapping
return harmonyFunctionName
}
// convertToValidChars converts a user-specified function name to a valid
// TypeScript identifier.
//
// Limitations:
//
// - This doesn't restrict reserved TypeScript keywords.
// - We don't perform a real ID_Start/ID_Continue check, and instead use the more
// restrictive unicode.IsLetter/unicode.IsDigit check. Unclear what kind of
// identifiers these models were trained on, so in the end we might want to
// convert unicode-heavy identifiers to their closest ASCII equivalents.
func (m *FunctionNameMap) convertToValidChars(userFunctionName string) string {
mapper := func(r rune) rune {
// first, replace certain characters with underscores
if r == ' ' || r == '-' || r == '.' {
return '_'
}
if unicode.IsLetter(r) || unicode.IsDigit(r) || r == '_' || r == '$' {
return r
}
// finally, remove any other characters
return -1
}
candidate := strings.Map(mapper, userFunctionName)
// set a default name if we end up with nothing left
if candidate == "" {
return "unnamed"
}
// if the candidate starts with a number, prepend an underscore to make it a
// valid identifier
if unicode.IsDigit(rune(candidate[0])) {
candidate = "_" + candidate
}
return candidate
}
func (m *FunctionNameMap) deriveName(userFunctionName string) string {
originalCandidate := m.convertToValidChars(userFunctionName)
candidate := originalCandidate
// Check for dupes, and if so, add a number to the end.
// We start at 2 because if we have dupes and the first is never renamed, it
// makes sense for them to be named, say, `f`, `f_2`, `f_3`
count := 2
for {
if _, exists := m.harmonyToUser[candidate]; !exists {
break
}
candidate = fmt.Sprintf("%s_%d", originalCandidate, count)
count++
}
return candidate
}

View File

@@ -1,4 +1,4 @@
package harmony
package server
import (
"fmt"
@@ -467,71 +467,3 @@ func TestHarmonyParserStreaming(t *testing.T) {
})
}
}
// TestFunctionConvertToValidChars tests only FunctionNameMap.convert(), which doesn't
// handle any saving (and therefore no dupe handling)
func TestFunctionConvertToValidChars(t *testing.T) {
tests := []struct {
name string
in string
want string
}{
{name: "replace spaces with underscores", in: "get weather", want: "get_weather"},
{name: "replace hyphens with underscores", in: "get-weather", want: "get_weather"},
{name: "replace periods with underscores", in: "get.weather", want: "get_weather"},
{name: "disallow non-word characters", in: "get weather!", want: "get_weather"},
{name: "strip out invalid non-alphanumeric unicode characters", in: "a🫠bc", want: "abc"},
{name: "names that only contain invalid characters", in: "🫠", want: "unnamed"},
{name: "leading number", in: "123", want: "_123"},
{name: "$ allowed", in: "$", want: "$"},
// show that we allow weird unicode letter characters, though we might want
// to convert them to their closest ASCII equivalents in the future
{name: "allow weird unicode letter characters", in: "𝓸𝓵𝓵𝓪𝓶𝓪", want: "𝓸𝓵𝓵𝓪𝓶𝓪"},
// names that look like words but are invalid (i.e., not ID_Start/ID_Continue)
{name: "disallow non-word characters that look like words", in: "ⓞⓛⓛⓐⓜⓐ123", want: "_123"},
}
for i, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
parser := NewFunctionNameMap()
got := parser.convertToValidChars(tt.in)
if got != tt.want {
t.Errorf("case %d: got %q, want %q", i, got, tt.want)
}
})
}
}
func TestFunctionConvertAndAdd(t *testing.T) {
// make a fresh map for each test, but within a test use the same map so we can test for dupe handling
tests := []struct {
name string
in []string
want []string
}{
{name: "basic dupe handling", in: []string{"get weather", "get weather"}, want: []string{"get_weather", "get_weather_2"}},
{name: "dupes from different user-specified names", in: []string{"get weather", "get_weather", "get-weather"}, want: []string{"get_weather", "get_weather_2", "get_weather_3"}},
{name: "non dupes after dupes", in: []string{"get weather", "get_weather", "get-weather", "something-different"}, want: []string{"get_weather", "get_weather_2", "get_weather_3", "something_different"}},
{name: "multiple sets of dupes", in: []string{"a", "a", "b", "a", "a", "b", "a"}, want: []string{"a", "a_2", "b", "a_3", "a_4", "b_2", "a_5"}},
}
for i, tt := range tests {
parser := NewFunctionNameMap()
t.Run(tt.name, func(t *testing.T) {
for j, in := range tt.in {
got := parser.ConvertAndAdd(in)
want := tt.want[j]
if got != want {
t.Errorf("case %d: got %q, want %q", i, got, want)
}
// check that the maps are correct
if parser.userToHarmony[in] != want {
t.Errorf("case %d: userToHarmony[%q] = %q, want %q", i, in, parser.userToHarmony[in], want)
}
if parser.harmonyToUser[want] != in {
t.Errorf("case %d: harmonyToUser[%q] = %q, want %q", i, want, parser.harmonyToUser[want], in)
}
}
})
}
}

View File

@@ -32,7 +32,6 @@ import (
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/harmony"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/openai"
@@ -195,7 +194,14 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return
}
useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template) && !req.Raw
useHarmony := shouldUseHarmony(*m) && !req.Raw
var harmonyMessageHandler *HarmonyMessageHandler
var harmonyToolParser *HarmonyToolCallAccumulator
if useHarmony {
harmonyMessageHandler = NewHarmonyMessageHandler()
harmonyMessageHandler.harmonyParser.AddImplicitStart()
harmonyToolParser = harmonyMessageHandler.CreateToolParser()
}
// Validate Think value: string values currently only allowed for gptoss models
if req.Think != nil && req.Think.IsString() && !useHarmony {
@@ -308,19 +314,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
prompt = b.String()
}
// If debug mode is enabled, return the rendered template instead of calling the model
if req.DebugRenderOnly {
c.JSON(http.StatusOK, api.DebugTemplateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
DebugInfo: api.DebugInfo{
RenderedTemplate: prompt,
ImageCount: len(images),
},
})
return
}
var thinkingState *thinking.Parser
if !useHarmony {
openingTag, closingTag := thinking.InferTags(m.Template.Template)
@@ -356,7 +349,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
},
}
if !useHarmony && thinkingState != nil {
if useHarmony {
content, thinking, toolContent := harmonyMessageHandler.AddContent(cr.Content, harmonyToolParser)
res.Response = content
res.Thinking = thinking
harmonyToolParser.Add(toolContent)
} else if thinkingState != nil {
thinking, content := thinkingState.AddContent(cr.Content)
res.Thinking = thinking
res.Response = content
@@ -367,6 +365,26 @@ func (s *Server) GenerateHandler(c *gin.Context) {
}
if cr.Done {
if useHarmony {
toolName, toolContent := harmonyToolParser.Drain()
if toolName != nil {
*toolName = strings.TrimPrefix(*toolName, "functions.")
var args api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
ch <- gin.H{"error": errStr}
return
}
res.ToolCalls = append(res.ToolCalls, api.ToolCall{
Function: api.ToolCallFunction{
Name: *toolName,
Arguments: args,
},
})
}
}
res.DoneReason = cr.DoneReason.String()
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
@@ -1572,58 +1590,14 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
msgs = filterThinkTags(msgs, m)
useHarmony := harmony.ShouldUseHarmony(m.Config.ModelFamily, m.Template)
processedTools := req.Tools
var functionNameMap *harmony.FunctionNameMap
var prefillContentOrThinking *bool
if useHarmony {
functionNameMap = harmony.NewFunctionNameMap()
var lastMessage *api.Message
if len(msgs) > 0 {
lastMessage = &msgs[len(msgs)-1]
}
// prefill content or thinking flag if the last message is an assistant message
if lastMessage != nil && lastMessage.Role == "assistant" {
if lastMessage.Content != "" {
trueVal := true
// true sets content to be prefilled
prefillContentOrThinking = &trueVal
} else if lastMessage.Thinking != "" {
// false sets thinking to be prefilled
falseVal := false
prefillContentOrThinking = &falseVal
}
}
// make a copy of tools to pass to the chat prompt. Function names may be
// renamed to be valid Harmony function names.
processedTools = make([]api.Tool, len(req.Tools))
copy(processedTools, req.Tools)
for i, tool := range processedTools {
processedTools[i].Function.Name = functionNameMap.ConvertAndAdd(tool.Function.Name)
}
}
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think)
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools, req.Think)
if err != nil {
slog.Error("chat prompt error", "error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// If debug mode is enabled, return the rendered template instead of calling the model
if req.DebugRenderOnly {
c.JSON(http.StatusOK, api.DebugTemplateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
DebugInfo: api.DebugInfo{
RenderedTemplate: prompt,
ImageCount: len(images),
},
})
return
}
useHarmony := shouldUseHarmony(*m)
// Validate Think value: string values currently only allowed for gptoss models
if req.Think != nil && req.Think.IsString() && !useHarmony {
@@ -1631,6 +1605,19 @@ func (s *Server) ChatHandler(c *gin.Context) {
return
}
var harmonyMessageHandler *HarmonyMessageHandler
var harmonyToolParser *HarmonyToolCallAccumulator
if useHarmony {
harmonyMessageHandler = NewHarmonyMessageHandler()
var lastMessage *api.Message
if len(msgs) > 0 {
lastMessage = &msgs[len(msgs)-1]
}
harmonyMessageHandler.harmonyParser.AddImplicitStartOrPrefill(lastMessage)
harmonyToolParser = harmonyMessageHandler.CreateToolParser()
}
var thinkingState *thinking.Parser
openingTag, closingTag := thinking.InferTags(m.Template.Template)
if req.Think != nil && req.Think.Bool() && openingTag != "" && closingTag != "" {
@@ -1650,17 +1637,15 @@ func (s *Server) ChatHandler(c *gin.Context) {
defer close(ch)
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
FunctionNameMap: functionNameMap,
PrefillContent: prefillContentOrThinking,
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
}, func(r llm.CompletionResponse) {
res := api.ChatResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Message: api.Message{Role: "assistant", Content: r.Content, Thinking: r.Thinking, ToolCalls: r.ToolCalls},
Message: api.Message{Role: "assistant", Content: r.Content},
Done: r.Done,
Metrics: api.Metrics{
PromptEvalCount: r.PromptEvalCount,
@@ -1676,10 +1661,30 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
if useHarmony {
content, thinking, toolContent := harmonyMessageHandler.AddContent(r.Content, harmonyToolParser)
res.Message.Content = content
res.Message.Thinking = thinking
harmonyToolParser.Add(toolContent)
if r.Done {
toolName, toolContent := harmonyToolParser.Drain()
if toolName != nil {
*toolName = strings.TrimPrefix(*toolName, "functions.")
var args api.ToolCallFunctionArguments
if err := json.Unmarshal([]byte(toolContent), &args); err != nil {
errStr := fmt.Sprintf("error parsing tool call: raw='%s', err=%s", toolContent, err.Error())
ch <- gin.H{"error": errStr}
return
}
res.Message.ToolCalls = []api.ToolCall{{Function: api.ToolCallFunction{Name: *toolName, Arguments: args}}}
}
}
// only send messages with meaningful content (empty messages confuse clients)
if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || res.Done {
ch <- res
}
return
}

View File

@@ -1,413 +0,0 @@
package server
import (
"bytes"
"encoding/json"
"net/http"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/discover"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
)
func TestGenerateDebugRenderOnly(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := mockRunner{
CompletionResponse: llm.CompletionResponse{
Done: true,
DoneReason: llm.DoneReasonStop,
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
},
}
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: discover.GetGPUInfo,
getCpuFn: discover.GetCPUInfo,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
// add small delay to simulate loading
time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{
llama: &mock,
}
return false
},
},
}
go s.sched.Run(t.Context())
// Create a test model
stream := false
_, digest := createBinFile(t, ggml.KV{
"general.architecture": "llama",
"llama.block_count": uint32(1),
"llama.context_length": uint32(8192),
"llama.embedding_length": uint32(4096),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(8),
"tokenizer.ggml.tokens": []string{""},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []*ggml.Tensor{
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
})
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "test-model",
Files: map[string]string{"file.gguf": digest},
Template: "{{ .Prompt }}",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
tests := []struct {
name string
request api.GenerateRequest
expectDebug bool
expectTemplate string
expectNumImages int
}{
{
name: "debug render only enabled",
request: api.GenerateRequest{
Model: "test-model",
Prompt: "Hello, world!",
DebugRenderOnly: true,
},
expectDebug: true,
expectTemplate: "Hello, world!",
},
{
name: "debug render only disabled",
request: api.GenerateRequest{
Model: "test-model",
Prompt: "Hello, world!",
DebugRenderOnly: false,
},
expectDebug: false,
},
{
name: "debug render only with system prompt",
request: api.GenerateRequest{
Model: "test-model",
Prompt: "User question",
System: "You are a helpful assistant",
DebugRenderOnly: true,
},
expectDebug: true,
expectTemplate: "User question",
},
{
name: "debug render only with template",
request: api.GenerateRequest{
Model: "test-model",
Prompt: "Hello",
Template: "PROMPT: {{ .Prompt }}",
DebugRenderOnly: true,
},
expectDebug: true,
expectTemplate: "PROMPT: Hello",
},
{
name: "debug render only with images",
request: api.GenerateRequest{
Model: "test-model",
Prompt: "Describe this image",
Images: []api.ImageData{[]byte("fake-image-data")},
DebugRenderOnly: true,
},
expectDebug: true,
expectTemplate: "[img-0]\n\nDescribe this image",
expectNumImages: 1,
},
{
name: "debug render only with raw mode",
request: api.GenerateRequest{
Model: "test-model",
Prompt: "Raw prompt text",
Raw: true,
DebugRenderOnly: true,
},
expectDebug: true,
expectTemplate: "Raw prompt text",
},
}
for _, tt := range tests {
// Test both with and without streaming
streamValues := []bool{false, true}
for _, stream := range streamValues {
streamSuffix := ""
if stream {
streamSuffix = " (streaming)"
}
t.Run(tt.name+streamSuffix, func(t *testing.T) {
req := tt.request
req.Stream = &stream
w := createRequest(t, s.GenerateHandler, req)
if tt.expectDebug {
if w.Code != http.StatusOK {
t.Errorf("expected status %d, got %d, body: %s", http.StatusOK, w.Code, w.Body.String())
}
var response api.DebugTemplateResponse
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if response.Model != tt.request.Model {
t.Errorf("expected model %s, got %s", tt.request.Model, response.Model)
}
if tt.expectTemplate != "" && response.DebugInfo.RenderedTemplate != tt.expectTemplate {
t.Errorf("expected template %q, got %q", tt.expectTemplate, response.DebugInfo.RenderedTemplate)
}
if tt.expectNumImages > 0 && response.DebugInfo.ImageCount != tt.expectNumImages {
t.Errorf("expected image count %d, got %d", tt.expectNumImages, response.DebugInfo.ImageCount)
}
} else {
// When debug is disabled, it should attempt normal processing
if w.Code != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
}
}
})
}
}
}
func TestChatDebugRenderOnly(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := mockRunner{
CompletionResponse: llm.CompletionResponse{
Done: true,
DoneReason: llm.DoneReasonStop,
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
},
}
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: discover.GetGPUInfo,
getCpuFn: discover.GetCPUInfo,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
// add small delay to simulate loading
time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{
llama: &mock,
}
return false
},
},
}
go s.sched.Run(t.Context())
// Create a test model
stream := false
_, digest := createBinFile(t, ggml.KV{
"general.architecture": "llama",
"llama.block_count": uint32(1),
"llama.context_length": uint32(8192),
"llama.embedding_length": uint32(4096),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(8),
"tokenizer.ggml.tokens": []string{""},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []*ggml.Tensor{
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
})
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "test-model",
Files: map[string]string{"file.gguf": digest},
Template: "{{ if .Tools }}{{ .Tools }}{{ end }}{{ range .Messages }}{{ .Role }}: {{ .Content }}\n{{ end }}",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
tests := []struct {
name string
request api.ChatRequest
expectDebug bool
expectTemplate string
expectNumImages int
}{
{
name: "chat debug render only enabled",
request: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "system", Content: "You are a helpful assistant"},
{Role: "user", Content: "Hello"},
},
DebugRenderOnly: true,
},
expectDebug: true,
expectTemplate: "system: You are a helpful assistant\nuser: Hello\n",
},
{
name: "chat debug render only disabled",
request: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "Hello"},
},
DebugRenderOnly: false,
},
expectDebug: false,
},
{
name: "chat debug with assistant message",
request: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: "Hi there!"},
{Role: "user", Content: "How are you?"},
},
DebugRenderOnly: true,
},
expectDebug: true,
expectTemplate: "user: Hello\nassistant: Hi there!\nuser: How are you?\n",
},
{
name: "chat debug with images",
request: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{
Role: "user",
Content: "What's in this image?",
Images: []api.ImageData{[]byte("fake-image-data")},
},
},
DebugRenderOnly: true,
},
expectDebug: true,
expectTemplate: "user: [img-0]What's in this image?\n",
expectNumImages: 1,
},
{
name: "chat debug with tools",
request: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{Role: "user", Content: "Get the weather"},
},
Tools: api.Tools{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get weather information",
},
},
},
DebugRenderOnly: true,
},
expectDebug: true,
expectTemplate: "[{\"type\":\"function\",\"function\":{\"name\":\"get_weather\",\"description\":\"Get weather information\",\"parameters\":{\"type\":\"\",\"required\":null,\"properties\":null}}}]user: Get the weather\n",
},
}
for _, tt := range tests {
// Test both with and without streaming
streamValues := []bool{false, true}
for _, stream := range streamValues {
streamSuffix := ""
if stream {
streamSuffix = " (streaming)"
}
t.Run(tt.name+streamSuffix, func(t *testing.T) {
req := tt.request
req.Stream = &stream
w := createRequest(t, s.ChatHandler, req)
if tt.expectDebug {
if w.Code != http.StatusOK {
t.Errorf("expected status %d, got %d, body: %s", http.StatusOK, w.Code, w.Body.String())
}
var response api.DebugTemplateResponse
if err := json.Unmarshal(w.Body.Bytes(), &response); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if response.Model != tt.request.Model {
t.Errorf("expected model %s, got %s", tt.request.Model, response.Model)
}
if tt.expectTemplate != "" && response.DebugInfo.RenderedTemplate != tt.expectTemplate {
t.Errorf("expected template %q, got %q", tt.expectTemplate, response.DebugInfo.RenderedTemplate)
}
if tt.expectNumImages > 0 && response.DebugInfo.ImageCount != tt.expectNumImages {
t.Errorf("expected image count %d, got %d", tt.expectNumImages, response.DebugInfo.ImageCount)
}
} else {
// When debug is disabled, it should attempt normal processing
if w.Code != http.StatusOK {
t.Errorf("expected status %d, got %d", http.StatusOK, w.Code)
}
}
})
}
}
}