Compare commits

..

25 Commits

Author SHA1 Message Date
jmorganca
38ed7c7a4f no parser yet 2025-10-13 14:38:10 -07:00
jmorganca
9ff8e5a64d wip 2025-10-13 13:01:54 -07:00
Jeffrey Morgan
6544e14735 Reapply "add truncate and shift parameters" (#12582) 2025-10-11 16:06:14 -07:00
Devon Rifkin
5db8a818a1 Merge pull request #12581 from ollama/drifkin/renderer-api-generate
routes: fix built-in renderers for `api/generate`
2025-10-11 14:10:23 -07:00
Devon Rifkin
6db8da9958 routes: fix built-in renderers for api/generate
Made it so when api/generate builds up a message array and generates the
prompt it now goes through the same function as `api/chat` for
consistency. This is where we hook the optional built-in renderers to
bypass templates, which was missing for `api/generate` before this
change.

Closes: #12578
2025-10-11 13:57:43 -07:00
frob
0c68ec8d6a discover: fix typo (#12565) 2025-10-11 12:06:02 -07:00
Daniel Hiltgen
70d9e363e1 doc: remove AMD EOL GPUs (#12567) 2025-10-10 17:16:29 -07:00
Michael Yang
1a2feb2a97 ollamarunner: fix deadlock
hardErrCh will deadlock since forwardBatch is blocked on
computeStartedCh which never gets sent. since the response to
hardErrCh is to panic, just panic instead
2025-10-10 16:49:57 -07:00
Daniel Hiltgen
aab2190420 implement nvml for linux (#12517)
* implement nvml for linux

* Improve scheduler logging when VRAM doesn't recover
2025-10-10 15:15:56 -07:00
Michael Yang
629db9dc43 comment split 2025-10-10 13:25:34 -07:00
Michael Yang
e0cd511661 fix test 2025-10-10 13:25:34 -07:00
Michael Yang
207332078f fix lint 2025-10-10 13:25:34 -07:00
Michael Yang
93085127f4 convert: slice gate_up weight 2025-10-10 13:25:34 -07:00
Michael Yang
c00fa9cc2b convert: split gate_up bias 2025-10-10 13:25:34 -07:00
yajianggroup
df411c4b02 refactor: using testing.B.Loop
Signed-off-by: yajianggroup <yajianggroup@outlook.com>
2025-10-10 13:25:29 -07:00
Jeffrey Morgan
3d32249c74 use llama runner for qwen3 (#12556) 2025-10-09 19:08:21 -07:00
Patrick Devine
d681cd7c29 thinking: allow "think": false for non-thinking models (#12555) 2025-10-09 18:46:00 -07:00
shengxinjing
47298fce39 refactor: use builtin max and min 2025-10-09 16:17:52 -07:00
shengxinjing
4a48937ef1 refactor: use builtin max and min 2025-10-09 16:17:52 -07:00
Michael Yang
967a82f52f ollamarunner: measure only active time 2025-10-09 15:44:04 -07:00
Michael Yang
bbbc73d637 llamarunner: update metrics
this change updates how metrics are collected. until now, performance
metrics, specifically initial input processing and subsequent generation
durations, were collected by taking the timestamp when creating a new
sequence, the first token generation, and completing generation. the
processing duration is taken as first token generation sub sequence
creation while generation is taken as completing generation sub first
token generation.

while this approach is an accurate end-to-end metric of processing and
generation, it's not comparable to other tools which only measure the
active, i.e. decode, duration.

this change updates the metrics to only capture decode duration so it
can be more directly compared to other tools
2025-10-09 15:44:04 -07:00
Daniel Hiltgen
15e3611d3d logs: quiet down context canceled on completion and scheduler noise (#12553)
* logs: quiet down context canceled on completion

If the client closes the connection before Completion finishes, we were
logging at error level implying the runner crashed which was misleading.

time=2025-10-08T22:59:20.566-07:00 level=ERROR source=server.go:1490 msg="post predict" error="Post \"http://127.0.0.1:57736/completion\": context canceled"

* quiet down scheduler log error on expected case

Since we don't hold the lock while performing memory load calculations, other
runners can unload in parallel, so finding no runner to unload is a valid scenario
which we shouldn't log at error level.
2025-10-09 10:37:47 -07:00
Parth Sareen
77060d462c routes: structured outputs for gpt-oss (#12460) 2025-10-08 19:13:38 -07:00
Patrick Devine
1b91d4dda1 openai: change the reasonin_effort field to also take none 2025-10-08 18:21:01 -07:00
Jeffrey Morgan
7d965258ce Revert "add truncate and shift parameters (#12519)" (#12545)
This reverts commit 6a62b894c7.
2025-10-08 17:57:57 -07:00
24 changed files with 1668 additions and 269 deletions

View File

@@ -85,6 +85,19 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
case "scales":
mxfp4s[name].scales = t
}
} else if strings.HasSuffix(t.Name(), "gate_up_exps.bias") {
// gate_up_exps is interleaved, need to split into gate_exps and up_exps
// e.g. gate_exps, up_exps = gate_up_exps[:, 0::2, ...], gate_up_exps[:, 1::2, ...]
out = append(out, slices.Collect(splitDim(t, 1,
split{
Replacer: strings.NewReplacer("gate_up_exps", "gate_exps"),
slices: []tensor.Slice{nil, tensor.S(0, int(t.Shape()[1]), 2)},
},
split{
Replacer: strings.NewReplacer("gate_up_exps", "up_exps"),
slices: []tensor.Slice{nil, tensor.S(1, int(t.Shape()[1]), 2)},
},
))...)
} else {
out = append(out, &ggml.Tensor{
Name: t.Name(),
@@ -97,17 +110,28 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
for name, mxfp4 := range mxfp4s {
dims := mxfp4.blocks.Shape()
if !strings.HasSuffix(name, ".weight") {
name += ".weight"
if strings.Contains(name, "ffn_down_exps") {
out = append(out, &ggml.Tensor{
Name: name + ".weight",
Kind: uint32(ggml.TensorTypeMXFP4),
Shape: []uint64{dims[0], dims[1], dims[2] * dims[3] * 2},
WriterTo: mxfp4,
})
} else if strings.Contains(name, "ffn_gate_up_exps") {
// gate_up_exps is interleaved, need to split into gate_exps and up_exps
// e.g. gate_exps, up_exps = gate_up_exps[:, 0::2, ...], gate_up_exps[:, 1::2, ...]
out = append(out, &ggml.Tensor{
Name: strings.Replace(name, "gate_up", "gate", 1) + ".weight",
Kind: uint32(ggml.TensorTypeMXFP4),
Shape: []uint64{dims[0], dims[1] / 2, dims[2] * dims[3] * 2},
WriterTo: mxfp4.slice(1, 0, int(dims[1]), 2),
}, &ggml.Tensor{
Name: strings.Replace(name, "gate_up", "up", 1) + ".weight",
Kind: uint32(ggml.TensorTypeMXFP4),
Shape: []uint64{dims[0], dims[1] / 2, dims[2] * dims[3] * 2},
WriterTo: mxfp4.slice(1, 1, int(dims[1]), 2),
})
}
out = append(out, &ggml.Tensor{
Name: name,
Kind: uint32(ggml.TensorTypeMXFP4),
Shape: []uint64{dims[0], dims[1], dims[2] * dims[3] * 2},
WriterTo: mxfp4,
})
}
return out
@@ -158,9 +182,21 @@ func (m *gptossModel) Replacements() []string {
}
type mxfp4 struct {
slices []tensor.Slice
blocks, scales Tensor
}
func (m *mxfp4) slice(dim, start, end, step int) *mxfp4 {
slice := slices.Repeat([]tensor.Slice{nil}, len(m.blocks.Shape()))
slice[dim] = tensor.S(start, end, step)
return &mxfp4{
slices: slice,
blocks: m.blocks,
scales: m.scales,
}
}
func (m *mxfp4) WriteTo(w io.Writer) (int64, error) {
var b bytes.Buffer
if _, err := m.blocks.WriteTo(&b); err != nil {
@@ -204,6 +240,13 @@ func (m *mxfp4) WriteTo(w io.Writer) (int64, error) {
return 0, err
}
if len(m.slices) > 0 {
out, err = out.Slice(m.slices...)
if err != nil {
return 0, err
}
}
out = tensor.Materialize(out)
if err := out.Reshape(out.Shape().TotalSize()); err != nil {

View File

@@ -16,7 +16,8 @@ import (
type split struct {
*strings.Replacer
dim int
dim int
slices []tensor.Slice
// fn is an optional function to apply to the tensor after slicing
fn func(tensor.Tensor) (tensor.Tensor, error)
@@ -32,9 +33,12 @@ func splitDim(t Tensor, dim int, splits ...split) iter.Seq[*ggml.Tensor] {
shape := slices.Clone(t.Shape())
shape[dim] = cmp.Or(uint64(split.dim), shape[dim]/uint64(len(splits)))
slice := slices.Repeat([]tensor.Slice{nil}, len(shape))
slice[dim] = tensor.S(offset, offset+int(shape[dim]))
offset += int(shape[dim])
slice := split.slices
if len(slice) == 0 {
slice = slices.Repeat([]tensor.Slice{nil}, len(shape))
slice[dim] = tensor.S(offset, offset+int(shape[dim]))
offset += int(shape[dim])
}
t.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
dims := make([]int, len(shape))

View File

@@ -408,7 +408,7 @@ func (r *bootstrapRunner) HasExited() bool {
func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs []string) []ml.DeviceInfo {
// TODO DRY out with llm/server.go
slog.Debug("spawing runner with", "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs)
slog.Debug("spawning runner with", "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs)
start := time.Now()
defer func() {
slog.Debug("bootstrap discovery took", "duration", time.Since(start), "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs)

View File

@@ -51,11 +51,11 @@ sudo modprobe nvidia_uvm`
Ollama supports the following AMD GPUs:
### Linux Support
| Family | Cards and accelerators |
| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------- |
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` `Vega 64` |
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` `V420` `V340` `V320` `Vega II Duo` `Vega II` `SSG` |
| AMD Instinct | `MI300X` `MI300A` `MI300` `MI250X` `MI250` `MI210` `MI200` `MI100` `MI60` |
| Family | Cards and accelerators |
| -------------- | -------------------------------------------------------------------------------------------------------------------- |
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` |
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` `V420` `V340` `V320` |
| AMD Instinct | `MI300X` `MI300A` `MI300` `MI250X` `MI250` `MI210` `MI200` `MI100` |
### Windows Support
With ROCm v6.2, the following GPUs are supported on Windows.

View File

@@ -243,7 +243,6 @@ func (kv KV) OllamaEngineRequired() bool {
"gemma3",
"gemma3n",
"mistral3",
"qwen3",
"qwen3moe",
"llama4",
"mllama",

View File

@@ -13,13 +13,13 @@ management libraries for more accurate VRAM usage reporting if available.
ggml/src/ggml-impl.h | 8 +
ggml/src/ggml-metal/ggml-metal.cpp | 3 +-
ggml/src/mem_hip.cpp | 449 +++++++++++++++++++++++++++++
ggml/src/mem_nvml.cpp | 172 +++++++++++
8 files changed, 718 insertions(+), 1 deletion(-)
ggml/src/mem_nvml.cpp | 209 ++++++++++++++
8 files changed, 755 insertions(+), 1 deletion(-)
create mode 100644 ggml/src/mem_hip.cpp
create mode 100644 ggml/src/mem_nvml.cpp
diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
index 0a2dae26..a6bf3378 100644
index 0a2dae26a..a6bf33785 100644
--- a/ggml/include/ggml-backend.h
+++ b/ggml/include/ggml-backend.h
@@ -169,6 +169,15 @@ extern "C" {
@@ -39,7 +39,7 @@ index 0a2dae26..a6bf3378 100644
GGML_API const char * ggml_backend_dev_name(ggml_backend_dev_t device);
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
index 33b3a15f..86191ef2 100644
index 33b3a15f0..86191ef2c 100644
--- a/ggml/src/CMakeLists.txt
+++ b/ggml/src/CMakeLists.txt
@@ -206,6 +206,8 @@ add_library(ggml-base
@@ -52,7 +52,7 @@ index 33b3a15f..86191ef2 100644
target_include_directories(ggml-base PRIVATE .)
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 531d6e27..3fa3a057 100644
index 531d6e272..3fa3a0575 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -261,6 +261,16 @@ static ggml_cuda_device_info ggml_cuda_init() {
@@ -184,7 +184,7 @@ index 531d6e27..3fa3a057 100644
/* .iface = */ ggml_backend_cuda_device_interface,
/* .reg = */ &reg,
diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h
index 06f9e7c1..eb8f66cb 100644
index 06f9e7c1e..eb8f66cb0 100644
--- a/ggml/src/ggml-cuda/vendors/hip.h
+++ b/ggml/src/ggml-cuda/vendors/hip.h
@@ -5,6 +5,9 @@
@@ -206,7 +206,7 @@ index 06f9e7c1..eb8f66cb 100644
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
index 86a1ebf6..9fc9fbfc 100644
index 86a1ebf62..9fc9fbfcf 100644
--- a/ggml/src/ggml-impl.h
+++ b/ggml/src/ggml-impl.h
@@ -635,6 +635,14 @@ static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx
@@ -225,7 +225,7 @@ index 86a1ebf6..9fc9fbfc 100644
}
#endif
diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp
index 08ab4fc9..17999a61 100644
index 08ab4fc91..17999a616 100644
--- a/ggml/src/ggml-metal/ggml-metal.cpp
+++ b/ggml/src/ggml-metal/ggml-metal.cpp
@@ -535,6 +535,7 @@ static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backen
@@ -247,7 +247,7 @@ index 08ab4fc9..17999a61 100644
/* .host_buffer = */ false,
diff --git a/ggml/src/mem_hip.cpp b/ggml/src/mem_hip.cpp
new file mode 100644
index 00000000..8ef19b8c
index 000000000..8ef19b8cf
--- /dev/null
+++ b/ggml/src/mem_hip.cpp
@@ -0,0 +1,449 @@
@@ -703,10 +703,10 @@ index 00000000..8ef19b8c
\ No newline at end of file
diff --git a/ggml/src/mem_nvml.cpp b/ggml/src/mem_nvml.cpp
new file mode 100644
index 00000000..aa05e9dc
index 000000000..c9073cef0
--- /dev/null
+++ b/ggml/src/mem_nvml.cpp
@@ -0,0 +1,172 @@
@@ -0,0 +1,209 @@
+// NVIDIA Management Library (NVML)
+//
+// https://developer.nvidia.com/management-library-nvml
@@ -721,6 +721,7 @@ index 00000000..aa05e9dc
+#include "ggml-impl.h"
+#include <filesystem>
+#include <mutex>
+#include <array>
+
+#ifdef _WIN32
+# define WIN32_LEAN_AND_MEAN
@@ -787,6 +788,7 @@ index 00000000..aa05e9dc
+ nvmlReturn_t (*nvmlShutdown)(void);
+ nvmlReturn_t (*nvmlDeviceGetHandleByUUID)(const char *, nvmlDevice_t *);
+ nvmlReturn_t (*nvmlDeviceGetMemoryInfo)(nvmlDevice_t, nvmlMemory_t *);
+ const char * (*nvmlErrorString)(nvmlReturn_t result);
+} nvml { NULL, NULL, NULL, NULL, NULL };
+static std::mutex ggml_nvml_lock;
+
@@ -824,7 +826,8 @@ index 00000000..aa05e9dc
+ nvml.nvmlShutdown = (nvmlReturn_enum (*)()) GetProcAddress((HMODULE)(nvml.handle), "nvmlShutdown");
+ nvml.nvmlDeviceGetHandleByUUID = (nvmlReturn_t (*)(const char *, nvmlDevice_t *)) GetProcAddress((HMODULE)(nvml.handle), "nvmlDeviceGetHandleByUUID");
+ nvml.nvmlDeviceGetMemoryInfo = (nvmlReturn_t (*)(nvmlDevice_t, nvmlMemory_t *)) GetProcAddress((HMODULE)(nvml.handle), "nvmlDeviceGetMemoryInfo");
+ if (nvml.nvmlInit_v2 == NULL || nvml.nvmlShutdown == NULL || nvml.nvmlDeviceGetHandleByUUID == NULL || nvml.nvmlDeviceGetMemoryInfo == NULL) {
+ nvml.nvmlErrorString = (const char * (*)(nvmlReturn_enum)) GetProcAddress((HMODULE)(nvml.handle), "nvmlErrorString");
+ if (nvml.nvmlInit_v2 == NULL || nvml.nvmlShutdown == NULL || nvml.nvmlDeviceGetHandleByUUID == NULL || nvml.nvmlDeviceGetMemoryInfo == NULL || nvml.nvmlErrorString == NULL) {
+ GGML_LOG_INFO("%s unable to locate required symbols in NVML.dll", __func__);
+ FreeLibrary((HMODULE)(nvml.handle));
+ nvml.handle = NULL;
@@ -833,11 +836,45 @@ index 00000000..aa05e9dc
+
+ SetErrorMode(old_mode);
+
+ nvmlReturn_t status = nvml.nvmlInit_v2();
+ if (status != NVML_SUCCESS) {
+ GGML_LOG_INFO("%s unable to initialize NVML: %s\n", __func__, nvml.nvmlErrorString(status));
+ FreeLibrary((HMODULE)(nvml.handle));
+ nvml.handle = NULL;
+ return status;
+ }
+#else
+ // Not currently wired up on Linux
+ return NVML_ERROR_NOT_SUPPORTED;
+ constexpr std::array<const char*, 2> libPaths = {
+ "/usr/lib/wsl/lib/libnvidia-ml.so.1", // Favor WSL2 path if present
+ "libnvidia-ml.so.1" // On a non-WSL2 system, it should be in the path
+ };
+ for (const char* path : libPaths) {
+ nvml.handle = dlopen(path, RTLD_LAZY);
+ if (nvml.handle) break;
+ }
+ if (nvml.handle == NULL) {
+ GGML_LOG_INFO("%s unable to load libnvidia-ml: %s\n", __func__, dlerror());
+ return NVML_ERROR_NOT_FOUND;
+ }
+ nvml.nvmlInit_v2 = (nvmlReturn_enum (*)()) dlsym(nvml.handle, "nvmlInit_v2");
+ nvml.nvmlShutdown = (nvmlReturn_enum (*)()) dlsym(nvml.handle, "nvmlShutdown");
+ nvml.nvmlDeviceGetHandleByUUID = (nvmlReturn_t (*)(const char *, nvmlDevice_t *)) dlsym(nvml.handle, "nvmlDeviceGetHandleByUUID");
+ nvml.nvmlDeviceGetMemoryInfo = (nvmlReturn_t (*)(nvmlDevice_t, nvmlMemory_t *)) dlsym(nvml.handle, "nvmlDeviceGetMemoryInfo");
+ nvml.nvmlErrorString = (const char * (*)(nvmlReturn_enum)) dlsym(nvml.handle, "nvmlErrorString");
+ if (nvml.nvmlInit_v2 == NULL || nvml.nvmlShutdown == NULL || nvml.nvmlDeviceGetHandleByUUID == NULL || nvml.nvmlDeviceGetMemoryInfo == NULL) {
+ GGML_LOG_INFO("%s unable to locate required symbols in libnvidia-ml.so", __func__);
+ dlclose(nvml.handle);
+ nvml.handle = NULL;
+ return NVML_ERROR_NOT_FOUND;
+ }
+ nvmlReturn_t status = nvml.nvmlInit_v2();
+ if (status != NVML_SUCCESS) {
+ GGML_LOG_INFO("%s unable to initialize NVML: %s\n", __func__, nvml.nvmlErrorString(status));
+ dlclose(nvml.handle);
+ nvml.handle = NULL;
+ return status;
+ }
+#endif
+ int status = nvml.nvmlInit_v2();
+ return NVML_SUCCESS;
+}
+
@@ -849,14 +886,14 @@ index 00000000..aa05e9dc
+ }
+ nvmlReturn_enum status = nvml.nvmlShutdown();
+ if (status != NVML_SUCCESS) {
+ GGML_LOG_INFO("%s failed to shutdown NVML: %d\n", __func__, status);
+ GGML_LOG_INFO("%s failed to shutdown NVML: %s\n", __func__, nvml.nvmlErrorString(status));
+ }
+#ifdef _WIN32
+ FreeLibrary((HMODULE)(nvml.handle));
+ nvml.handle = NULL;
+#else
+ // Not currently wired up on Linux
+ dlclose(nvml.handle);
+#endif
+ nvml.handle = NULL;
+}
+
+int ggml_nvml_get_device_memory(const char *uuid, size_t *free, size_t *total) {

View File

@@ -1488,7 +1488,10 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
serverReq.Header.Set("Content-Type", "application/json")
res, err := http.DefaultClient.Do(serverReq)
if err != nil {
if err != nil && errors.Is(err, context.Canceled) {
// client closed connection
return err
} else if err != nil {
slog.Error("post predict", "error", err)
return errors.New("model runner has unexpectedly stopped, this may be due to resource limitations or an internal error, check ollama server logs for details")
}
@@ -1500,7 +1503,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
return fmt.Errorf("failed reading llm error response: %w", err)
}
log.Printf("llm predict error: %s", bodyBytes)
return api.StatusError{StatusCode: res.StatusCode, Status: res.Status, ErrorMessage: strings.TrimSpace(string(bodyBytes))}
return api.StatusError{StatusCode: res.StatusCode, ErrorMessage: strings.TrimSpace(string(bodyBytes))}
}
scanner := bufio.NewScanner(res.Body)

View File

@@ -12,6 +12,7 @@
#include "ggml-impl.h"
#include <filesystem>
#include <mutex>
#include <array>
#ifdef _WIN32
# define WIN32_LEAN_AND_MEAN
@@ -78,6 +79,7 @@ struct {
nvmlReturn_t (*nvmlShutdown)(void);
nvmlReturn_t (*nvmlDeviceGetHandleByUUID)(const char *, nvmlDevice_t *);
nvmlReturn_t (*nvmlDeviceGetMemoryInfo)(nvmlDevice_t, nvmlMemory_t *);
const char * (*nvmlErrorString)(nvmlReturn_t result);
} nvml { NULL, NULL, NULL, NULL, NULL };
static std::mutex ggml_nvml_lock;
@@ -115,7 +117,8 @@ int ggml_nvml_init() {
nvml.nvmlShutdown = (nvmlReturn_enum (*)()) GetProcAddress((HMODULE)(nvml.handle), "nvmlShutdown");
nvml.nvmlDeviceGetHandleByUUID = (nvmlReturn_t (*)(const char *, nvmlDevice_t *)) GetProcAddress((HMODULE)(nvml.handle), "nvmlDeviceGetHandleByUUID");
nvml.nvmlDeviceGetMemoryInfo = (nvmlReturn_t (*)(nvmlDevice_t, nvmlMemory_t *)) GetProcAddress((HMODULE)(nvml.handle), "nvmlDeviceGetMemoryInfo");
if (nvml.nvmlInit_v2 == NULL || nvml.nvmlShutdown == NULL || nvml.nvmlDeviceGetHandleByUUID == NULL || nvml.nvmlDeviceGetMemoryInfo == NULL) {
nvml.nvmlErrorString = (const char * (*)(nvmlReturn_enum)) GetProcAddress((HMODULE)(nvml.handle), "nvmlErrorString");
if (nvml.nvmlInit_v2 == NULL || nvml.nvmlShutdown == NULL || nvml.nvmlDeviceGetHandleByUUID == NULL || nvml.nvmlDeviceGetMemoryInfo == NULL || nvml.nvmlErrorString == NULL) {
GGML_LOG_INFO("%s unable to locate required symbols in NVML.dll", __func__);
FreeLibrary((HMODULE)(nvml.handle));
nvml.handle = NULL;
@@ -124,11 +127,45 @@ int ggml_nvml_init() {
SetErrorMode(old_mode);
nvmlReturn_t status = nvml.nvmlInit_v2();
if (status != NVML_SUCCESS) {
GGML_LOG_INFO("%s unable to initialize NVML: %s\n", __func__, nvml.nvmlErrorString(status));
FreeLibrary((HMODULE)(nvml.handle));
nvml.handle = NULL;
return status;
}
#else
// Not currently wired up on Linux
return NVML_ERROR_NOT_SUPPORTED;
constexpr std::array<const char*, 2> libPaths = {
"/usr/lib/wsl/lib/libnvidia-ml.so.1", // Favor WSL2 path if present
"libnvidia-ml.so.1" // On a non-WSL2 system, it should be in the path
};
for (const char* path : libPaths) {
nvml.handle = dlopen(path, RTLD_LAZY);
if (nvml.handle) break;
}
if (nvml.handle == NULL) {
GGML_LOG_INFO("%s unable to load libnvidia-ml: %s\n", __func__, dlerror());
return NVML_ERROR_NOT_FOUND;
}
nvml.nvmlInit_v2 = (nvmlReturn_enum (*)()) dlsym(nvml.handle, "nvmlInit_v2");
nvml.nvmlShutdown = (nvmlReturn_enum (*)()) dlsym(nvml.handle, "nvmlShutdown");
nvml.nvmlDeviceGetHandleByUUID = (nvmlReturn_t (*)(const char *, nvmlDevice_t *)) dlsym(nvml.handle, "nvmlDeviceGetHandleByUUID");
nvml.nvmlDeviceGetMemoryInfo = (nvmlReturn_t (*)(nvmlDevice_t, nvmlMemory_t *)) dlsym(nvml.handle, "nvmlDeviceGetMemoryInfo");
nvml.nvmlErrorString = (const char * (*)(nvmlReturn_enum)) dlsym(nvml.handle, "nvmlErrorString");
if (nvml.nvmlInit_v2 == NULL || nvml.nvmlShutdown == NULL || nvml.nvmlDeviceGetHandleByUUID == NULL || nvml.nvmlDeviceGetMemoryInfo == NULL) {
GGML_LOG_INFO("%s unable to locate required symbols in libnvidia-ml.so", __func__);
dlclose(nvml.handle);
nvml.handle = NULL;
return NVML_ERROR_NOT_FOUND;
}
nvmlReturn_t status = nvml.nvmlInit_v2();
if (status != NVML_SUCCESS) {
GGML_LOG_INFO("%s unable to initialize NVML: %s\n", __func__, nvml.nvmlErrorString(status));
dlclose(nvml.handle);
nvml.handle = NULL;
return status;
}
#endif
int status = nvml.nvmlInit_v2();
return NVML_SUCCESS;
}
@@ -140,14 +177,14 @@ void ggml_nvml_release() {
}
nvmlReturn_enum status = nvml.nvmlShutdown();
if (status != NVML_SUCCESS) {
GGML_LOG_INFO("%s failed to shutdown NVML: %d\n", __func__, status);
GGML_LOG_INFO("%s failed to shutdown NVML: %s\n", __func__, nvml.nvmlErrorString(status));
}
#ifdef _WIN32
FreeLibrary((HMODULE)(nvml.handle));
nvml.handle = NULL;
#else
// Not currently wired up on Linux
dlclose(nvml.handle);
#endif
nvml.handle = NULL;
}
int ggml_nvml_get_device_memory(const char *uuid, size_t *free, size_t *total) {

View File

@@ -251,7 +251,7 @@ func BenchmarkBytePairEncoding(b *testing.B) {
bts := bts[:n]
b.Run("encode"+strconv.Itoa(n), func(b *testing.B) {
b.ResetTimer()
for range b.N {
for b.Loop() {
_, err := tokenizer.Encode(string(bts), true)
if err != nil {
b.Fatal(err)
@@ -266,7 +266,7 @@ func BenchmarkBytePairEncoding(b *testing.B) {
}
b.ResetTimer()
for range b.N {
for b.Loop() {
_, err := tokenizer.Decode(ids)
if err != nil {
b.Fatal(err)
@@ -276,7 +276,7 @@ func BenchmarkBytePairEncoding(b *testing.B) {
b.Run("split"+strconv.Itoa(n), func(b *testing.B) {
b.ResetTimer()
for range b.N {
for b.Loop() {
slices.Collect(tokenizer.split(string(bts)))
}
})

View File

@@ -73,7 +73,7 @@ func (p ImageProcessor) bestResolution(img image.Point, possibleResolutions []im
for i, res := range possibleResolutions {
scaleW := float64(res.X) / float64(w)
scaleH := float64(res.Y) / float64(h)
scale := math.Min(scaleW, scaleH)
scale := min(scaleW, scaleH)
scales[i] = scale
}
@@ -124,11 +124,11 @@ func (p ImageProcessor) maxResolution(imageRes, targetRes image.Point) image.Poi
if scaleW < scaleH {
newRes = image.Point{
targetRes.X,
int(math.Min(math.Floor(float64(imageRes.Y)*scaleW), float64(targetRes.Y))),
int(min(math.Floor(float64(imageRes.Y)*scaleW), float64(targetRes.Y))),
}
} else {
newRes = image.Point{
int(math.Min(math.Floor(float64(imageRes.X)*scaleH), float64(targetRes.X))),
int(min(math.Floor(float64(imageRes.X)*scaleH), float64(targetRes.X))),
targetRes.Y,
}
}

View File

@@ -53,7 +53,7 @@ func (p ImageProcessor) fitToCanvas(imageSize, canvasSize image.Point) image.Poi
tw := min(max(imageSize.X, p.imageSize), canvasSize.X)
th := min(max(imageSize.Y, p.imageSize), canvasSize.Y)
r := math.Min(
r := min(
float64(tw)/float64(imageSize.X),
float64(th)/float64(imageSize.Y),
)
@@ -89,10 +89,10 @@ func (p ImageProcessor) optimalTiledCanvas(imageSize image.Point) image.Point {
if minUpscale == 0 {
minUpscale = s
} else {
minUpscale = math.Min(minUpscale, s)
minUpscale = min(minUpscale, s)
}
} else {
maxDownscale = math.Max(maxDownscale, s)
maxDownscale = max(maxDownscale, s)
}
}

219
model/parsers/glm46.go Normal file
View File

@@ -0,0 +1,219 @@
package parsers
import (
"context"
"encoding/json"
"log/slog"
"strings"
"unicode"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil"
)
const (
glm46CollectingContent glm46ParserState = iota
CollectingThinkingContent
CollectingToolContent
)
const (
thinkingCloseTag = "</think>"
)
// TODO(gguo): add a field for isThinking
type GLM46Parser struct {
state qwenParserState
buffer strings.Builder
tools []api.Tool
}
func (p *GLM46Parser) HasToolSupport() bool {
return true
}
// TODO(gguo): changes this to reference an objects param
func (p *GLM46Parser) HasThinkingSupport() bool {
return true
}
func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
p.tools = tools
// p.state = p.initialState()
return tools
}
type glm46EventThinkingContent struct {
content string
}
func (glm46EventThinkingContent) isGLM46Event() {}
func (p *GLM46Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
p.buffer.WriteString(s)
events := p.parseEvents()
var toolCalls []api.ToolCall
var sb strings.Builder
for _, event := range events {
switch event := event.(type) {
case glm46EventRawToolCall:
toolCall, err := parseJSONToolCall(event, p.tools)
if err != nil {
slog.Warn("qwen tool call parsing failed", "error", err)
return "", "", nil, err
}
toolCalls = append(toolCalls, toolCall)
case glm46EventThinkingContent:
sb.WriteString(event.content)
case glm46EventContent:
// TODO(drifkin): if the same turn contains multiple interleaved content
// events, we naively append them together here.
sb.WriteString(event.content)
}
}
return sb.String(), "", toolCalls, nil
}
func (p *GLM46Parser) parseEvents() []glm46Event {
var all []glm46Event
keepLooping := true
for keepLooping {
var events []glm46Event
events, keepLooping = p.eat()
if len(events) > 0 {
all = append(all, events...)
}
}
if len(all) > 0 {
slog.Log(context.TODO(), logutil.LevelTrace, "qwen events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
}
return all
}
func emitContentBeforeTag(p *GLM46Parser, events []glm46Event, tag string) []glm46Event {
split := strings.SplitN(p.buffer.String(), tag, 2)
before := split[0]
before = strings.TrimRightFunc(before, unicode.IsSpace)
if len(before) > 0 {
events = append(events, glm46EventContent{content: before})
}
after := split[1]
p.buffer.Reset()
p.buffer.WriteString(after)
return events
}
func (p *GLM46Parser) eat() ([]glm46Event, bool) {
var events []glm46Event
switch p.state {
case glm46CollectingContent:
if strings.Contains(p.buffer.String(), toolOpenTag) {
events = emitContentBeforeTag(p, events, toolOpenTag)
p.state = glm46CollectingToolContent
return events, true
} else if overlapLen := overlap(p.buffer.String(), toolOpenTag); overlapLen > 0 {
beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen]
trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen
unambiguous := p.buffer.String()[:ambiguousStart]
ambiguous := p.buffer.String()[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 { // why does qwen3coder not have this here
events = append(events, glm46EventContent{content: unambiguous})
}
return events, false
} else {
whitespaceLen := trailingWhitespaceLen(p.buffer.String())
ambiguousStart := len(p.buffer.String()) - whitespaceLen
unambiguous := p.buffer.String()[:ambiguousStart]
ambiguous := p.buffer.String()[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, glm46EventContent{content: unambiguous})
}
return events, false
}
case CollectingToolContent:
if strings.Contains(p.buffer.String(), glm46ToolCloseTag) {
split := strings.SplitN(p.buffer.String(), toolCloseTag, 2)
before := split[0]
if len(before) == 0 {
slog.Warn("qwen tool call closing tag found but no content before it")
}
after := strings.TrimLeftFunc(split[1], unicode.IsSpace)
events = append(events, glm46EventRawToolCall{raw: before})
p.buffer.Reset()
p.buffer.WriteString(after)
p.state = glm46CollectingContent
return events, true
} else {
return events, false
}
case glm46CollectingThinkingContent: // so we want to hip the unambiguous stuff
if strings.Contains(p.buffer.String(), thinkingCloseTag) {
split := strings.SplitN(p.buffer.String(), thinkingCloseTag, 2)
before := split[0]
if len(before) == 0 {
slog.Warn("qwen tool call closing tag found but no content before it")
}
after := strings.TrimLeftFunc(split[1], unicode.IsSpace)
if len(before) > 0 {
events = append(events, glm46EventThinkingContent{content: before})
}
p.buffer.Reset()
p.buffer.WriteString(after)
p.state = glm46CollectingContent
return events, true
} else if overlapLen := overlap(p.buffer.String(), thinkingCloseTag); overlapLen > 0 { // we see part of a close thinking tag
beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen]
trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen
unambiguous := p.buffer.String()[:ambiguousStart]
ambiguous := p.buffer.String()[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, glm46EventThinkingContent{content: unambiguous})
}
return events, false
} else {
whitespaceLen := trailingWhitespaceLen(p.buffer.String())
ambiguousStart := len(p.buffer.String()) - whitespaceLen
unambiguous := p.buffer.String()[:ambiguousStart]
ambiguous := p.buffer.String()[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, glm46EventThinkingContent{content: unambiguous})
}
return events, false
}
default:
panic("unreachable")
}
}
func parseJSONToolCall(raw glm46EventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
var toolCallFunction api.ToolCallFunction
if err := json.Unmarshal([]byte(raw.raw), &toolCallFunction); err != nil {
return api.ToolCall{}, err
}
toolCall := api.ToolCall{}
toolCall.Function = toolCallFunction
return toolCall, nil
}

View File

@@ -21,6 +21,9 @@ func ParserForName(name string) Parser {
case "qwen3-coder":
parser := &Qwen3CoderParser{}
return parser
case "glm-4.6":
parser := &GLM46Parser{}
return parser
case "passthrough":
return &PassthroughParser{}
case "harmony":

View File

@@ -0,0 +1,239 @@
package renderers
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
func TestGLM46Renderer(t *testing.T) {
tests := []struct {
name string
messages []api.Message
tools []api.Tool
thinkValue *api.ThinkValue
expected string
}{
{
name: "basic",
messages: []api.Message{
{Role: "user", Content: "Hello, how are you?"},
},
expected: `[gMASK]<sop><|user|>
Hello, how are you?<|assistant|>`,
},
{
name: "basic with system message",
messages: []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Hello, how are you?"},
},
expected: `[gMASK]<sop><|system|>
You are a helpful assistant.<|user|>
Hello, how are you?<|assistant|>`,
},
{
name: "basic with user assistant user",
messages: []api.Message{
{Role: "user", Content: "What is the capital of France?"},
{Role: "assistant", Thinking: "Let me analyze the request...", Content: "The capital of France is Paris."},
{Role: "user", Content: "Fantastic!"},
},
expected: `[gMASK]<sop><|user|>
What is the capital of France?<|assistant|>
The capital of France is Paris.<|user|>
Fantastic!<|assistant|>`,
},
{
name: "tools",
messages: []api.Message{
{Role: "system", Content: "You are a helpful assistant with access to tools."},
{Role: "user", Content: "What is the weather like in Tokyo?"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather in a given location",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: map[string]api.ToolProperty{
"location": {
Type: api.PropertyType{"string"},
Description: "The city and state, e.g. San Francisco, CA",
},
"unit": {
Type: api.PropertyType{"string"},
Enum: []any{"celsius", "fahrenheit"},
},
},
},
},
},
},
expected: `[gMASK]<sop><|system|>
# Tools
You may call one or more functions to assist with the user query.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{"type":"function","function":{"name":"get_weather","description":"Get the current weather in a given location","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","description":"","enum":["celsius","fahrenheit"]}}}}}
</tools>
For each function call, output the function name and arguments within the following XML format:
<tool_call>{function-name}
<arg_key>{arg-key-1}</arg_key>
<arg_value>{arg-value-1}</arg_value>
<arg_key>{arg-key-2}</arg_key>
<arg_value>{arg-value-2}</arg_value>
...
</tool_call><|system|>
You are a helpful assistant with access to tools.<|user|>
What is the weather like in Tokyo?<|assistant|>`,
},
{
name: "tool calls",
messages: []api.Message{
{Role: "system", Content: "You are a helpful assistant with access to tools."},
{Role: "user", Content: "What is the weather like in Tokyo?"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Tokyo, Japan",
"unit": "celsius",
},
},
},
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Japan",
"unit": "fahrenheit",
},
},
},
},
},
{
Role: "tool",
Content: "{\"temperature\": 22, \"weather\": \"partly cloudy\", \"humidity\": 65}",
ToolName: "get_weather",
},
{
Role: "tool",
Content: "{\"temperature\": 68, \"weather\": \"sunny\", \"humidity\": 75}",
ToolName: "get_weather",
},
{
Role: "assistant",
Content: "The weather in Tokyo is currently partly cloudy with a temperature of 22°C and 65% humidity. It's a pleasant day with moderate temperatures.",
},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather in a given location",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: map[string]api.ToolProperty{
"location": {
Type: api.PropertyType{"string"},
Description: "The city and state, e.g. San Francisco, CA",
},
"unit": {
Type: api.PropertyType{"string"},
Enum: []any{"celsius", "fahrenheit"},
},
},
},
},
},
},
expected: `[gMASK]<sop><|system|>
# Tools
You may call one or more functions to assist with the user query.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{"type":"function","function":{"name":"get_weather","description":"Get the current weather in a given location","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","description":"","enum":["celsius","fahrenheit"]}}}}}
</tools>
For each function call, output the function name and arguments within the following XML format:
<tool_call>{function-name}
<arg_key>{arg-key-1}</arg_key>
<arg_value>{arg-value-1}</arg_value>
<arg_key>{arg-key-2}</arg_key>
<arg_value>{arg-value-2}</arg_value>
...
</tool_call><|system|>
You are a helpful assistant with access to tools.<|user|>
What is the weather like in Tokyo?<|assistant|>
<think></think>
<tool_call>get_weather
<arg_key>location</arg_key>
<arg_value>Tokyo, Japan</arg_value>
<arg_key>unit</arg_key>
<arg_value>celsius</arg_value>
</tool_call>
<tool_call>get_weather
<arg_key>location</arg_key>
<arg_value>Japan</arg_value>
<arg_key>unit</arg_key>
<arg_value>fahrenheit</arg_value>
</tool_call><|observation|>
<tool_response>
{"temperature": 22, "weather": "partly cloudy", "humidity": 65}
</tool_response>
<tool_response>
{"temperature": 68, "weather": "sunny", "humidity": 75}
</tool_response><|assistant|>
<think></think>
The weather in Tokyo is currently partly cloudy with a temperature of 22°C and 65% humidity. It's a pleasant day with moderate temperatures.<|assistant|>`,
},
{
name: "think true",
messages: []api.Message{
{Role: "user", Content: "Hello, how are you?"},
},
thinkValue: &api.ThinkValue{Value: true},
expected: `[gMASK]<sop><|user|>
Hello, how are you?<|assistant|>`,
},
{
name: "think false",
messages: []api.Message{
{Role: "user", Content: "Hello, how are you?"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `[gMASK]<sop><|user|>
Hello, how are you?/nothink<|assistant|>
<think></think>`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rendered, err := GLM46Renderer(tt.messages, tt.tools, tt.thinkValue)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(rendered, tt.expected); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
t.Logf("Got:\n%s", rendered)
t.Logf("Expected:\n%s", tt.expected)
}
})
}
}

109
model/renderers/gml46.go Normal file
View File

@@ -0,0 +1,109 @@
package renderers
import (
"encoding/json"
"fmt"
"strings"
"github.com/ollama/ollama/api"
)
func GLM46Renderer(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
var sb strings.Builder
sb.WriteString("[gMASK]<sop>")
var lastUserIndex int
for i, message := range messages {
if message.Role == "user" {
lastUserIndex = i
}
}
if len(tools) > 0 {
sb.WriteString("<|system|>\n")
sb.WriteString("# Tools\n\n")
sb.WriteString("You may call one or more functions to assist with the user query.\n\n")
sb.WriteString("You are provided with function signatures within <tools></tools> XML tags:\n")
sb.WriteString("<tools>\n")
for _, tool := range tools {
d, _ := json.Marshal(tool)
sb.WriteString(string(d) + "\n")
}
sb.WriteString("</tools>\n\n")
sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n")
sb.WriteString("<tool_call>{function-name}\n")
sb.WriteString("<arg_key>{arg-key-1}</arg_key>\n")
sb.WriteString("<arg_value>{arg-value-1}</arg_value>\n")
sb.WriteString("<arg_key>{arg-key-2}</arg_key>\n")
sb.WriteString("<arg_value>{arg-value-2}</arg_value>\n")
sb.WriteString("...\n")
sb.WriteString("</tool_call>")
}
for i, message := range messages {
switch message.Role {
case "user":
sb.WriteString("<|user|>\n")
sb.WriteString(message.Content)
if thinkValue != nil && !thinkValue.Bool() && !strings.HasSuffix(message.Content, "/nothink") {
sb.WriteString("/nothink")
}
case "assistant":
sb.WriteString("<|assistant|>")
if i > lastUserIndex {
if message.Thinking != "" {
sb.WriteString("\n<think>" + message.Thinking + "</think>")
} else {
sb.WriteString("\n<think></think>")
}
}
if message.Content != "" {
sb.WriteString("\n" + message.Content)
}
if len(message.ToolCalls) > 0 {
for _, toolCall := range message.ToolCalls {
sb.WriteString("\n<tool_call>" + toolCall.Function.Name + "\n")
for key, value := range toolCall.Function.Arguments {
sb.WriteString("<arg_key>" + key + "</arg_key>\n")
var valueStr string
if str, ok := value.(string); ok {
valueStr = str
} else {
jsonBytes, err := json.Marshal(value)
if err != nil {
valueStr = fmt.Sprintf("%v", value)
} else {
valueStr = string(jsonBytes)
}
}
sb.WriteString("<arg_value>" + valueStr + "</arg_value>\n")
}
sb.WriteString("</tool_call>")
}
}
case "tool":
if i == 0 || messages[i-1].Role != "tool" {
sb.WriteString("<|observation|>")
}
sb.WriteString("\n<tool_response>\n")
sb.WriteString(message.Content)
sb.WriteString("\n</tool_response>")
case "system":
sb.WriteString("<|system|>\n")
sb.WriteString(message.Content)
}
}
// Add generation prompt
sb.WriteString("<|assistant|>")
fmt.Println("thinkValue", thinkValue, thinkValue.Bool())
if thinkValue != nil && !thinkValue.Bool() {
sb.WriteString("\n<think></think>")
}
return sb.String(), nil
}

View File

@@ -20,6 +20,8 @@ func rendererForName(name string) rendererFunc {
switch name {
case "qwen3-coder":
return Qwen3CoderRenderer
case "glm-4.6":
return GLM46Renderer
default:
return nil
}

View File

@@ -567,18 +567,24 @@ func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
}
var think *api.ThinkValue
var effort string
if r.Reasoning != nil {
if !slices.Contains([]string{"high", "medium", "low", "none"}, r.Reasoning.Effort) {
return nil, fmt.Errorf("invalid reasoning value: '%s' (must be \"high\", \"medium\", \"low\", or \"none\")", r.Reasoning.Effort)
effort = r.Reasoning.Effort
} else if r.ReasoningEffort != nil {
effort = *r.ReasoningEffort
}
if effort != "" {
if !slices.Contains([]string{"high", "medium", "low", "none"}, effort) {
return nil, fmt.Errorf("invalid reasoning value: '%s' (must be \"high\", \"medium\", \"low\", or \"none\")", effort)
}
if r.Reasoning.Effort == "none" {
if effort == "none" {
think = &api.ThinkValue{Value: false}
} else {
think = &api.ThinkValue{Value: r.Reasoning.Effort}
think = &api.ThinkValue{Value: effort}
}
} else if r.ReasoningEffort != nil {
think = &api.ThinkValue{Value: *r.ReasoningEffort}
}
return &api.ChatRequest{

View File

@@ -85,10 +85,10 @@ type Sequence struct {
doneReason llm.DoneReason
// Metrics
startProcessingTime time.Time
startGenerationTime time.Time
numDecoded int
numPromptInputs int
processingDuration time.Duration
generationDuration time.Duration
numDecoded int
numPromptInputs int
}
type NewSequenceParams struct {
@@ -106,8 +106,6 @@ var errorInputTooLong = errors.New("the input length exceeds the context length"
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
s.ready.Wait()
startTime := time.Now()
inputs, err := s.inputs(prompt, images)
if err != nil {
return nil, fmt.Errorf("failed to process inputs: %w", err)
@@ -153,18 +151,17 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
}
return &Sequence{
inputs: inputs,
numPromptInputs: len(inputs),
startProcessingTime: startTime,
numPredict: params.numPredict,
pendingResponses: make([]string, 0),
responses: make(chan string, 100),
quit: make(chan bool, 1),
embedding: make(chan []float32, 1),
samplingCtx: sc,
embeddingOnly: params.embedding,
stop: params.stop,
numKeep: params.numKeep,
inputs: inputs,
numPromptInputs: len(inputs),
numPredict: params.numPredict,
pendingResponses: make([]string, 0),
responses: make(chan string, 100),
quit: make(chan bool, 1),
embedding: make(chan []float32, 1),
samplingCtx: sc,
embeddingOnly: params.embedding,
stop: params.stop,
numKeep: params.numKeep,
}, nil
}
@@ -454,8 +451,8 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
return nil
}
err := s.lc.Decode(batch)
if err != nil {
t := time.Now()
if err := s.lc.Decode(batch); err != nil {
return fmt.Errorf("failed to decode batch: %w", err)
}
@@ -475,9 +472,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
continue
}
seq.numDecoded += 1
if seq.numDecoded == 1 {
seq.startGenerationTime = time.Now()
s.lc.Synchronize()
seq.numDecoded++
if seq.numDecoded > 1 {
seq.generationDuration += time.Since(t)
} else {
seq.processingDuration += time.Since(t)
}
// if done processing the prompt, generate an embedding and return
@@ -668,9 +668,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
Done: true,
DoneReason: seq.doneReason,
PromptEvalCount: seq.numPromptInputs,
PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
PromptEvalDuration: seq.processingDuration,
EvalCount: seq.numDecoded,
EvalDuration: time.Since(seq.startGenerationTime),
EvalDuration: seq.generationDuration,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
}

View File

@@ -94,10 +94,11 @@ type Sequence struct {
doneReason llm.DoneReason
// Metrics
startProcessingTime time.Time
startGenerationTime time.Time
numPredicted int
numPromptInputs int
startedAt, lastUpdatedAt time.Time
processingDuration time.Duration
samplingDuration time.Duration
numPredicted int
numPromptInputs int
}
type NewSequenceParams struct {
@@ -115,8 +116,6 @@ var errorInputTooLong = errors.New("the input length exceeds the context length"
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
s.ready.Wait()
startTime := time.Now()
inputs, ctxs, mmStore, err := s.inputs(prompt, images)
if err != nil {
return nil, fmt.Errorf("failed to process inputs: %w", err)
@@ -176,21 +175,20 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
// TODO(jessegross): Ingest cached history for grammar
return &Sequence{
ctxs: ctxs,
mmStore: mmStore,
inputs: inputs,
numPromptInputs: len(inputs),
startProcessingTime: startTime,
numPredict: params.numPredict,
pendingResponses: make([]string, 0),
responses: make(chan string, 100),
quit: make(chan bool, 1),
embedding: make(chan []float32, 1),
sampler: params.sampler,
embeddingOnly: params.embedding,
stop: params.stop,
numKeep: params.numKeep,
shift: params.shift,
ctxs: ctxs,
mmStore: mmStore,
inputs: inputs,
numPromptInputs: len(inputs),
numPredict: params.numPredict,
pendingResponses: make([]string, 0),
responses: make(chan string, 100),
quit: make(chan bool, 1),
embedding: make(chan []float32, 1),
sampler: params.sampler,
embeddingOnly: params.embedding,
stop: params.stop,
numKeep: params.numKeep,
shift: params.shift,
}, nil
}
@@ -336,9 +334,6 @@ type Server struct {
// TODO (jmorganca): make this n_batch
batchSize int
// Used to signal a hard failure during async processing which will panic the runner
hardErrCh chan error
// Simple counter used only for trace logging batches
batchID int
@@ -421,25 +416,25 @@ func (s *Server) run(ctx context.Context) {
supportsAsync := pooling.Type(s.model.Backend().Config().Uint("pooling_type")) == pooling.TypeNone
var activeBatch batchState
var previousBatch batchState
for {
select {
case <-ctx.Done():
return
case err := <-s.hardErrCh:
panic(err)
default:
var err error
activeBatch, err = s.forwardBatch(activeBatch)
nextBatch, err := s.forwardBatch(previousBatch)
if err != nil {
panic(err)
}
if supportsAsync {
go s.computeBatch(activeBatch)
go s.computeBatch(nextBatch)
} else {
s.computeBatch(activeBatch)
s.computeBatch(nextBatch)
}
previousBatch = nextBatch
}
}
}
@@ -581,6 +576,13 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
seq.inputs = seq.inputs[len(seq.pendingInputs):]
}
startedAt := time.Now()
for i := range nextBatch.seqs {
if nextBatch.seqs[i] != nil && nextBatch.seqs[i].startedAt.IsZero() {
nextBatch.seqs[i].startedAt = startedAt
}
}
if resumeSeq != -1 {
s.nextSeq = resumeSeq
} else {
@@ -675,9 +677,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
// don't sample prompt processing
if len(seq.inputs) != 0 {
if !s.cache.enabled {
s.hardErrCh <- fmt.Errorf("caching disabled but unable to fit entire input in a batch")
s.mu.Unlock()
return
panic("caching disabled but unable to fit entire input in a batch")
}
continue
}
@@ -701,6 +701,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
activeBatch.modelOutput)
outputs := activeBatch.modelOutput.Floats()
t := time.Now()
logutil.Trace("computeBatch: logits ready", "batchID", activeBatch.id)
@@ -713,8 +714,10 @@ func (s *Server) computeBatch(activeBatch batchState) {
continue
}
seq.lastUpdatedAt = t
if seq.numPredicted == 1 {
seq.startGenerationTime = time.Now()
seq.processingDuration = seq.lastUpdatedAt.Sub(seq.startedAt)
seq.startedAt = seq.lastUpdatedAt
}
// if done processing the prompt, generate an embedding and return
@@ -729,8 +732,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", activeBatch.batch.Outputs.Dim(0), "vocabSize", vocabSize, "iBatches", iBatches)
token, err := seq.sampler.Sample(outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
if err != nil {
s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)
return
panic("failed to sample token")
}
nextBatchTokens[i].Token = token
@@ -747,8 +749,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
if err != nil {
s.hardErrCh <- fmt.Errorf("failed to decode token: %w", err)
return
panic("failed to decode token")
}
seq.pendingResponses = append(seq.pendingResponses, piece)
@@ -793,6 +794,13 @@ func (s *Server) computeBatch(activeBatch batchState) {
s.removeSequence(i, llm.DoneReasonConnectionClosed)
}
}
samplingDuration := time.Since(t)
for i, seq := range s.seqs {
if seq != nil && nextBatchTokens[i] != nil {
s.seqs[i].samplingDuration += samplingDuration
}
}
}
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
@@ -912,9 +920,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
Done: true,
DoneReason: seq.doneReason,
PromptEvalCount: seq.numPromptInputs,
PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
PromptEvalDuration: seq.processingDuration,
EvalCount: seq.numPredicted,
EvalDuration: time.Since(seq.startGenerationTime),
EvalDuration: seq.lastUpdatedAt.Sub(seq.startedAt) - seq.samplingDuration,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
}
@@ -1329,7 +1337,6 @@ func Execute(args []string) error {
server := &Server{
modelPath: *mpath,
status: llm.ServerStatusLaunched,
hardErrCh: make(chan error, 1),
}
server.cond = sync.NewCond(&server.mu)

View File

@@ -332,14 +332,16 @@ func (s *Server) GenerateHandler(c *gin.Context) {
}
modelCaps := m.Capabilities()
if req.Think != nil {
if slices.Contains(modelCaps, model.CapabilityThinking) {
caps = append(caps, model.CapabilityThinking)
} else {
// add thinking if the model supports it
if slices.Contains(modelCaps, model.CapabilityThinking) {
caps = append(caps, model.CapabilityThinking)
if req.Think == nil {
req.Think = &api.ThinkValue{Value: true}
}
} else {
if req.Think != nil && req.Think.Bool() {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)})
return
}
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
@@ -401,12 +403,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
msgs = append(msgs, m.Messages...)
}
userMsg := api.Message{Role: "user", Content: req.Prompt}
for _, i := range images {
imgPrompt := ""
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]"+imgPrompt, i.ID)})
userMsg.Images = append(userMsg.Images, i.Data)
}
values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})
values.Messages = append(msgs, userMsg)
}
values.Think = req.Think != nil && req.Think.Bool()
@@ -427,12 +428,31 @@ func (s *Server) GenerateHandler(c *gin.Context) {
b.WriteString(s)
}
if err := tmpl.Execute(&b, values); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// check that we're in the `api/chat`-like flow, and if so, generate the
// prompt the same way
// TEMP(drifkin): we should really just detect the chat-like flow and call
// the real chat handler, but doing this as a stopgap to get renderer
// support for generate
if values.Messages != nil && values.Suffix == "" && req.Template == "" {
prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, req.Truncate == nil || *req.Truncate)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// TEMP(drifkin): req.Context will be removed very soon, but we're temporarily supporting it in this flow here
if req.Context != nil {
b.WriteString(prompt)
prompt = b.String()
}
} else {
// legacy flow
if err := tmpl.Execute(&b, values); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
prompt = b.String()
prompt = b.String()
}
}
// If debug mode is enabled, return the rendered template instead of calling the model
@@ -535,7 +555,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
ch <- res
}); err != nil {
ch <- err
var serr api.StatusError
if errors.As(err, &serr) {
ch <- gin.H{"error": serr.ErrorMessage, "status": serr.StatusCode}
} else {
ch <- gin.H{"error": err.Error()}
}
}
}()
@@ -549,20 +574,18 @@ func (s *Server) GenerateHandler(c *gin.Context) {
sbThinking.WriteString(t.Thinking)
sbContent.WriteString(t.Response)
r = t
case api.StatusError:
c.JSON(t.StatusCode, gin.H{"error": t.ErrorMessage})
return
case error:
c.JSON(http.StatusInternalServerError, gin.H{"error": t.Error()})
return
// TODO (jmorganca): remove use of gin.H here and instead expect
// api.StatusError to be send in the channel
case gin.H:
if msg, ok := t["error"].(string); ok {
c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
return
msg, ok := t["error"].(string)
if !ok {
msg = "unexpected error format in response"
}
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
status, ok := t["status"].(int)
if !ok {
status = http.StatusInternalServerError
}
c.JSON(status, gin.H{"error": msg})
return
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
@@ -1627,16 +1650,28 @@ func streamResponse(c *gin.Context, ch chan any) {
return false
}
if statusError, ok := val.(api.StatusError); ok {
c.Header("Content-Type", "application/json")
c.AbortWithStatusJSON(statusError.StatusCode, gin.H{"error": statusError.ErrorMessage})
return false
}
// errors are provided as a gin.H with an "error" field and
// an optional "status" field. For errors that are streamed
// before any content, we need to set the status code and
// content type for the error.
if h, ok := val.(gin.H); ok {
if e, ok := h["error"].(string); ok {
status, ok := h["status"].(int)
if !ok {
status = http.StatusInternalServerError
}
if err, ok := val.(error); ok {
c.Header("Content-Type", "application/json")
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return false
if !c.Writer.Written() {
c.Header("Content-Type", "application/json")
c.JSON(status, gin.H{"error": e})
} else {
if err := json.NewEncoder(c.Writer).Encode(gin.H{"error": e}); err != nil {
slog.Error("streamResponse failed to encode json error", "error", err)
}
}
return false
}
}
bts, err := json.Marshal(val)
@@ -1898,14 +1933,16 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
modelCaps := m.Capabilities()
if req.Think != nil {
if slices.Contains(modelCaps, model.CapabilityThinking) {
caps = append(caps, model.CapabilityThinking)
} else {
// add thinking if the model supports it
if slices.Contains(modelCaps, model.CapabilityThinking) {
caps = append(caps, model.CapabilityThinking)
if req.Think == nil {
req.Think = &api.ThinkValue{Value: true}
}
} else {
if req.Think != nil && req.Think.Bool() {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)})
return
}
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
@@ -2001,90 +2038,174 @@ func (s *Server) ChatHandler(c *gin.Context) {
toolParser = tools.NewParser(m.Template.Template, req.Tools)
}
type structuredOutputsState int
const (
structuredOutputsState_None structuredOutputsState = iota
structuredOutputsState_ReadyToApply
structuredOutputsState_Applying
)
ch := make(chan any)
go func() {
defer close(ch)
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
Shift: req.Shift == nil || *req.Shift,
Truncate: truncate,
}, func(r llm.CompletionResponse) {
res := api.ChatResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Message: api.Message{Role: "assistant", Content: r.Content},
Done: r.Done,
Metrics: api.Metrics{
PromptEvalCount: r.PromptEvalCount,
PromptEvalDuration: r.PromptEvalDuration,
EvalCount: r.EvalCount,
EvalDuration: r.EvalDuration,
},
}
if r.Done {
res.DoneReason = r.DoneReason.String()
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
structuredOutputsState := structuredOutputsState_None
for {
var tb strings.Builder
currentFormat := req.Format
// structured outputs via double request is enabled when:
// 1. the model supports the thinking capability and
// 2. it uses a built-in parser or our generic thinking parser
// Note that the current approach does not work for (potential future)
// non-thinking models that emit anything before actual content. This
// current approach uses the transition from parsed thinking content to
// parsed non-thinking content as the signal to turn constraining on
if req.Format != nil && structuredOutputsState == structuredOutputsState_None && ((builtinParser != nil || thinkingState != nil) && slices.Contains(m.Capabilities(), model.CapabilityThinking)) {
currentFormat = nil
}
if builtinParser != nil {
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser input", "parser", m.Config.Parser, "content", r.Content)
content, thinking, toolCalls, err := builtinParser.Add(r.Content, r.Done)
if err != nil {
ch <- gin.H{"error": err.Error()}
return
// sets up new context given parent context per request
ctx, cancel := context.WithCancel(c.Request.Context())
err := r.Completion(ctx, llm.CompletionRequest{
Prompt: prompt,
Images: images,
Format: currentFormat,
Options: opts,
Shift: req.Shift == nil || *req.Shift,
Truncate: truncate,
}, func(r llm.CompletionResponse) {
res := api.ChatResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Message: api.Message{Role: "assistant", Content: r.Content},
Done: r.Done,
Metrics: api.Metrics{
PromptEvalCount: r.PromptEvalCount,
PromptEvalDuration: r.PromptEvalDuration,
EvalCount: r.EvalCount,
EvalDuration: r.EvalDuration,
},
}
if r.Done {
res.DoneReason = r.DoneReason.String()
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}
res.Message.Content = content
res.Message.Thinking = thinking
res.Message.ToolCalls = toolCalls
if builtinParser != nil {
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser input", "parser", m.Config.Parser, "content", r.Content)
if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || r.Done {
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser output", "parser", m.Config.Parser, "content", content, "thinking", thinking, "toolCalls", toolCalls, "done", r.Done)
ch <- res
} else {
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser empty output", "parser", m.Config.Parser)
}
content, thinking, toolCalls, err := builtinParser.Add(r.Content, r.Done)
if err != nil {
ch <- gin.H{"error": err.Error()}
return
}
return
}
if thinkingState != nil {
thinkingContent, remainingContent := thinkingState.AddContent(res.Message.Content)
if thinkingContent == "" && remainingContent == "" && !r.Done {
// need to accumulate more to decide what to send
return
}
res.Message.Content = remainingContent
res.Message.Thinking = thinkingContent
}
if len(req.Tools) > 0 {
toolCalls, content := toolParser.Add(res.Message.Content)
if len(content) > 0 {
res.Message.Content = content
} else if len(toolCalls) > 0 {
res.Message.Thinking = thinking
res.Message.ToolCalls = toolCalls
res.Message.Content = ""
} else if res.Message.Thinking != "" {
// don't return
} else {
if r.Done {
res.Message.Content = toolParser.Content()
tb.WriteString(thinking)
// we are now receiving content from the model - we should start applying structured outputs
if structuredOutputsState == structuredOutputsState_None && req.Format != nil && tb.String() != "" && res.Message.Content != "" {
structuredOutputsState = structuredOutputsState_ReadyToApply
cancel()
return
}
if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || r.Done {
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser output", "parser", m.Config.Parser, "content", content, "thinking", thinking, "toolCalls", toolCalls, "done", r.Done)
ch <- res
} else {
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser empty output", "parser", m.Config.Parser)
}
return
}
if thinkingState != nil {
thinkingContent, remainingContent := thinkingState.AddContent(res.Message.Content)
if thinkingContent == "" && remainingContent == "" && !r.Done {
// need to accumulate more to decide what to send
return
}
res.Message.Thinking = thinkingContent
tb.WriteString(thinkingContent)
// emit the collected thinking text before restarting with structured outputs and clear unstructured content
// to avoid leaking mixed tokens like "</think>Hello"
if structuredOutputsState == structuredOutputsState_None && req.Format != nil && tb.String() != "" && remainingContent != "" {
structuredOutputsState = structuredOutputsState_ReadyToApply
res.Message.Content = ""
ch <- res
cancel()
return
}
res.Message.Content = remainingContent
}
if len(req.Tools) > 0 {
toolCalls, content := toolParser.Add(res.Message.Content)
if len(content) > 0 {
res.Message.Content = content
} else if len(toolCalls) > 0 {
res.Message.ToolCalls = toolCalls
res.Message.Content = ""
} else if res.Message.Thinking != "" {
// don't return
} else {
if r.Done {
res.Message.Content = toolParser.Content()
ch <- res
}
return
}
}
ch <- res
})
if err != nil {
if structuredOutputsState == structuredOutputsState_ReadyToApply && strings.Contains(err.Error(), "context canceled") && c.Request.Context().Err() == nil {
// only ignores error if it's a context cancellation due to setting structured outputs
} else {
var serr api.StatusError
if errors.As(err, &serr) {
ch <- gin.H{"error": serr.ErrorMessage, "status": serr.StatusCode}
} else {
ch <- gin.H{"error": err.Error()}
}
return
}
}
ch <- res
}); err != nil {
ch <- err
// ignored structured outputs cancellation falls through to here, start a new request with the structured outputs and updated prompt. use the
if structuredOutputsState == structuredOutputsState_ReadyToApply {
structuredOutputsState = structuredOutputsState_Applying
msg := api.Message{
Role: "assistant",
Thinking: tb.String(),
}
msgs = append(msgs, msg)
prompt, _, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
if err != nil {
slog.Error("chat prompt error applying structured outputs", "error", err)
ch <- gin.H{"error": err.Error()}
return
}
// force constraining by terminating thinking header, the parser is already at this state
// when the last message is thinking, the rendered for gpt-oss cannot disambiguate between having the
// model continue thinking or ending thinking and outputting the final message.
// TODO(parthsareen): consider adding prefill disambiguation logic to the renderer for structured outputs.
if shouldUseHarmony(m) || (builtinParser != nil && m.Config.Parser == "harmony") {
prompt += "<|end|><|start|>assistant<|channel|>final<|message|>"
}
continue
}
break
}
}()
@@ -2102,20 +2223,18 @@ func (s *Server) ChatHandler(c *gin.Context) {
if len(req.Tools) > 0 {
toolCalls = append(toolCalls, t.Message.ToolCalls...)
}
case api.StatusError:
c.JSON(t.StatusCode, gin.H{"error": t.ErrorMessage})
return
case error:
c.JSON(http.StatusInternalServerError, gin.H{"error": t.Error()})
return
// TODO (jmorganca): remove use of gin.H here and instead expect
// api.StatusError to be send in the channel
case gin.H:
if msg, ok := t["error"].(string); ok {
c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
return
msg, ok := t["error"].(string)
if !ok {
msg = "unexpected error format in response"
}
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
status, ok := t["status"].(int)
if !ok {
status = http.StatusInternalServerError
}
c.JSON(status, gin.H{"error": msg})
return
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})

View File

@@ -146,7 +146,7 @@ func TestGenerateDebugRenderOnly(t *testing.T) {
DebugRenderOnly: true,
},
expectDebug: true,
expectTemplate: "[img-0]\n\nDescribe this image",
expectTemplate: "[img-0]Describe this image",
expectNumImages: 1,
},
{

View File

@@ -0,0 +1,313 @@
package server
import (
"bytes"
"encoding/json"
"net/http"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/discover"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
)
// TestGenerateWithBuiltinRenderer tests that api/generate uses built-in renderers
// when in chat-like flow (messages present, no suffix, no template)
func TestGenerateWithBuiltinRenderer(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := mockRunner{
CompletionResponse: llm.CompletionResponse{
Done: true,
DoneReason: llm.DoneReasonStop,
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
},
}
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: getGpuFn,
getCpuFn: getCpuFn,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{
llama: &mock,
}
return false
},
},
}
go s.sched.Run(t.Context())
// Create a model with a built-in renderer (qwen3-coder)
_, digest := createBinFile(t, ggml.KV{
"general.architecture": "qwen3",
"qwen3.block_count": uint32(1),
"qwen3.context_length": uint32(8192),
"qwen3.embedding_length": uint32(4096),
"qwen3.attention.head_count": uint32(32),
"qwen3.attention.head_count_kv": uint32(8),
"tokenizer.ggml.tokens": []string{""},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []*ggml.Tensor{
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
})
// Create a model with the qwen3-coder renderer
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "test-renderer",
Files: map[string]string{"file.gguf": digest},
Renderer: "qwen3-coder",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
mock.CompletionResponse.Content = "Hi!"
t.Run("chat-like flow uses renderer", func(t *testing.T) {
// Test that when using messages (chat-like flow), the built-in renderer is used
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-renderer",
Prompt: "Write a hello world function",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
// The qwen3-coder renderer produces output with <|im_start|> and <|im_end|> tags
// When messages are built internally from prompt, it should use the renderer
if !strings.Contains(mock.CompletionRequest.Prompt, "<|im_start|>") {
t.Errorf("expected prompt to contain <|im_start|> from qwen3-coder renderer, got: %s", mock.CompletionRequest.Prompt)
}
if !strings.Contains(mock.CompletionRequest.Prompt, "<|im_end|>") {
t.Errorf("expected prompt to contain <|im_end|> from qwen3-coder renderer, got: %s", mock.CompletionRequest.Prompt)
}
})
t.Run("chat-like flow with system message uses renderer", func(t *testing.T) {
// Test that system messages work with the renderer
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-renderer",
Prompt: "Write a hello world function",
System: "You are a helpful coding assistant.",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
// Should contain the system message and use renderer format
if !strings.Contains(mock.CompletionRequest.Prompt, "<|im_start|>system") {
t.Errorf("expected prompt to contain system message with renderer format, got: %s", mock.CompletionRequest.Prompt)
}
if !strings.Contains(mock.CompletionRequest.Prompt, "You are a helpful coding assistant.") {
t.Errorf("expected prompt to contain system message content, got: %s", mock.CompletionRequest.Prompt)
}
})
t.Run("custom template bypasses renderer", func(t *testing.T) {
// Test that providing a custom template uses the legacy flow
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-renderer",
Prompt: "Write a hello world function",
Template: "{{ .Prompt }}",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
// Should NOT use the renderer format when custom template is provided
if strings.Contains(mock.CompletionRequest.Prompt, "<|im_start|>") {
t.Errorf("expected prompt to NOT use renderer when custom template provided, got: %s", mock.CompletionRequest.Prompt)
}
// Should just be the raw prompt from the template
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "Write a hello world function"); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
// Create a model with suffix support for the next test
w = createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "test-suffix-renderer",
From: "test-renderer",
Template: `{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
{{- else }}{{ .Prompt }}
{{- end }}`,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
t.Run("suffix bypasses renderer", func(t *testing.T) {
// Test that providing a suffix uses the legacy flow
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-suffix-renderer",
Prompt: "def add(",
Suffix: " return c",
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
// Should NOT use the renderer format when suffix is provided
if strings.Contains(mock.CompletionRequest.Prompt, "<|im_start|>") {
t.Errorf("expected prompt to NOT use renderer when suffix provided, got: %s", mock.CompletionRequest.Prompt)
}
// Should use the suffix template format
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "<PRE> def add( <SUF> return c <MID>"); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
}
// TestGenerateWithDebugRenderOnly tests that debug_render_only works with built-in renderers
func TestGenerateWithDebugRenderOnly(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := mockRunner{
CompletionResponse: llm.CompletionResponse{
Done: true,
DoneReason: llm.DoneReasonStop,
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
},
}
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: getGpuFn,
getCpuFn: getCpuFn,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{
llama: &mock,
}
return false
},
},
}
go s.sched.Run(t.Context())
// Create a model with a built-in renderer
_, digest := createBinFile(t, ggml.KV{
"general.architecture": "qwen3",
"qwen3.block_count": uint32(1),
"qwen3.context_length": uint32(8192),
"qwen3.embedding_length": uint32(4096),
"qwen3.attention.head_count": uint32(32),
"qwen3.attention.head_count_kv": uint32(8),
"tokenizer.ggml.tokens": []string{""},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []*ggml.Tensor{
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
})
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "test-debug-renderer",
Files: map[string]string{"file.gguf": digest},
Renderer: "qwen3-coder",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
t.Run("debug_render_only with renderer", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-debug-renderer",
Prompt: "Write a hello world function",
System: "You are a coding assistant",
DebugRenderOnly: true,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
var resp api.GenerateResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatal(err)
}
if resp.DebugInfo == nil {
t.Fatalf("expected debug info, got nil")
}
// Verify that the rendered template uses the built-in renderer
if !strings.Contains(resp.DebugInfo.RenderedTemplate, "<|im_start|>") {
t.Errorf("expected rendered template to use qwen3-coder renderer format, got: %s", resp.DebugInfo.RenderedTemplate)
}
if !strings.Contains(resp.DebugInfo.RenderedTemplate, "You are a coding assistant") {
t.Errorf("expected rendered template to contain system message, got: %s", resp.DebugInfo.RenderedTemplate)
}
if !strings.Contains(resp.DebugInfo.RenderedTemplate, "Write a hello world function") {
t.Errorf("expected rendered template to contain prompt, got: %s", resp.DebugInfo.RenderedTemplate)
}
})
}

View File

@@ -158,11 +158,26 @@ func TestGenerateChat(t *testing.T) {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"registry.ollama.ai/library/test:latest does not support thinking"}`); diff != "" {
if diff := cmp.Diff(w.Body.String(), `{"error":"\"test\" does not support thinking"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("model can't think but think set false", func(t *testing.T) {
think := false
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test",
Messages: []api.Message{
{Role: "user", Content: "Hello!"},
},
Think: &api.ThinkValue{Value: think},
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
})
t.Run("missing model", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, api.ChatRequest{})
if w.Code != http.StatusBadRequest {
@@ -1292,4 +1307,238 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
t.Errorf("expected content %q, got %q", "Based on my analysis, the solution is straightforward.", got)
}
})
t.Run("structured outputs restart non-stream", func(t *testing.T) {
var (
requestsMu sync.Mutex
requests []llm.CompletionRequest
wg sync.WaitGroup
)
wg.Add(2)
format := json.RawMessage(`{"type":"object","properties":{"answer":{"type":"string"}}}`)
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
defer wg.Done()
requestsMu.Lock()
requests = append(requests, r)
callNum := len(requests)
requestsMu.Unlock()
switch callNum {
case 1:
fn(llm.CompletionResponse{
Content: " I am thinking through this problem. </think> {\"answer\":\"42\"}",
Done: false,
PromptEvalCount: 1,
PromptEvalDuration: 1,
})
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(time.Second):
t.Fatalf("timeout waiting for structured outputs cancellation")
return nil
}
case 2:
fn(llm.CompletionResponse{
Content: `{"answer":"42"}`,
Done: true,
DoneReason: llm.DoneReasonStop,
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
})
return nil
default:
t.Fatalf("unexpected number of completion calls: %d", callNum)
return nil
}
}
think := true
streamRequest := false
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test-thinking",
Messages: []api.Message{{Role: "user", Content: "Please respond in JSON."}},
Think: &api.ThinkValue{Value: think},
Stream: &streamRequest,
Format: format,
})
wg.Wait()
mock.CompletionFn = nil
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
if len(requests) != 2 {
t.Fatalf("expected two completion calls, got %d", len(requests))
}
if requests[0].Format != nil {
t.Errorf("expected first completion format to be nil, got %q", requests[0].Format)
}
if !bytes.Equal([]byte(format), []byte(requests[1].Format)) {
t.Errorf("expected second completion format to match original format")
}
var resp api.ChatResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatal(err)
}
if resp.Message.Thinking != "I am thinking through this problem. " {
t.Errorf("expected thinking %q, got %q", "I am thinking through this problem. ", resp.Message.Thinking)
}
if resp.Message.Content != `{"answer":"42"}` {
t.Errorf("expected content %q, got %q", `{"answer":"42"}`, resp.Message.Content)
}
if !resp.Done {
t.Errorf("expected response to be done")
}
if resp.DoneReason != "stop" {
t.Errorf("expected done reason stop, got %s", resp.DoneReason)
}
})
t.Run("structured outputs restart streaming", func(t *testing.T) {
var (
requestsMu sync.Mutex
requests []llm.CompletionRequest
wg sync.WaitGroup
)
wg.Add(2)
format := json.RawMessage(`{"type":"object","properties":{"answer":{"type":"string"}}}`)
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
defer wg.Done()
requestsMu.Lock()
requests = append(requests, r)
callNum := len(requests)
requestsMu.Unlock()
switch callNum {
case 1:
fn(llm.CompletionResponse{
Content: " I am thinking through this problem. </think> {\"answer\":\"42\"}",
Done: false,
PromptEvalCount: 1,
PromptEvalDuration: 1,
})
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(time.Second):
t.Fatalf("timeout waiting for structured outputs cancellation")
return nil
}
case 2:
fn(llm.CompletionResponse{
Content: `{"answer":"42"}`,
Done: true,
DoneReason: llm.DoneReasonStop,
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
})
return nil
default:
t.Fatalf("unexpected number of completion calls: %d", callNum)
return nil
}
}
think := true
streamRequest := true
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test-thinking",
Messages: []api.Message{{Role: "user", Content: "Please respond in JSON."}},
Think: &api.ThinkValue{Value: think},
Stream: &streamRequest,
Format: format,
})
wg.Wait()
mock.CompletionFn = nil
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
if len(requests) != 2 {
t.Fatalf("expected two completion calls, got %d", len(requests))
}
if requests[0].Format != nil {
t.Errorf("expected first completion format to be nil, got %q", requests[0].Format)
}
if !bytes.Equal([]byte(format), []byte(requests[1].Format)) {
t.Errorf("expected second completion format to match original format")
}
decoder := json.NewDecoder(w.Body)
var events []api.ChatResponse
for {
var event api.ChatResponse
if err := decoder.Decode(&event); err == io.EOF {
break
} else if err != nil {
t.Fatal(err)
}
events = append(events, event)
if event.Done {
break
}
}
if len(events) < 2 {
t.Fatalf("expected at least two streaming events, got %d", len(events))
}
first := events[0]
if first.Message.Thinking != "I am thinking through this problem. " {
t.Errorf("expected first event thinking %q, got %q", "I am thinking through this problem. ", first.Message.Thinking)
}
if first.Message.Content != "" {
t.Errorf("expected first event content to be empty, got %q", first.Message.Content)
}
if first.Done {
t.Error("expected first event to be non-terminal")
}
last := events[len(events)-1]
if last.Message.Thinking != "" {
t.Errorf("expected final event thinking to be empty, got %q", last.Message.Thinking)
}
if last.Message.Content != `{"answer":"42"}` {
t.Errorf("expected final event content %q, got %q", `{"answer":"42"}`, last.Message.Content)
}
if !last.Done {
t.Error("expected final event to be done")
}
if last.DoneReason != "stop" {
t.Errorf("expected final done reason stop, got %s", last.DoneReason)
}
})
}

