Compare commits
25 Commits
jmorganca/
...
jmorganca/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
38ed7c7a4f | ||
|
|
9ff8e5a64d | ||
|
|
6544e14735 | ||
|
|
5db8a818a1 | ||
|
|
6db8da9958 | ||
|
|
0c68ec8d6a | ||
|
|
70d9e363e1 | ||
|
|
1a2feb2a97 | ||
|
|
aab2190420 | ||
|
|
629db9dc43 | ||
|
|
e0cd511661 | ||
|
|
207332078f | ||
|
|
93085127f4 | ||
|
|
c00fa9cc2b | ||
|
|
df411c4b02 | ||
|
|
3d32249c74 | ||
|
|
d681cd7c29 | ||
|
|
47298fce39 | ||
|
|
4a48937ef1 | ||
|
|
967a82f52f | ||
|
|
bbbc73d637 | ||
|
|
15e3611d3d | ||
|
|
77060d462c | ||
|
|
1b91d4dda1 | ||
|
|
7d965258ce |
@@ -85,6 +85,19 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
case "scales":
|
||||
mxfp4s[name].scales = t
|
||||
}
|
||||
} else if strings.HasSuffix(t.Name(), "gate_up_exps.bias") {
|
||||
// gate_up_exps is interleaved, need to split into gate_exps and up_exps
|
||||
// e.g. gate_exps, up_exps = gate_up_exps[:, 0::2, ...], gate_up_exps[:, 1::2, ...]
|
||||
out = append(out, slices.Collect(splitDim(t, 1,
|
||||
split{
|
||||
Replacer: strings.NewReplacer("gate_up_exps", "gate_exps"),
|
||||
slices: []tensor.Slice{nil, tensor.S(0, int(t.Shape()[1]), 2)},
|
||||
},
|
||||
split{
|
||||
Replacer: strings.NewReplacer("gate_up_exps", "up_exps"),
|
||||
slices: []tensor.Slice{nil, tensor.S(1, int(t.Shape()[1]), 2)},
|
||||
},
|
||||
))...)
|
||||
} else {
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
@@ -97,17 +110,28 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
|
||||
for name, mxfp4 := range mxfp4s {
|
||||
dims := mxfp4.blocks.Shape()
|
||||
|
||||
if !strings.HasSuffix(name, ".weight") {
|
||||
name += ".weight"
|
||||
if strings.Contains(name, "ffn_down_exps") {
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name + ".weight",
|
||||
Kind: uint32(ggml.TensorTypeMXFP4),
|
||||
Shape: []uint64{dims[0], dims[1], dims[2] * dims[3] * 2},
|
||||
WriterTo: mxfp4,
|
||||
})
|
||||
} else if strings.Contains(name, "ffn_gate_up_exps") {
|
||||
// gate_up_exps is interleaved, need to split into gate_exps and up_exps
|
||||
// e.g. gate_exps, up_exps = gate_up_exps[:, 0::2, ...], gate_up_exps[:, 1::2, ...]
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: strings.Replace(name, "gate_up", "gate", 1) + ".weight",
|
||||
Kind: uint32(ggml.TensorTypeMXFP4),
|
||||
Shape: []uint64{dims[0], dims[1] / 2, dims[2] * dims[3] * 2},
|
||||
WriterTo: mxfp4.slice(1, 0, int(dims[1]), 2),
|
||||
}, &ggml.Tensor{
|
||||
Name: strings.Replace(name, "gate_up", "up", 1) + ".weight",
|
||||
Kind: uint32(ggml.TensorTypeMXFP4),
|
||||
Shape: []uint64{dims[0], dims[1] / 2, dims[2] * dims[3] * 2},
|
||||
WriterTo: mxfp4.slice(1, 1, int(dims[1]), 2),
|
||||
})
|
||||
}
|
||||
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: uint32(ggml.TensorTypeMXFP4),
|
||||
Shape: []uint64{dims[0], dims[1], dims[2] * dims[3] * 2},
|
||||
WriterTo: mxfp4,
|
||||
})
|
||||
}
|
||||
|
||||
return out
|
||||
@@ -158,9 +182,21 @@ func (m *gptossModel) Replacements() []string {
|
||||
}
|
||||
|
||||
type mxfp4 struct {
|
||||
slices []tensor.Slice
|
||||
|
||||
blocks, scales Tensor
|
||||
}
|
||||
|
||||
func (m *mxfp4) slice(dim, start, end, step int) *mxfp4 {
|
||||
slice := slices.Repeat([]tensor.Slice{nil}, len(m.blocks.Shape()))
|
||||
slice[dim] = tensor.S(start, end, step)
|
||||
return &mxfp4{
|
||||
slices: slice,
|
||||
blocks: m.blocks,
|
||||
scales: m.scales,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mxfp4) WriteTo(w io.Writer) (int64, error) {
|
||||
var b bytes.Buffer
|
||||
if _, err := m.blocks.WriteTo(&b); err != nil {
|
||||
@@ -204,6 +240,13 @@ func (m *mxfp4) WriteTo(w io.Writer) (int64, error) {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if len(m.slices) > 0 {
|
||||
out, err = out.Slice(m.slices...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
out = tensor.Materialize(out)
|
||||
|
||||
if err := out.Reshape(out.Shape().TotalSize()); err != nil {
|
||||
|
||||
@@ -16,7 +16,8 @@ import (
|
||||
|
||||
type split struct {
|
||||
*strings.Replacer
|
||||
dim int
|
||||
dim int
|
||||
slices []tensor.Slice
|
||||
|
||||
// fn is an optional function to apply to the tensor after slicing
|
||||
fn func(tensor.Tensor) (tensor.Tensor, error)
|
||||
@@ -32,9 +33,12 @@ func splitDim(t Tensor, dim int, splits ...split) iter.Seq[*ggml.Tensor] {
|
||||
shape := slices.Clone(t.Shape())
|
||||
shape[dim] = cmp.Or(uint64(split.dim), shape[dim]/uint64(len(splits)))
|
||||
|
||||
slice := slices.Repeat([]tensor.Slice{nil}, len(shape))
|
||||
slice[dim] = tensor.S(offset, offset+int(shape[dim]))
|
||||
offset += int(shape[dim])
|
||||
slice := split.slices
|
||||
if len(slice) == 0 {
|
||||
slice = slices.Repeat([]tensor.Slice{nil}, len(shape))
|
||||
slice[dim] = tensor.S(offset, offset+int(shape[dim]))
|
||||
offset += int(shape[dim])
|
||||
}
|
||||
|
||||
t.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
dims := make([]int, len(shape))
|
||||
|
||||
@@ -408,7 +408,7 @@ func (r *bootstrapRunner) HasExited() bool {
|
||||
|
||||
func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs []string) []ml.DeviceInfo {
|
||||
// TODO DRY out with llm/server.go
|
||||
slog.Debug("spawing runner with", "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs)
|
||||
slog.Debug("spawning runner with", "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs)
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
slog.Debug("bootstrap discovery took", "duration", time.Since(start), "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs)
|
||||
|
||||
10
docs/gpu.md
10
docs/gpu.md
@@ -51,11 +51,11 @@ sudo modprobe nvidia_uvm`
|
||||
Ollama supports the following AMD GPUs:
|
||||
|
||||
### Linux Support
|
||||
| Family | Cards and accelerators |
|
||||
| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` `Vega 64` |
|
||||
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` `V420` `V340` `V320` `Vega II Duo` `Vega II` `SSG` |
|
||||
| AMD Instinct | `MI300X` `MI300A` `MI300` `MI250X` `MI250` `MI210` `MI200` `MI100` `MI60` |
|
||||
| Family | Cards and accelerators |
|
||||
| -------------- | -------------------------------------------------------------------------------------------------------------------- |
|
||||
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` |
|
||||
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` `V420` `V340` `V320` |
|
||||
| AMD Instinct | `MI300X` `MI300A` `MI300` `MI250X` `MI250` `MI210` `MI200` `MI100` |
|
||||
|
||||
### Windows Support
|
||||
With ROCm v6.2, the following GPUs are supported on Windows.
|
||||
|
||||
@@ -243,7 +243,6 @@ func (kv KV) OllamaEngineRequired() bool {
|
||||
"gemma3",
|
||||
"gemma3n",
|
||||
"mistral3",
|
||||
"qwen3",
|
||||
"qwen3moe",
|
||||
"llama4",
|
||||
"mllama",
|
||||
|
||||
@@ -13,13 +13,13 @@ management libraries for more accurate VRAM usage reporting if available.
|
||||
ggml/src/ggml-impl.h | 8 +
|
||||
ggml/src/ggml-metal/ggml-metal.cpp | 3 +-
|
||||
ggml/src/mem_hip.cpp | 449 +++++++++++++++++++++++++++++
|
||||
ggml/src/mem_nvml.cpp | 172 +++++++++++
|
||||
8 files changed, 718 insertions(+), 1 deletion(-)
|
||||
ggml/src/mem_nvml.cpp | 209 ++++++++++++++
|
||||
8 files changed, 755 insertions(+), 1 deletion(-)
|
||||
create mode 100644 ggml/src/mem_hip.cpp
|
||||
create mode 100644 ggml/src/mem_nvml.cpp
|
||||
|
||||
diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
|
||||
index 0a2dae26..a6bf3378 100644
|
||||
index 0a2dae26a..a6bf33785 100644
|
||||
--- a/ggml/include/ggml-backend.h
|
||||
+++ b/ggml/include/ggml-backend.h
|
||||
@@ -169,6 +169,15 @@ extern "C" {
|
||||
@@ -39,7 +39,7 @@ index 0a2dae26..a6bf3378 100644
|
||||
|
||||
GGML_API const char * ggml_backend_dev_name(ggml_backend_dev_t device);
|
||||
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
|
||||
index 33b3a15f..86191ef2 100644
|
||||
index 33b3a15f0..86191ef2c 100644
|
||||
--- a/ggml/src/CMakeLists.txt
|
||||
+++ b/ggml/src/CMakeLists.txt
|
||||
@@ -206,6 +206,8 @@ add_library(ggml-base
|
||||
@@ -52,7 +52,7 @@ index 33b3a15f..86191ef2 100644
|
||||
|
||||
target_include_directories(ggml-base PRIVATE .)
|
||||
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
index 531d6e27..3fa3a057 100644
|
||||
index 531d6e272..3fa3a0575 100644
|
||||
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
@@ -261,6 +261,16 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
||||
@@ -184,7 +184,7 @@ index 531d6e27..3fa3a057 100644
|
||||
/* .iface = */ ggml_backend_cuda_device_interface,
|
||||
/* .reg = */ ®,
|
||||
diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h
|
||||
index 06f9e7c1..eb8f66cb 100644
|
||||
index 06f9e7c1e..eb8f66cb0 100644
|
||||
--- a/ggml/src/ggml-cuda/vendors/hip.h
|
||||
+++ b/ggml/src/ggml-cuda/vendors/hip.h
|
||||
@@ -5,6 +5,9 @@
|
||||
@@ -206,7 +206,7 @@ index 06f9e7c1..eb8f66cb 100644
|
||||
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
|
||||
#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
|
||||
diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
|
||||
index 86a1ebf6..9fc9fbfc 100644
|
||||
index 86a1ebf62..9fc9fbfcf 100644
|
||||
--- a/ggml/src/ggml-impl.h
|
||||
+++ b/ggml/src/ggml-impl.h
|
||||
@@ -635,6 +635,14 @@ static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx
|
||||
@@ -225,7 +225,7 @@ index 86a1ebf6..9fc9fbfc 100644
|
||||
}
|
||||
#endif
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp
|
||||
index 08ab4fc9..17999a61 100644
|
||||
index 08ab4fc91..17999a616 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal.cpp
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal.cpp
|
||||
@@ -535,6 +535,7 @@ static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backen
|
||||
@@ -247,7 +247,7 @@ index 08ab4fc9..17999a61 100644
|
||||
/* .host_buffer = */ false,
|
||||
diff --git a/ggml/src/mem_hip.cpp b/ggml/src/mem_hip.cpp
|
||||
new file mode 100644
|
||||
index 00000000..8ef19b8c
|
||||
index 000000000..8ef19b8cf
|
||||
--- /dev/null
|
||||
+++ b/ggml/src/mem_hip.cpp
|
||||
@@ -0,0 +1,449 @@
|
||||
@@ -703,10 +703,10 @@ index 00000000..8ef19b8c
|
||||
\ No newline at end of file
|
||||
diff --git a/ggml/src/mem_nvml.cpp b/ggml/src/mem_nvml.cpp
|
||||
new file mode 100644
|
||||
index 00000000..aa05e9dc
|
||||
index 000000000..c9073cef0
|
||||
--- /dev/null
|
||||
+++ b/ggml/src/mem_nvml.cpp
|
||||
@@ -0,0 +1,172 @@
|
||||
@@ -0,0 +1,209 @@
|
||||
+// NVIDIA Management Library (NVML)
|
||||
+//
|
||||
+// https://developer.nvidia.com/management-library-nvml
|
||||
@@ -721,6 +721,7 @@ index 00000000..aa05e9dc
|
||||
+#include "ggml-impl.h"
|
||||
+#include <filesystem>
|
||||
+#include <mutex>
|
||||
+#include <array>
|
||||
+
|
||||
+#ifdef _WIN32
|
||||
+# define WIN32_LEAN_AND_MEAN
|
||||
@@ -787,6 +788,7 @@ index 00000000..aa05e9dc
|
||||
+ nvmlReturn_t (*nvmlShutdown)(void);
|
||||
+ nvmlReturn_t (*nvmlDeviceGetHandleByUUID)(const char *, nvmlDevice_t *);
|
||||
+ nvmlReturn_t (*nvmlDeviceGetMemoryInfo)(nvmlDevice_t, nvmlMemory_t *);
|
||||
+ const char * (*nvmlErrorString)(nvmlReturn_t result);
|
||||
+} nvml { NULL, NULL, NULL, NULL, NULL };
|
||||
+static std::mutex ggml_nvml_lock;
|
||||
+
|
||||
@@ -824,7 +826,8 @@ index 00000000..aa05e9dc
|
||||
+ nvml.nvmlShutdown = (nvmlReturn_enum (*)()) GetProcAddress((HMODULE)(nvml.handle), "nvmlShutdown");
|
||||
+ nvml.nvmlDeviceGetHandleByUUID = (nvmlReturn_t (*)(const char *, nvmlDevice_t *)) GetProcAddress((HMODULE)(nvml.handle), "nvmlDeviceGetHandleByUUID");
|
||||
+ nvml.nvmlDeviceGetMemoryInfo = (nvmlReturn_t (*)(nvmlDevice_t, nvmlMemory_t *)) GetProcAddress((HMODULE)(nvml.handle), "nvmlDeviceGetMemoryInfo");
|
||||
+ if (nvml.nvmlInit_v2 == NULL || nvml.nvmlShutdown == NULL || nvml.nvmlDeviceGetHandleByUUID == NULL || nvml.nvmlDeviceGetMemoryInfo == NULL) {
|
||||
+ nvml.nvmlErrorString = (const char * (*)(nvmlReturn_enum)) GetProcAddress((HMODULE)(nvml.handle), "nvmlErrorString");
|
||||
+ if (nvml.nvmlInit_v2 == NULL || nvml.nvmlShutdown == NULL || nvml.nvmlDeviceGetHandleByUUID == NULL || nvml.nvmlDeviceGetMemoryInfo == NULL || nvml.nvmlErrorString == NULL) {
|
||||
+ GGML_LOG_INFO("%s unable to locate required symbols in NVML.dll", __func__);
|
||||
+ FreeLibrary((HMODULE)(nvml.handle));
|
||||
+ nvml.handle = NULL;
|
||||
@@ -833,11 +836,45 @@ index 00000000..aa05e9dc
|
||||
+
|
||||
+ SetErrorMode(old_mode);
|
||||
+
|
||||
+ nvmlReturn_t status = nvml.nvmlInit_v2();
|
||||
+ if (status != NVML_SUCCESS) {
|
||||
+ GGML_LOG_INFO("%s unable to initialize NVML: %s\n", __func__, nvml.nvmlErrorString(status));
|
||||
+ FreeLibrary((HMODULE)(nvml.handle));
|
||||
+ nvml.handle = NULL;
|
||||
+ return status;
|
||||
+ }
|
||||
+#else
|
||||
+ // Not currently wired up on Linux
|
||||
+ return NVML_ERROR_NOT_SUPPORTED;
|
||||
+ constexpr std::array<const char*, 2> libPaths = {
|
||||
+ "/usr/lib/wsl/lib/libnvidia-ml.so.1", // Favor WSL2 path if present
|
||||
+ "libnvidia-ml.so.1" // On a non-WSL2 system, it should be in the path
|
||||
+ };
|
||||
+ for (const char* path : libPaths) {
|
||||
+ nvml.handle = dlopen(path, RTLD_LAZY);
|
||||
+ if (nvml.handle) break;
|
||||
+ }
|
||||
+ if (nvml.handle == NULL) {
|
||||
+ GGML_LOG_INFO("%s unable to load libnvidia-ml: %s\n", __func__, dlerror());
|
||||
+ return NVML_ERROR_NOT_FOUND;
|
||||
+ }
|
||||
+ nvml.nvmlInit_v2 = (nvmlReturn_enum (*)()) dlsym(nvml.handle, "nvmlInit_v2");
|
||||
+ nvml.nvmlShutdown = (nvmlReturn_enum (*)()) dlsym(nvml.handle, "nvmlShutdown");
|
||||
+ nvml.nvmlDeviceGetHandleByUUID = (nvmlReturn_t (*)(const char *, nvmlDevice_t *)) dlsym(nvml.handle, "nvmlDeviceGetHandleByUUID");
|
||||
+ nvml.nvmlDeviceGetMemoryInfo = (nvmlReturn_t (*)(nvmlDevice_t, nvmlMemory_t *)) dlsym(nvml.handle, "nvmlDeviceGetMemoryInfo");
|
||||
+ nvml.nvmlErrorString = (const char * (*)(nvmlReturn_enum)) dlsym(nvml.handle, "nvmlErrorString");
|
||||
+ if (nvml.nvmlInit_v2 == NULL || nvml.nvmlShutdown == NULL || nvml.nvmlDeviceGetHandleByUUID == NULL || nvml.nvmlDeviceGetMemoryInfo == NULL) {
|
||||
+ GGML_LOG_INFO("%s unable to locate required symbols in libnvidia-ml.so", __func__);
|
||||
+ dlclose(nvml.handle);
|
||||
+ nvml.handle = NULL;
|
||||
+ return NVML_ERROR_NOT_FOUND;
|
||||
+ }
|
||||
+ nvmlReturn_t status = nvml.nvmlInit_v2();
|
||||
+ if (status != NVML_SUCCESS) {
|
||||
+ GGML_LOG_INFO("%s unable to initialize NVML: %s\n", __func__, nvml.nvmlErrorString(status));
|
||||
+ dlclose(nvml.handle);
|
||||
+ nvml.handle = NULL;
|
||||
+ return status;
|
||||
+ }
|
||||
+#endif
|
||||
+ int status = nvml.nvmlInit_v2();
|
||||
+ return NVML_SUCCESS;
|
||||
+}
|
||||
+
|
||||
@@ -849,14 +886,14 @@ index 00000000..aa05e9dc
|
||||
+ }
|
||||
+ nvmlReturn_enum status = nvml.nvmlShutdown();
|
||||
+ if (status != NVML_SUCCESS) {
|
||||
+ GGML_LOG_INFO("%s failed to shutdown NVML: %d\n", __func__, status);
|
||||
+ GGML_LOG_INFO("%s failed to shutdown NVML: %s\n", __func__, nvml.nvmlErrorString(status));
|
||||
+ }
|
||||
+#ifdef _WIN32
|
||||
+ FreeLibrary((HMODULE)(nvml.handle));
|
||||
+ nvml.handle = NULL;
|
||||
+#else
|
||||
+ // Not currently wired up on Linux
|
||||
+ dlclose(nvml.handle);
|
||||
+#endif
|
||||
+ nvml.handle = NULL;
|
||||
+}
|
||||
+
|
||||
+int ggml_nvml_get_device_memory(const char *uuid, size_t *free, size_t *total) {
|
||||
|
||||
@@ -1488,7 +1488,10 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
serverReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
res, err := http.DefaultClient.Do(serverReq)
|
||||
if err != nil {
|
||||
if err != nil && errors.Is(err, context.Canceled) {
|
||||
// client closed connection
|
||||
return err
|
||||
} else if err != nil {
|
||||
slog.Error("post predict", "error", err)
|
||||
return errors.New("model runner has unexpectedly stopped, this may be due to resource limitations or an internal error, check ollama server logs for details")
|
||||
}
|
||||
@@ -1500,7 +1503,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
return fmt.Errorf("failed reading llm error response: %w", err)
|
||||
}
|
||||
log.Printf("llm predict error: %s", bodyBytes)
|
||||
return api.StatusError{StatusCode: res.StatusCode, Status: res.Status, ErrorMessage: strings.TrimSpace(string(bodyBytes))}
|
||||
return api.StatusError{StatusCode: res.StatusCode, ErrorMessage: strings.TrimSpace(string(bodyBytes))}
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(res.Body)
|
||||
|
||||
51
ml/backend/ggml/ggml/src/mem_nvml.cpp
vendored
51
ml/backend/ggml/ggml/src/mem_nvml.cpp
vendored
@@ -12,6 +12,7 @@
|
||||
#include "ggml-impl.h"
|
||||
#include <filesystem>
|
||||
#include <mutex>
|
||||
#include <array>
|
||||
|
||||
#ifdef _WIN32
|
||||
# define WIN32_LEAN_AND_MEAN
|
||||
@@ -78,6 +79,7 @@ struct {
|
||||
nvmlReturn_t (*nvmlShutdown)(void);
|
||||
nvmlReturn_t (*nvmlDeviceGetHandleByUUID)(const char *, nvmlDevice_t *);
|
||||
nvmlReturn_t (*nvmlDeviceGetMemoryInfo)(nvmlDevice_t, nvmlMemory_t *);
|
||||
const char * (*nvmlErrorString)(nvmlReturn_t result);
|
||||
} nvml { NULL, NULL, NULL, NULL, NULL };
|
||||
static std::mutex ggml_nvml_lock;
|
||||
|
||||
@@ -115,7 +117,8 @@ int ggml_nvml_init() {
|
||||
nvml.nvmlShutdown = (nvmlReturn_enum (*)()) GetProcAddress((HMODULE)(nvml.handle), "nvmlShutdown");
|
||||
nvml.nvmlDeviceGetHandleByUUID = (nvmlReturn_t (*)(const char *, nvmlDevice_t *)) GetProcAddress((HMODULE)(nvml.handle), "nvmlDeviceGetHandleByUUID");
|
||||
nvml.nvmlDeviceGetMemoryInfo = (nvmlReturn_t (*)(nvmlDevice_t, nvmlMemory_t *)) GetProcAddress((HMODULE)(nvml.handle), "nvmlDeviceGetMemoryInfo");
|
||||
if (nvml.nvmlInit_v2 == NULL || nvml.nvmlShutdown == NULL || nvml.nvmlDeviceGetHandleByUUID == NULL || nvml.nvmlDeviceGetMemoryInfo == NULL) {
|
||||
nvml.nvmlErrorString = (const char * (*)(nvmlReturn_enum)) GetProcAddress((HMODULE)(nvml.handle), "nvmlErrorString");
|
||||
if (nvml.nvmlInit_v2 == NULL || nvml.nvmlShutdown == NULL || nvml.nvmlDeviceGetHandleByUUID == NULL || nvml.nvmlDeviceGetMemoryInfo == NULL || nvml.nvmlErrorString == NULL) {
|
||||
GGML_LOG_INFO("%s unable to locate required symbols in NVML.dll", __func__);
|
||||
FreeLibrary((HMODULE)(nvml.handle));
|
||||
nvml.handle = NULL;
|
||||
@@ -124,11 +127,45 @@ int ggml_nvml_init() {
|
||||
|
||||
SetErrorMode(old_mode);
|
||||
|
||||
nvmlReturn_t status = nvml.nvmlInit_v2();
|
||||
if (status != NVML_SUCCESS) {
|
||||
GGML_LOG_INFO("%s unable to initialize NVML: %s\n", __func__, nvml.nvmlErrorString(status));
|
||||
FreeLibrary((HMODULE)(nvml.handle));
|
||||
nvml.handle = NULL;
|
||||
return status;
|
||||
}
|
||||
#else
|
||||
// Not currently wired up on Linux
|
||||
return NVML_ERROR_NOT_SUPPORTED;
|
||||
constexpr std::array<const char*, 2> libPaths = {
|
||||
"/usr/lib/wsl/lib/libnvidia-ml.so.1", // Favor WSL2 path if present
|
||||
"libnvidia-ml.so.1" // On a non-WSL2 system, it should be in the path
|
||||
};
|
||||
for (const char* path : libPaths) {
|
||||
nvml.handle = dlopen(path, RTLD_LAZY);
|
||||
if (nvml.handle) break;
|
||||
}
|
||||
if (nvml.handle == NULL) {
|
||||
GGML_LOG_INFO("%s unable to load libnvidia-ml: %s\n", __func__, dlerror());
|
||||
return NVML_ERROR_NOT_FOUND;
|
||||
}
|
||||
nvml.nvmlInit_v2 = (nvmlReturn_enum (*)()) dlsym(nvml.handle, "nvmlInit_v2");
|
||||
nvml.nvmlShutdown = (nvmlReturn_enum (*)()) dlsym(nvml.handle, "nvmlShutdown");
|
||||
nvml.nvmlDeviceGetHandleByUUID = (nvmlReturn_t (*)(const char *, nvmlDevice_t *)) dlsym(nvml.handle, "nvmlDeviceGetHandleByUUID");
|
||||
nvml.nvmlDeviceGetMemoryInfo = (nvmlReturn_t (*)(nvmlDevice_t, nvmlMemory_t *)) dlsym(nvml.handle, "nvmlDeviceGetMemoryInfo");
|
||||
nvml.nvmlErrorString = (const char * (*)(nvmlReturn_enum)) dlsym(nvml.handle, "nvmlErrorString");
|
||||
if (nvml.nvmlInit_v2 == NULL || nvml.nvmlShutdown == NULL || nvml.nvmlDeviceGetHandleByUUID == NULL || nvml.nvmlDeviceGetMemoryInfo == NULL) {
|
||||
GGML_LOG_INFO("%s unable to locate required symbols in libnvidia-ml.so", __func__);
|
||||
dlclose(nvml.handle);
|
||||
nvml.handle = NULL;
|
||||
return NVML_ERROR_NOT_FOUND;
|
||||
}
|
||||
nvmlReturn_t status = nvml.nvmlInit_v2();
|
||||
if (status != NVML_SUCCESS) {
|
||||
GGML_LOG_INFO("%s unable to initialize NVML: %s\n", __func__, nvml.nvmlErrorString(status));
|
||||
dlclose(nvml.handle);
|
||||
nvml.handle = NULL;
|
||||
return status;
|
||||
}
|
||||
#endif
|
||||
int status = nvml.nvmlInit_v2();
|
||||
return NVML_SUCCESS;
|
||||
}
|
||||
|
||||
@@ -140,14 +177,14 @@ void ggml_nvml_release() {
|
||||
}
|
||||
nvmlReturn_enum status = nvml.nvmlShutdown();
|
||||
if (status != NVML_SUCCESS) {
|
||||
GGML_LOG_INFO("%s failed to shutdown NVML: %d\n", __func__, status);
|
||||
GGML_LOG_INFO("%s failed to shutdown NVML: %s\n", __func__, nvml.nvmlErrorString(status));
|
||||
}
|
||||
#ifdef _WIN32
|
||||
FreeLibrary((HMODULE)(nvml.handle));
|
||||
nvml.handle = NULL;
|
||||
#else
|
||||
// Not currently wired up on Linux
|
||||
dlclose(nvml.handle);
|
||||
#endif
|
||||
nvml.handle = NULL;
|
||||
}
|
||||
|
||||
int ggml_nvml_get_device_memory(const char *uuid, size_t *free, size_t *total) {
|
||||
|
||||
@@ -251,7 +251,7 @@ func BenchmarkBytePairEncoding(b *testing.B) {
|
||||
bts := bts[:n]
|
||||
b.Run("encode"+strconv.Itoa(n), func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
for b.Loop() {
|
||||
_, err := tokenizer.Encode(string(bts), true)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
@@ -266,7 +266,7 @@ func BenchmarkBytePairEncoding(b *testing.B) {
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
for b.Loop() {
|
||||
_, err := tokenizer.Decode(ids)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
@@ -276,7 +276,7 @@ func BenchmarkBytePairEncoding(b *testing.B) {
|
||||
|
||||
b.Run("split"+strconv.Itoa(n), func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
for b.Loop() {
|
||||
slices.Collect(tokenizer.split(string(bts)))
|
||||
}
|
||||
})
|
||||
|
||||
@@ -73,7 +73,7 @@ func (p ImageProcessor) bestResolution(img image.Point, possibleResolutions []im
|
||||
for i, res := range possibleResolutions {
|
||||
scaleW := float64(res.X) / float64(w)
|
||||
scaleH := float64(res.Y) / float64(h)
|
||||
scale := math.Min(scaleW, scaleH)
|
||||
scale := min(scaleW, scaleH)
|
||||
|
||||
scales[i] = scale
|
||||
}
|
||||
@@ -124,11 +124,11 @@ func (p ImageProcessor) maxResolution(imageRes, targetRes image.Point) image.Poi
|
||||
if scaleW < scaleH {
|
||||
newRes = image.Point{
|
||||
targetRes.X,
|
||||
int(math.Min(math.Floor(float64(imageRes.Y)*scaleW), float64(targetRes.Y))),
|
||||
int(min(math.Floor(float64(imageRes.Y)*scaleW), float64(targetRes.Y))),
|
||||
}
|
||||
} else {
|
||||
newRes = image.Point{
|
||||
int(math.Min(math.Floor(float64(imageRes.X)*scaleH), float64(targetRes.X))),
|
||||
int(min(math.Floor(float64(imageRes.X)*scaleH), float64(targetRes.X))),
|
||||
targetRes.Y,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,7 +53,7 @@ func (p ImageProcessor) fitToCanvas(imageSize, canvasSize image.Point) image.Poi
|
||||
tw := min(max(imageSize.X, p.imageSize), canvasSize.X)
|
||||
th := min(max(imageSize.Y, p.imageSize), canvasSize.Y)
|
||||
|
||||
r := math.Min(
|
||||
r := min(
|
||||
float64(tw)/float64(imageSize.X),
|
||||
float64(th)/float64(imageSize.Y),
|
||||
)
|
||||
@@ -89,10 +89,10 @@ func (p ImageProcessor) optimalTiledCanvas(imageSize image.Point) image.Point {
|
||||
if minUpscale == 0 {
|
||||
minUpscale = s
|
||||
} else {
|
||||
minUpscale = math.Min(minUpscale, s)
|
||||
minUpscale = min(minUpscale, s)
|
||||
}
|
||||
} else {
|
||||
maxDownscale = math.Max(maxDownscale, s)
|
||||
maxDownscale = max(maxDownscale, s)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
219
model/parsers/glm46.go
Normal file
219
model/parsers/glm46.go
Normal file
@@ -0,0 +1,219 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
const (
|
||||
glm46CollectingContent glm46ParserState = iota
|
||||
CollectingThinkingContent
|
||||
CollectingToolContent
|
||||
)
|
||||
|
||||
const (
|
||||
thinkingCloseTag = "</think>"
|
||||
)
|
||||
|
||||
// TODO(gguo): add a field for isThinking
|
||||
type GLM46Parser struct {
|
||||
state qwenParserState
|
||||
buffer strings.Builder
|
||||
tools []api.Tool
|
||||
}
|
||||
|
||||
func (p *GLM46Parser) HasToolSupport() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// TODO(gguo): changes this to reference an objects param
|
||||
func (p *GLM46Parser) HasThinkingSupport() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
|
||||
p.tools = tools
|
||||
// p.state = p.initialState()
|
||||
return tools
|
||||
}
|
||||
|
||||
type glm46EventThinkingContent struct {
|
||||
content string
|
||||
}
|
||||
|
||||
func (glm46EventThinkingContent) isGLM46Event() {}
|
||||
|
||||
func (p *GLM46Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||
p.buffer.WriteString(s)
|
||||
events := p.parseEvents()
|
||||
|
||||
var toolCalls []api.ToolCall
|
||||
var sb strings.Builder
|
||||
for _, event := range events {
|
||||
switch event := event.(type) {
|
||||
case glm46EventRawToolCall:
|
||||
toolCall, err := parseJSONToolCall(event, p.tools)
|
||||
if err != nil {
|
||||
slog.Warn("qwen tool call parsing failed", "error", err)
|
||||
return "", "", nil, err
|
||||
}
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
case glm46EventThinkingContent:
|
||||
sb.WriteString(event.content)
|
||||
case glm46EventContent:
|
||||
// TODO(drifkin): if the same turn contains multiple interleaved content
|
||||
// events, we naively append them together here.
|
||||
sb.WriteString(event.content)
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String(), "", toolCalls, nil
|
||||
}
|
||||
|
||||
func (p *GLM46Parser) parseEvents() []glm46Event {
|
||||
var all []glm46Event
|
||||
|
||||
keepLooping := true
|
||||
for keepLooping {
|
||||
var events []glm46Event
|
||||
events, keepLooping = p.eat()
|
||||
if len(events) > 0 {
|
||||
all = append(all, events...)
|
||||
}
|
||||
}
|
||||
|
||||
if len(all) > 0 {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "qwen events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
|
||||
}
|
||||
|
||||
return all
|
||||
}
|
||||
|
||||
func emitContentBeforeTag(p *GLM46Parser, events []glm46Event, tag string) []glm46Event {
|
||||
split := strings.SplitN(p.buffer.String(), tag, 2)
|
||||
before := split[0]
|
||||
before = strings.TrimRightFunc(before, unicode.IsSpace)
|
||||
if len(before) > 0 {
|
||||
events = append(events, glm46EventContent{content: before})
|
||||
}
|
||||
after := split[1]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(after)
|
||||
return events
|
||||
}
|
||||
|
||||
func (p *GLM46Parser) eat() ([]glm46Event, bool) {
|
||||
var events []glm46Event
|
||||
|
||||
switch p.state {
|
||||
case glm46CollectingContent:
|
||||
if strings.Contains(p.buffer.String(), toolOpenTag) {
|
||||
events = emitContentBeforeTag(p, events, toolOpenTag)
|
||||
p.state = glm46CollectingToolContent
|
||||
return events, true
|
||||
} else if overlapLen := overlap(p.buffer.String(), toolOpenTag); overlapLen > 0 {
|
||||
beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen]
|
||||
trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
|
||||
ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen
|
||||
|
||||
unambiguous := p.buffer.String()[:ambiguousStart]
|
||||
ambiguous := p.buffer.String()[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 { // why does qwen3coder not have this here
|
||||
events = append(events, glm46EventContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
} else {
|
||||
whitespaceLen := trailingWhitespaceLen(p.buffer.String())
|
||||
ambiguousStart := len(p.buffer.String()) - whitespaceLen
|
||||
|
||||
unambiguous := p.buffer.String()[:ambiguousStart]
|
||||
ambiguous := p.buffer.String()[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, glm46EventContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
}
|
||||
case CollectingToolContent:
|
||||
if strings.Contains(p.buffer.String(), glm46ToolCloseTag) {
|
||||
split := strings.SplitN(p.buffer.String(), toolCloseTag, 2)
|
||||
before := split[0]
|
||||
if len(before) == 0 {
|
||||
slog.Warn("qwen tool call closing tag found but no content before it")
|
||||
}
|
||||
|
||||
after := strings.TrimLeftFunc(split[1], unicode.IsSpace)
|
||||
events = append(events, glm46EventRawToolCall{raw: before})
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(after)
|
||||
p.state = glm46CollectingContent
|
||||
return events, true
|
||||
} else {
|
||||
return events, false
|
||||
}
|
||||
case glm46CollectingThinkingContent: // so we want to hip the unambiguous stuff
|
||||
if strings.Contains(p.buffer.String(), thinkingCloseTag) {
|
||||
split := strings.SplitN(p.buffer.String(), thinkingCloseTag, 2)
|
||||
before := split[0]
|
||||
if len(before) == 0 {
|
||||
slog.Warn("qwen tool call closing tag found but no content before it")
|
||||
}
|
||||
after := strings.TrimLeftFunc(split[1], unicode.IsSpace)
|
||||
if len(before) > 0 {
|
||||
events = append(events, glm46EventThinkingContent{content: before})
|
||||
}
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(after)
|
||||
p.state = glm46CollectingContent
|
||||
return events, true
|
||||
} else if overlapLen := overlap(p.buffer.String(), thinkingCloseTag); overlapLen > 0 { // we see part of a close thinking tag
|
||||
beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen]
|
||||
trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
|
||||
ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen
|
||||
|
||||
unambiguous := p.buffer.String()[:ambiguousStart]
|
||||
ambiguous := p.buffer.String()[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, glm46EventThinkingContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
} else {
|
||||
whitespaceLen := trailingWhitespaceLen(p.buffer.String())
|
||||
ambiguousStart := len(p.buffer.String()) - whitespaceLen
|
||||
|
||||
unambiguous := p.buffer.String()[:ambiguousStart]
|
||||
ambiguous := p.buffer.String()[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, glm46EventThinkingContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
}
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
}
|
||||
|
||||
func parseJSONToolCall(raw glm46EventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
|
||||
var toolCallFunction api.ToolCallFunction
|
||||
if err := json.Unmarshal([]byte(raw.raw), &toolCallFunction); err != nil {
|
||||
return api.ToolCall{}, err
|
||||
}
|
||||
|
||||
toolCall := api.ToolCall{}
|
||||
toolCall.Function = toolCallFunction
|
||||
|
||||
return toolCall, nil
|
||||
}
|
||||
@@ -21,6 +21,9 @@ func ParserForName(name string) Parser {
|
||||
case "qwen3-coder":
|
||||
parser := &Qwen3CoderParser{}
|
||||
return parser
|
||||
case "glm-4.6":
|
||||
parser := &GLM46Parser{}
|
||||
return parser
|
||||
case "passthrough":
|
||||
return &PassthroughParser{}
|
||||
case "harmony":
|
||||
|
||||
239
model/renderers/glm46_test.go
Normal file
239
model/renderers/glm46_test.go
Normal file
@@ -0,0 +1,239 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
// TestGLM46Renderer exercises the GLM 4.6 prompt renderer against literal
// expected transcripts: plain user turns, system prompts, multi-turn
// conversations, tool definitions, tool calls/responses, and the
// think-enabled/disabled variants.
func TestGLM46Renderer(t *testing.T) {
	tests := []struct {
		name       string
		messages   []api.Message
		tools      []api.Tool
		thinkValue *api.ThinkValue
		expected   string
	}{
		{
			name: "basic",
			messages: []api.Message{
				{Role: "user", Content: "Hello, how are you?"},
			},
			expected: `[gMASK]<sop><|user|>
Hello, how are you?<|assistant|>`,
		},
		{
			name: "basic with system message",
			messages: []api.Message{
				{Role: "system", Content: "You are a helpful assistant."},
				{Role: "user", Content: "Hello, how are you?"},
			},
			expected: `[gMASK]<sop><|system|>
You are a helpful assistant.<|user|>
Hello, how are you?<|assistant|>`,
		},
		{
			// Thinking on the assistant turn is dropped because it precedes
			// the final user message.
			name: "basic with user assistant user",
			messages: []api.Message{
				{Role: "user", Content: "What is the capital of France?"},
				{Role: "assistant", Thinking: "Let me analyze the request...", Content: "The capital of France is Paris."},
				{Role: "user", Content: "Fantastic!"},
			},
			expected: `[gMASK]<sop><|user|>
What is the capital of France?<|assistant|>
The capital of France is Paris.<|user|>
Fantastic!<|assistant|>`,
		},
		{
			// Tool definitions are rendered into a leading <|system|> turn
			// before the user-provided system message.
			name: "tools",
			messages: []api.Message{
				{Role: "system", Content: "You are a helpful assistant with access to tools."},
				{Role: "user", Content: "What is the weather like in Tokyo?"},
			},
			tools: []api.Tool{
				{
					Type: "function",
					Function: api.ToolFunction{
						Name:        "get_weather",
						Description: "Get the current weather in a given location",
						Parameters: api.ToolFunctionParameters{
							Type:     "object",
							Required: []string{"location"},
							Properties: map[string]api.ToolProperty{
								"location": {
									Type:        api.PropertyType{"string"},
									Description: "The city and state, e.g. San Francisco, CA",
								},
								"unit": {
									Type: api.PropertyType{"string"},
									Enum: []any{"celsius", "fahrenheit"},
								},
							},
						},
					},
				},
			},
			expected: `[gMASK]<sop><|system|>
# Tools

You may call one or more functions to assist with the user query.

You are provided with function signatures within <tools></tools> XML tags:
<tools>
{"type":"function","function":{"name":"get_weather","description":"Get the current weather in a given location","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","description":"","enum":["celsius","fahrenheit"]}}}}}
</tools>

For each function call, output the function name and arguments within the following XML format:
<tool_call>{function-name}
<arg_key>{arg-key-1}</arg_key>
<arg_value>{arg-value-1}</arg_value>
<arg_key>{arg-key-2}</arg_key>
<arg_value>{arg-value-2}</arg_value>
...
</tool_call><|system|>
You are a helpful assistant with access to tools.<|user|>
What is the weather like in Tokyo?<|assistant|>`,
		},
		{
			// Consecutive tool responses share one <|observation|> turn, and
			// assistant turns after the last user message get a <think> block.
			name: "tool calls",
			messages: []api.Message{
				{Role: "system", Content: "You are a helpful assistant with access to tools."},
				{Role: "user", Content: "What is the weather like in Tokyo?"},
				{
					Role: "assistant",
					ToolCalls: []api.ToolCall{
						{
							Function: api.ToolCallFunction{
								Name: "get_weather",
								Arguments: api.ToolCallFunctionArguments{
									"location": "Tokyo, Japan",
									"unit":     "celsius",
								},
							},
						},
						{
							Function: api.ToolCallFunction{
								Name: "get_weather",
								Arguments: api.ToolCallFunctionArguments{
									"location": "Japan",
									"unit":     "fahrenheit",
								},
							},
						},
					},
				},
				{
					Role:     "tool",
					Content:  "{\"temperature\": 22, \"weather\": \"partly cloudy\", \"humidity\": 65}",
					ToolName: "get_weather",
				},
				{
					Role:     "tool",
					Content:  "{\"temperature\": 68, \"weather\": \"sunny\", \"humidity\": 75}",
					ToolName: "get_weather",
				},
				{
					Role:    "assistant",
					Content: "The weather in Tokyo is currently partly cloudy with a temperature of 22°C and 65% humidity. It's a pleasant day with moderate temperatures.",
				},
			},
			tools: []api.Tool{
				{
					Type: "function",
					Function: api.ToolFunction{
						Name:        "get_weather",
						Description: "Get the current weather in a given location",
						Parameters: api.ToolFunctionParameters{
							Type:     "object",
							Required: []string{"location"},
							Properties: map[string]api.ToolProperty{
								"location": {
									Type:        api.PropertyType{"string"},
									Description: "The city and state, e.g. San Francisco, CA",
								},
								"unit": {
									Type: api.PropertyType{"string"},
									Enum: []any{"celsius", "fahrenheit"},
								},
							},
						},
					},
				},
			},
			expected: `[gMASK]<sop><|system|>
# Tools

You may call one or more functions to assist with the user query.

You are provided with function signatures within <tools></tools> XML tags:
<tools>
{"type":"function","function":{"name":"get_weather","description":"Get the current weather in a given location","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","description":"","enum":["celsius","fahrenheit"]}}}}}
</tools>

For each function call, output the function name and arguments within the following XML format:
<tool_call>{function-name}
<arg_key>{arg-key-1}</arg_key>
<arg_value>{arg-value-1}</arg_value>
<arg_key>{arg-key-2}</arg_key>
<arg_value>{arg-value-2}</arg_value>
...
</tool_call><|system|>
You are a helpful assistant with access to tools.<|user|>
What is the weather like in Tokyo?<|assistant|>
<think></think>
<tool_call>get_weather
<arg_key>location</arg_key>
<arg_value>Tokyo, Japan</arg_value>
<arg_key>unit</arg_key>
<arg_value>celsius</arg_value>
</tool_call>
<tool_call>get_weather
<arg_key>location</arg_key>
<arg_value>Japan</arg_value>
<arg_key>unit</arg_key>
<arg_value>fahrenheit</arg_value>
</tool_call><|observation|>
<tool_response>
{"temperature": 22, "weather": "partly cloudy", "humidity": 65}
</tool_response>
<tool_response>
{"temperature": 68, "weather": "sunny", "humidity": 75}
</tool_response><|assistant|>
<think></think>
The weather in Tokyo is currently partly cloudy with a temperature of 22°C and 65% humidity. It's a pleasant day with moderate temperatures.<|assistant|>`,
		},
		{
			name: "think true",
			messages: []api.Message{
				{Role: "user", Content: "Hello, how are you?"},
			},
			thinkValue: &api.ThinkValue{Value: true},
			expected: `[gMASK]<sop><|user|>
Hello, how are you?<|assistant|>`,
		},
		{
			// Disabled thinking appends /nothink to the user turn and
			// pre-fills an empty <think> block in the generation prompt.
			name: "think false",
			messages: []api.Message{
				{Role: "user", Content: "Hello, how are you?"},
			},
			thinkValue: &api.ThinkValue{Value: false},
			expected: `[gMASK]<sop><|user|>
Hello, how are you?/nothink<|assistant|>
<think></think>`,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			rendered, err := GLM46Renderer(tt.messages, tt.tools, tt.thinkValue)
			if err != nil {
				t.Fatal(err)
			}
			// cmp.Diff(got, want) marks got lines with "-" and want with "+".
			if diff := cmp.Diff(rendered, tt.expected); diff != "" {
				t.Errorf("mismatch (-got +want):\n%s", diff)
				t.Logf("Got:\n%s", rendered)
				t.Logf("Expected:\n%s", tt.expected)
			}
		})
	}
}
|
||||
109
model/renderers/gml46.go
Normal file
109
model/renderers/gml46.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package renderers
|
||||
|
||||
import (
	"encoding/json"
	"fmt"
	"sort"
	"strings"

	"github.com/ollama/ollama/api"
)
|
||||
|
||||
func GLM46Renderer(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString("[gMASK]<sop>")
|
||||
|
||||
var lastUserIndex int
|
||||
for i, message := range messages {
|
||||
if message.Role == "user" {
|
||||
lastUserIndex = i
|
||||
}
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
sb.WriteString("<|system|>\n")
|
||||
sb.WriteString("# Tools\n\n")
|
||||
sb.WriteString("You may call one or more functions to assist with the user query.\n\n")
|
||||
sb.WriteString("You are provided with function signatures within <tools></tools> XML tags:\n")
|
||||
sb.WriteString("<tools>\n")
|
||||
for _, tool := range tools {
|
||||
d, _ := json.Marshal(tool)
|
||||
sb.WriteString(string(d) + "\n")
|
||||
}
|
||||
sb.WriteString("</tools>\n\n")
|
||||
sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n")
|
||||
sb.WriteString("<tool_call>{function-name}\n")
|
||||
sb.WriteString("<arg_key>{arg-key-1}</arg_key>\n")
|
||||
sb.WriteString("<arg_value>{arg-value-1}</arg_value>\n")
|
||||
sb.WriteString("<arg_key>{arg-key-2}</arg_key>\n")
|
||||
sb.WriteString("<arg_value>{arg-value-2}</arg_value>\n")
|
||||
sb.WriteString("...\n")
|
||||
sb.WriteString("</tool_call>")
|
||||
}
|
||||
|
||||
for i, message := range messages {
|
||||
switch message.Role {
|
||||
case "user":
|
||||
sb.WriteString("<|user|>\n")
|
||||
sb.WriteString(message.Content)
|
||||
if thinkValue != nil && !thinkValue.Bool() && !strings.HasSuffix(message.Content, "/nothink") {
|
||||
sb.WriteString("/nothink")
|
||||
}
|
||||
case "assistant":
|
||||
sb.WriteString("<|assistant|>")
|
||||
if i > lastUserIndex {
|
||||
if message.Thinking != "" {
|
||||
sb.WriteString("\n<think>" + message.Thinking + "</think>")
|
||||
} else {
|
||||
sb.WriteString("\n<think></think>")
|
||||
}
|
||||
}
|
||||
if message.Content != "" {
|
||||
sb.WriteString("\n" + message.Content)
|
||||
}
|
||||
if len(message.ToolCalls) > 0 {
|
||||
for _, toolCall := range message.ToolCalls {
|
||||
sb.WriteString("\n<tool_call>" + toolCall.Function.Name + "\n")
|
||||
for key, value := range toolCall.Function.Arguments {
|
||||
sb.WriteString("<arg_key>" + key + "</arg_key>\n")
|
||||
|
||||
var valueStr string
|
||||
if str, ok := value.(string); ok {
|
||||
valueStr = str
|
||||
} else {
|
||||
jsonBytes, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
valueStr = fmt.Sprintf("%v", value)
|
||||
} else {
|
||||
valueStr = string(jsonBytes)
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString("<arg_value>" + valueStr + "</arg_value>\n")
|
||||
}
|
||||
|
||||
sb.WriteString("</tool_call>")
|
||||
}
|
||||
}
|
||||
case "tool":
|
||||
if i == 0 || messages[i-1].Role != "tool" {
|
||||
sb.WriteString("<|observation|>")
|
||||
}
|
||||
sb.WriteString("\n<tool_response>\n")
|
||||
sb.WriteString(message.Content)
|
||||
sb.WriteString("\n</tool_response>")
|
||||
case "system":
|
||||
sb.WriteString("<|system|>\n")
|
||||
sb.WriteString(message.Content)
|
||||
}
|
||||
}
|
||||
|
||||
// Add generation prompt
|
||||
sb.WriteString("<|assistant|>")
|
||||
fmt.Println("thinkValue", thinkValue, thinkValue.Bool())
|
||||
if thinkValue != nil && !thinkValue.Bool() {
|
||||
sb.WriteString("\n<think></think>")
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
}
|
||||
@@ -20,6 +20,8 @@ func rendererForName(name string) rendererFunc {
|
||||
switch name {
|
||||
case "qwen3-coder":
|
||||
return Qwen3CoderRenderer
|
||||
case "glm-4.6":
|
||||
return GLM46Renderer
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -567,18 +567,24 @@ func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
||||
}
|
||||
|
||||
var think *api.ThinkValue
|
||||
var effort string
|
||||
|
||||
if r.Reasoning != nil {
|
||||
if !slices.Contains([]string{"high", "medium", "low", "none"}, r.Reasoning.Effort) {
|
||||
return nil, fmt.Errorf("invalid reasoning value: '%s' (must be \"high\", \"medium\", \"low\", or \"none\")", r.Reasoning.Effort)
|
||||
effort = r.Reasoning.Effort
|
||||
} else if r.ReasoningEffort != nil {
|
||||
effort = *r.ReasoningEffort
|
||||
}
|
||||
|
||||
if effort != "" {
|
||||
if !slices.Contains([]string{"high", "medium", "low", "none"}, effort) {
|
||||
return nil, fmt.Errorf("invalid reasoning value: '%s' (must be \"high\", \"medium\", \"low\", or \"none\")", effort)
|
||||
}
|
||||
|
||||
if r.Reasoning.Effort == "none" {
|
||||
if effort == "none" {
|
||||
think = &api.ThinkValue{Value: false}
|
||||
} else {
|
||||
think = &api.ThinkValue{Value: r.Reasoning.Effort}
|
||||
think = &api.ThinkValue{Value: effort}
|
||||
}
|
||||
} else if r.ReasoningEffort != nil {
|
||||
think = &api.ThinkValue{Value: *r.ReasoningEffort}
|
||||
}
|
||||
|
||||
return &api.ChatRequest{
|
||||
|
||||
@@ -85,10 +85,10 @@ type Sequence struct {
|
||||
doneReason llm.DoneReason
|
||||
|
||||
// Metrics
|
||||
startProcessingTime time.Time
|
||||
startGenerationTime time.Time
|
||||
numDecoded int
|
||||
numPromptInputs int
|
||||
processingDuration time.Duration
|
||||
generationDuration time.Duration
|
||||
numDecoded int
|
||||
numPromptInputs int
|
||||
}
|
||||
|
||||
type NewSequenceParams struct {
|
||||
@@ -106,8 +106,6 @@ var errorInputTooLong = errors.New("the input length exceeds the context length"
|
||||
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
|
||||
s.ready.Wait()
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
inputs, err := s.inputs(prompt, images)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to process inputs: %w", err)
|
||||
@@ -153,18 +151,17 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
||||
}
|
||||
|
||||
return &Sequence{
|
||||
inputs: inputs,
|
||||
numPromptInputs: len(inputs),
|
||||
startProcessingTime: startTime,
|
||||
numPredict: params.numPredict,
|
||||
pendingResponses: make([]string, 0),
|
||||
responses: make(chan string, 100),
|
||||
quit: make(chan bool, 1),
|
||||
embedding: make(chan []float32, 1),
|
||||
samplingCtx: sc,
|
||||
embeddingOnly: params.embedding,
|
||||
stop: params.stop,
|
||||
numKeep: params.numKeep,
|
||||
inputs: inputs,
|
||||
numPromptInputs: len(inputs),
|
||||
numPredict: params.numPredict,
|
||||
pendingResponses: make([]string, 0),
|
||||
responses: make(chan string, 100),
|
||||
quit: make(chan bool, 1),
|
||||
embedding: make(chan []float32, 1),
|
||||
samplingCtx: sc,
|
||||
embeddingOnly: params.embedding,
|
||||
stop: params.stop,
|
||||
numKeep: params.numKeep,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -454,8 +451,8 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
||||
return nil
|
||||
}
|
||||
|
||||
err := s.lc.Decode(batch)
|
||||
if err != nil {
|
||||
t := time.Now()
|
||||
if err := s.lc.Decode(batch); err != nil {
|
||||
return fmt.Errorf("failed to decode batch: %w", err)
|
||||
}
|
||||
|
||||
@@ -475,9 +472,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
||||
continue
|
||||
}
|
||||
|
||||
seq.numDecoded += 1
|
||||
if seq.numDecoded == 1 {
|
||||
seq.startGenerationTime = time.Now()
|
||||
s.lc.Synchronize()
|
||||
seq.numDecoded++
|
||||
if seq.numDecoded > 1 {
|
||||
seq.generationDuration += time.Since(t)
|
||||
} else {
|
||||
seq.processingDuration += time.Since(t)
|
||||
}
|
||||
|
||||
// if done processing the prompt, generate an embedding and return
|
||||
@@ -668,9 +668,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
Done: true,
|
||||
DoneReason: seq.doneReason,
|
||||
PromptEvalCount: seq.numPromptInputs,
|
||||
PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
|
||||
PromptEvalDuration: seq.processingDuration,
|
||||
EvalCount: seq.numDecoded,
|
||||
EvalDuration: time.Since(seq.startGenerationTime),
|
||||
EvalDuration: seq.generationDuration,
|
||||
}); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
@@ -94,10 +94,11 @@ type Sequence struct {
|
||||
doneReason llm.DoneReason
|
||||
|
||||
// Metrics
|
||||
startProcessingTime time.Time
|
||||
startGenerationTime time.Time
|
||||
numPredicted int
|
||||
numPromptInputs int
|
||||
startedAt, lastUpdatedAt time.Time
|
||||
processingDuration time.Duration
|
||||
samplingDuration time.Duration
|
||||
numPredicted int
|
||||
numPromptInputs int
|
||||
}
|
||||
|
||||
type NewSequenceParams struct {
|
||||
@@ -115,8 +116,6 @@ var errorInputTooLong = errors.New("the input length exceeds the context length"
|
||||
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
|
||||
s.ready.Wait()
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
inputs, ctxs, mmStore, err := s.inputs(prompt, images)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to process inputs: %w", err)
|
||||
@@ -176,21 +175,20 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
||||
// TODO(jessegross): Ingest cached history for grammar
|
||||
|
||||
return &Sequence{
|
||||
ctxs: ctxs,
|
||||
mmStore: mmStore,
|
||||
inputs: inputs,
|
||||
numPromptInputs: len(inputs),
|
||||
startProcessingTime: startTime,
|
||||
numPredict: params.numPredict,
|
||||
pendingResponses: make([]string, 0),
|
||||
responses: make(chan string, 100),
|
||||
quit: make(chan bool, 1),
|
||||
embedding: make(chan []float32, 1),
|
||||
sampler: params.sampler,
|
||||
embeddingOnly: params.embedding,
|
||||
stop: params.stop,
|
||||
numKeep: params.numKeep,
|
||||
shift: params.shift,
|
||||
ctxs: ctxs,
|
||||
mmStore: mmStore,
|
||||
inputs: inputs,
|
||||
numPromptInputs: len(inputs),
|
||||
numPredict: params.numPredict,
|
||||
pendingResponses: make([]string, 0),
|
||||
responses: make(chan string, 100),
|
||||
quit: make(chan bool, 1),
|
||||
embedding: make(chan []float32, 1),
|
||||
sampler: params.sampler,
|
||||
embeddingOnly: params.embedding,
|
||||
stop: params.stop,
|
||||
numKeep: params.numKeep,
|
||||
shift: params.shift,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -336,9 +334,6 @@ type Server struct {
|
||||
// TODO (jmorganca): make this n_batch
|
||||
batchSize int
|
||||
|
||||
// Used to signal a hard failure during async processing which will panic the runner
|
||||
hardErrCh chan error
|
||||
|
||||
// Simple counter used only for trace logging batches
|
||||
batchID int
|
||||
|
||||
@@ -421,25 +416,25 @@ func (s *Server) run(ctx context.Context) {
|
||||
|
||||
supportsAsync := pooling.Type(s.model.Backend().Config().Uint("pooling_type")) == pooling.TypeNone
|
||||
|
||||
var activeBatch batchState
|
||||
var previousBatch batchState
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case err := <-s.hardErrCh:
|
||||
panic(err)
|
||||
default:
|
||||
var err error
|
||||
activeBatch, err = s.forwardBatch(activeBatch)
|
||||
nextBatch, err := s.forwardBatch(previousBatch)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if supportsAsync {
|
||||
go s.computeBatch(activeBatch)
|
||||
go s.computeBatch(nextBatch)
|
||||
} else {
|
||||
s.computeBatch(activeBatch)
|
||||
s.computeBatch(nextBatch)
|
||||
}
|
||||
|
||||
previousBatch = nextBatch
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -581,6 +576,13 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
|
||||
seq.inputs = seq.inputs[len(seq.pendingInputs):]
|
||||
}
|
||||
|
||||
startedAt := time.Now()
|
||||
for i := range nextBatch.seqs {
|
||||
if nextBatch.seqs[i] != nil && nextBatch.seqs[i].startedAt.IsZero() {
|
||||
nextBatch.seqs[i].startedAt = startedAt
|
||||
}
|
||||
}
|
||||
|
||||
if resumeSeq != -1 {
|
||||
s.nextSeq = resumeSeq
|
||||
} else {
|
||||
@@ -675,9 +677,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
// don't sample prompt processing
|
||||
if len(seq.inputs) != 0 {
|
||||
if !s.cache.enabled {
|
||||
s.hardErrCh <- fmt.Errorf("caching disabled but unable to fit entire input in a batch")
|
||||
s.mu.Unlock()
|
||||
return
|
||||
panic("caching disabled but unable to fit entire input in a batch")
|
||||
}
|
||||
continue
|
||||
}
|
||||
@@ -701,6 +701,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
activeBatch.modelOutput)
|
||||
|
||||
outputs := activeBatch.modelOutput.Floats()
|
||||
t := time.Now()
|
||||
|
||||
logutil.Trace("computeBatch: logits ready", "batchID", activeBatch.id)
|
||||
|
||||
@@ -713,8 +714,10 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
continue
|
||||
}
|
||||
|
||||
seq.lastUpdatedAt = t
|
||||
if seq.numPredicted == 1 {
|
||||
seq.startGenerationTime = time.Now()
|
||||
seq.processingDuration = seq.lastUpdatedAt.Sub(seq.startedAt)
|
||||
seq.startedAt = seq.lastUpdatedAt
|
||||
}
|
||||
|
||||
// if done processing the prompt, generate an embedding and return
|
||||
@@ -729,8 +732,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", activeBatch.batch.Outputs.Dim(0), "vocabSize", vocabSize, "iBatches", iBatches)
|
||||
token, err := seq.sampler.Sample(outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
|
||||
if err != nil {
|
||||
s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)
|
||||
return
|
||||
panic("failed to sample token")
|
||||
}
|
||||
|
||||
nextBatchTokens[i].Token = token
|
||||
@@ -747,8 +749,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
|
||||
piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
|
||||
if err != nil {
|
||||
s.hardErrCh <- fmt.Errorf("failed to decode token: %w", err)
|
||||
return
|
||||
panic("failed to decode token")
|
||||
}
|
||||
|
||||
seq.pendingResponses = append(seq.pendingResponses, piece)
|
||||
@@ -793,6 +794,13 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
s.removeSequence(i, llm.DoneReasonConnectionClosed)
|
||||
}
|
||||
}
|
||||
|
||||
samplingDuration := time.Since(t)
|
||||
for i, seq := range s.seqs {
|
||||
if seq != nil && nextBatchTokens[i] != nil {
|
||||
s.seqs[i].samplingDuration += samplingDuration
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -912,9 +920,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
Done: true,
|
||||
DoneReason: seq.doneReason,
|
||||
PromptEvalCount: seq.numPromptInputs,
|
||||
PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
|
||||
PromptEvalDuration: seq.processingDuration,
|
||||
EvalCount: seq.numPredicted,
|
||||
EvalDuration: time.Since(seq.startGenerationTime),
|
||||
EvalDuration: seq.lastUpdatedAt.Sub(seq.startedAt) - seq.samplingDuration,
|
||||
}); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
|
||||
}
|
||||
@@ -1329,7 +1337,6 @@ func Execute(args []string) error {
|
||||
server := &Server{
|
||||
modelPath: *mpath,
|
||||
status: llm.ServerStatusLaunched,
|
||||
hardErrCh: make(chan error, 1),
|
||||
}
|
||||
|
||||
server.cond = sync.NewCond(&server.mu)
|
||||
|
||||
357
server/routes.go
357
server/routes.go
@@ -332,14 +332,16 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
modelCaps := m.Capabilities()
|
||||
if req.Think != nil {
|
||||
if slices.Contains(modelCaps, model.CapabilityThinking) {
|
||||
caps = append(caps, model.CapabilityThinking)
|
||||
} else {
|
||||
// add thinking if the model supports it
|
||||
if slices.Contains(modelCaps, model.CapabilityThinking) {
|
||||
caps = append(caps, model.CapabilityThinking)
|
||||
if req.Think == nil {
|
||||
req.Think = &api.ThinkValue{Value: true}
|
||||
}
|
||||
} else {
|
||||
if req.Think != nil && req.Think.Bool() {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
|
||||
@@ -401,12 +403,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
msgs = append(msgs, m.Messages...)
|
||||
}
|
||||
|
||||
userMsg := api.Message{Role: "user", Content: req.Prompt}
|
||||
for _, i := range images {
|
||||
imgPrompt := ""
|
||||
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]"+imgPrompt, i.ID)})
|
||||
userMsg.Images = append(userMsg.Images, i.Data)
|
||||
}
|
||||
|
||||
values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})
|
||||
values.Messages = append(msgs, userMsg)
|
||||
}
|
||||
|
||||
values.Think = req.Think != nil && req.Think.Bool()
|
||||
@@ -427,12 +428,31 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
b.WriteString(s)
|
||||
}
|
||||
|
||||
if err := tmpl.Execute(&b, values); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
// check that we're in the `api/chat`-like flow, and if so, generate the
|
||||
// prompt the same way
|
||||
// TEMP(drifkin): we should really just detect the chat-like flow and call
|
||||
// the real chat handler, but doing this as a stopgap to get renderer
|
||||
// support for generate
|
||||
if values.Messages != nil && values.Suffix == "" && req.Template == "" {
|
||||
prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, req.Truncate == nil || *req.Truncate)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
// TEMP(drifkin): req.Context will be removed very soon, but we're temporarily supporting it in this flow here
|
||||
if req.Context != nil {
|
||||
b.WriteString(prompt)
|
||||
prompt = b.String()
|
||||
}
|
||||
} else {
|
||||
// legacy flow
|
||||
if err := tmpl.Execute(&b, values); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
prompt = b.String()
|
||||
prompt = b.String()
|
||||
}
|
||||
}
|
||||
|
||||
// If debug mode is enabled, return the rendered template instead of calling the model
|
||||
@@ -535,7 +555,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
|
||||
ch <- res
|
||||
}); err != nil {
|
||||
ch <- err
|
||||
var serr api.StatusError
|
||||
if errors.As(err, &serr) {
|
||||
ch <- gin.H{"error": serr.ErrorMessage, "status": serr.StatusCode}
|
||||
} else {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -549,20 +574,18 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
sbThinking.WriteString(t.Thinking)
|
||||
sbContent.WriteString(t.Response)
|
||||
r = t
|
||||
case api.StatusError:
|
||||
c.JSON(t.StatusCode, gin.H{"error": t.ErrorMessage})
|
||||
return
|
||||
case error:
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": t.Error()})
|
||||
return
|
||||
// TODO (jmorganca): remove use of gin.H here and instead expect
|
||||
// api.StatusError to be send in the channel
|
||||
case gin.H:
|
||||
if msg, ok := t["error"].(string); ok {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
|
||||
return
|
||||
msg, ok := t["error"].(string)
|
||||
if !ok {
|
||||
msg = "unexpected error format in response"
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
|
||||
|
||||
status, ok := t["status"].(int)
|
||||
if !ok {
|
||||
status = http.StatusInternalServerError
|
||||
}
|
||||
|
||||
c.JSON(status, gin.H{"error": msg})
|
||||
return
|
||||
default:
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
|
||||
@@ -1627,16 +1650,28 @@ func streamResponse(c *gin.Context, ch chan any) {
|
||||
return false
|
||||
}
|
||||
|
||||
if statusError, ok := val.(api.StatusError); ok {
|
||||
c.Header("Content-Type", "application/json")
|
||||
c.AbortWithStatusJSON(statusError.StatusCode, gin.H{"error": statusError.ErrorMessage})
|
||||
return false
|
||||
}
|
||||
// errors are provided as a gin.H with an "error" field and
|
||||
// an optional "status" field. For errors that are streamed
|
||||
// before any content, we need to set the status code and
|
||||
// content type for the error.
|
||||
if h, ok := val.(gin.H); ok {
|
||||
if e, ok := h["error"].(string); ok {
|
||||
status, ok := h["status"].(int)
|
||||
if !ok {
|
||||
status = http.StatusInternalServerError
|
||||
}
|
||||
|
||||
if err, ok := val.(error); ok {
|
||||
c.Header("Content-Type", "application/json")
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return false
|
||||
if !c.Writer.Written() {
|
||||
c.Header("Content-Type", "application/json")
|
||||
c.JSON(status, gin.H{"error": e})
|
||||
} else {
|
||||
if err := json.NewEncoder(c.Writer).Encode(gin.H{"error": e}); err != nil {
|
||||
slog.Error("streamResponse failed to encode json error", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
bts, err := json.Marshal(val)
|
||||
@@ -1898,14 +1933,16 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
modelCaps := m.Capabilities()
|
||||
if req.Think != nil {
|
||||
if slices.Contains(modelCaps, model.CapabilityThinking) {
|
||||
caps = append(caps, model.CapabilityThinking)
|
||||
} else {
|
||||
// add thinking if the model supports it
|
||||
if slices.Contains(modelCaps, model.CapabilityThinking) {
|
||||
caps = append(caps, model.CapabilityThinking)
|
||||
if req.Think == nil {
|
||||
req.Think = &api.ThinkValue{Value: true}
|
||||
}
|
||||
} else {
|
||||
if req.Think != nil && req.Think.Bool() {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
|
||||
@@ -2001,90 +2038,174 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
toolParser = tools.NewParser(m.Template.Template, req.Tools)
|
||||
}
|
||||
|
||||
type structuredOutputsState int
|
||||
const (
|
||||
structuredOutputsState_None structuredOutputsState = iota
|
||||
structuredOutputsState_ReadyToApply
|
||||
structuredOutputsState_Applying
|
||||
)
|
||||
|
||||
ch := make(chan any)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
|
||||
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
Images: images,
|
||||
Format: req.Format,
|
||||
Options: opts,
|
||||
Shift: req.Shift == nil || *req.Shift,
|
||||
Truncate: truncate,
|
||||
}, func(r llm.CompletionResponse) {
|
||||
res := api.ChatResponse{
|
||||
Model: req.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Message: api.Message{Role: "assistant", Content: r.Content},
|
||||
Done: r.Done,
|
||||
Metrics: api.Metrics{
|
||||
PromptEvalCount: r.PromptEvalCount,
|
||||
PromptEvalDuration: r.PromptEvalDuration,
|
||||
EvalCount: r.EvalCount,
|
||||
EvalDuration: r.EvalDuration,
|
||||
},
|
||||
}
|
||||
if r.Done {
|
||||
res.DoneReason = r.DoneReason.String()
|
||||
res.TotalDuration = time.Since(checkpointStart)
|
||||
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||
structuredOutputsState := structuredOutputsState_None
|
||||
|
||||
for {
|
||||
var tb strings.Builder
|
||||
|
||||
currentFormat := req.Format
|
||||
// structured outputs via double request is enabled when:
|
||||
// 1. the model supports the thinking capability and
|
||||
// 2. it uses a built-in parser or our generic thinking parser
|
||||
|
||||
// Note that the current approach does not work for (potential future)
|
||||
// non-thinking models that emit anything before actual content. This
|
||||
// current approach uses the transition from parsed thinking content to
|
||||
// parsed non-thinking content as the signal to turn constraining on
|
||||
|
||||
if req.Format != nil && structuredOutputsState == structuredOutputsState_None && ((builtinParser != nil || thinkingState != nil) && slices.Contains(m.Capabilities(), model.CapabilityThinking)) {
|
||||
currentFormat = nil
|
||||
}
|
||||
|
||||
if builtinParser != nil {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser input", "parser", m.Config.Parser, "content", r.Content)
|
||||
|
||||
content, thinking, toolCalls, err := builtinParser.Add(r.Content, r.Done)
|
||||
if err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
return
|
||||
// sets up new context given parent context per request
|
||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||
err := r.Completion(ctx, llm.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
Images: images,
|
||||
Format: currentFormat,
|
||||
Options: opts,
|
||||
Shift: req.Shift == nil || *req.Shift,
|
||||
Truncate: truncate,
|
||||
}, func(r llm.CompletionResponse) {
|
||||
res := api.ChatResponse{
|
||||
Model: req.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Message: api.Message{Role: "assistant", Content: r.Content},
|
||||
Done: r.Done,
|
||||
Metrics: api.Metrics{
|
||||
PromptEvalCount: r.PromptEvalCount,
|
||||
PromptEvalDuration: r.PromptEvalDuration,
|
||||
EvalCount: r.EvalCount,
|
||||
EvalDuration: r.EvalDuration,
|
||||
},
|
||||
}
|
||||
if r.Done {
|
||||
res.DoneReason = r.DoneReason.String()
|
||||
res.TotalDuration = time.Since(checkpointStart)
|
||||
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||
}
|
||||
|
||||
res.Message.Content = content
|
||||
res.Message.Thinking = thinking
|
||||
res.Message.ToolCalls = toolCalls
|
||||
if builtinParser != nil {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser input", "parser", m.Config.Parser, "content", r.Content)
|
||||
|
||||
if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || r.Done {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser output", "parser", m.Config.Parser, "content", content, "thinking", thinking, "toolCalls", toolCalls, "done", r.Done)
|
||||
ch <- res
|
||||
} else {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser empty output", "parser", m.Config.Parser)
|
||||
}
|
||||
content, thinking, toolCalls, err := builtinParser.Add(r.Content, r.Done)
|
||||
if err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if thinkingState != nil {
|
||||
thinkingContent, remainingContent := thinkingState.AddContent(res.Message.Content)
|
||||
if thinkingContent == "" && remainingContent == "" && !r.Done {
|
||||
// need to accumulate more to decide what to send
|
||||
return
|
||||
}
|
||||
res.Message.Content = remainingContent
|
||||
res.Message.Thinking = thinkingContent
|
||||
}
|
||||
|
||||
if len(req.Tools) > 0 {
|
||||
toolCalls, content := toolParser.Add(res.Message.Content)
|
||||
if len(content) > 0 {
|
||||
res.Message.Content = content
|
||||
} else if len(toolCalls) > 0 {
|
||||
res.Message.Thinking = thinking
|
||||
res.Message.ToolCalls = toolCalls
|
||||
res.Message.Content = ""
|
||||
} else if res.Message.Thinking != "" {
|
||||
// don't return
|
||||
} else {
|
||||
if r.Done {
|
||||
res.Message.Content = toolParser.Content()
|
||||
|
||||
tb.WriteString(thinking)
|
||||
// we are now receiving content from the model - we should start applying structured outputs
|
||||
if structuredOutputsState == structuredOutputsState_None && req.Format != nil && tb.String() != "" && res.Message.Content != "" {
|
||||
structuredOutputsState = structuredOutputsState_ReadyToApply
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
|
||||
if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || r.Done {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser output", "parser", m.Config.Parser, "content", content, "thinking", thinking, "toolCalls", toolCalls, "done", r.Done)
|
||||
ch <- res
|
||||
} else {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser empty output", "parser", m.Config.Parser)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if thinkingState != nil {
|
||||
thinkingContent, remainingContent := thinkingState.AddContent(res.Message.Content)
|
||||
if thinkingContent == "" && remainingContent == "" && !r.Done {
|
||||
// need to accumulate more to decide what to send
|
||||
return
|
||||
}
|
||||
res.Message.Thinking = thinkingContent
|
||||
tb.WriteString(thinkingContent)
|
||||
// emit the collected thinking text before restarting with structured outputs and clear unstructured content
|
||||
// to avoid leaking mixed tokens like "</think>Hello"
|
||||
if structuredOutputsState == structuredOutputsState_None && req.Format != nil && tb.String() != "" && remainingContent != "" {
|
||||
structuredOutputsState = structuredOutputsState_ReadyToApply
|
||||
res.Message.Content = ""
|
||||
ch <- res
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
res.Message.Content = remainingContent
|
||||
}
|
||||
|
||||
if len(req.Tools) > 0 {
|
||||
toolCalls, content := toolParser.Add(res.Message.Content)
|
||||
if len(content) > 0 {
|
||||
res.Message.Content = content
|
||||
} else if len(toolCalls) > 0 {
|
||||
res.Message.ToolCalls = toolCalls
|
||||
res.Message.Content = ""
|
||||
} else if res.Message.Thinking != "" {
|
||||
// don't return
|
||||
} else {
|
||||
if r.Done {
|
||||
res.Message.Content = toolParser.Content()
|
||||
ch <- res
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
ch <- res
|
||||
})
|
||||
if err != nil {
|
||||
if structuredOutputsState == structuredOutputsState_ReadyToApply && strings.Contains(err.Error(), "context canceled") && c.Request.Context().Err() == nil {
|
||||
// only ignores error if it's a context cancellation due to setting structured outputs
|
||||
} else {
|
||||
var serr api.StatusError
|
||||
if errors.As(err, &serr) {
|
||||
ch <- gin.H{"error": serr.ErrorMessage, "status": serr.StatusCode}
|
||||
} else {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
ch <- res
|
||||
}); err != nil {
|
||||
ch <- err
|
||||
// ignored structured outputs cancellation falls through to here, start a new request with the structured outputs and updated prompt. use the
|
||||
if structuredOutputsState == structuredOutputsState_ReadyToApply {
|
||||
structuredOutputsState = structuredOutputsState_Applying
|
||||
msg := api.Message{
|
||||
Role: "assistant",
|
||||
Thinking: tb.String(),
|
||||
}
|
||||
|
||||
msgs = append(msgs, msg)
|
||||
prompt, _, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
|
||||
if err != nil {
|
||||
slog.Error("chat prompt error applying structured outputs", "error", err)
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
return
|
||||
}
|
||||
// force constraining by terminating thinking header, the parser is already at this state
|
||||
// when the last message is thinking, the rendered for gpt-oss cannot disambiguate between having the
|
||||
// model continue thinking or ending thinking and outputting the final message.
|
||||
// TODO(parthsareen): consider adding prefill disambiguation logic to the renderer for structured outputs.
|
||||
if shouldUseHarmony(m) || (builtinParser != nil && m.Config.Parser == "harmony") {
|
||||
prompt += "<|end|><|start|>assistant<|channel|>final<|message|>"
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -2102,20 +2223,18 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
if len(req.Tools) > 0 {
|
||||
toolCalls = append(toolCalls, t.Message.ToolCalls...)
|
||||
}
|
||||
case api.StatusError:
|
||||
c.JSON(t.StatusCode, gin.H{"error": t.ErrorMessage})
|
||||
return
|
||||
case error:
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": t.Error()})
|
||||
return
|
||||
// TODO (jmorganca): remove use of gin.H here and instead expect
|
||||
// api.StatusError to be send in the channel
|
||||
case gin.H:
|
||||
if msg, ok := t["error"].(string); ok {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
|
||||
return
|
||||
msg, ok := t["error"].(string)
|
||||
if !ok {
|
||||
msg = "unexpected error format in response"
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
|
||||
|
||||
status, ok := t["status"].(int)
|
||||
if !ok {
|
||||
status = http.StatusInternalServerError
|
||||
}
|
||||
|
||||
c.JSON(status, gin.H{"error": msg})
|
||||
return
|
||||
default:
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
|
||||
|
||||
@@ -146,7 +146,7 @@ func TestGenerateDebugRenderOnly(t *testing.T) {
|
||||
DebugRenderOnly: true,
|
||||
},
|
||||
expectDebug: true,
|
||||
expectTemplate: "[img-0]\n\nDescribe this image",
|
||||
expectTemplate: "[img-0]Describe this image",
|
||||
expectNumImages: 1,
|
||||
},
|
||||
{
|
||||
|
||||
313
server/routes_generate_renderer_test.go
Normal file
313
server/routes_generate_renderer_test.go
Normal file
@@ -0,0 +1,313 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/discover"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/llm"
|
||||
)
|
||||
|
||||
// TestGenerateWithBuiltinRenderer tests that api/generate uses built-in renderers
|
||||
// when in chat-like flow (messages present, no suffix, no template)
|
||||
func TestGenerateWithBuiltinRenderer(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
mock := mockRunner{
|
||||
CompletionResponse: llm.CompletionResponse{
|
||||
Done: true,
|
||||
DoneReason: llm.DoneReasonStop,
|
||||
PromptEvalCount: 1,
|
||||
PromptEvalDuration: 1,
|
||||
EvalCount: 1,
|
||||
EvalDuration: 1,
|
||||
},
|
||||
}
|
||||
|
||||
s := Server{
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
expiredCh: make(chan *runnerRef, 1),
|
||||
unloadedCh: make(chan any, 1),
|
||||
loaded: make(map[string]*runnerRef),
|
||||
newServerFn: newMockServer(&mock),
|
||||
getGpuFn: getGpuFn,
|
||||
getCpuFn: getCpuFn,
|
||||
reschedDelay: 250 * time.Millisecond,
|
||||
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
|
||||
time.Sleep(time.Millisecond)
|
||||
req.successCh <- &runnerRef{
|
||||
llama: &mock,
|
||||
}
|
||||
return false
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
go s.sched.Run(t.Context())
|
||||
|
||||
// Create a model with a built-in renderer (qwen3-coder)
|
||||
_, digest := createBinFile(t, ggml.KV{
|
||||
"general.architecture": "qwen3",
|
||||
"qwen3.block_count": uint32(1),
|
||||
"qwen3.context_length": uint32(8192),
|
||||
"qwen3.embedding_length": uint32(4096),
|
||||
"qwen3.attention.head_count": uint32(32),
|
||||
"qwen3.attention.head_count_kv": uint32(8),
|
||||
"tokenizer.ggml.tokens": []string{""},
|
||||
"tokenizer.ggml.scores": []float32{0},
|
||||
"tokenizer.ggml.token_type": []int32{0},
|
||||
}, []*ggml.Tensor{
|
||||
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
})
|
||||
|
||||
// Create a model with the qwen3-coder renderer
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "test-renderer",
|
||||
Files: map[string]string{"file.gguf": digest},
|
||||
Renderer: "qwen3-coder",
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
mock.CompletionResponse.Content = "Hi!"
|
||||
|
||||
t.Run("chat-like flow uses renderer", func(t *testing.T) {
|
||||
// Test that when using messages (chat-like flow), the built-in renderer is used
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test-renderer",
|
||||
Prompt: "Write a hello world function",
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
// The qwen3-coder renderer produces output with <|im_start|> and <|im_end|> tags
|
||||
// When messages are built internally from prompt, it should use the renderer
|
||||
if !strings.Contains(mock.CompletionRequest.Prompt, "<|im_start|>") {
|
||||
t.Errorf("expected prompt to contain <|im_start|> from qwen3-coder renderer, got: %s", mock.CompletionRequest.Prompt)
|
||||
}
|
||||
|
||||
if !strings.Contains(mock.CompletionRequest.Prompt, "<|im_end|>") {
|
||||
t.Errorf("expected prompt to contain <|im_end|> from qwen3-coder renderer, got: %s", mock.CompletionRequest.Prompt)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("chat-like flow with system message uses renderer", func(t *testing.T) {
|
||||
// Test that system messages work with the renderer
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test-renderer",
|
||||
Prompt: "Write a hello world function",
|
||||
System: "You are a helpful coding assistant.",
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
// Should contain the system message and use renderer format
|
||||
if !strings.Contains(mock.CompletionRequest.Prompt, "<|im_start|>system") {
|
||||
t.Errorf("expected prompt to contain system message with renderer format, got: %s", mock.CompletionRequest.Prompt)
|
||||
}
|
||||
|
||||
if !strings.Contains(mock.CompletionRequest.Prompt, "You are a helpful coding assistant.") {
|
||||
t.Errorf("expected prompt to contain system message content, got: %s", mock.CompletionRequest.Prompt)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("custom template bypasses renderer", func(t *testing.T) {
|
||||
// Test that providing a custom template uses the legacy flow
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test-renderer",
|
||||
Prompt: "Write a hello world function",
|
||||
Template: "{{ .Prompt }}",
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
// Should NOT use the renderer format when custom template is provided
|
||||
if strings.Contains(mock.CompletionRequest.Prompt, "<|im_start|>") {
|
||||
t.Errorf("expected prompt to NOT use renderer when custom template provided, got: %s", mock.CompletionRequest.Prompt)
|
||||
}
|
||||
|
||||
// Should just be the raw prompt from the template
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "Write a hello world function"); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
// Create a model with suffix support for the next test
|
||||
w = createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "test-suffix-renderer",
|
||||
From: "test-renderer",
|
||||
Template: `{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
|
||||
{{- else }}{{ .Prompt }}
|
||||
{{- end }}`,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
t.Run("suffix bypasses renderer", func(t *testing.T) {
|
||||
// Test that providing a suffix uses the legacy flow
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test-suffix-renderer",
|
||||
Prompt: "def add(",
|
||||
Suffix: " return c",
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
// Should NOT use the renderer format when suffix is provided
|
||||
if strings.Contains(mock.CompletionRequest.Prompt, "<|im_start|>") {
|
||||
t.Errorf("expected prompt to NOT use renderer when suffix provided, got: %s", mock.CompletionRequest.Prompt)
|
||||
}
|
||||
|
||||
// Should use the suffix template format
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "<PRE> def add( <SUF> return c <MID>"); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestGenerateWithDebugRenderOnly tests that debug_render_only works with built-in renderers
|
||||
func TestGenerateWithDebugRenderOnly(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
mock := mockRunner{
|
||||
CompletionResponse: llm.CompletionResponse{
|
||||
Done: true,
|
||||
DoneReason: llm.DoneReasonStop,
|
||||
PromptEvalCount: 1,
|
||||
PromptEvalDuration: 1,
|
||||
EvalCount: 1,
|
||||
EvalDuration: 1,
|
||||
},
|
||||
}
|
||||
|
||||
s := Server{
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
expiredCh: make(chan *runnerRef, 1),
|
||||
unloadedCh: make(chan any, 1),
|
||||
loaded: make(map[string]*runnerRef),
|
||||
newServerFn: newMockServer(&mock),
|
||||
getGpuFn: getGpuFn,
|
||||
getCpuFn: getCpuFn,
|
||||
reschedDelay: 250 * time.Millisecond,
|
||||
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
|
||||
time.Sleep(time.Millisecond)
|
||||
req.successCh <- &runnerRef{
|
||||
llama: &mock,
|
||||
}
|
||||
return false
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
go s.sched.Run(t.Context())
|
||||
|
||||
// Create a model with a built-in renderer
|
||||
_, digest := createBinFile(t, ggml.KV{
|
||||
"general.architecture": "qwen3",
|
||||
"qwen3.block_count": uint32(1),
|
||||
"qwen3.context_length": uint32(8192),
|
||||
"qwen3.embedding_length": uint32(4096),
|
||||
"qwen3.attention.head_count": uint32(32),
|
||||
"qwen3.attention.head_count_kv": uint32(8),
|
||||
"tokenizer.ggml.tokens": []string{""},
|
||||
"tokenizer.ggml.scores": []float32{0},
|
||||
"tokenizer.ggml.token_type": []int32{0},
|
||||
}, []*ggml.Tensor{
|
||||
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
})
|
||||
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "test-debug-renderer",
|
||||
Files: map[string]string{"file.gguf": digest},
|
||||
Renderer: "qwen3-coder",
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
t.Run("debug_render_only with renderer", func(t *testing.T) {
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test-debug-renderer",
|
||||
Prompt: "Write a hello world function",
|
||||
System: "You are a coding assistant",
|
||||
DebugRenderOnly: true,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var resp api.GenerateResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if resp.DebugInfo == nil {
|
||||
t.Fatalf("expected debug info, got nil")
|
||||
}
|
||||
|
||||
// Verify that the rendered template uses the built-in renderer
|
||||
if !strings.Contains(resp.DebugInfo.RenderedTemplate, "<|im_start|>") {
|
||||
t.Errorf("expected rendered template to use qwen3-coder renderer format, got: %s", resp.DebugInfo.RenderedTemplate)
|
||||
}
|
||||
|
||||
if !strings.Contains(resp.DebugInfo.RenderedTemplate, "You are a coding assistant") {
|
||||
t.Errorf("expected rendered template to contain system message, got: %s", resp.DebugInfo.RenderedTemplate)
|
||||
}
|
||||
|
||||
if !strings.Contains(resp.DebugInfo.RenderedTemplate, "Write a hello world function") {
|
||||
t.Errorf("expected rendered template to contain prompt, got: %s", resp.DebugInfo.RenderedTemplate)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -158,11 +158,26 @@ func TestGenerateChat(t *testing.T) {
|
||||
t.Errorf("expected status 400, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(w.Body.String(), `{"error":"registry.ollama.ai/library/test:latest does not support thinking"}`); diff != "" {
|
||||
if diff := cmp.Diff(w.Body.String(), `{"error":"\"test\" does not support thinking"}`); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("model can't think but think set false", func(t *testing.T) {
|
||||
think := false
|
||||
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "test",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
Think: &api.ThinkValue{Value: think},
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing model", func(t *testing.T) {
|
||||
w := createRequest(t, s.ChatHandler, api.ChatRequest{})
|
||||
if w.Code != http.StatusBadRequest {
|
||||
@@ -1292,4 +1307,238 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
|
||||
t.Errorf("expected content %q, got %q", "Based on my analysis, the solution is straightforward.", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("structured outputs restart non-stream", func(t *testing.T) {
|
||||
var (
|
||||
requestsMu sync.Mutex
|
||||
requests []llm.CompletionRequest
|
||||
wg sync.WaitGroup
|
||||
)
|
||||
|
||||
wg.Add(2)
|
||||
|
||||
format := json.RawMessage(`{"type":"object","properties":{"answer":{"type":"string"}}}`)
|
||||
|
||||
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
|
||||
defer wg.Done()
|
||||
|
||||
requestsMu.Lock()
|
||||
requests = append(requests, r)
|
||||
callNum := len(requests)
|
||||
requestsMu.Unlock()
|
||||
|
||||
switch callNum {
|
||||
case 1:
|
||||
fn(llm.CompletionResponse{
|
||||
Content: " I am thinking through this problem. </think> {\"answer\":\"42\"}",
|
||||
Done: false,
|
||||
PromptEvalCount: 1,
|
||||
PromptEvalDuration: 1,
|
||||
})
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("timeout waiting for structured outputs cancellation")
|
||||
return nil
|
||||
}
|
||||
case 2:
|
||||
fn(llm.CompletionResponse{
|
||||
Content: `{"answer":"42"}`,
|
||||
Done: true,
|
||||
DoneReason: llm.DoneReasonStop,
|
||||
PromptEvalCount: 1,
|
||||
PromptEvalDuration: 1,
|
||||
EvalCount: 1,
|
||||
EvalDuration: 1,
|
||||
})
|
||||
return nil
|
||||
default:
|
||||
t.Fatalf("unexpected number of completion calls: %d", callNum)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
think := true
|
||||
streamRequest := false
|
||||
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "test-thinking",
|
||||
Messages: []api.Message{{Role: "user", Content: "Please respond in JSON."}},
|
||||
Think: &api.ThinkValue{Value: think},
|
||||
Stream: &streamRequest,
|
||||
Format: format,
|
||||
})
|
||||
|
||||
wg.Wait()
|
||||
mock.CompletionFn = nil
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if len(requests) != 2 {
|
||||
t.Fatalf("expected two completion calls, got %d", len(requests))
|
||||
}
|
||||
|
||||
if requests[0].Format != nil {
|
||||
t.Errorf("expected first completion format to be nil, got %q", requests[0].Format)
|
||||
}
|
||||
|
||||
if !bytes.Equal([]byte(format), []byte(requests[1].Format)) {
|
||||
t.Errorf("expected second completion format to match original format")
|
||||
}
|
||||
|
||||
var resp api.ChatResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if resp.Message.Thinking != "I am thinking through this problem. " {
|
||||
t.Errorf("expected thinking %q, got %q", "I am thinking through this problem. ", resp.Message.Thinking)
|
||||
}
|
||||
|
||||
if resp.Message.Content != `{"answer":"42"}` {
|
||||
t.Errorf("expected content %q, got %q", `{"answer":"42"}`, resp.Message.Content)
|
||||
}
|
||||
|
||||
if !resp.Done {
|
||||
t.Errorf("expected response to be done")
|
||||
}
|
||||
|
||||
if resp.DoneReason != "stop" {
|
||||
t.Errorf("expected done reason stop, got %s", resp.DoneReason)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("structured outputs restart streaming", func(t *testing.T) {
|
||||
var (
|
||||
requestsMu sync.Mutex
|
||||
requests []llm.CompletionRequest
|
||||
wg sync.WaitGroup
|
||||
)
|
||||
|
||||
wg.Add(2)
|
||||
|
||||
format := json.RawMessage(`{"type":"object","properties":{"answer":{"type":"string"}}}`)
|
||||
|
||||
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
|
||||
defer wg.Done()
|
||||
|
||||
requestsMu.Lock()
|
||||
requests = append(requests, r)
|
||||
callNum := len(requests)
|
||||
requestsMu.Unlock()
|
||||
|
||||
switch callNum {
|
||||
case 1:
|
||||
fn(llm.CompletionResponse{
|
||||
Content: " I am thinking through this problem. </think> {\"answer\":\"42\"}",
|
||||
Done: false,
|
||||
PromptEvalCount: 1,
|
||||
PromptEvalDuration: 1,
|
||||
})
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("timeout waiting for structured outputs cancellation")
|
||||
return nil
|
||||
}
|
||||
case 2:
|
||||
fn(llm.CompletionResponse{
|
||||
Content: `{"answer":"42"}`,
|
||||
Done: true,
|
||||
DoneReason: llm.DoneReasonStop,
|
||||
PromptEvalCount: 1,
|
||||
PromptEvalDuration: 1,
|
||||
EvalCount: 1,
|
||||
EvalDuration: 1,
|
||||
})
|
||||
return nil
|
||||
default:
|
||||
t.Fatalf("unexpected number of completion calls: %d", callNum)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
think := true
|
||||
streamRequest := true
|
||||
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "test-thinking",
|
||||
Messages: []api.Message{{Role: "user", Content: "Please respond in JSON."}},
|
||||
Think: &api.ThinkValue{Value: think},
|
||||
Stream: &streamRequest,
|
||||
Format: format,
|
||||
})
|
||||
|
||||
wg.Wait()
|
||||
mock.CompletionFn = nil
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if len(requests) != 2 {
|
||||
t.Fatalf("expected two completion calls, got %d", len(requests))
|
||||
}
|
||||
|
||||
if requests[0].Format != nil {
|
||||
t.Errorf("expected first completion format to be nil, got %q", requests[0].Format)
|
||||
}
|
||||
|
||||
if !bytes.Equal([]byte(format), []byte(requests[1].Format)) {
|
||||
t.Errorf("expected second completion format to match original format")
|
||||
}
|
||||
|
||||
decoder := json.NewDecoder(w.Body)
|
||||
var events []api.ChatResponse
|
||||
for {
|
||||
var event api.ChatResponse
|
||||
if err := decoder.Decode(&event); err == io.EOF {
|
||||
break
|
||||
} else if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
events = append(events, event)
|
||||
if event.Done {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(events) < 2 {
|
||||
t.Fatalf("expected at least two streaming events, got %d", len(events))
|
||||
}
|
||||
|
||||
first := events[0]
|
||||
if first.Message.Thinking != "I am thinking through this problem. " {
|
||||
t.Errorf("expected first event thinking %q, got %q", "I am thinking through this problem. ", first.Message.Thinking)
|
||||
}
|
||||
|
||||
if first.Message.Content != "" {
|
||||
t.Errorf("expected first event content to be empty, got %q", first.Message.Content)
|
||||
}
|
||||
|
||||
if first.Done {
|
||||
t.Error("expected first event to be non-terminal")
|
||||
}
|
||||
|
||||
last := events[len(events)-1]
|
||||
if last.Message.Thinking != "" {
|
||||
t.Errorf("expected final event thinking to be empty, got %q", last.Message.Thinking)
|
||||
}
|
||||
|
||||
if last.Message.Content != `{"answer":"42"}` {
|
||||
t.Errorf("expected final event content %q, got %q", `{"answer":"42"}`, last.Message.Content)
|
||||
}
|
||||
|
||||
if !last.Done {
|
||||
t.Error("expected final event to be done")
|
||||
}
|
||||
|
||||
if last.DoneReason != "stop" {
|
||||
t.Errorf("expected final done reason stop, got %s", last.DoneReason)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/ml"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
@@ -229,8 +230,9 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
||||
}
|
||||
|
||||
if runnerToExpire == nil {
|
||||
// Shouildn't happen
|
||||
slog.Error("runner to expire was nil!")
|
||||
// While we were performing load calculations, the loaded runner(s) unloaded in parallel
|
||||
// so findRunnerToUnload returned no runners. We'll try again and the loadedCount should be zero
|
||||
slog.Debug("runner to expire was nil, retrying")
|
||||
continue
|
||||
}
|
||||
// Trigger an expiration to unload once it's done
|
||||
@@ -644,27 +646,35 @@ func (s *Scheduler) waitForVRAMRecovery(runner *runnerRef, runners []discover.Fi
|
||||
totalMemoryBefore += gpu.TotalMemory
|
||||
freeMemoryBefore += gpu.FreeMemory
|
||||
}
|
||||
totalMemoryNow := totalMemoryBefore
|
||||
freeMemoryNow := freeMemoryBefore
|
||||
|
||||
go func() {
|
||||
expiresAt := start.Add(5 * time.Second) // typical convergence is 0.5-1.5s
|
||||
// typical convergence is 0.5-1.5s - If it takes more than 5 seconds to discover and converge, let the scheduler estimate VRAM usage
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
ticker := time.NewTicker(250 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
<-ticker.C
|
||||
if time.Now().After(expiresAt) {
|
||||
slog.Warn("gpu VRAM usage didn't recover within timeout", "seconds", time.Since(start).Seconds(), "runner", runner)
|
||||
finished <- struct{}{}
|
||||
}
|
||||
|
||||
// Query GPUs, look for free to go back up
|
||||
gpusNow := s.getGpuFn(context.Background(), runners)
|
||||
var totalMemoryNow, freeMemoryNow uint64
|
||||
for _, gpu := range gpusNow {
|
||||
totalMemoryNow += gpu.TotalMemory
|
||||
freeMemoryNow += gpu.FreeMemory
|
||||
}
|
||||
// If we're within ~80% of the estimated memory usage recovered, bail out
|
||||
if float32(freeMemoryNow-freeMemoryBefore) > float32(runner.vramSize)*0.8 {
|
||||
slog.Debug(fmt.Sprintf("gpu VRAM free memory converged after %0.2f seconds", time.Since(start).Seconds()), "runner", runner)
|
||||
select {
|
||||
case <-ticker.C:
|
||||
// Query GPUs, look for free to go back up
|
||||
gpusNow := s.getGpuFn(ctx, runners)
|
||||
totalMemoryNow = 0
|
||||
freeMemoryNow = 0
|
||||
for _, gpu := range gpusNow {
|
||||
totalMemoryNow += gpu.TotalMemory
|
||||
freeMemoryNow += gpu.FreeMemory
|
||||
}
|
||||
logutil.Trace("gpu VRAM convergence", "percent", int(max(float32(freeMemoryNow-freeMemoryBefore), 0.0)/float32(runner.vramSize)*100))
|
||||
// If we're within ~75% of the estimated memory usage recovered, bail out
|
||||
if float32(freeMemoryNow-freeMemoryBefore) > float32(runner.vramSize)*0.75 {
|
||||
slog.Debug(fmt.Sprintf("gpu VRAM free memory converged after %0.2f seconds", time.Since(start).Seconds()), "free_before", format.HumanBytes2(freeMemoryBefore), "free_now", format.HumanBytes2(freeMemoryNow), "runner", runner)
|
||||
finished <- struct{}{}
|
||||
return
|
||||
}
|
||||
case <-ctx.Done():
|
||||
slog.Debug("gpu VRAM usage didn't recover within timeout", "seconds", time.Since(start).Seconds(), "free_before", format.HumanBytes2(freeMemoryBefore), "free_now", format.HumanBytes2(freeMemoryNow), "runner", runner)
|
||||
finished <- struct{}{}
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user