Compare commits

..

15 Commits

Author SHA1 Message Date
jmorganca
38ed7c7a4f no parser yet 2025-10-13 14:38:10 -07:00
jmorganca
9ff8e5a64d wip 2025-10-13 13:01:54 -07:00
Jeffrey Morgan
6544e14735 Reapply "add truncate and shift parameters" (#12582) 2025-10-11 16:06:14 -07:00
Devon Rifkin
5db8a818a1 Merge pull request #12581 from ollama/drifkin/renderer-api-generate
routes: fix built-in renderers for `api/generate`
2025-10-11 14:10:23 -07:00
Devon Rifkin
6db8da9958 routes: fix built-in renderers for api/generate
Made it so when api/generate builds up a message array and generates the
prompt it now goes through the same function as `api/chat` for
consistency. This is where we hook the optional built-in renderers to
bypass templates, which was missing for `api/generate` before this
change.

Closes: #12578
2025-10-11 13:57:43 -07:00
frob
0c68ec8d6a discover: fix typo (#12565) 2025-10-11 12:06:02 -07:00
Daniel Hiltgen
70d9e363e1 doc: remove AMD EOL GPUs (#12567) 2025-10-10 17:16:29 -07:00
Michael Yang
1a2feb2a97 ollamarunner: fix deadlock
hardErrCh will deadlock since forwardBatch is blocked on
computeStartedCh which never gets sent. since the response to
hardErrCh is to panic, just panic instead
2025-10-10 16:49:57 -07:00
Daniel Hiltgen
aab2190420 implement nvml for linux (#12517)
* implement nvml for linux

* Improve scheduler logging when VRAM doesn't recover
2025-10-10 15:15:56 -07:00
Michael Yang
629db9dc43 comment split 2025-10-10 13:25:34 -07:00
Michael Yang
e0cd511661 fix test 2025-10-10 13:25:34 -07:00
Michael Yang
207332078f fix lint 2025-10-10 13:25:34 -07:00
Michael Yang
93085127f4 convert: slice gate_up weight 2025-10-10 13:25:34 -07:00
Michael Yang
c00fa9cc2b convert: split gate_up bias 2025-10-10 13:25:34 -07:00
yajianggroup
df411c4b02 refactor: using testing.B.Loop
Signed-off-by: yajianggroup <yajianggroup@outlook.com>
2025-10-10 13:25:29 -07:00
23 changed files with 1408 additions and 144 deletions

View File

@@ -106,6 +106,14 @@ type GenerateRequest struct {
// before this option was introduced)
Think *ThinkValue `json:"think,omitempty"`
// Truncate is a boolean that, when set to true, truncates the chat history messages
// if the rendered prompt exceeds the context length limit.
Truncate *bool `json:"truncate,omitempty"`
// Shift is a boolean that, when set to true, shifts the chat history
// when hitting the context length limit instead of erroring.
Shift *bool `json:"shift,omitempty"`
// DebugRenderOnly is a debug option that, when set to true, returns the rendered
// template instead of calling the model.
DebugRenderOnly bool `json:"_debug_render_only,omitempty"`
@@ -140,6 +148,14 @@ type ChatRequest struct {
// for supported models.
Think *ThinkValue `json:"think,omitempty"`
// Truncate is a boolean that, when set to true, truncates the chat history messages
// if the rendered prompt exceeds the context length limit.
Truncate *bool `json:"truncate,omitempty"`
// Shift is a boolean that, when set to true, shifts the chat history
// when hitting the context length limit instead of erroring.
Shift *bool `json:"shift,omitempty"`
// DebugRenderOnly is a debug option that, when set to true, returns the rendered
// template instead of calling the model.
DebugRenderOnly bool `json:"_debug_render_only,omitempty"`

View File

@@ -85,6 +85,19 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
case "scales":
mxfp4s[name].scales = t
}
} else if strings.HasSuffix(t.Name(), "gate_up_exps.bias") {
// gate_up_exps is interleaved, need to split into gate_exps and up_exps
// e.g. gate_exps, up_exps = gate_up_exps[:, 0::2, ...], gate_up_exps[:, 1::2, ...]
out = append(out, slices.Collect(splitDim(t, 1,
split{
Replacer: strings.NewReplacer("gate_up_exps", "gate_exps"),
slices: []tensor.Slice{nil, tensor.S(0, int(t.Shape()[1]), 2)},
},
split{
Replacer: strings.NewReplacer("gate_up_exps", "up_exps"),
slices: []tensor.Slice{nil, tensor.S(1, int(t.Shape()[1]), 2)},
},
))...)
} else {
out = append(out, &ggml.Tensor{
Name: t.Name(),
@@ -97,17 +110,28 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
for name, mxfp4 := range mxfp4s {
dims := mxfp4.blocks.Shape()
if !strings.HasSuffix(name, ".weight") {
name += ".weight"
if strings.Contains(name, "ffn_down_exps") {
out = append(out, &ggml.Tensor{
Name: name + ".weight",
Kind: uint32(ggml.TensorTypeMXFP4),
Shape: []uint64{dims[0], dims[1], dims[2] * dims[3] * 2},
WriterTo: mxfp4,
})
} else if strings.Contains(name, "ffn_gate_up_exps") {
// gate_up_exps is interleaved, need to split into gate_exps and up_exps
// e.g. gate_exps, up_exps = gate_up_exps[:, 0::2, ...], gate_up_exps[:, 1::2, ...]
out = append(out, &ggml.Tensor{
Name: strings.Replace(name, "gate_up", "gate", 1) + ".weight",
Kind: uint32(ggml.TensorTypeMXFP4),
Shape: []uint64{dims[0], dims[1] / 2, dims[2] * dims[3] * 2},
WriterTo: mxfp4.slice(1, 0, int(dims[1]), 2),
}, &ggml.Tensor{
Name: strings.Replace(name, "gate_up", "up", 1) + ".weight",
Kind: uint32(ggml.TensorTypeMXFP4),
Shape: []uint64{dims[0], dims[1] / 2, dims[2] * dims[3] * 2},
WriterTo: mxfp4.slice(1, 1, int(dims[1]), 2),
})
}
out = append(out, &ggml.Tensor{
Name: name,
Kind: uint32(ggml.TensorTypeMXFP4),
Shape: []uint64{dims[0], dims[1], dims[2] * dims[3] * 2},
WriterTo: mxfp4,
})
}
return out
@@ -158,9 +182,21 @@ func (m *gptossModel) Replacements() []string {
}
type mxfp4 struct {
slices []tensor.Slice
blocks, scales Tensor
}
func (m *mxfp4) slice(dim, start, end, step int) *mxfp4 {
slice := slices.Repeat([]tensor.Slice{nil}, len(m.blocks.Shape()))
slice[dim] = tensor.S(start, end, step)
return &mxfp4{
slices: slice,
blocks: m.blocks,
scales: m.scales,
}
}
func (m *mxfp4) WriteTo(w io.Writer) (int64, error) {
var b bytes.Buffer
if _, err := m.blocks.WriteTo(&b); err != nil {
@@ -204,6 +240,13 @@ func (m *mxfp4) WriteTo(w io.Writer) (int64, error) {
return 0, err
}
if len(m.slices) > 0 {
out, err = out.Slice(m.slices...)
if err != nil {
return 0, err
}
}
out = tensor.Materialize(out)
if err := out.Reshape(out.Shape().TotalSize()); err != nil {

View File

@@ -16,7 +16,8 @@ import (
type split struct {
*strings.Replacer
dim int
dim int
slices []tensor.Slice
// fn is an optional function to apply to the tensor after slicing
fn func(tensor.Tensor) (tensor.Tensor, error)
@@ -32,9 +33,12 @@ func splitDim(t Tensor, dim int, splits ...split) iter.Seq[*ggml.Tensor] {
shape := slices.Clone(t.Shape())
shape[dim] = cmp.Or(uint64(split.dim), shape[dim]/uint64(len(splits)))
slice := slices.Repeat([]tensor.Slice{nil}, len(shape))
slice[dim] = tensor.S(offset, offset+int(shape[dim]))
offset += int(shape[dim])
slice := split.slices
if len(slice) == 0 {
slice = slices.Repeat([]tensor.Slice{nil}, len(shape))
slice[dim] = tensor.S(offset, offset+int(shape[dim]))
offset += int(shape[dim])
}
t.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
dims := make([]int, len(shape))

View File

@@ -408,7 +408,7 @@ func (r *bootstrapRunner) HasExited() bool {
func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs []string) []ml.DeviceInfo {
// TODO DRY out with llm/server.go
slog.Debug("spawing runner with", "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs)
slog.Debug("spawning runner with", "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs)
start := time.Now()
defer func() {
slog.Debug("bootstrap discovery took", "duration", time.Since(start), "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs)

View File

@@ -51,11 +51,11 @@ sudo modprobe nvidia_uvm`
Ollama supports the following AMD GPUs:
### Linux Support
| Family | Cards and accelerators |
| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------- |
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` `Vega 64` |
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` `V420` `V340` `V320` `Vega II Duo` `Vega II` `SSG` |
| AMD Instinct | `MI300X` `MI300A` `MI300` `MI250X` `MI250` `MI210` `MI200` `MI100` `MI60` |
| Family | Cards and accelerators |
| -------------- | -------------------------------------------------------------------------------------------------------------------- |
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` |
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` `V420` `V340` `V320` |
| AMD Instinct | `MI300X` `MI300A` `MI300` `MI250X` `MI250` `MI210` `MI200` `MI100` |
### Windows Support
With ROCm v6.2, the following GPUs are supported on Windows.

View File

@@ -13,13 +13,13 @@ management libraries for more accurate VRAM usage reporting if available.
ggml/src/ggml-impl.h | 8 +
ggml/src/ggml-metal/ggml-metal.cpp | 3 +-
ggml/src/mem_hip.cpp | 449 +++++++++++++++++++++++++++++
ggml/src/mem_nvml.cpp | 172 +++++++++++
8 files changed, 718 insertions(+), 1 deletion(-)
ggml/src/mem_nvml.cpp | 209 ++++++++++++++
8 files changed, 755 insertions(+), 1 deletion(-)
create mode 100644 ggml/src/mem_hip.cpp
create mode 100644 ggml/src/mem_nvml.cpp
diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
index 0a2dae26..a6bf3378 100644
index 0a2dae26a..a6bf33785 100644
--- a/ggml/include/ggml-backend.h
+++ b/ggml/include/ggml-backend.h
@@ -169,6 +169,15 @@ extern "C" {
@@ -39,7 +39,7 @@ index 0a2dae26..a6bf3378 100644
GGML_API const char * ggml_backend_dev_name(ggml_backend_dev_t device);
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
index 33b3a15f..86191ef2 100644
index 33b3a15f0..86191ef2c 100644
--- a/ggml/src/CMakeLists.txt
+++ b/ggml/src/CMakeLists.txt
@@ -206,6 +206,8 @@ add_library(ggml-base
@@ -52,7 +52,7 @@ index 33b3a15f..86191ef2 100644
target_include_directories(ggml-base PRIVATE .)
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 531d6e27..3fa3a057 100644
index 531d6e272..3fa3a0575 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -261,6 +261,16 @@ static ggml_cuda_device_info ggml_cuda_init() {
@@ -184,7 +184,7 @@ index 531d6e27..3fa3a057 100644
/* .iface = */ ggml_backend_cuda_device_interface,
/* .reg = */ &reg,
diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h
index 06f9e7c1..eb8f66cb 100644
index 06f9e7c1e..eb8f66cb0 100644
--- a/ggml/src/ggml-cuda/vendors/hip.h
+++ b/ggml/src/ggml-cuda/vendors/hip.h
@@ -5,6 +5,9 @@
@@ -206,7 +206,7 @@ index 06f9e7c1..eb8f66cb 100644
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
index 86a1ebf6..9fc9fbfc 100644
index 86a1ebf62..9fc9fbfcf 100644
--- a/ggml/src/ggml-impl.h
+++ b/ggml/src/ggml-impl.h
@@ -635,6 +635,14 @@ static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx
@@ -225,7 +225,7 @@ index 86a1ebf6..9fc9fbfc 100644
}
#endif
diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp
index 08ab4fc9..17999a61 100644
index 08ab4fc91..17999a616 100644
--- a/ggml/src/ggml-metal/ggml-metal.cpp
+++ b/ggml/src/ggml-metal/ggml-metal.cpp
@@ -535,6 +535,7 @@ static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backen
@@ -247,7 +247,7 @@ index 08ab4fc9..17999a61 100644
/* .host_buffer = */ false,
diff --git a/ggml/src/mem_hip.cpp b/ggml/src/mem_hip.cpp
new file mode 100644
index 00000000..8ef19b8c
index 000000000..8ef19b8cf
--- /dev/null
+++ b/ggml/src/mem_hip.cpp
@@ -0,0 +1,449 @@
@@ -703,10 +703,10 @@ index 00000000..8ef19b8c
\ No newline at end of file
diff --git a/ggml/src/mem_nvml.cpp b/ggml/src/mem_nvml.cpp
new file mode 100644
index 00000000..aa05e9dc
index 000000000..c9073cef0
--- /dev/null
+++ b/ggml/src/mem_nvml.cpp
@@ -0,0 +1,172 @@
@@ -0,0 +1,209 @@
+// NVIDIA Management Library (NVML)
+//
+// https://developer.nvidia.com/management-library-nvml
@@ -721,6 +721,7 @@ index 00000000..aa05e9dc
+#include "ggml-impl.h"
+#include <filesystem>
+#include <mutex>
+#include <array>
+
+#ifdef _WIN32
+# define WIN32_LEAN_AND_MEAN
@@ -787,6 +788,7 @@ index 00000000..aa05e9dc
+ nvmlReturn_t (*nvmlShutdown)(void);
+ nvmlReturn_t (*nvmlDeviceGetHandleByUUID)(const char *, nvmlDevice_t *);
+ nvmlReturn_t (*nvmlDeviceGetMemoryInfo)(nvmlDevice_t, nvmlMemory_t *);
+ const char * (*nvmlErrorString)(nvmlReturn_t result);
+} nvml { NULL, NULL, NULL, NULL, NULL };
+static std::mutex ggml_nvml_lock;
+
@@ -824,7 +826,8 @@ index 00000000..aa05e9dc
+ nvml.nvmlShutdown = (nvmlReturn_enum (*)()) GetProcAddress((HMODULE)(nvml.handle), "nvmlShutdown");
+ nvml.nvmlDeviceGetHandleByUUID = (nvmlReturn_t (*)(const char *, nvmlDevice_t *)) GetProcAddress((HMODULE)(nvml.handle), "nvmlDeviceGetHandleByUUID");
+ nvml.nvmlDeviceGetMemoryInfo = (nvmlReturn_t (*)(nvmlDevice_t, nvmlMemory_t *)) GetProcAddress((HMODULE)(nvml.handle), "nvmlDeviceGetMemoryInfo");
+ if (nvml.nvmlInit_v2 == NULL || nvml.nvmlShutdown == NULL || nvml.nvmlDeviceGetHandleByUUID == NULL || nvml.nvmlDeviceGetMemoryInfo == NULL) {
+ nvml.nvmlErrorString = (const char * (*)(nvmlReturn_enum)) GetProcAddress((HMODULE)(nvml.handle), "nvmlErrorString");
+ if (nvml.nvmlInit_v2 == NULL || nvml.nvmlShutdown == NULL || nvml.nvmlDeviceGetHandleByUUID == NULL || nvml.nvmlDeviceGetMemoryInfo == NULL || nvml.nvmlErrorString == NULL) {
+ GGML_LOG_INFO("%s unable to locate required symbols in NVML.dll", __func__);
+ FreeLibrary((HMODULE)(nvml.handle));
+ nvml.handle = NULL;
@@ -833,11 +836,45 @@ index 00000000..aa05e9dc
+
+ SetErrorMode(old_mode);
+
+ nvmlReturn_t status = nvml.nvmlInit_v2();
+ if (status != NVML_SUCCESS) {
+ GGML_LOG_INFO("%s unable to initialize NVML: %s\n", __func__, nvml.nvmlErrorString(status));
+ FreeLibrary((HMODULE)(nvml.handle));
+ nvml.handle = NULL;
+ return status;
+ }
+#else
+ // Not currently wired up on Linux
+ return NVML_ERROR_NOT_SUPPORTED;
+ constexpr std::array<const char*, 2> libPaths = {
+ "/usr/lib/wsl/lib/libnvidia-ml.so.1", // Favor WSL2 path if present
+ "libnvidia-ml.so.1" // On a non-WSL2 system, it should be in the path
+ };
+ for (const char* path : libPaths) {
+ nvml.handle = dlopen(path, RTLD_LAZY);
+ if (nvml.handle) break;
+ }
+ if (nvml.handle == NULL) {
+ GGML_LOG_INFO("%s unable to load libnvidia-ml: %s\n", __func__, dlerror());
+ return NVML_ERROR_NOT_FOUND;
+ }
+ nvml.nvmlInit_v2 = (nvmlReturn_enum (*)()) dlsym(nvml.handle, "nvmlInit_v2");
+ nvml.nvmlShutdown = (nvmlReturn_enum (*)()) dlsym(nvml.handle, "nvmlShutdown");
+ nvml.nvmlDeviceGetHandleByUUID = (nvmlReturn_t (*)(const char *, nvmlDevice_t *)) dlsym(nvml.handle, "nvmlDeviceGetHandleByUUID");
+ nvml.nvmlDeviceGetMemoryInfo = (nvmlReturn_t (*)(nvmlDevice_t, nvmlMemory_t *)) dlsym(nvml.handle, "nvmlDeviceGetMemoryInfo");
+ nvml.nvmlErrorString = (const char * (*)(nvmlReturn_enum)) dlsym(nvml.handle, "nvmlErrorString");
+ if (nvml.nvmlInit_v2 == NULL || nvml.nvmlShutdown == NULL || nvml.nvmlDeviceGetHandleByUUID == NULL || nvml.nvmlDeviceGetMemoryInfo == NULL) {
+ GGML_LOG_INFO("%s unable to locate required symbols in libnvidia-ml.so", __func__);
+ dlclose(nvml.handle);
+ nvml.handle = NULL;
+ return NVML_ERROR_NOT_FOUND;
+ }
+ nvmlReturn_t status = nvml.nvmlInit_v2();
+ if (status != NVML_SUCCESS) {
+ GGML_LOG_INFO("%s unable to initialize NVML: %s\n", __func__, nvml.nvmlErrorString(status));
+ dlclose(nvml.handle);
+ nvml.handle = NULL;
+ return status;
+ }
+#endif
+ int status = nvml.nvmlInit_v2();
+ return NVML_SUCCESS;
+}
+
@@ -849,14 +886,14 @@ index 00000000..aa05e9dc
+ }
+ nvmlReturn_enum status = nvml.nvmlShutdown();
+ if (status != NVML_SUCCESS) {
+ GGML_LOG_INFO("%s failed to shutdown NVML: %d\n", __func__, status);
+ GGML_LOG_INFO("%s failed to shutdown NVML: %s\n", __func__, nvml.nvmlErrorString(status));
+ }
+#ifdef _WIN32
+ FreeLibrary((HMODULE)(nvml.handle));
+ nvml.handle = NULL;
+#else
+ // Not currently wired up on Linux
+ dlclose(nvml.handle);
+#endif
+ nvml.handle = NULL;
+}
+
+int ggml_nvml_get_device_memory(const char *uuid, size_t *free, size_t *total) {

View File

@@ -1379,7 +1379,9 @@ type CompletionRequest struct {
Images []ImageData
Options *api.Options
Grammar string // set before sending the request to the subprocess
Grammar string // set before sending the request to the subprocess
Shift bool
Truncate bool
}
// DoneReason represents the reason why a completion response is done
@@ -1501,7 +1503,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
return fmt.Errorf("failed reading llm error response: %w", err)
}
log.Printf("llm predict error: %s", bodyBytes)
return fmt.Errorf("%s", bodyBytes)
return api.StatusError{StatusCode: res.StatusCode, ErrorMessage: strings.TrimSpace(string(bodyBytes))}
}
scanner := bufio.NewScanner(res.Body)

View File

@@ -12,6 +12,7 @@
#include "ggml-impl.h"
#include <filesystem>
#include <mutex>
#include <array>
#ifdef _WIN32
# define WIN32_LEAN_AND_MEAN
@@ -78,6 +79,7 @@ struct {
nvmlReturn_t (*nvmlShutdown)(void);
nvmlReturn_t (*nvmlDeviceGetHandleByUUID)(const char *, nvmlDevice_t *);
nvmlReturn_t (*nvmlDeviceGetMemoryInfo)(nvmlDevice_t, nvmlMemory_t *);
const char * (*nvmlErrorString)(nvmlReturn_t result);
} nvml { NULL, NULL, NULL, NULL, NULL };
static std::mutex ggml_nvml_lock;
@@ -115,7 +117,8 @@ int ggml_nvml_init() {
nvml.nvmlShutdown = (nvmlReturn_enum (*)()) GetProcAddress((HMODULE)(nvml.handle), "nvmlShutdown");
nvml.nvmlDeviceGetHandleByUUID = (nvmlReturn_t (*)(const char *, nvmlDevice_t *)) GetProcAddress((HMODULE)(nvml.handle), "nvmlDeviceGetHandleByUUID");
nvml.nvmlDeviceGetMemoryInfo = (nvmlReturn_t (*)(nvmlDevice_t, nvmlMemory_t *)) GetProcAddress((HMODULE)(nvml.handle), "nvmlDeviceGetMemoryInfo");
if (nvml.nvmlInit_v2 == NULL || nvml.nvmlShutdown == NULL || nvml.nvmlDeviceGetHandleByUUID == NULL || nvml.nvmlDeviceGetMemoryInfo == NULL) {
nvml.nvmlErrorString = (const char * (*)(nvmlReturn_enum)) GetProcAddress((HMODULE)(nvml.handle), "nvmlErrorString");
if (nvml.nvmlInit_v2 == NULL || nvml.nvmlShutdown == NULL || nvml.nvmlDeviceGetHandleByUUID == NULL || nvml.nvmlDeviceGetMemoryInfo == NULL || nvml.nvmlErrorString == NULL) {
GGML_LOG_INFO("%s unable to locate required symbols in NVML.dll", __func__);
FreeLibrary((HMODULE)(nvml.handle));
nvml.handle = NULL;
@@ -124,11 +127,45 @@ int ggml_nvml_init() {
SetErrorMode(old_mode);
nvmlReturn_t status = nvml.nvmlInit_v2();
if (status != NVML_SUCCESS) {
GGML_LOG_INFO("%s unable to initialize NVML: %s\n", __func__, nvml.nvmlErrorString(status));
FreeLibrary((HMODULE)(nvml.handle));
nvml.handle = NULL;
return status;
}
#else
// Not currently wired up on Linux
return NVML_ERROR_NOT_SUPPORTED;
constexpr std::array<const char*, 2> libPaths = {
"/usr/lib/wsl/lib/libnvidia-ml.so.1", // Favor WSL2 path if present
"libnvidia-ml.so.1" // On a non-WSL2 system, it should be in the path
};
for (const char* path : libPaths) {
nvml.handle = dlopen(path, RTLD_LAZY);
if (nvml.handle) break;
}
if (nvml.handle == NULL) {
GGML_LOG_INFO("%s unable to load libnvidia-ml: %s\n", __func__, dlerror());
return NVML_ERROR_NOT_FOUND;
}
nvml.nvmlInit_v2 = (nvmlReturn_enum (*)()) dlsym(nvml.handle, "nvmlInit_v2");
nvml.nvmlShutdown = (nvmlReturn_enum (*)()) dlsym(nvml.handle, "nvmlShutdown");
nvml.nvmlDeviceGetHandleByUUID = (nvmlReturn_t (*)(const char *, nvmlDevice_t *)) dlsym(nvml.handle, "nvmlDeviceGetHandleByUUID");
nvml.nvmlDeviceGetMemoryInfo = (nvmlReturn_t (*)(nvmlDevice_t, nvmlMemory_t *)) dlsym(nvml.handle, "nvmlDeviceGetMemoryInfo");
nvml.nvmlErrorString = (const char * (*)(nvmlReturn_enum)) dlsym(nvml.handle, "nvmlErrorString");
if (nvml.nvmlInit_v2 == NULL || nvml.nvmlShutdown == NULL || nvml.nvmlDeviceGetHandleByUUID == NULL || nvml.nvmlDeviceGetMemoryInfo == NULL) {
GGML_LOG_INFO("%s unable to locate required symbols in libnvidia-ml.so", __func__);
dlclose(nvml.handle);
nvml.handle = NULL;
return NVML_ERROR_NOT_FOUND;
}
nvmlReturn_t status = nvml.nvmlInit_v2();
if (status != NVML_SUCCESS) {
GGML_LOG_INFO("%s unable to initialize NVML: %s\n", __func__, nvml.nvmlErrorString(status));
dlclose(nvml.handle);
nvml.handle = NULL;
return status;
}
#endif
int status = nvml.nvmlInit_v2();
return NVML_SUCCESS;
}
@@ -140,14 +177,14 @@ void ggml_nvml_release() {
}
nvmlReturn_enum status = nvml.nvmlShutdown();
if (status != NVML_SUCCESS) {
GGML_LOG_INFO("%s failed to shutdown NVML: %d\n", __func__, status);
GGML_LOG_INFO("%s failed to shutdown NVML: %s\n", __func__, nvml.nvmlErrorString(status));
}
#ifdef _WIN32
FreeLibrary((HMODULE)(nvml.handle));
nvml.handle = NULL;
#else
// Not currently wired up on Linux
dlclose(nvml.handle);
#endif
nvml.handle = NULL;
}
int ggml_nvml_get_device_memory(const char *uuid, size_t *free, size_t *total) {

View File

@@ -251,7 +251,7 @@ func BenchmarkBytePairEncoding(b *testing.B) {
bts := bts[:n]
b.Run("encode"+strconv.Itoa(n), func(b *testing.B) {
b.ResetTimer()
for range b.N {
for b.Loop() {
_, err := tokenizer.Encode(string(bts), true)
if err != nil {
b.Fatal(err)
@@ -266,7 +266,7 @@ func BenchmarkBytePairEncoding(b *testing.B) {
}
b.ResetTimer()
for range b.N {
for b.Loop() {
_, err := tokenizer.Decode(ids)
if err != nil {
b.Fatal(err)
@@ -276,7 +276,7 @@ func BenchmarkBytePairEncoding(b *testing.B) {
b.Run("split"+strconv.Itoa(n), func(b *testing.B) {
b.ResetTimer()
for range b.N {
for b.Loop() {
slices.Collect(tokenizer.split(string(bts)))
}
})

219
model/parsers/glm46.go Normal file
View File

@@ -0,0 +1,219 @@
package parsers
import (
"context"
"encoding/json"
"log/slog"
"strings"
"unicode"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/logutil"
)
const (
glm46CollectingContent glm46ParserState = iota
CollectingThinkingContent
CollectingToolContent
)
const (
thinkingCloseTag = "</think>"
)
// TODO(gguo): add a field for isThinking
type GLM46Parser struct {
state qwenParserState
buffer strings.Builder
tools []api.Tool
}
func (p *GLM46Parser) HasToolSupport() bool {
return true
}
// TODO(gguo): changes this to reference an objects param
func (p *GLM46Parser) HasThinkingSupport() bool {
return true
}
func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
p.tools = tools
// p.state = p.initialState()
return tools
}
type glm46EventThinkingContent struct {
content string
}
func (glm46EventThinkingContent) isGLM46Event() {}
func (p *GLM46Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
p.buffer.WriteString(s)
events := p.parseEvents()
var toolCalls []api.ToolCall
var sb strings.Builder
for _, event := range events {
switch event := event.(type) {
case glm46EventRawToolCall:
toolCall, err := parseJSONToolCall(event, p.tools)
if err != nil {
slog.Warn("qwen tool call parsing failed", "error", err)
return "", "", nil, err
}
toolCalls = append(toolCalls, toolCall)
case glm46EventThinkingContent:
sb.WriteString(event.content)
case glm46EventContent:
// TODO(drifkin): if the same turn contains multiple interleaved content
// events, we naively append them together here.
sb.WriteString(event.content)
}
}
return sb.String(), "", toolCalls, nil
}
func (p *GLM46Parser) parseEvents() []glm46Event {
var all []glm46Event
keepLooping := true
for keepLooping {
var events []glm46Event
events, keepLooping = p.eat()
if len(events) > 0 {
all = append(all, events...)
}
}
if len(all) > 0 {
slog.Log(context.TODO(), logutil.LevelTrace, "qwen events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
}
return all
}
func emitContentBeforeTag(p *GLM46Parser, events []glm46Event, tag string) []glm46Event {
split := strings.SplitN(p.buffer.String(), tag, 2)
before := split[0]
before = strings.TrimRightFunc(before, unicode.IsSpace)
if len(before) > 0 {
events = append(events, glm46EventContent{content: before})
}
after := split[1]
p.buffer.Reset()
p.buffer.WriteString(after)
return events
}
func (p *GLM46Parser) eat() ([]glm46Event, bool) {
var events []glm46Event
switch p.state {
case glm46CollectingContent:
if strings.Contains(p.buffer.String(), toolOpenTag) {
events = emitContentBeforeTag(p, events, toolOpenTag)
p.state = glm46CollectingToolContent
return events, true
} else if overlapLen := overlap(p.buffer.String(), toolOpenTag); overlapLen > 0 {
beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen]
trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen
unambiguous := p.buffer.String()[:ambiguousStart]
ambiguous := p.buffer.String()[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 { // why does qwen3coder not have this here
events = append(events, glm46EventContent{content: unambiguous})
}
return events, false
} else {
whitespaceLen := trailingWhitespaceLen(p.buffer.String())
ambiguousStart := len(p.buffer.String()) - whitespaceLen
unambiguous := p.buffer.String()[:ambiguousStart]
ambiguous := p.buffer.String()[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, glm46EventContent{content: unambiguous})
}
return events, false
}
case CollectingToolContent:
if strings.Contains(p.buffer.String(), glm46ToolCloseTag) {
split := strings.SplitN(p.buffer.String(), toolCloseTag, 2)
before := split[0]
if len(before) == 0 {
slog.Warn("qwen tool call closing tag found but no content before it")
}
after := strings.TrimLeftFunc(split[1], unicode.IsSpace)
events = append(events, glm46EventRawToolCall{raw: before})
p.buffer.Reset()
p.buffer.WriteString(after)
p.state = glm46CollectingContent
return events, true
} else {
return events, false
}
case glm46CollectingThinkingContent: // so we want to hip the unambiguous stuff
if strings.Contains(p.buffer.String(), thinkingCloseTag) {
split := strings.SplitN(p.buffer.String(), thinkingCloseTag, 2)
before := split[0]
if len(before) == 0 {
slog.Warn("qwen tool call closing tag found but no content before it")
}
after := strings.TrimLeftFunc(split[1], unicode.IsSpace)
if len(before) > 0 {
events = append(events, glm46EventThinkingContent{content: before})
}
p.buffer.Reset()
p.buffer.WriteString(after)
p.state = glm46CollectingContent
return events, true
} else if overlapLen := overlap(p.buffer.String(), thinkingCloseTag); overlapLen > 0 { // we see part of a close thinking tag
beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen]
trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen
unambiguous := p.buffer.String()[:ambiguousStart]
ambiguous := p.buffer.String()[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, glm46EventThinkingContent{content: unambiguous})
}
return events, false
} else {
whitespaceLen := trailingWhitespaceLen(p.buffer.String())
ambiguousStart := len(p.buffer.String()) - whitespaceLen
unambiguous := p.buffer.String()[:ambiguousStart]
ambiguous := p.buffer.String()[ambiguousStart:]
p.buffer.Reset()
p.buffer.WriteString(ambiguous)
if len(unambiguous) > 0 {
events = append(events, glm46EventThinkingContent{content: unambiguous})
}
return events, false
}
default:
panic("unreachable")
}
}
func parseJSONToolCall(raw glm46EventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
var toolCallFunction api.ToolCallFunction
if err := json.Unmarshal([]byte(raw.raw), &toolCallFunction); err != nil {
return api.ToolCall{}, err
}
toolCall := api.ToolCall{}
toolCall.Function = toolCallFunction
return toolCall, nil
}

View File

@@ -21,6 +21,9 @@ func ParserForName(name string) Parser {
case "qwen3-coder":
parser := &Qwen3CoderParser{}
return parser
case "glm-4.6":
parser := &GLM46Parser{}
return parser
case "passthrough":
return &PassthroughParser{}
case "harmony":

View File

@@ -0,0 +1,239 @@
package renderers
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
)
func TestGLM46Renderer(t *testing.T) {
tests := []struct {
name string
messages []api.Message
tools []api.Tool
thinkValue *api.ThinkValue
expected string
}{
{
name: "basic",
messages: []api.Message{
{Role: "user", Content: "Hello, how are you?"},
},
expected: `[gMASK]<sop><|user|>
Hello, how are you?<|assistant|>`,
},
{
name: "basic with system message",
messages: []api.Message{
{Role: "system", Content: "You are a helpful assistant."},
{Role: "user", Content: "Hello, how are you?"},
},
expected: `[gMASK]<sop><|system|>
You are a helpful assistant.<|user|>
Hello, how are you?<|assistant|>`,
},
{
name: "basic with user assistant user",
messages: []api.Message{
{Role: "user", Content: "What is the capital of France?"},
{Role: "assistant", Thinking: "Let me analyze the request...", Content: "The capital of France is Paris."},
{Role: "user", Content: "Fantastic!"},
},
expected: `[gMASK]<sop><|user|>
What is the capital of France?<|assistant|>
The capital of France is Paris.<|user|>
Fantastic!<|assistant|>`,
},
{
name: "tools",
messages: []api.Message{
{Role: "system", Content: "You are a helpful assistant with access to tools."},
{Role: "user", Content: "What is the weather like in Tokyo?"},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather in a given location",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: map[string]api.ToolProperty{
"location": {
Type: api.PropertyType{"string"},
Description: "The city and state, e.g. San Francisco, CA",
},
"unit": {
Type: api.PropertyType{"string"},
Enum: []any{"celsius", "fahrenheit"},
},
},
},
},
},
},
expected: `[gMASK]<sop><|system|>
# Tools
You may call one or more functions to assist with the user query.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{"type":"function","function":{"name":"get_weather","description":"Get the current weather in a given location","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","description":"","enum":["celsius","fahrenheit"]}}}}}
</tools>
For each function call, output the function name and arguments within the following XML format:
<tool_call>{function-name}
<arg_key>{arg-key-1}</arg_key>
<arg_value>{arg-value-1}</arg_value>
<arg_key>{arg-key-2}</arg_key>
<arg_value>{arg-value-2}</arg_value>
...
</tool_call><|system|>
You are a helpful assistant with access to tools.<|user|>
What is the weather like in Tokyo?<|assistant|>`,
},
{
name: "tool calls",
messages: []api.Message{
{Role: "system", Content: "You are a helpful assistant with access to tools."},
{Role: "user", Content: "What is the weather like in Tokyo?"},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Tokyo, Japan",
"unit": "celsius",
},
},
},
{
Function: api.ToolCallFunction{
Name: "get_weather",
Arguments: api.ToolCallFunctionArguments{
"location": "Japan",
"unit": "fahrenheit",
},
},
},
},
},
{
Role: "tool",
Content: "{\"temperature\": 22, \"weather\": \"partly cloudy\", \"humidity\": 65}",
ToolName: "get_weather",
},
{
Role: "tool",
Content: "{\"temperature\": 68, \"weather\": \"sunny\", \"humidity\": 75}",
ToolName: "get_weather",
},
{
Role: "assistant",
Content: "The weather in Tokyo is currently partly cloudy with a temperature of 22°C and 65% humidity. It's a pleasant day with moderate temperatures.",
},
},
tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather in a given location",
Parameters: api.ToolFunctionParameters{
Type: "object",
Required: []string{"location"},
Properties: map[string]api.ToolProperty{
"location": {
Type: api.PropertyType{"string"},
Description: "The city and state, e.g. San Francisco, CA",
},
"unit": {
Type: api.PropertyType{"string"},
Enum: []any{"celsius", "fahrenheit"},
},
},
},
},
},
},
expected: `[gMASK]<sop><|system|>
# Tools
You may call one or more functions to assist with the user query.
You are provided with function signatures within <tools></tools> XML tags:
<tools>
{"type":"function","function":{"name":"get_weather","description":"Get the current weather in a given location","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","description":"","enum":["celsius","fahrenheit"]}}}}}
</tools>
For each function call, output the function name and arguments within the following XML format:
<tool_call>{function-name}
<arg_key>{arg-key-1}</arg_key>
<arg_value>{arg-value-1}</arg_value>
<arg_key>{arg-key-2}</arg_key>
<arg_value>{arg-value-2}</arg_value>
...
</tool_call><|system|>
You are a helpful assistant with access to tools.<|user|>
What is the weather like in Tokyo?<|assistant|>
<think></think>
<tool_call>get_weather
<arg_key>location</arg_key>
<arg_value>Tokyo, Japan</arg_value>
<arg_key>unit</arg_key>
<arg_value>celsius</arg_value>
</tool_call>
<tool_call>get_weather
<arg_key>location</arg_key>
<arg_value>Japan</arg_value>
<arg_key>unit</arg_key>
<arg_value>fahrenheit</arg_value>
</tool_call><|observation|>
<tool_response>
{"temperature": 22, "weather": "partly cloudy", "humidity": 65}
</tool_response>
<tool_response>
{"temperature": 68, "weather": "sunny", "humidity": 75}
</tool_response><|assistant|>
<think></think>
The weather in Tokyo is currently partly cloudy with a temperature of 22°C and 65% humidity. It's a pleasant day with moderate temperatures.<|assistant|>`,
},
{
name: "think true",
messages: []api.Message{
{Role: "user", Content: "Hello, how are you?"},
},
thinkValue: &api.ThinkValue{Value: true},
expected: `[gMASK]<sop><|user|>
Hello, how are you?<|assistant|>`,
},
{
name: "think false",
messages: []api.Message{
{Role: "user", Content: "Hello, how are you?"},
},
thinkValue: &api.ThinkValue{Value: false},
expected: `[gMASK]<sop><|user|>
Hello, how are you?/nothink<|assistant|>
<think></think>`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
rendered, err := GLM46Renderer(tt.messages, tt.tools, tt.thinkValue)
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(rendered, tt.expected); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
t.Logf("Got:\n%s", rendered)
t.Logf("Expected:\n%s", tt.expected)
}
})
}
}

109
model/renderers/gml46.go Normal file
View File

@@ -0,0 +1,109 @@
package renderers
import (
"encoding/json"
"fmt"
"strings"
"github.com/ollama/ollama/api"
)
func GLM46Renderer(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
var sb strings.Builder
sb.WriteString("[gMASK]<sop>")
var lastUserIndex int
for i, message := range messages {
if message.Role == "user" {
lastUserIndex = i
}
}
if len(tools) > 0 {
sb.WriteString("<|system|>\n")
sb.WriteString("# Tools\n\n")
sb.WriteString("You may call one or more functions to assist with the user query.\n\n")
sb.WriteString("You are provided with function signatures within <tools></tools> XML tags:\n")
sb.WriteString("<tools>\n")
for _, tool := range tools {
d, _ := json.Marshal(tool)
sb.WriteString(string(d) + "\n")
}
sb.WriteString("</tools>\n\n")
sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n")
sb.WriteString("<tool_call>{function-name}\n")
sb.WriteString("<arg_key>{arg-key-1}</arg_key>\n")
sb.WriteString("<arg_value>{arg-value-1}</arg_value>\n")
sb.WriteString("<arg_key>{arg-key-2}</arg_key>\n")
sb.WriteString("<arg_value>{arg-value-2}</arg_value>\n")
sb.WriteString("...\n")
sb.WriteString("</tool_call>")
}
for i, message := range messages {
switch message.Role {
case "user":
sb.WriteString("<|user|>\n")
sb.WriteString(message.Content)
if thinkValue != nil && !thinkValue.Bool() && !strings.HasSuffix(message.Content, "/nothink") {
sb.WriteString("/nothink")
}
case "assistant":
sb.WriteString("<|assistant|>")
if i > lastUserIndex {
if message.Thinking != "" {
sb.WriteString("\n<think>" + message.Thinking + "</think>")
} else {
sb.WriteString("\n<think></think>")
}
}
if message.Content != "" {
sb.WriteString("\n" + message.Content)
}
if len(message.ToolCalls) > 0 {
for _, toolCall := range message.ToolCalls {
sb.WriteString("\n<tool_call>" + toolCall.Function.Name + "\n")
for key, value := range toolCall.Function.Arguments {
sb.WriteString("<arg_key>" + key + "</arg_key>\n")
var valueStr string
if str, ok := value.(string); ok {
valueStr = str
} else {
jsonBytes, err := json.Marshal(value)
if err != nil {
valueStr = fmt.Sprintf("%v", value)
} else {
valueStr = string(jsonBytes)
}
}
sb.WriteString("<arg_value>" + valueStr + "</arg_value>\n")
}
sb.WriteString("</tool_call>")
}
}
case "tool":
if i == 0 || messages[i-1].Role != "tool" {
sb.WriteString("<|observation|>")
}
sb.WriteString("\n<tool_response>\n")
sb.WriteString(message.Content)
sb.WriteString("\n</tool_response>")
case "system":
sb.WriteString("<|system|>\n")
sb.WriteString(message.Content)
}
}
// Add generation prompt
sb.WriteString("<|assistant|>")
fmt.Println("thinkValue", thinkValue, thinkValue.Bool())
if thinkValue != nil && !thinkValue.Bool() {
sb.WriteString("\n<think></think>")
}
return sb.String(), nil
}

View File

@@ -20,6 +20,8 @@ func rendererForName(name string) rendererFunc {
switch name {
case "qwen3-coder":
return Qwen3CoderRenderer
case "glm-4.6":
return GLM46Renderer
default:
return nil
}

View File

@@ -79,6 +79,9 @@ type Sequence struct {
// true if an embedding are to be returned instead of text generation
embeddingOnly bool
// shift if context window is exceeded
shift bool
doneReason llm.DoneReason
// Metrics
@@ -94,8 +97,12 @@ type NewSequenceParams struct {
numKeep int
samplingParams *llama.SamplingParams
embedding bool
shift bool
truncate bool
}
var errorInputTooLong = errors.New("the input length exceeds the context length")
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
s.ready.Wait()
@@ -119,6 +126,10 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
if len(inputs) > s.cache.numCtx {
discard := len(inputs) - s.cache.numCtx
if !params.truncate {
return nil, errorInputTooLong
}
newInputs := inputs[:params.numKeep]
newInputs = append(newInputs, inputs[params.numKeep+discard:]...)
@@ -385,6 +396,11 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
for i, input := range seq.inputs {
if len(seq.cache.Inputs)+len(seq.pendingInputs)+1 > s.cache.numCtx {
if len(seq.pendingInputs) == 0 {
if !seq.shift {
s.removeSequence(seqIdx, llm.DoneReasonLength)
break
}
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
if err != nil {
var reprocess *ErrReprocessInputs
@@ -583,8 +599,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
numKeep: req.Options.NumKeep,
samplingParams: &samplingParams,
embedding: false,
shift: req.Shift,
truncate: req.Truncate,
})
if err != nil {
if errors.Is(err, errorInputTooLong) {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
return
}

View File

@@ -88,6 +88,9 @@ type Sequence struct {
// true if an embedding are to be returned instead of text generation
embeddingOnly bool
// shift if context window is exceeded
shift bool
doneReason llm.DoneReason
// Metrics
@@ -104,8 +107,12 @@ type NewSequenceParams struct {
numKeep int32
sampler sample.Sampler
embedding bool
shift bool
truncate bool
}
var errorInputTooLong = errors.New("the input length exceeds the context length")
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
s.ready.Wait()
@@ -125,6 +132,11 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
if int32(len(inputs)) > s.cache.numCtx {
discard := int32(len(inputs)) - s.cache.numCtx
if !params.truncate {
return nil, errorInputTooLong
}
promptStart := params.numKeep + discard
// If we need to truncate in the middle of a unbreakable batch, remove the entire batch
@@ -176,6 +188,7 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
embeddingOnly: params.embedding,
stop: params.stop,
numKeep: params.numKeep,
shift: params.shift,
}, nil
}
@@ -321,9 +334,6 @@ type Server struct {
// TODO (jmorganca): make this n_batch
batchSize int
// Used to signal a hard failure during async processing which will panic the runner
hardErrCh chan error
// Simple counter used only for trace logging batches
batchID int
@@ -411,8 +421,6 @@ func (s *Server) run(ctx context.Context) {
select {
case <-ctx.Done():
return
case err := <-s.hardErrCh:
panic(err)
default:
var err error
nextBatch, err := s.forwardBatch(previousBatch)
@@ -522,6 +530,12 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
break
}
if !seq.shift {
s.removeSequence(seqIdx, llm.DoneReasonLength)
nextBatch.seqs[seqIdx] = nil
break
}
err = s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
if err != nil {
var reprocess *ErrReprocessInputs
@@ -663,9 +677,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
// don't sample prompt processing
if len(seq.inputs) != 0 {
if !s.cache.enabled {
s.hardErrCh <- fmt.Errorf("caching disabled but unable to fit entire input in a batch")
s.mu.Unlock()
return
panic("caching disabled but unable to fit entire input in a batch")
}
continue
}
@@ -720,8 +732,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", activeBatch.batch.Outputs.Dim(0), "vocabSize", vocabSize, "iBatches", iBatches)
token, err := seq.sampler.Sample(outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
if err != nil {
s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)
return
panic("failed to sample token")
}
nextBatchTokens[i].Token = token
@@ -738,8 +749,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
if err != nil {
s.hardErrCh <- fmt.Errorf("failed to decode token: %w", err)
return
panic("failed to decode token")
}
seq.pendingResponses = append(seq.pendingResponses, piece)
@@ -841,8 +851,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
numKeep: int32(req.Options.NumKeep),
sampler: sampler,
embedding: false,
shift: req.Shift,
truncate: req.Truncate,
})
if err != nil {
if errors.Is(err, errorInputTooLong) {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
return
}
@@ -1321,7 +1337,6 @@ func Execute(args []string) error {
server := &Server{
modelPath: *mpath,
status: llm.ServerStatusLaunched,
hardErrCh: make(chan error, 1),
}
server.cond = sync.NewCond(&server.mu)

View File

@@ -20,7 +20,7 @@ type tokenizeFunc func(context.Context, string) ([]int, error)
// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
// latest message and 2) system messages
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (prompt string, images []llm.ImageData, _ error) {
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool, think *api.ThinkValue, truncate bool) (prompt string, images []llm.ImageData, _ error) {
var system []api.Message
// TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent
@@ -59,7 +59,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
}
}
if ctxLen > opts.NumCtx {
if truncate && ctxLen > opts.NumCtx {
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
break
} else {

View File

@@ -27,16 +27,18 @@ func TestChatPrompt(t *testing.T) {
visionModel := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
cases := []struct {
name string
model Model
limit int
msgs []api.Message
name string
model Model
limit int
truncate bool
msgs []api.Message
expect
}{
{
name: "messages",
model: visionModel,
limit: 64,
name: "messages",
model: visionModel,
limit: 64,
truncate: true,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
@@ -47,9 +49,10 @@ func TestChatPrompt(t *testing.T) {
},
},
{
name: "truncate messages",
model: visionModel,
limit: 1,
name: "truncate messages",
model: visionModel,
limit: 1,
truncate: true,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
@@ -60,9 +63,10 @@ func TestChatPrompt(t *testing.T) {
},
},
{
name: "truncate messages with image",
model: visionModel,
limit: 64,
name: "truncate messages with image",
model: visionModel,
limit: 64,
truncate: true,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
@@ -76,9 +80,10 @@ func TestChatPrompt(t *testing.T) {
},
},
{
name: "truncate messages with images",
model: visionModel,
limit: 64,
name: "truncate messages with images",
model: visionModel,
limit: 64,
truncate: true,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
{Role: "assistant", Content: "I-I'm a what?"},
@@ -92,9 +97,10 @@ func TestChatPrompt(t *testing.T) {
},
},
{
name: "messages with images",
model: visionModel,
limit: 2048,
name: "messages with images",
model: visionModel,
limit: 2048,
truncate: true,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
{Role: "assistant", Content: "I-I'm a what?"},
@@ -109,9 +115,10 @@ func TestChatPrompt(t *testing.T) {
},
},
{
name: "message with image tag",
model: visionModel,
limit: 2048,
name: "message with image tag",
model: visionModel,
limit: 2048,
truncate: true,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry! [img]", Images: []api.ImageData{[]byte("something")}},
{Role: "assistant", Content: "I-I'm a what?"},
@@ -126,9 +133,10 @@ func TestChatPrompt(t *testing.T) {
},
},
{
name: "messages with interleaved images",
model: visionModel,
limit: 2048,
name: "messages with interleaved images",
model: visionModel,
limit: 2048,
truncate: true,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "user", Images: []api.ImageData{[]byte("something")}},
@@ -145,9 +153,10 @@ func TestChatPrompt(t *testing.T) {
},
},
{
name: "truncate message with interleaved images",
model: visionModel,
limit: 1024,
name: "truncate message with interleaved images",
model: visionModel,
limit: 1024,
truncate: true,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "user", Images: []api.ImageData{[]byte("something")}},
@@ -163,9 +172,10 @@ func TestChatPrompt(t *testing.T) {
},
},
{
name: "message with system prompt",
model: visionModel,
limit: 2048,
name: "message with system prompt",
model: visionModel,
limit: 2048,
truncate: true,
msgs: []api.Message{
{Role: "system", Content: "You are the Test Who Lived."},
{Role: "user", Content: "You're a test, Harry!"},
@@ -177,9 +187,10 @@ func TestChatPrompt(t *testing.T) {
},
},
{
name: "out of order system",
model: visionModel,
limit: 2048,
name: "out of order system",
model: visionModel,
limit: 2048,
truncate: true,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
@@ -191,9 +202,10 @@ func TestChatPrompt(t *testing.T) {
},
},
{
name: "multiple images same prompt",
model: visionModel,
limit: 2048,
name: "multiple images same prompt",
model: visionModel,
limit: 2048,
truncate: true,
msgs: []api.Message{
{Role: "user", Content: "Compare these two pictures of hotdogs", Images: []api.ImageData{[]byte("one hotdog"), []byte("two hotdogs")}},
},
@@ -202,6 +214,20 @@ func TestChatPrompt(t *testing.T) {
images: [][]byte{[]byte("one hotdog"), []byte("two hotdogs")},
},
},
{
name: "no truncate with limit exceeded",
model: visionModel,
limit: 10,
truncate: false,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{
prompt: "You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
},
},
}
for _, tt := range cases {
@@ -209,7 +235,7 @@ func TestChatPrompt(t *testing.T) {
model := tt.model
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
think := false
prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &api.ThinkValue{Value: think})
prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &api.ThinkValue{Value: think}, tt.truncate)
if tt.error == nil && err != nil {
t.Fatal(err)
} else if tt.error != nil && err != tt.error {

View File

@@ -403,12 +403,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
msgs = append(msgs, m.Messages...)
}
userMsg := api.Message{Role: "user", Content: req.Prompt}
for _, i := range images {
imgPrompt := ""
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]"+imgPrompt, i.ID)})
userMsg.Images = append(userMsg.Images, i.Data)
}
values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})
values.Messages = append(msgs, userMsg)
}
values.Think = req.Think != nil && req.Think.Bool()
@@ -429,12 +428,31 @@ func (s *Server) GenerateHandler(c *gin.Context) {
b.WriteString(s)
}
if err := tmpl.Execute(&b, values); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// check that we're in the `api/chat`-like flow, and if so, generate the
// prompt the same way
// TEMP(drifkin): we should really just detect the chat-like flow and call
// the real chat handler, but doing this as a stopgap to get renderer
// support for generate
if values.Messages != nil && values.Suffix == "" && req.Template == "" {
prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, req.Truncate == nil || *req.Truncate)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// TEMP(drifkin): req.Context will be removed very soon, but we're temporarily supporting it in this flow here
if req.Context != nil {
b.WriteString(prompt)
prompt = b.String()
}
} else {
// legacy flow
if err := tmpl.Execute(&b, values); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
prompt = b.String()
prompt = b.String()
}
}
// If debug mode is enabled, return the rendered template instead of calling the model
@@ -470,10 +488,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
var sb strings.Builder
defer close(ch)
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
Shift: req.Shift == nil || *req.Shift,
Truncate: req.Truncate == nil || *req.Truncate,
}, func(cr llm.CompletionResponse) {
res := api.GenerateResponse{
Model: req.Model,
@@ -535,7 +555,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
ch <- res
}); err != nil {
ch <- gin.H{"error": err.Error()}
var serr api.StatusError
if errors.As(err, &serr) {
ch <- gin.H{"error": serr.ErrorMessage, "status": serr.StatusCode}
} else {
ch <- gin.H{"error": err.Error()}
}
}
}()
@@ -555,7 +580,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
msg = "unexpected error format in response"
}
c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
status, ok := t["status"].(int)
if !ok {
status = http.StatusInternalServerError
}
c.JSON(status, gin.H{"error": msg})
return
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
@@ -1620,6 +1650,30 @@ func streamResponse(c *gin.Context, ch chan any) {
return false
}
// errors are provided as a gin.H with an "error" field and
// an optional "status" field. For errors that are streamed
// before any content, we need to set the status code and
// content type for the error.
if h, ok := val.(gin.H); ok {
if e, ok := h["error"].(string); ok {
status, ok := h["status"].(int)
if !ok {
status = http.StatusInternalServerError
}
if !c.Writer.Written() {
c.Header("Content-Type", "application/json")
c.JSON(status, gin.H{"error": e})
} else {
if err := json.NewEncoder(c.Writer).Encode(gin.H{"error": e}); err != nil {
slog.Error("streamResponse failed to encode json error", "error", err)
}
}
return false
}
}
bts, err := json.Marshal(val)
if err != nil {
slog.Info(fmt.Sprintf("streamResponse: json.Marshal failed with %s", err))
@@ -1939,7 +1993,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
}
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think)
truncate := req.Truncate == nil || *req.Truncate
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
if err != nil {
slog.Error("chat prompt error", "error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -2016,10 +2071,12 @@ func (s *Server) ChatHandler(c *gin.Context) {
// sets up new context given parent context per request
ctx, cancel := context.WithCancel(c.Request.Context())
err := r.Completion(ctx, llm.CompletionRequest{
Prompt: prompt,
Images: images,
Format: currentFormat,
Options: opts,
Prompt: prompt,
Images: images,
Format: currentFormat,
Options: opts,
Shift: req.Shift == nil || *req.Shift,
Truncate: truncate,
}, func(r llm.CompletionResponse) {
res := api.ChatResponse{
Model: req.Model,
@@ -2113,7 +2170,12 @@ func (s *Server) ChatHandler(c *gin.Context) {
if structuredOutputsState == structuredOutputsState_ReadyToApply && strings.Contains(err.Error(), "context canceled") && c.Request.Context().Err() == nil {
// only ignores error if it's a context cancellation due to setting structured outputs
} else {
ch <- gin.H{"error": err.Error()}
var serr api.StatusError
if errors.As(err, &serr) {
ch <- gin.H{"error": serr.ErrorMessage, "status": serr.StatusCode}
} else {
ch <- gin.H{"error": err.Error()}
}
return
}
}
@@ -2127,7 +2189,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
}
msgs = append(msgs, msg)
prompt, _, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think)
prompt, _, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
if err != nil {
slog.Error("chat prompt error applying structured outputs", "error", err)
ch <- gin.H{"error": err.Error()}
@@ -2167,7 +2229,12 @@ func (s *Server) ChatHandler(c *gin.Context) {
msg = "unexpected error format in response"
}
c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
status, ok := t["status"].(int)
if !ok {
status = http.StatusInternalServerError
}
c.JSON(status, gin.H{"error": msg})
return
default:
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})

View File

@@ -146,7 +146,7 @@ func TestGenerateDebugRenderOnly(t *testing.T) {
DebugRenderOnly: true,
},
expectDebug: true,
expectTemplate: "[img-0]\n\nDescribe this image",
expectTemplate: "[img-0]Describe this image",
expectNumImages: 1,
},
{

View File

@@ -0,0 +1,313 @@
package server
import (
"bytes"
"encoding/json"
"net/http"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/discover"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
)
// TestGenerateWithBuiltinRenderer tests that api/generate uses built-in renderers
// when in chat-like flow (messages present, no suffix, no template)
func TestGenerateWithBuiltinRenderer(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := mockRunner{
CompletionResponse: llm.CompletionResponse{
Done: true,
DoneReason: llm.DoneReasonStop,
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
},
}
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: getGpuFn,
getCpuFn: getCpuFn,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{
llama: &mock,
}
return false
},
},
}
go s.sched.Run(t.Context())
// Create a model with a built-in renderer (qwen3-coder)
_, digest := createBinFile(t, ggml.KV{
"general.architecture": "qwen3",
"qwen3.block_count": uint32(1),
"qwen3.context_length": uint32(8192),
"qwen3.embedding_length": uint32(4096),
"qwen3.attention.head_count": uint32(32),
"qwen3.attention.head_count_kv": uint32(8),
"tokenizer.ggml.tokens": []string{""},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []*ggml.Tensor{
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
})
// Create a model with the qwen3-coder renderer
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "test-renderer",
Files: map[string]string{"file.gguf": digest},
Renderer: "qwen3-coder",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
mock.CompletionResponse.Content = "Hi!"
t.Run("chat-like flow uses renderer", func(t *testing.T) {
// Test that when using messages (chat-like flow), the built-in renderer is used
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-renderer",
Prompt: "Write a hello world function",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
// The qwen3-coder renderer produces output with <|im_start|> and <|im_end|> tags
// When messages are built internally from prompt, it should use the renderer
if !strings.Contains(mock.CompletionRequest.Prompt, "<|im_start|>") {
t.Errorf("expected prompt to contain <|im_start|> from qwen3-coder renderer, got: %s", mock.CompletionRequest.Prompt)
}
if !strings.Contains(mock.CompletionRequest.Prompt, "<|im_end|>") {
t.Errorf("expected prompt to contain <|im_end|> from qwen3-coder renderer, got: %s", mock.CompletionRequest.Prompt)
}
})
t.Run("chat-like flow with system message uses renderer", func(t *testing.T) {
// Test that system messages work with the renderer
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-renderer",
Prompt: "Write a hello world function",
System: "You are a helpful coding assistant.",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
// Should contain the system message and use renderer format
if !strings.Contains(mock.CompletionRequest.Prompt, "<|im_start|>system") {
t.Errorf("expected prompt to contain system message with renderer format, got: %s", mock.CompletionRequest.Prompt)
}
if !strings.Contains(mock.CompletionRequest.Prompt, "You are a helpful coding assistant.") {
t.Errorf("expected prompt to contain system message content, got: %s", mock.CompletionRequest.Prompt)
}
})
t.Run("custom template bypasses renderer", func(t *testing.T) {
// Test that providing a custom template uses the legacy flow
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-renderer",
Prompt: "Write a hello world function",
Template: "{{ .Prompt }}",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
// Should NOT use the renderer format when custom template is provided
if strings.Contains(mock.CompletionRequest.Prompt, "<|im_start|>") {
t.Errorf("expected prompt to NOT use renderer when custom template provided, got: %s", mock.CompletionRequest.Prompt)
}
// Should just be the raw prompt from the template
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "Write a hello world function"); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
// Create a model with suffix support for the next test
w = createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "test-suffix-renderer",
From: "test-renderer",
Template: `{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
{{- else }}{{ .Prompt }}
{{- end }}`,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
t.Run("suffix bypasses renderer", func(t *testing.T) {
// Test that providing a suffix uses the legacy flow
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-suffix-renderer",
Prompt: "def add(",
Suffix: " return c",
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
// Should NOT use the renderer format when suffix is provided
if strings.Contains(mock.CompletionRequest.Prompt, "<|im_start|>") {
t.Errorf("expected prompt to NOT use renderer when suffix provided, got: %s", mock.CompletionRequest.Prompt)
}
// Should use the suffix template format
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "<PRE> def add( <SUF> return c <MID>"); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
}
// TestGenerateWithDebugRenderOnly tests that debug_render_only works with built-in renderers
func TestGenerateWithDebugRenderOnly(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := mockRunner{
CompletionResponse: llm.CompletionResponse{
Done: true,
DoneReason: llm.DoneReasonStop,
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
},
}
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: getGpuFn,
getCpuFn: getCpuFn,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{
llama: &mock,
}
return false
},
},
}
go s.sched.Run(t.Context())
// Create a model with a built-in renderer
_, digest := createBinFile(t, ggml.KV{
"general.architecture": "qwen3",
"qwen3.block_count": uint32(1),
"qwen3.context_length": uint32(8192),
"qwen3.embedding_length": uint32(4096),
"qwen3.attention.head_count": uint32(32),
"qwen3.attention.head_count_kv": uint32(8),
"tokenizer.ggml.tokens": []string{""},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []*ggml.Tensor{
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
})
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "test-debug-renderer",
Files: map[string]string{"file.gguf": digest},
Renderer: "qwen3-coder",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
t.Run("debug_render_only with renderer", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-debug-renderer",
Prompt: "Write a hello world function",
System: "You are a coding assistant",
DebugRenderOnly: true,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
var resp api.GenerateResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatal(err)
}
if resp.DebugInfo == nil {
t.Fatalf("expected debug info, got nil")
}
// Verify that the rendered template uses the built-in renderer
if !strings.Contains(resp.DebugInfo.RenderedTemplate, "<|im_start|>") {
t.Errorf("expected rendered template to use qwen3-coder renderer format, got: %s", resp.DebugInfo.RenderedTemplate)
}
if !strings.Contains(resp.DebugInfo.RenderedTemplate, "You are a coding assistant") {
t.Errorf("expected rendered template to contain system message, got: %s", resp.DebugInfo.RenderedTemplate)
}
if !strings.Contains(resp.DebugInfo.RenderedTemplate, "Write a hello world function") {
t.Errorf("expected rendered template to contain prompt, got: %s", resp.DebugInfo.RenderedTemplate)
}
})
}

View File

@@ -609,6 +609,58 @@ func TestGenerateChat(t *testing.T) {
t.Errorf("final tool call mismatch (-got +want):\n%s", diff)
}
})
t.Run("status error non-streaming", func(t *testing.T) {
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
return api.StatusError{
StatusCode: http.StatusServiceUnavailable,
Status: "Service Unavailable",
ErrorMessage: "model is overloaded",
}
}
stream := false
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test",
Messages: []api.Message{
{Role: "user", Content: "Hello!"},
},
Stream: &stream,
})
if w.Code != http.StatusServiceUnavailable {
t.Errorf("expected status 503, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"model is overloaded"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("status error streaming", func(t *testing.T) {
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
return api.StatusError{
StatusCode: http.StatusTooManyRequests,
Status: "Too Many Requests",
ErrorMessage: "rate limit exceeded",
}
}
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test",
Messages: []api.Message{
{Role: "user", Content: "Hello!"},
},
})
if w.Code != http.StatusTooManyRequests {
t.Errorf("expected status 429, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"rate limit exceeded"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
}
func TestGenerate(t *testing.T) {
@@ -983,6 +1035,55 @@ func TestGenerate(t *testing.T) {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("status error non-streaming", func(t *testing.T) {
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
return api.StatusError{
StatusCode: http.StatusServiceUnavailable,
Status: "Service Unavailable",
ErrorMessage: "model is overloaded",
}
}
streamRequest := false
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test",
Prompt: "Hello!",
Stream: &streamRequest,
})
if w.Code != http.StatusServiceUnavailable {
t.Errorf("expected status 503, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"model is overloaded"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("status error streaming", func(t *testing.T) {
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
return api.StatusError{
StatusCode: http.StatusTooManyRequests,
Status: "Too Many Requests",
ErrorMessage: "rate limit exceeded",
}
}
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test",
Prompt: "Hello!",
Stream: &stream,
})
if w.Code != http.StatusTooManyRequests {
t.Errorf("expected status 429, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"rate limit exceeded"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
}
func TestChatWithPromptEndingInThinkTag(t *testing.T) {

View File

@@ -21,6 +21,7 @@ import (
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/types/model"
)
@@ -645,27 +646,35 @@ func (s *Scheduler) waitForVRAMRecovery(runner *runnerRef, runners []discover.Fi
totalMemoryBefore += gpu.TotalMemory
freeMemoryBefore += gpu.FreeMemory
}
totalMemoryNow := totalMemoryBefore
freeMemoryNow := freeMemoryBefore
go func() {
expiresAt := start.Add(5 * time.Second) // typical convergence is 0.5-1.5s
// typical convergence is 0.5-1.5s - If it takes more than 5 seconds to discover and converge, let the scheduler estimate VRAM usage
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
ticker := time.NewTicker(250 * time.Millisecond)
defer ticker.Stop()
for {
<-ticker.C
if time.Now().After(expiresAt) {
slog.Warn("gpu VRAM usage didn't recover within timeout", "seconds", time.Since(start).Seconds(), "runner", runner)
finished <- struct{}{}
}
// Query GPUs, look for free to go back up
gpusNow := s.getGpuFn(context.Background(), runners)
var totalMemoryNow, freeMemoryNow uint64
for _, gpu := range gpusNow {
totalMemoryNow += gpu.TotalMemory
freeMemoryNow += gpu.FreeMemory
}
// If we're within ~80% of the estimated memory usage recovered, bail out
if float32(freeMemoryNow-freeMemoryBefore) > float32(runner.vramSize)*0.8 {
slog.Debug(fmt.Sprintf("gpu VRAM free memory converged after %0.2f seconds", time.Since(start).Seconds()), "runner", runner)
select {
case <-ticker.C:
// Query GPUs, look for free to go back up
gpusNow := s.getGpuFn(ctx, runners)
totalMemoryNow = 0
freeMemoryNow = 0
for _, gpu := range gpusNow {
totalMemoryNow += gpu.TotalMemory
freeMemoryNow += gpu.FreeMemory
}
logutil.Trace("gpu VRAM convergence", "percent", int(max(float32(freeMemoryNow-freeMemoryBefore), 0.0)/float32(runner.vramSize)*100))
// If we're within ~75% of the estimated memory usage recovered, bail out
if float32(freeMemoryNow-freeMemoryBefore) > float32(runner.vramSize)*0.75 {
slog.Debug(fmt.Sprintf("gpu VRAM free memory converged after %0.2f seconds", time.Since(start).Seconds()), "free_before", format.HumanBytes2(freeMemoryBefore), "free_now", format.HumanBytes2(freeMemoryNow), "runner", runner)
finished <- struct{}{}
return
}
case <-ctx.Done():
slog.Debug("gpu VRAM usage didn't recover within timeout", "seconds", time.Since(start).Seconds(), "free_before", format.HumanBytes2(freeMemoryBefore), "free_now", format.HumanBytes2(freeMemoryNow), "runner", runner)
finished <- struct{}{}
return
}