View File

@@ -21,6 +21,7 @@ import (
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/types/model"
)
@@ -229,8 +230,9 @@ func (s *Scheduler) processPending(ctx context.Context) {
}
if runnerToExpire == nil {
// Shouildn't happen
slog.Error("runner to expire was nil!")
// While we were performing load calculations, the loaded runner(s) unloaded in parallel
// so findRunnerToUnload returned no runners. We'll try again and the loadedCount should be zero
slog.Debug("runner to expire was nil, retrying")
continue
}
// Trigger an expiration to unload once it's done
@@ -644,27 +646,35 @@ func (s *Scheduler) waitForVRAMRecovery(runner *runnerRef, runners []discover.Fi
totalMemoryBefore += gpu.TotalMemory
freeMemoryBefore += gpu.FreeMemory
}
totalMemoryNow := totalMemoryBefore
freeMemoryNow := freeMemoryBefore
go func() {
expiresAt := start.Add(5 * time.Second) // typical convergence is 0.5-1.5s
// typical convergence is 0.5-1.5s - If it takes more than 5 seconds to discover and converge, let the scheduler estimate VRAM usage
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
ticker := time.NewTicker(250 * time.Millisecond)
defer ticker.Stop()
for {
<-ticker.C
if time.Now().After(expiresAt) {
slog.Warn("gpu VRAM usage didn't recover within timeout", "seconds", time.Since(start).Seconds(), "runner", runner)
finished <- struct{}{}
}
// Query GPUs, look for free to go back up
gpusNow := s.getGpuFn(context.Background(), runners)
var totalMemoryNow, freeMemoryNow uint64
for _, gpu := range gpusNow {
totalMemoryNow += gpu.TotalMemory
freeMemoryNow += gpu.FreeMemory
}
// If we're within ~80% of the estimated memory usage recovered, bail out
if float32(freeMemoryNow-freeMemoryBefore) > float32(runner.vramSize)*0.8 {
slog.Debug(fmt.Sprintf("gpu VRAM free memory converged after %0.2f seconds", time.Since(start).Seconds()), "runner", runner)
select {
case <-ticker.C:
// Query GPUs, look for free to go back up
gpusNow := s.getGpuFn(ctx, runners)
totalMemoryNow = 0
freeMemoryNow = 0
for _, gpu := range gpusNow {
totalMemoryNow += gpu.TotalMemory
freeMemoryNow += gpu.FreeMemory
}
logutil.Trace("gpu VRAM convergence", "percent", int(max(float32(freeMemoryNow-freeMemoryBefore), 0.0)/float32(runner.vramSize)*100))
// If we're within ~75% of the estimated memory usage recovered, bail out
if float32(freeMemoryNow-freeMemoryBefore) > float32(runner.vramSize)*0.75 {
slog.Debug(fmt.Sprintf("gpu VRAM free memory converged after %0.2f seconds", time.Since(start).Seconds()), "free_before", format.HumanBytes2(freeMemoryBefore), "free_now", format.HumanBytes2(freeMemoryNow), "runner", runner)
finished <- struct{}{}
return
}
case <-ctx.Done():
slog.Debug("gpu VRAM usage didn't recover within timeout", "seconds", time.Since(start).Seconds(), "free_before", format.HumanBytes2(freeMemoryBefore), "free_now", format.HumanBytes2(freeMemoryNow), "runner", runner)
finished <- struct{}{}
return
}