Compare commits

..

23 Commits

Author SHA1 Message Date
Parth Sareen
77060d462c routes: structured outputs for gpt-oss (#12460) 2025-10-08 19:13:38 -07:00
Patrick Devine
1b91d4dda1 openai: change the reasonin_effort field to also take none 2025-10-08 18:21:01 -07:00
Jeffrey Morgan
7d965258ce Revert "add truncate and shift parameters (#12519)" (#12545)
This reverts commit 6a62b894c7.
2025-10-08 17:57:57 -07:00
Jeffrey Morgan
6a62b894c7 add truncate and shift parameters (#12519) 2025-10-08 17:05:05 -07:00
Patrick Devine
90d429f5a8 thinking: turn on thinking mode for all reasoning models (#12533) 2025-10-08 16:50:13 -07:00
Jesse Gross
1fc35f1260 kvcache: Clean up sliding window state with independent batches
Sliding windows models (e.g. gpt-oss, gemma3) remove tokens that
are out of the cache's window each time we start a new forward pass.

The cache storage needs to handle the window size for each sequence
plus the batch size, since the batch needs to attend to the full
window size. This means that we have greater than a window size
stored while processing the batch.

When the next batch comes, we are currently only looking at the
sequences in the incoming batch to slide the window forward.
However, we also need to clean up the other sequences that might
be occupying space in the batch processing buffer to ensure each
sequence is only using its window size of storage. Failure to do
this can result in "no kv cache slot found" errors.

Fixes: #10127
2025-10-08 16:43:14 -07:00
Jesse Gross
aa45f7ce27 discover: Disable flash attention for Jetson Xavier (CC 7.2)
GGML picks the wrong kernel and these systems fail with:
Sep 28 22:25:39 xavier ollama[48999]: //ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cu:437:
ERROR: CUDA kernel flash_attn_ext_f16 has no device code compatible with CUDA arch 720. ggml-cuda.cu
was compiled for: __CUDA_ARCH_LIST__

Fixes #12442
2025-10-08 09:56:15 -07:00
Daniel Hiltgen
4e5d862ec4 Integration test tuning (#12492)
Remove some flaky scenarios, and switch to chat for better reliability
2025-10-08 09:51:25 -07:00
Daniel Hiltgen
303be9304c docs: improve accuracy of LLM library docs (#12530) 2025-10-07 16:21:07 -07:00
Daniel Hiltgen
bd15eba4e4 Bring back escape valve for llm libraries and fix Jetpack6 crash (#12529)
* Bring back escape valve for llm libraries

If the new discovery logic picks the wrong library, this gives users the
ability to force a specific one using the same pattern as before. This
can also potentially speed up bootstrap discovery if one of the libraries
takes a long time to load and ultimately bind to no devices.  For example
unsupported AMD iGPUS can sometimes take a while to discover and rule out.

* Bypass extra discovery on jetpack systems

On at least Jetpack6, cuda_v12 appears to expose the iGPU, but crashes later on in
cublasInit so if we detect a Jetpack, short-circuit and use that variant.
2025-10-07 16:06:14 -07:00
Devon Rifkin
bc71278670 Merge pull request #12509 from ollama/drifkin/oai-compat-refactor
openai: refactor to split compat layer and middleware
2025-10-06 16:22:08 -07:00
Daniel Hiltgen
918231931c win: fix build script (#12513) 2025-10-06 14:46:45 -07:00
Daniel Hiltgen
04c1849878 discovery: prevent dup OLLAMA_LIBRARY_PATH (#12514)
This variable isn't currently documented or intended as something the user can
override, but if the user happens to set OLLAMA_LIBRARY_PATH we were doubling
this in the subprocess environment which will cause problems with the new
bootstrap discovery logic.
2025-10-06 14:36:44 -07:00
Devon Rifkin
2c2f4deaa9 openai: refactor to split compat layer and middleware
This makes the core openai compat layer independent of the middleware
that adapts it to our particular gin routes
2025-10-05 14:18:56 -07:00
Daniel Hiltgen
292767afb4 CI: fix win arm build (#12502)
Resolve subtle erroraction stickiness difference between x86 and arm builder setup
2025-10-04 11:46:45 -07:00
Daniel Hiltgen
ae5e0f0889 CI: replace clang compiler for windows (#12495) 2025-10-04 09:18:42 -07:00
Jesse Gross
19e6796eac llm: Support KV cache quantization with gpt-oss
With the new version of GGML in #12245, KV cache quantization
no longer causes a fallback to CPU.
2025-10-03 16:31:58 -07:00
Grace
33801c1597 Fixed Deepseek2 adding nil tensor error 2025-10-03 14:20:06 -07:00
Daniel Hiltgen
e4340667e3 Workaround broken NVIDIA iGPU free VRAM data (#12490)
The CUDA APIs for reporting free VRAM are useless on NVIDIA iGPU
systems as they only return the kernels actual free memory and ignore
buff/cache allocations which on a typical system will quickly fill up
most of the free system memory.  As a result, we incorrectly think
there's very little available for GPU allocations which is wrong.
2025-10-03 12:17:21 -07:00
Patrick Devine
2fa1e92a99 test: add template error test (#12489) 2025-10-03 12:05:34 -07:00
Daniel Hiltgen
07e36761c3 ci: place rocm windows in correct runner dir (#12487) 2025-10-03 07:28:40 -07:00
Daniel Hiltgen
c29fb007c0 CI: temporarily disable clang install (#12486)
This will likely yield builds that have problems with unicode characters
but at least we can start testing the release while we try to find an
alternate clang compiler for windows, or mingw ships a fixed version.
2025-10-02 20:31:18 -07:00
Daniel Hiltgen
730ed6e9e1 ci: fix windows build (#12485) 2025-10-02 19:16:01 -07:00
29 changed files with 2458 additions and 1582 deletions

View File

@@ -94,7 +94,7 @@ jobs:
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
rocm-version: '6.2'
flags: '-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
runner_dir: ''
runner_dir: 'rocm'
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
environment: release
env:
@@ -163,7 +163,7 @@ jobs:
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} -DOLLAMA_RUNNER_DIR="${{ matrix.runner_dir }}"
cmake --build --parallel --preset "${{ matrix.preset }}"
cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || 'CPU' }}" --strip --parallel 8
rm -force dist\lib\ollama\rocm\rocblas\library\*gfx906*
Remove-Item -Path dist\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue
env:
CMAKE_GENERATOR: Ninja
- uses: actions/upload-artifact@v4
@@ -176,19 +176,19 @@ jobs:
matrix:
os: [windows]
arch: [amd64, arm64]
include:
- os: windows
arch: amd64
llvmarch: x86_64
- os: windows
arch: arm64
llvmarch: aarch64
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
environment: release
needs: [setup-environment]
env:
GOFLAGS: ${{ needs.setup-environment.outputs.GOFLAGS }}
steps:
- name: Install AMD64 system dependencies
if: matrix.arch == 'amd64'
run: |
$ErrorActionPreference = "Stop"
Start-Process "C:\msys64\usr\bin\pacman.exe" -ArgumentList @("-S", "--noconfirm", "mingw-w64-clang-x86_64-gcc-compat", "mingw-w64-clang-x86_64-clang") -NoNewWindow -Wait
echo "C:\msys64\usr\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "C:\msys64\clang64\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
- name: Install ARM64 system dependencies
if: matrix.arch == 'arm64'
run: |
@@ -200,15 +200,29 @@ jobs:
choco install -y --no-progress git gzip
echo "C:\Program Files\Git\cmd" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
Invoke-WebRequest -Uri "https://github.com/mstorsjo/llvm-mingw/releases/download/20240619/llvm-mingw-20240619-ucrt-aarch64.zip" -OutFile "${{ runner.temp }}\llvm-mingw-ucrt-aarch64.zip"
Expand-Archive -Path ${{ runner.temp }}\llvm-mingw-ucrt-aarch64.zip -DestinationPath "C:\Program Files\"
$installPath=(Resolve-Path -Path "C:\Program Files\llvm-mingw-*-ucrt-aarch64").path
echo $installPath\bin | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
- name: Install clang and gcc-compat
run: |
$ErrorActionPreference = "Stop"
Set-ExecutionPolicy Bypass -Scope Process -Force
Invoke-WebRequest -Uri "https://github.com/mstorsjo/llvm-mingw/releases/download/20240619/llvm-mingw-20240619-ucrt-${{ matrix.llvmarch }}.zip" -OutFile "${{ runner.temp }}\llvm-mingw-ucrt.zip"
Expand-Archive -Path ${{ runner.temp }}\llvm-mingw-ucrt.zip -DestinationPath "C:\Program Files\"
$installPath=(Resolve-Path -Path "C:\Program Files\llvm-mingw-*-ucrt*").path
echo "$installPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version-file: go.mod
- name: Verify gcc is actually clang
run: |
$ErrorActionPreference='Continue'
$version=& gcc -v 2>&1
$version=$version -join "`n"
echo "gcc is $version"
if ($version -notmatch 'clang') {
echo "ERROR: GCC must be clang for proper utf16 handling"
exit 1
}
$ErrorActionPreference='Stop'
- run: |
go build -o dist/${{ matrix.os }}-${{ matrix.arch }}/ .
- uses: actions/upload-artifact@v4

View File

@@ -936,7 +936,7 @@ func (t *ThinkValue) UnmarshalJSON(data []byte) error {
return nil
}
return fmt.Errorf("think must be a boolean or string (\"high\", \"medium\", \"low\")")
return fmt.Errorf("think must be a boolean or string (\"high\", \"medium\", \"low\", true, or false)")
}
// MarshalJSON implements json.Marshaler

View File

@@ -2,11 +2,12 @@ package discover
import (
"context"
"fmt"
"log/slog"
"os"
"path/filepath"
"regexp"
"runtime"
"strconv"
"strings"
"github.com/ollama/ollama/format"
@@ -60,17 +61,14 @@ func devInfoToInfoList(devs []ml.DeviceInfo) GpuInfoList {
DependencyPath: dev.LibraryPath,
DriverMajor: dev.DriverMajor,
DriverMinor: dev.DriverMinor,
ComputeMajor: dev.ComputeMajor,
ComputeMinor: dev.ComputeMinor,
}
if dev.Library == "CUDA" || dev.Library == "ROCm" {
info.MinimumMemory = 457 * format.MebiByte
}
if dev.Library == "ROCm" {
info.Compute = fmt.Sprintf("gfx%x%02x", dev.ComputeMajor, dev.ComputeMinor)
if rocmDir != "" {
info.DependencyPath = append(info.DependencyPath, rocmDir)
}
} else {
info.Compute = fmt.Sprintf("%d.%d", dev.ComputeMajor, dev.ComputeMinor)
if dev.Library == "ROCm" && rocmDir != "" {
info.DependencyPath = append(info.DependencyPath, rocmDir)
}
resp = append(resp, info)
}
@@ -146,3 +144,35 @@ func GetSystemInfo() SystemInfo {
GPUs: gpus,
}
}
func cudaJetpack() string {
if runtime.GOARCH == "arm64" && runtime.GOOS == "linux" {
if CudaTegra != "" {
ver := strings.Split(CudaTegra, ".")
if len(ver) > 0 {
return "jetpack" + ver[0]
}
} else if data, err := os.ReadFile("/etc/nv_tegra_release"); err == nil {
r := regexp.MustCompile(` R(\d+) `)
m := r.FindSubmatch(data)
if len(m) != 2 {
slog.Info("Unexpected format for /etc/nv_tegra_release. Set JETSON_JETPACK to select version")
} else {
if l4t, err := strconv.Atoi(string(m[1])); err == nil {
// Note: mapping from L4t -> JP is inconsistent (can't just subtract 30)
// https://developer.nvidia.com/embedded/jetpack-archive
switch l4t {
case 35:
return "jetpack5"
case 36:
return "jetpack6"
default:
// Newer Jetson systems use the SBSU runtime
slog.Debug("unrecognized L4T version", "nv_tegra_release", string(data))
}
}
}
}
}
return ""
}

View File

@@ -78,6 +78,8 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev
}
slog.Info("discovering available GPUs...")
requested := envconfig.LLMLibrary()
jetpack := cudaJetpack()
// For our initial discovery pass, we gather all the known GPUs through
// all the libraries that were detected. This pass may include GPUs that
@@ -86,6 +88,14 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev
// times concurrently leading to memory contention
for dir := range libDirs {
var dirs []string
if dir != "" {
if requested != "" && filepath.Base(dir) != requested {
slog.Debug("skipping available library at users request", "requested", requested, "libDir", dir)
continue
} else if jetpack != "" && filepath.Base(dir) != "cuda_"+jetpack {
continue
}
}
if dir == "" {
dirs = []string{LibOllamaPath}
} else {
@@ -330,6 +340,9 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev
}
}
// Apply any iGPU workarounds
iGPUWorkarounds(devices)
return devices
}
@@ -439,15 +452,18 @@ func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs []s
cmd.Stderr = os.Stderr
}
// cmd.SysProcAttr = llm.LlamaServerSysProcAttr // circular dependency - bring back once refactored
cmd.Env = append(cmd.Env, "OLLAMA_LIBRARY_PATH="+strings.Join(ollamaLibDirs, string(filepath.ListSeparator)))
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
pathNeeded := true
ollamaPathNeeded := true
extraDone := make([]bool, len(extraEnvs))
for i := range cmd.Env {
cmp := strings.SplitN(cmd.Env[i], "=", 2)
if strings.EqualFold(cmp[0], pathEnv) {
cmd.Env[i] = pathEnv + "=" + pathEnvVal
pathNeeded = false
} else if strings.EqualFold(cmp[0], "OLLAMA_LIBRARY_PATH") {
cmd.Env[i] = "OLLAMA_LIBRARY_PATH=" + strings.Join(ollamaLibDirs, string(filepath.ListSeparator))
ollamaPathNeeded = false
} else {
for j := range extraEnvs {
if extraDone[j] {
@@ -464,6 +480,9 @@ func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs []s
if pathNeeded {
cmd.Env = append(cmd.Env, pathEnv+"="+pathEnvVal)
}
if ollamaPathNeeded {
cmd.Env = append(cmd.Env, "OLLAMA_LIBRARY_PATH="+strings.Join(ollamaLibDirs, string(filepath.ListSeparator)))
}
for i := range extraDone {
if !extraDone[i] {
cmd.Env = append(cmd.Env, extraEnvs[i])
@@ -540,3 +559,32 @@ func GetDevicesFromRunner(ctx context.Context, runner BaseRunner) ([]ml.DeviceIn
}
}
}
func iGPUWorkarounds(devices []ml.DeviceInfo) {
// short circuit if we have no iGPUs
anyiGPU := false
for i := range devices {
if devices[i].Integrated {
anyiGPU = true
break
}
}
if !anyiGPU {
return
}
memInfo, err := GetCPUMem()
if err != nil {
slog.Debug("failed to fetch system memory information for iGPU", "error", err)
return
}
for i := range devices {
if !devices[i].Integrated {
continue
}
// NVIDIA iGPUs return useless free VRAM data which ignores system buff/cache
if devices[i].Library == "CUDA" {
devices[i].FreeMemory = memInfo.FreeMemory
}
}
}

View File

@@ -37,9 +37,10 @@ type GpuInfo struct { // TODO better name maybe "InferenceProcessor"?
UnreliableFreeMemory bool
// GPU information
filterID string // AMD Workaround: The numeric ID of the device used to filter out other devices
Name string `json:"name"` // user friendly name if available
Compute string `json:"compute"` // Compute Capability or gfx
filterID string // AMD Workaround: The numeric ID of the device used to filter out other devices
Name string `json:"name"` // user friendly name if available
ComputeMajor int `json:"compute_major"` // Compute Capability or gfx
ComputeMinor int `json:"compute_minor"`
// Driver Information - TODO no need to put this on each GPU
DriverMajor int `json:"driver_major,omitempty"`
@@ -173,7 +174,7 @@ func (l GpuInfoList) FlashAttentionSupported() bool {
for _, gpu := range l {
supportsFA := gpu.Library == "cpu" ||
gpu.Name == "Metal" || gpu.Library == "Metal" ||
(gpu.Library == "CUDA" && gpu.DriverMajor >= 7) ||
(gpu.Library == "CUDA" && gpu.DriverMajor >= 7 && !(gpu.ComputeMajor == 7 && gpu.ComputeMinor == 2)) || // We don't have kernels for Jetson Xavier
gpu.Library == "ROCm"
if !supportsFA {

View File

@@ -38,26 +38,14 @@ Join the [Discord](https://discord.gg/ollama) for help interpreting the logs.
## LLM libraries
Ollama includes multiple LLM libraries compiled for different GPUs and CPU vector features. Ollama tries to pick the best one based on the capabilities of your system. If this autodetection has problems, or you run into other problems (e.g. crashes in your GPU) you can workaround this by forcing a specific LLM library. `cpu_avx2` will perform the best, followed by `cpu_avx` and the slowest but most compatible is `cpu`. Rosetta emulation under MacOS will work with the `cpu` library.
In the server log, you will see a message that looks something like this (varies from release to release):
```
Dynamic LLM libraries [rocm_v6 cpu cpu_avx cpu_avx2 cuda_v12 rocm_v5]
```
Ollama includes multiple LLM libraries compiled for different GPU libraries and versions. Ollama tries to pick the best one based on the capabilities of your system. If this autodetection has problems, or you run into other problems (e.g. crashes in your GPU) you can workaround this by forcing a specific LLM library.
**Experimental LLM Library Override**
You can set OLLAMA_LLM_LIBRARY to any of the available LLM libraries to bypass autodetection, so for example, if you have a CUDA card, but want to force the CPU LLM library with AVX2 vector support, use:
You can set OLLAMA_LLM_LIBRARY to any of the available LLM libraries to limit autodetection, so for example, if you have both CUDA and AMD GPUs, but want to force the CUDA v13 only, use:
```shell
OLLAMA_LLM_LIBRARY="cpu_avx2" ollama serve
```
You can see what features your CPU has with the following.
```shell
cat /proc/cpuinfo| grep flags | head -1
OLLAMA_LLM_LIBRARY="cuda_v13" ollama serve
```
## Installing older or pre-release versions on Linux

View File

@@ -870,11 +870,6 @@ func (f GGML) SupportsKVCacheType(cacheType string) bool {
return true
}
if arch := f.KV().Architecture(); slices.Contains([]string{"gptoss", "gpt-oss"}, arch) {
// gpt-oss uses attention with sinks which does not support quantized cache types
slog.Warn("model only supports non-quantized cache types", "model", arch)
return false
}
return slices.Contains([]string{"q8_0", "q4_0"}, cacheType)
}

View File

@@ -17,16 +17,21 @@ func TestBlueSky(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
// Set up the test data
req := api.GenerateRequest{
Model: smol,
Prompt: blueSkyPrompt,
req := api.ChatRequest{
Model: smol,
Messages: []api.Message{
{
Role: "user",
Content: blueSkyPrompt,
},
},
Stream: &stream,
Options: map[string]any{
"temperature": 0,
"seed": 123,
},
}
GenerateTestHelper(ctx, t, req, blueSkyExpected)
ChatTestHelper(ctx, t, req, blueSkyExpected)
}
func TestUnicode(t *testing.T) {
@@ -34,10 +39,15 @@ func TestUnicode(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
defer cancel()
// Set up the test data
req := api.GenerateRequest{
req := api.ChatRequest{
// DeepSeek has a Unicode tokenizer regex, making it a unicode torture test
Model: "deepseek-coder-v2:16b-lite-instruct-q2_K", // TODO is there an ollama-engine model we can switch to and keep the coverage?
Prompt: "天空为什么是蓝色的?", // Why is the sky blue?
Model: "deepseek-coder-v2:16b-lite-instruct-q2_K", // TODO is there an ollama-engine model we can switch to and keep the coverage?
Messages: []api.Message{
{
Role: "user",
Content: "天空为什么是蓝色的?", // Why is the sky blue?
},
},
Stream: &stream,
Options: map[string]any{
"temperature": 0,
@@ -57,9 +67,14 @@ func TestUnicode(t *testing.T) {
if err != nil {
t.Fatalf("failed to load model %s: %s", req.Model, err)
}
defer func() {
// best effort unload once we're done with the model
client.Generate(ctx, &api.GenerateRequest{Model: req.Model, KeepAlive: &api.Duration{Duration: 0}}, func(rsp api.GenerateResponse) error { return nil })
}()
skipIfNotGPULoaded(ctx, t, client, req.Model, 100)
DoGenerate(ctx, t, client, req, []string{
DoChat(ctx, t, client, req, []string{
"散射", // scattering
"频率", // frequency
}, 120*time.Second, 120*time.Second)
@@ -69,9 +84,14 @@ func TestExtendedUnicodeOutput(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
// Set up the test data
req := api.GenerateRequest{
Model: "gemma2:2b",
Prompt: "Output some smily face emoji",
req := api.ChatRequest{
Model: "gemma2:2b",
Messages: []api.Message{
{
Role: "user",
Content: "Output some smily face emoji",
},
},
Stream: &stream,
Options: map[string]any{
"temperature": 0,
@@ -83,7 +103,7 @@ func TestExtendedUnicodeOutput(t *testing.T) {
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatal(err)
}
DoGenerate(ctx, t, client, req, []string{"😀", "😊", "😁", "😂", "😄", "😃"}, 120*time.Second, 120*time.Second)
DoChat(ctx, t, client, req, []string{"😀", "😊", "😁", "😂", "😄", "😃"}, 120*time.Second, 120*time.Second)
}
func TestUnicodeModelDir(t *testing.T) {
@@ -108,14 +128,19 @@ func TestUnicodeModelDir(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
req := api.GenerateRequest{
Model: smol,
Prompt: blueSkyPrompt,
req := api.ChatRequest{
Model: smol,
Messages: []api.Message{
{
Role: "user",
Content: blueSkyPrompt,
},
},
Stream: &stream,
Options: map[string]any{
"temperature": 0,
"seed": 123,
},
}
GenerateTestHelper(ctx, t, req, blueSkyExpected)
ChatTestHelper(ctx, t, req, blueSkyExpected)
}

View File

@@ -20,9 +20,9 @@ import (
)
// Send multiple requests in parallel (concurrently) to a single model and ensure responses are expected
func TestConcurrentGenerate(t *testing.T) {
func TestConcurrentChat(t *testing.T) {
// Assumes all requests have the same model
req, resp := GenerateRequests()
req, resp := ChatRequests()
numParallel := int(envconfig.NumParallel() + 1)
iterLimit := 3
@@ -57,7 +57,7 @@ func TestConcurrentGenerate(t *testing.T) {
slog.Info("Starting", "thread", i, "iter", j)
// On slower GPUs it can take a while to process the concurrent requests
// so we allow a much longer initial timeout
DoGenerate(ctx, t, client, req[k], resp[k], 120*time.Second, 20*time.Second)
DoChat(ctx, t, client, req[k], resp[k], 120*time.Second, 20*time.Second)
}
}(i)
}
@@ -163,7 +163,7 @@ chooseModels:
wg.Add(1)
go func(i int) {
defer wg.Done()
reqs, resps := GenerateRequests()
reqs, resps := ChatRequests()
for j := 0; j < 3; j++ {
if time.Now().Sub(started) > softTimeout {
slog.Info("exceeded soft timeout, winding down test")
@@ -171,8 +171,8 @@ chooseModels:
}
k := r.Int() % len(reqs)
reqs[k].Model = chosenModels[i]
slog.Info("Starting", "model", reqs[k].Model, "iteration", j, "request", reqs[k].Prompt)
DoGenerate(ctx, t, client, reqs[k], resps[k],
slog.Info("Starting", "model", reqs[k].Model, "iteration", j, "request", reqs[k].Messages[0].Content)
DoChat(ctx, t, client, reqs[k], resps[k],
120*time.Second, // Be extra patient for the model to load initially
10*time.Second, // Once results start streaming, fail if they stall
)

View File

@@ -21,9 +21,14 @@ func TestLongInputContext(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
// Set up the test data
req := api.GenerateRequest{
Model: smol,
Prompt: "Oh, dont speak to me of Austria. Perhaps I dont understand things, but Austria never has wished, and does not wish, for war. She is betraying us! Russia alone must save Europe. Our gracious sovereign recognizes his high vocation and will be true to it. That is the one thing I have faith in! Our good and wonderful sovereign has to perform the noblest role on earth, and he is so virtuous and noble that God will not forsake him. He will fulfill his vocation and crush the hydra of revolution, which has become more terrible than ever in the person of this murderer and villain! We alone must avenge the blood of the just one.... Whom, I ask you, can we rely on?... England with her commercial spirit will not and cannot understand the Emperor Alexanders loftiness of soul. She has refused to evacuate Malta. She wanted to find, and still seeks, some secret motive in our actions. What answer did Novosíltsev get? None. The English have not understood and cannot understand the self-abnegation of our Emperor who wants nothing for himself, but only desires the good of mankind. And what have they promised? Nothing! And what little they have promised they will not perform! Prussia has always declared that Buonaparte is invincible, and that all Europe is powerless before him.... And I dont believe a word that Hardenburg says, or Haugwitz either. This famous Prussian neutrality is just a trap. I have faith only in God and the lofty destiny of our adored monarch. He will save Europe! What country is this referring to?",
req := api.ChatRequest{
Model: smol,
Messages: []api.Message{
{
Role: "user",
Content: "Oh, dont speak to me of Austria. Perhaps I dont understand things, but Austria never has wished, and does not wish, for war. She is betraying us! Russia alone must save Europe. Our gracious sovereign recognizes his high vocation and will be true to it. That is the one thing I have faith in! Our good and wonderful sovereign has to perform the noblest role on earth, and he is so virtuous and noble that God will not forsake him. He will fulfill his vocation and crush the hydra of revolution, which has become more terrible than ever in the person of this murderer and villain! We alone must avenge the blood of the just one.... Whom, I ask you, can we rely on?... England with her commercial spirit will not and cannot understand the Emperor Alexanders loftiness of soul. She has refused to evacuate Malta. She wanted to find, and still seeks, some secret motive in our actions. What answer did Novosíltsev get? None. The English have not understood and cannot understand the self-abnegation of our Emperor who wants nothing for himself, but only desires the good of mankind. And what have they promised? Nothing! And what little they have promised they will not perform! Prussia has always declared that Buonaparte is invincible, and that all Europe is powerless before him.... And I dont believe a word that Hardenburg says, or Haugwitz either. This famous Prussian neutrality is just a trap. I have faith only in God and the lofty destiny of our adored monarch. He will save Europe! What country is this referring to?",
},
},
Stream: &stream,
Options: map[string]any{
"temperature": 0,
@@ -36,7 +41,7 @@ func TestLongInputContext(t *testing.T) {
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("PullIfMissing failed: %v", err)
}
DoGenerate(ctx, t, client, req, []string{"russia", "german", "france", "england", "austria", "prussia", "europe", "individuals", "coalition", "conflict"}, 120*time.Second, 10*time.Second)
DoChat(ctx, t, client, req, []string{"russia", "german", "france", "england", "austria", "prussia", "europe", "individuals", "coalition", "conflict"}, 120*time.Second, 10*time.Second)
}
func TestContextExhaustion(t *testing.T) {
@@ -48,9 +53,14 @@ func TestContextExhaustion(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
// Set up the test data
req := api.GenerateRequest{
Model: smol,
Prompt: "Write me a story in english with a lot of emojis",
req := api.ChatRequest{
Model: smol,
Messages: []api.Message{
{
Role: "user",
Content: "Write me a story in english with a lot of emojis",
},
},
Stream: &stream,
Options: map[string]any{
"temperature": 0,
@@ -63,12 +73,12 @@ func TestContextExhaustion(t *testing.T) {
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("PullIfMissing failed: %v", err)
}
DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived", "sunny", "cloudy", "clear", "water", "time", "travel", "world"}, 120*time.Second, 10*time.Second)
DoChat(ctx, t, client, req, []string{"once", "upon", "lived", "sunny", "cloudy", "clear", "water", "time", "travel", "world"}, 120*time.Second, 10*time.Second)
}
// Send multiple generate requests with prior context and ensure the response is coherant and expected
func TestParallelGenerateWithHistory(t *testing.T) {
modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
modelOverride := "gpt-oss:20b"
req, resp := GenerateRequests()
numParallel := 2
iterLimit := 2
@@ -155,7 +165,7 @@ func TestGenerateWithHistory(t *testing.T) {
// Send multiple chat requests with prior context and ensure the response is coherant and expected
func TestParallelChatWithHistory(t *testing.T) {
modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
modelOverride := "gpt-oss:20b"
req, resp := ChatRequests()
numParallel := 2
iterLimit := 2

View File

@@ -15,7 +15,7 @@ import (
// First run of this scenario on a target system will take a long time to download
// ~1.5TB of models. Set a sufficiently large -timeout for your network speed
func TestLibraryModelsGenerate(t *testing.T) {
func TestLibraryModelsChat(t *testing.T) {
softTimeout, hardTimeout := getTimeouts(t)
slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
@@ -43,9 +43,14 @@ func TestLibraryModelsGenerate(t *testing.T) {
t.Skip(fmt.Sprintf("Skipping %s architecture %s != %s", model, arch, targetArch))
}
}
req := api.GenerateRequest{
Model: model,
Prompt: blueSkyPrompt,
req := api.ChatRequest{
Model: model,
Messages: []api.Message{
{
Role: "user",
Content: blueSkyPrompt,
},
},
KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]interface{}{
"temperature": 0.1,
@@ -58,13 +63,13 @@ func TestLibraryModelsGenerate(t *testing.T) {
anyResp = []string{"select", "from"}
} else if model == "granite3-guardian" || model == "shieldgemma" || model == "llama-guard3" || model == "bespoke-minicheck" {
anyResp = []string{"yes", "no", "safe", "unsafe"}
} else if model == "openthinker" || model == "nexusraven" {
} else if model == "openthinker" {
anyResp = []string{"plugin", "im_sep", "components", "function call"}
} else if model == "starcoder" || model == "starcoder2" || model == "magicoder" || model == "deepseek-coder" {
req.Prompt = "def fibonacci():"
req.Messages[0].Content = "def fibonacci():"
anyResp = []string{"f(n)", "sequence", "n-1", "main()", "__main__", "while"}
}
DoGenerate(ctx, t, client, req, anyResp, 120*time.Second, 30*time.Second)
DoChat(ctx, t, client, req, anyResp, 120*time.Second, 30*time.Second)
})
}
}

View File

@@ -34,17 +34,22 @@ func TestVisionModels(t *testing.T) {
if err != nil {
t.Fatal(err)
}
req := api.GenerateRequest{
Model: v.model,
Prompt: "what does the text in this image say?",
req := api.ChatRequest{
Model: v.model,
Messages: []api.Message{
{
Role: "user",
Content: "what does the text in this image say?",
Images: []api.ImageData{
image,
},
},
},
Stream: &stream,
Options: map[string]any{
"seed": 42,
"temperature": 0.0,
},
Images: []api.ImageData{
image,
},
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
@@ -56,8 +61,15 @@ func TestVisionModels(t *testing.T) {
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatal(err)
}
// Preload to skip if we're less than 80% on GPU to avoid extremely slow tests
err = client.Generate(ctx, &api.GenerateRequest{Model: req.Model}, func(response api.GenerateResponse) error { return nil })
if err != nil {
t.Fatalf("failed to load model %s: %s", req.Model, err)
}
skipIfNotGPULoaded(ctx, t, client, req.Model, 80)
// llava models on CPU can be quite slow to start
DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
DoChat(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
})
}
}

View File

@@ -19,7 +19,7 @@ import (
"github.com/ollama/ollama/format"
)
func TestModelsGenerate(t *testing.T) {
func TestModelsChat(t *testing.T) {
softTimeout, hardTimeout := getTimeouts(t)
slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
@@ -66,15 +66,23 @@ func TestModelsGenerate(t *testing.T) {
}
}
// TODO - fiddle with context size
req := api.GenerateRequest{
Model: model,
Prompt: blueSkyPrompt,
req := api.ChatRequest{
Model: model,
Messages: []api.Message{
{
Role: "user",
Content: blueSkyPrompt,
},
},
KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
},
}
DoGenerate(ctx, t, client, req, blueSkyExpected, 120*time.Second, 30*time.Second)
DoChat(ctx, t, client, req, blueSkyExpected, 120*time.Second, 30*time.Second)
// best effort unload once we're done with the model
client.Generate(ctx, &api.GenerateRequest{Model: req.Model, KeepAlive: &api.Duration{Duration: 0}}, func(rsp api.GenerateResponse) error { return nil })
})
}
}
@@ -128,8 +136,9 @@ func TestModelsEmbed(t *testing.T) {
}
}
req := api.EmbeddingRequest{
Model: model,
Prompt: "why is the sky blue?",
Model: model,
Prompt: "why is the sky blue?",
KeepAlive: &api.Duration{Duration: 10 * time.Second},
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
@@ -139,6 +148,10 @@ func TestModelsEmbed(t *testing.T) {
if err != nil {
t.Fatalf("embeddings call failed %s", err)
}
defer func() {
// best effort unload once we're done with the model
client.Generate(ctx, &api.GenerateRequest{Model: req.Model, KeepAlive: &api.Duration{Duration: 0}}, func(rsp api.GenerateResponse) error { return nil })
}()
if len(resp.Embedding) == 0 {
t.Errorf("zero length embedding response")
}

View File

@@ -173,9 +173,14 @@ func doModelPerfTest(t *testing.T, chatModels []string) {
slog.Info("skipping long prompt", "model", model, "num_ctx", numCtx, "gpu_percent", gpuPercent)
continue
}
req := api.GenerateRequest{
Model: model,
Prompt: tc.prompt,
req := api.ChatRequest{
Model: model,
Messages: []api.Message{
{
Role: "user",
Content: tc.prompt,
},
},
KeepAlive: &api.Duration{Duration: 20 * time.Second}, // long enough to ensure a ps returns
Options: map[string]interface{}{
"temperature": 0,
@@ -184,7 +189,7 @@ func doModelPerfTest(t *testing.T, chatModels []string) {
},
}
atLeastOne := false
var resp api.GenerateResponse
var resp api.ChatResponse
stream := false
req.Stream = &stream
@@ -198,7 +203,7 @@ func doModelPerfTest(t *testing.T, chatModels []string) {
)
defer cancel()
err = client.Generate(genCtx, &req, func(rsp api.GenerateResponse) error {
err = client.Chat(genCtx, &req, func(rsp api.ChatResponse) error {
resp = rsp
return nil
})
@@ -214,13 +219,13 @@ func doModelPerfTest(t *testing.T, chatModels []string) {
}
loaded = true
for _, expResp := range tc.anyResp {
if strings.Contains(strings.ToLower(resp.Response), expResp) {
if strings.Contains(strings.ToLower(resp.Message.Content), expResp) {
atLeastOne = true
break
}
}
if !atLeastOne {
t.Fatalf("response didn't contain expected values: ctx:%d expected:%v response:%s ", numCtx, tc.anyResp, resp.Response)
t.Fatalf("response didn't contain expected values: ctx:%d expected:%v response:%s ", numCtx, tc.anyResp, resp.Message.Content)
}
models, err := client.ListRunning(ctx)
if err != nil {

View File

@@ -74,9 +74,14 @@ func TestQuantization(t *testing.T) {
}
stream := true
genReq := api.GenerateRequest{
Model: newName,
Prompt: blueSkyPrompt,
chatReq := api.ChatRequest{
Model: newName,
Messages: []api.Message{
{
Role: "user",
Content: blueSkyPrompt,
},
},
KeepAlive: &api.Duration{Duration: 3 * time.Second},
Options: map[string]any{
"seed": 42,
@@ -91,8 +96,8 @@ func TestQuantization(t *testing.T) {
reqCtx, reqCancel := context.WithCancel(ctx)
atLeastOne := false
var buf bytes.Buffer
genfn := func(response api.GenerateResponse) error {
buf.Write([]byte(response.Response))
chatfn := func(response api.ChatResponse) error {
buf.Write([]byte(response.Message.Content))
fullResp := strings.ToLower(buf.String())
for _, resp := range blueSkyExpected {
if strings.Contains(fullResp, resp) {
@@ -108,14 +113,14 @@ func TestQuantization(t *testing.T) {
done := make(chan int)
var genErr error
go func() {
genErr = client.Generate(reqCtx, &genReq, genfn)
genErr = client.Chat(reqCtx, &chatReq, chatfn)
done <- 0
}()
select {
case <-done:
if genErr != nil && !atLeastOne {
t.Fatalf("failed with %s request prompt %s ", genReq.Model, genReq.Prompt)
t.Fatalf("failed with %s request prompt %s ", chatReq.Model, chatReq.Messages[0].Content)
}
case <-ctx.Done():
t.Error("outer test context done while waiting for generate")

View File

@@ -15,6 +15,7 @@ import (
"net/http"
"net/url"
"os"
"os/exec"
"path/filepath"
"runtime"
"strconv"
@@ -24,7 +25,6 @@ import (
"time"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/app/lifecycle"
"github.com/ollama/ollama/format"
)
@@ -38,6 +38,7 @@ var (
// Note: add newer models at the top of the list to test them first
ollamaEngineChatModels = []string{
"qwen3-coder:30b",
"gpt-oss:20b",
"gemma3n:e2b",
"mistral-small3.2:latest",
@@ -46,6 +47,7 @@ var (
"qwen2.5-coder:latest",
"qwen2.5vl:3b",
"qwen3:0.6b", // dense
"qwen3:1.7b", // dense
"qwen3:30b", // MOE
"gemma3:1b",
"llama3.1:latest",
@@ -265,16 +267,16 @@ var (
"Explain the physics involved in them. Be breif in your reply",
"Explain the chemistry involved in them. Be breif in your reply",
"What are common myths related to them? Be brief in your reply",
"What are common fairytales related to them? Be brief in your reply",
"Can they form if there is no rain? Be breif in your reply",
"Can they form if there are no clouds? Be breif in your reply",
"Do they happen on other planets? Be brief in your reply",
}
rainbowExpected = []string{"water", "droplet", "mist", "glow", "refract", "reflect", "scatter", "wave", "color", "spectrum", "raindrop", "atmosphere", "frequency", "end", "gold", "fortune", "blessing", "prosperity", "magic", "shower", "sky", "shimmer", "light", "storm", "sunny"}
rainbowExpected = []string{"water", "droplet", "mist", "glow", "refract", "reflect", "scatter", "particles", "wave", "color", "spectrum", "raindrop", "atmosphere", "frequency", "shower", "sky", "shimmer", "light", "storm", "sunny", "sunburst", "phenomenon", "mars", "venus", "jupiter"}
)
func init() {
lifecycle.InitLogging()
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
slog.SetDefault(logger)
custom := os.Getenv("OLLAMA_TEST_DEFAULT_MODEL")
if custom != "" {
slog.Info("setting default test model to " + custom)
@@ -335,6 +337,7 @@ func GetTestEndpoint() (*api.Client, string) {
var serverMutex sync.Mutex
var serverReady bool
var serverLogFile string
func startServer(t *testing.T, ctx context.Context, ollamaHost string) error {
// Make sure the server has been built
@@ -361,8 +364,9 @@ func startServer(t *testing.T, ctx context.Context, ollamaHost string) error {
t.Setenv("OLLAMA_HOST", ollamaHost)
}
logDir := t.TempDir()
slog.Info("starting server", "url", ollamaHost)
done, err := lifecycle.SpawnServer(ctx, "../ollama")
done, err := SpawnServer(ctx, "../ollama", logDir)
if err != nil {
return fmt.Errorf("failed to start server: %w", err)
}
@@ -385,6 +389,36 @@ func startServer(t *testing.T, ctx context.Context, ollamaHost string) error {
return nil
}
func SpawnServer(ctx context.Context, command, logDir string) (chan int, error) {
done := make(chan int)
fp, err := os.CreateTemp(logDir, "ollama-server-*.log")
if err != nil {
return nil, fmt.Errorf("failed to create log file: %w", err)
}
serverLogFile = fp.Name()
cmd := exec.CommandContext(ctx, command, "serve")
cmd.Stderr = fp
cmd.Stdout = fp
go func() {
slog.Info("starting server...")
if err := cmd.Run(); err != nil {
// "signal: killed" expected
if !strings.Contains(err.Error(), "signal") {
slog.Info("failed to run server", "error", err)
}
}
var code int
if cmd.ProcessState != nil {
code = cmd.ProcessState.ExitCode()
}
slog.Info("server exited")
done <- code
}()
return done, nil
}
func PullIfMissing(ctx context.Context, client *api.Client, modelName string) error {
slog.Info("checking status of model", "model", modelName)
showReq := &api.ShowRequest{Name: modelName}
@@ -445,12 +479,6 @@ func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, strin
client, testEndpoint := GetTestEndpoint()
if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
serverProcMutex.Lock()
fp, err := os.CreateTemp("", "ollama-server-*.log")
if err != nil {
t.Fatalf("failed to generate log file: %s", err)
}
lifecycle.ServerLogFile = fp.Name()
fp.Close()
if err := startServer(t, ctx, testEndpoint); err != nil {
t.Fatal(err)
}
@@ -478,36 +506,32 @@ func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, strin
if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
defer serverProcMutex.Unlock()
if t.Failed() {
fp, err := os.Open(lifecycle.ServerLogFile)
fp, err := os.Open(serverLogFile)
if err != nil {
slog.Error("failed to open server log", "logfile", lifecycle.ServerLogFile, "error", err)
slog.Error("failed to open server log", "logfile", serverLogFile, "error", err)
return
}
defer fp.Close()
data, err := io.ReadAll(fp)
if err != nil {
slog.Error("failed to read server log", "logfile", lifecycle.ServerLogFile, "error", err)
slog.Error("failed to read server log", "logfile", serverLogFile, "error", err)
return
}
slog.Warn("SERVER LOG FOLLOWS")
os.Stderr.Write(data)
slog.Warn("END OF SERVER")
}
err := os.Remove(lifecycle.ServerLogFile)
if err != nil && !os.IsNotExist(err) {
slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err)
}
}
}
}
func GenerateTestHelper(ctx context.Context, t *testing.T, genReq api.GenerateRequest, anyResp []string) {
func ChatTestHelper(ctx context.Context, t *testing.T, req api.ChatRequest, anyResp []string) {
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, genReq.Model); err != nil {
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatal(err)
}
DoGenerate(ctx, t, client, genReq, anyResp, 30*time.Second, 10*time.Second)
DoChat(ctx, t, client, req, anyResp, 30*time.Second, 10*time.Second)
}
func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq api.GenerateRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) []int {
@@ -726,8 +750,14 @@ func skipIfNotGPULoaded(ctx context.Context, t *testing.T, client *api.Client, m
loaded := []string{}
for _, m := range models.Models {
loaded = append(loaded, m.Name)
if m.Name != model {
continue
if strings.Contains(model, ":") {
if m.Name != model {
continue
}
} else if strings.Contains(m.Name, ":") {
if !strings.HasPrefix(m.Name, model+":") {
continue
}
}
gpuPercent := 0
switch {

View File

@@ -160,7 +160,15 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity
if c.swaMemorySize == 0 {
c.swaMemorySize = c.swaWindowSize
}
if int(c.swaMemorySize) > capacity {
// We will allocate space in the cache for the stop token, which won't be part of a follow on
// sequence, so allocate an extra token of storage to ensure that we can jump back without
// causing a cache break. As an optimization, only do this when we have parallel sequences
// because the extra token will live in the batch buffer and won't get overwritten if we
// only have a single sequence.
if c.swaMemorySize != math.MaxInt32 && maxSequences > 1 {
c.swaMemorySize = max(c.swaMemorySize, c.swaWindowSize+1)
}
if int(c.swaMemorySize) >= capacity {
c.swaMemorySize = math.MaxInt32
}
@@ -214,7 +222,6 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) e
c.curLoc, err = c.findStartLoc()
}
if err != nil {
slog.Warn("unable to find a kv cache slot", "cache", c)
return err
}
@@ -288,23 +295,44 @@ func (c *Causal) updateSlidingWindow() {
return
}
type lowestPosition struct {
pos int32
curBatch bool
}
// create a map of unique sequences to the lowest position in that sequence
lowestPos := make(map[int]int32)
lowestPos := make(map[int]lowestPosition)
for i := range c.curPositions {
seq := c.curSequences[i]
pos, ok := lowestPos[seq]
lowest, ok := lowestPos[seq]
if !ok {
pos = c.curPositions[i]
} else if c.curPositions[i] < pos {
pos = c.curPositions[i]
lowest = lowestPosition{pos: c.curPositions[i], curBatch: true}
} else if c.curPositions[i] < lowest.pos {
lowest.pos = c.curPositions[i]
}
lowestPos[seq] = pos
lowestPos[seq] = lowest
}
// for any sequences are not part of this batch, clean up any tokens
// that are no longer needed after the processing of the previous
// batch
for seq, seqRange := range c.cellRanges {
if _, ok := lowestPos[seq]; !ok {
var last int32
for i := seqRange.min; i <= seqRange.max; i++ {
if slices.Contains(c.cells[i].sequences, seq) {
last = max(last, c.cells[i].pos)
}
}
lowestPos[seq] = lowestPosition{pos: last + 1, curBatch: false}
}
}
// delete any entries that are beyond the window of the oldest position in the sequence
for seq, pos := range lowestPos {
for seq, lowest := range lowestPos {
oldRange, ok := c.cellRanges[seq]
if !ok {
continue
@@ -314,13 +342,13 @@ func (c *Causal) updateSlidingWindow() {
for i := oldRange.min; i <= oldRange.max; i++ {
if slices.Contains(c.cells[i].sequences, seq) {
if c.cells[i].pos < pos-c.swaMemorySize {
if c.cells[i].pos < lowest.pos-c.swaMemorySize {
c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
} else {
newRange.min = min(newRange.min, i)
newRange.max = max(newRange.max, i)
}
if c.cells[i].pos >= pos-c.swaWindowSize {
if lowest.curBatch && c.cells[i].pos >= lowest.pos-c.swaWindowSize {
c.curCellRange.min = min(c.curCellRange.min, i)
c.curCellRange.max = max(c.curCellRange.max, i)
}
@@ -657,9 +685,11 @@ func (c *Causal) CanResume(seq int, pos int32) bool {
// for sliding window, check that the window of the new sequence is contained in
// the window of what we are storing
var first int32 = math.MaxInt32
var last int32 = -1
for i := seqRange.min; i <= seqRange.max; i++ {
if slices.Contains(c.cells[i].sequences, seq) {
first = min(first, c.cells[i].pos)
last = max(last, c.cells[i].pos)
}
}
@@ -668,10 +698,8 @@ func (c *Causal) CanResume(seq int, pos int32) bool {
return false
}
lastWindowStart := max(0, last-c.swaMemorySize)
posWindowStart := max(0, pos-c.swaWindowSize)
return posWindowStart >= lastWindowStart
return posWindowStart >= first && pos <= last+1
}
func (c *Causal) shift(seq int, beginIndex, offset int32) error {

View File

@@ -96,6 +96,86 @@ func TestSWA(t *testing.T) {
testCache(t, backend, cache, tests)
}
func TestSWASeparateBatches(t *testing.T) {
backend := &testBackend{}
cache := NewSWACache(1, nil)
defer cache.Close()
cache.Init(backend, ml.DTypeF16, 2, 16, 2)
x := float32(math.Inf(-1))
tests := []testCase{
{
name: "First seq 0",
in: []float32{1, 2},
inShape: []int{1, 1, 2},
seqs: []int{0, 0},
pos: []int32{0, 1},
expected: []float32{1, 2},
expectedShape: []int{1, 1, 2},
expectedMask: []float32{
0, x,
0, 0,
},
},
{
name: "Second seq 0",
in: []float32{3, 4},
inShape: []int{1, 1, 2},
seqs: []int{0, 0},
pos: []int32{2, 3},
expected: []float32{2, 3, 4},
expectedShape: []int{1, 1, 3},
expectedMask: []float32{
0, 0, x,
x, 0, 0,
},
},
{
name: "First seq 1",
in: []float32{5, 6},
inShape: []int{1, 1, 2},
seqs: []int{1, 1},
pos: []int32{0, 1},
expected: []float32{5, 6},
expectedShape: []int{1, 1, 2},
expectedMask: []float32{
0, x,
0, 0,
},
},
{
name: "Second seq 1",
in: []float32{7, 8},
inShape: []int{1, 1, 2},
seqs: []int{1, 1},
pos: []int32{2, 3},
expected: []float32{6, 3, 4, 7, 8},
expectedShape: []int{1, 1, 5},
expectedMask: []float32{
0, x, x, 0, x,
x, x, x, 0, 0,
},
},
{
name: "Third seq 0",
in: []float32{9, 10},
inShape: []int{1, 1, 2},
seqs: []int{0, 0},
pos: []int32{4, 5},
expected: []float32{9, 10, 3, 4},
expectedShape: []int{1, 1, 4},
expectedMask: []float32{
0, x, x, 0,
0, 0, x, x,
},
},
}
testCache(t, backend, cache, tests)
}
func TestSWAMem(t *testing.T) {
backend := &testBackend{}
cache := NewSWAMemCache(1, 3, nil)
@@ -431,15 +511,15 @@ func TestCanResume(t *testing.T) {
defer context.Close()
err := cache.StartForward(context, input.Batch{
Positions: []int32{0, 1, 2, 3},
Sequences: []int{0, 0, 0, 0},
Positions: []int32{0, 1, 2, 3, 4},
Sequences: []int{0, 0, 0, 0, 0},
}, false)
if err != nil {
t.Fatalf("StartForward failed: %v", err)
}
cache.SetLayer(0)
tensor := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4)
tensor := context.FromFloatSlice([]float32{1, 2, 3, 4, 5}, 1, 1, 5)
cache.Put(context, tensor, tensor)
// with window size 4, nothing has slid out of the window yet
@@ -455,18 +535,21 @@ func TestCanResume(t *testing.T) {
if !cache.CanResume(0, 3) {
t.Errorf("CanResume(0, 3) = false, want true (latest position)")
}
if !cache.CanResume(0, 4) {
t.Errorf("CanResume(0, 4) = false, want true (latest position)")
}
// shift window by adding position 4
// shift window by adding position 5
err = cache.StartForward(context, input.Batch{
Positions: []int32{4, 5},
Sequences: []int{0, 0},
Positions: []int32{5},
Sequences: []int{0},
}, false)
if err != nil {
t.Fatalf("StartForward failed: %v", err)
}
cache.SetLayer(0)
tensor = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2)
tensor = context.FromFloatSlice([]float32{6}, 1, 1, 1)
cache.Put(context, tensor, tensor)
// only the latest position has overlapping windows
@@ -503,28 +586,28 @@ func TestCanResumeSWAMem(t *testing.T) {
defer context.Close()
err := cache.StartForward(context, input.Batch{
Positions: []int32{0, 1, 2, 3, 4, 5},
Sequences: []int{0, 0, 0, 0, 0, 0},
Positions: []int32{0, 1, 2, 3, 4, 5, 6},
Sequences: []int{0, 0, 0, 0, 0, 0, 0},
}, false)
if err != nil {
t.Fatalf("StartForward failed: %v", err)
}
cache.SetLayer(0)
tensor := context.FromFloatSlice([]float32{1, 2, 3, 4, 5, 6}, 1, 1, 6)
tensor := context.FromFloatSlice([]float32{1, 2, 3, 4, 5, 6, 7}, 1, 1, 7)
cache.Put(context, tensor, tensor)
// shift window by adding position 6
// shift window by adding position 7
err = cache.StartForward(context, input.Batch{
Positions: []int32{6, 7},
Sequences: []int{0, 0},
Positions: []int32{7},
Sequences: []int{0},
}, false)
if err != nil {
t.Fatalf("StartForward failed: %v", err)
}
cache.SetLayer(0)
tensor = context.FromFloatSlice([]float32{7, 8}, 1, 1, 2)
tensor = context.FromFloatSlice([]float32{8}, 1, 1, 1)
cache.Put(context, tensor, tensor)
// only the latest position has overlapping windows

View File

@@ -266,11 +266,18 @@ func estimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
}
// Only include GPUs that can fit the graph, gpu minimum, the layer buffer and at least more layer
if gpus[i].FreeMemory < overhead+gzo+max(graphPartialOffload, graphFullOffload)+gpus[i].MinimumMemory+2*layerSize {
var compute string
if gpus[i].Library == "ROCm" {
compute = fmt.Sprintf("gfx%x%02x", gpus[i].ComputeMajor, gpus[i].ComputeMinor)
} else {
compute = fmt.Sprintf("%d.%d", gpus[i].ComputeMajor, gpus[i].ComputeMinor)
}
slog.Debug("gpu has too little memory to allocate any layers",
"id", gpus[i].ID,
"library", gpus[i].Library,
"variant", gpus[i].Variant,
"compute", gpus[i].Compute,
"compute", compute,
"driver", fmt.Sprintf("%d.%d", gpus[i].DriverMajor, gpus[i].DriverMinor),
"name", gpus[i].Name,
"total", format.HumanBytes2(gpus[i].TotalMemory),

View File

@@ -359,20 +359,22 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
s.cmd.Stderr = s.status
s.cmd.SysProcAttr = LlamaServerSysProcAttr
s.cmd.Env = append(s.cmd.Env, "OLLAMA_LIBRARY_PATH="+strings.Join(ggmlPaths, string(filepath.ListSeparator)))
// Always filter down the set of GPUs in case there are any unsupported devices that might crash
envWorkarounds := gpus.GetVisibleDevicesEnv()
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
// Update or add the path variable with our adjusted version
pathNeeded := true
ollamaPathNeeded := true
envWorkaroundDone := make([]bool, len(envWorkarounds))
for i := range s.cmd.Env {
cmp := strings.SplitN(s.cmd.Env[i], "=", 2)
if strings.EqualFold(cmp[0], pathEnv) {
s.cmd.Env[i] = pathEnv + "=" + pathEnvVal
pathNeeded = false
} else if strings.EqualFold(cmp[0], "OLLAMA_LIBRARY_PATH") {
s.cmd.Env[i] = "OLLAMA_LIBRARY_PATH=" + strings.Join(ggmlPaths, string(filepath.ListSeparator))
ollamaPathNeeded = false
} else if len(envWorkarounds) != 0 {
for j, kv := range envWorkarounds {
tmp := strings.SplitN(kv, "=", 2)
@@ -386,6 +388,9 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
if pathNeeded {
s.cmd.Env = append(s.cmd.Env, pathEnv+"="+pathEnvVal)
}
if ollamaPathNeeded {
s.cmd.Env = append(s.cmd.Env, "OLLAMA_LIBRARY_PATH="+strings.Join(ggmlPaths, string(filepath.ListSeparator)))
}
for i, done := range envWorkaroundDone {
if !done {
s.cmd.Env = append(s.cmd.Env, envWorkarounds[i])

424
middleware/openai.go Normal file
View File

@@ -0,0 +1,424 @@
package middleware
import (
"bytes"
"encoding/json"
"fmt"
"io"
"math/rand"
"net/http"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/openai"
)
type BaseWriter struct {
gin.ResponseWriter
}
type ChatWriter struct {
stream bool
streamOptions *openai.StreamOptions
id string
toolCallSent bool
BaseWriter
}
type CompleteWriter struct {
stream bool
streamOptions *openai.StreamOptions
id string
BaseWriter
}
type ListWriter struct {
BaseWriter
}
type RetrieveWriter struct {
BaseWriter
model string
}
type EmbedWriter struct {
BaseWriter
model string
}
func (w *BaseWriter) writeError(data []byte) (int, error) {
var serr api.StatusError
err := json.Unmarshal(data, &serr)
if err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(openai.NewError(http.StatusInternalServerError, serr.Error()))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *ChatWriter) writeResponse(data []byte) (int, error) {
var chatResponse api.ChatResponse
err := json.Unmarshal(data, &chatResponse)
if err != nil {
return 0, err
}
// chat chunk
if w.stream {
c := openai.ToChunk(w.id, chatResponse, w.toolCallSent)
d, err := json.Marshal(c)
if err != nil {
return 0, err
}
if !w.toolCallSent && len(c.Choices) > 0 && len(c.Choices[0].Delta.ToolCalls) > 0 {
w.toolCallSent = true
}
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
if err != nil {
return 0, err
}
if chatResponse.Done {
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
u := openai.ToUsage(chatResponse)
c.Usage = &u
c.Choices = []openai.ChunkChoice{}
d, err := json.Marshal(c)
if err != nil {
return 0, err
}
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
if err != nil {
return 0, err
}
}
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
if err != nil {
return 0, err
}
}
return len(data), nil
}
// chat completion
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToChatCompletion(w.id, chatResponse))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *ChatWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
var generateResponse api.GenerateResponse
err := json.Unmarshal(data, &generateResponse)
if err != nil {
return 0, err
}
// completion chunk
if w.stream {
c := openai.ToCompleteChunk(w.id, generateResponse)
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
c.Usage = &openai.Usage{}
}
d, err := json.Marshal(c)
if err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
if err != nil {
return 0, err
}
if generateResponse.Done {
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
u := openai.ToUsageGenerate(generateResponse)
c.Usage = &u
c.Choices = []openai.CompleteChunkChoice{}
d, err := json.Marshal(c)
if err != nil {
return 0, err
}
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
if err != nil {
return 0, err
}
}
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
if err != nil {
return 0, err
}
}
return len(data), nil
}
// completion
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToCompletion(w.id, generateResponse))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *CompleteWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
func (w *ListWriter) writeResponse(data []byte) (int, error) {
var listResponse api.ListResponse
err := json.Unmarshal(data, &listResponse)
if err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToListCompletion(listResponse))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *ListWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
var showResponse api.ShowResponse
err := json.Unmarshal(data, &showResponse)
if err != nil {
return 0, err
}
// retrieve completion
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToModel(showResponse, w.model))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *RetrieveWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
var embedResponse api.EmbedResponse
err := json.Unmarshal(data, &embedResponse)
if err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToEmbeddingList(w.model, embedResponse))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *EmbedWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
func ListMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
w := &ListWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
}
c.Writer = w
c.Next()
}
}
func RetrieveMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
w := &RetrieveWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
model: c.Param("model"),
}
c.Writer = w
c.Next()
}
}
func CompletionsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req openai.CompletionRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
return
}
var b bytes.Buffer
genReq, err := openai.FromCompleteRequest(req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
return
}
if err := json.NewEncoder(&b).Encode(genReq); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
w := &CompleteWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream,
id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
streamOptions: req.StreamOptions,
}
c.Writer = w
c.Next()
}
}
func EmbeddingsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req openai.EmbedRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
return
}
if req.Input == "" {
req.Input = []string{""}
}
if req.Input == nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "invalid input"))
return
}
if v, ok := req.Input.([]any); ok && len(v) == 0 {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "invalid input"))
return
}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input, Dimensions: req.Dimensions}); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
w := &EmbedWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
model: req.Model,
}
c.Writer = w
c.Next()
}
}
func ChatMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req openai.ChatCompletionRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
return
}
if len(req.Messages) == 0 {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "[] is too short - 'messages'"))
return
}
var b bytes.Buffer
chatReq, err := openai.FromChatRequest(req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
return
}
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
w := &ChatWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream,
id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
streamOptions: req.StreamOptions,
}
c.Writer = w
c.Next()
}
}

928
middleware/openai_test.go Normal file
View File

@@ -0,0 +1,928 @@
package middleware
import (
"bytes"
"encoding/base64"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"reflect"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/openai"
)
const (
prefix = `data:image/jpeg;base64,`
image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
)
var (
False = false
True = true
)
func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
return func(c *gin.Context) {
bodyBytes, _ := io.ReadAll(c.Request.Body)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
err := json.Unmarshal(bodyBytes, capturedRequest)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request")
}
c.Next()
}
}
func TestChatMiddleware(t *testing.T) {
type testCase struct {
name string
body string
req api.ChatRequest
err openai.ErrorResponse
}
var capturedRequest *api.ChatRequest
testCases := []testCase{
{
name: "chat handler",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{
Role: "user",
Content: "Hello",
},
},
Options: map[string]any{
"temperature": 1.0,
"top_p": 1.0,
},
Stream: &False,
},
},
{
name: "chat handler with options",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "Hello"}
],
"stream": true,
"max_tokens": 999,
"seed": 123,
"stop": ["\n", "stop"],
"temperature": 3.0,
"frequency_penalty": 4.0,
"presence_penalty": 5.0,
"top_p": 6.0,
"response_format": {"type": "json_object"}
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{
Role: "user",
Content: "Hello",
},
},
Options: map[string]any{
"num_predict": 999.0, // float because JSON doesn't distinguish between float and int
"seed": 123.0,
"stop": []any{"\n", "stop"},
"temperature": 3.0,
"frequency_penalty": 4.0,
"presence_penalty": 5.0,
"top_p": 6.0,
},
Format: json.RawMessage(`"json"`),
Stream: &True,
},
},
{
name: "chat handler with streaming usage",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "Hello"}
],
"stream": true,
"stream_options": {"include_usage": true},
"max_tokens": 999,
"seed": 123,
"stop": ["\n", "stop"],
"temperature": 3.0,
"frequency_penalty": 4.0,
"presence_penalty": 5.0,
"top_p": 6.0,
"response_format": {"type": "json_object"}
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{
Role: "user",
Content: "Hello",
},
},
Options: map[string]any{
"num_predict": 999.0, // float because JSON doesn't distinguish between float and int
"seed": 123.0,
"stop": []any{"\n", "stop"},
"temperature": 3.0,
"frequency_penalty": 4.0,
"presence_penalty": 5.0,
"top_p": 6.0,
},
Format: json.RawMessage(`"json"`),
Stream: &True,
},
},
{
name: "chat handler with image content",
body: `{
"model": "test-model",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "Hello"
},
{
"type": "image_url",
"image_url": {
"url": "` + prefix + image + `"
}
}
]
}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{
Role: "user",
Content: "Hello",
},
{
Role: "user",
Images: []api.ImageData{
func() []byte {
img, _ := base64.StdEncoding.DecodeString(image)
return img
}(),
},
},
},
Options: map[string]any{
"temperature": 1.0,
"top_p": 1.0,
},
Stream: &False,
},
},
{
name: "chat handler with tools",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "What's the weather like in Paris Today?"},
{"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{
Role: "user",
Content: "What's the weather like in Paris Today?",
},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: map[string]any{
"location": "Paris, France",
"format": "celsius",
},
},
},
},
},
},
Options: map[string]any{
"temperature": 1.0,
"top_p": 1.0,
},
Stream: &False,
},
},
{
name: "chat handler with tools and content",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "What's the weather like in Paris Today?"},
{"role": "assistant", "content": "Let's see what the weather is like in Paris", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{
Role: "user",
Content: "What's the weather like in Paris Today?",
},
{
Role: "assistant",
Content: "Let's see what the weather is like in Paris",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: map[string]any{
"location": "Paris, France",
"format": "celsius",
},
},
},
},
},
},
Options: map[string]any{
"temperature": 1.0,
"top_p": 1.0,
},
Stream: &False,
},
},
{
name: "chat handler with tools and empty content",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "What's the weather like in Paris Today?"},
{"role": "assistant", "content": "", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{
Role: "user",
Content: "What's the weather like in Paris Today?",
},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: map[string]any{
"location": "Paris, France",
"format": "celsius",
},
},
},
},
},
},
Options: map[string]any{
"temperature": 1.0,
"top_p": 1.0,
},
Stream: &False,
},
},
{
name: "chat handler with tools and thinking content",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "What's the weather like in Paris Today?"},
{"role": "assistant", "reasoning": "Let's see what the weather is like in Paris", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{
Role: "user",
Content: "What's the weather like in Paris Today?",
},
{
Role: "assistant",
Thinking: "Let's see what the weather is like in Paris",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: map[string]any{
"location": "Paris, France",
"format": "celsius",
},
},
},
},
},
},
Options: map[string]any{
"temperature": 1.0,
"top_p": 1.0,
},
Stream: &False,
},
},
{
name: "tool response with call ID",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "What's the weather like in Paris Today?"},
{"role": "assistant", "tool_calls": [{"id": "id_abc", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]},
{"role": "tool", "tool_call_id": "id_abc", "content": "The weather in Paris is 20 degrees Celsius"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{
Role: "user",
Content: "What's the weather like in Paris Today?",
},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: map[string]any{
"location": "Paris, France",
"format": "celsius",
},
},
},
},
},
{
Role: "tool",
Content: "The weather in Paris is 20 degrees Celsius",
ToolName: "get_current_weather",
},
},
Options: map[string]any{
"temperature": 1.0,
"top_p": 1.0,
},
Stream: &False,
},
},
{
name: "tool response with name",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "What's the weather like in Paris Today?"},
{"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]},
{"role": "tool", "name": "get_current_weather", "content": "The weather in Paris is 20 degrees Celsius"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{
Role: "user",
Content: "What's the weather like in Paris Today?",
},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: map[string]any{
"location": "Paris, France",
"format": "celsius",
},
},
},
},
},
{
Role: "tool",
Content: "The weather in Paris is 20 degrees Celsius",
ToolName: "get_current_weather",
},
},
Options: map[string]any{
"temperature": 1.0,
"top_p": 1.0,
},
Stream: &False,
},
},
{
name: "chat handler with streaming tools",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "What's the weather like in Paris?"}
],
"stream": true,
"tools": [{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get the current weather",
"parameters": {
"type": "object",
"required": ["location"],
"properties": {
"location": {
"type": "string",
"description": "The city and state"
},
"unit": {
"type": "string",
"enum": ["celsius", "fahrenheit"]
}
}
}
}
}]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{
Role: "user",
Content: "What's the weather like in Paris?",
},
},
Tools: []api.Tool{
{
Type: "function",
Function: api.ToolFunction{
Name: "get_weather",
Description: "Get the current weather",
Parameters: struct {
Type string `json:"type"`
Defs any `json:"$defs,omitempty"`
Items any `json:"items,omitempty"`
Required []string `json:"required"`
Properties map[string]api.ToolProperty `json:"properties"`
}{
Type: "object",
Required: []string{"location"},
Properties: map[string]api.ToolProperty{
"location": {
Type: api.PropertyType{"string"},
Description: "The city and state",
},
"unit": {
Type: api.PropertyType{"string"},
Enum: []any{"celsius", "fahrenheit"},
},
},
},
},
},
},
Options: map[string]any{
"temperature": 1.0,
"top_p": 1.0,
},
Stream: &True,
},
},
{
name: "chat handler error forwarding",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": 2}
]
}`,
err: openai.ErrorResponse{
Error: openai.Error{
Message: "invalid message content type: float64",
Type: "invalid_request_error",
},
},
},
}
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(ChatMiddleware(), captureRequestMiddleware(&capturedRequest))
router.Handle(http.MethodPost, "/api/chat", endpoint)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(tc.body))
req.Header.Set("Content-Type", "application/json")
defer func() { capturedRequest = nil }()
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
var errResp openai.ErrorResponse
if resp.Code != http.StatusOK {
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatal(err)
}
return
}
if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
t.Fatalf("requests did not match: %+v", diff)
}
if diff := cmp.Diff(tc.err, errResp); diff != "" {
t.Fatalf("errors did not match for %s:\n%s", tc.name, diff)
}
})
}
}
func TestCompletionsMiddleware(t *testing.T) {
type testCase struct {
name string
body string
req api.GenerateRequest
err openai.ErrorResponse
}
var capturedRequest *api.GenerateRequest
testCases := []testCase{
{
name: "completions handler",
body: `{
"model": "test-model",
"prompt": "Hello",
"temperature": 0.8,
"stop": ["\n", "stop"],
"suffix": "suffix"
}`,
req: api.GenerateRequest{
Model: "test-model",
Prompt: "Hello",
Options: map[string]any{
"frequency_penalty": 0.0,
"presence_penalty": 0.0,
"temperature": 0.8,
"top_p": 1.0,
"stop": []any{"\n", "stop"},
},
Suffix: "suffix",
Stream: &False,
},
},
{
name: "completions handler stream",
body: `{
"model": "test-model",
"prompt": "Hello",
"stream": true,
"temperature": 0.8,
"stop": ["\n", "stop"],
"suffix": "suffix"
}`,
req: api.GenerateRequest{
Model: "test-model",
Prompt: "Hello",
Options: map[string]any{
"frequency_penalty": 0.0,
"presence_penalty": 0.0,
"temperature": 0.8,
"top_p": 1.0,
"stop": []any{"\n", "stop"},
},
Suffix: "suffix",
Stream: &True,
},
},
{
name: "completions handler stream with usage",
body: `{
"model": "test-model",
"prompt": "Hello",
"stream": true,
"stream_options": {"include_usage": true},
"temperature": 0.8,
"stop": ["\n", "stop"],
"suffix": "suffix"
}`,
req: api.GenerateRequest{
Model: "test-model",
Prompt: "Hello",
Options: map[string]any{
"frequency_penalty": 0.0,
"presence_penalty": 0.0,
"temperature": 0.8,
"top_p": 1.0,
"stop": []any{"\n", "stop"},
},
Suffix: "suffix",
Stream: &True,
},
},
{
name: "completions handler error forwarding",
body: `{
"model": "test-model",
"prompt": "Hello",
"temperature": null,
"stop": [1, 2],
"suffix": "suffix"
}`,
err: openai.ErrorResponse{
Error: openai.Error{
Message: "invalid type for 'stop' field: float64",
Type: "invalid_request_error",
},
},
},
}
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest))
router.Handle(http.MethodPost, "/api/generate", endpoint)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
var errResp openai.ErrorResponse
if resp.Code != http.StatusOK {
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatal(err)
}
}
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
t.Fatal("requests did not match")
}
if !reflect.DeepEqual(tc.err, errResp) {
t.Fatal("errors did not match")
}
capturedRequest = nil
})
}
}
func TestEmbeddingsMiddleware(t *testing.T) {
type testCase struct {
name string
body string
req api.EmbedRequest
err openai.ErrorResponse
}
var capturedRequest *api.EmbedRequest
testCases := []testCase{
{
name: "embed handler single input",
body: `{
"input": "Hello",
"model": "test-model"
}`,
req: api.EmbedRequest{
Input: "Hello",
Model: "test-model",
},
},
{
name: "embed handler batch input",
body: `{
"input": ["Hello", "World"],
"model": "test-model"
}`,
req: api.EmbedRequest{
Input: []any{"Hello", "World"},
Model: "test-model",
},
},
{
name: "embed handler error forwarding",
body: `{
"model": "test-model"
}`,
err: openai.ErrorResponse{
Error: openai.Error{
Message: "invalid input",
Type: "invalid_request_error",
},
},
},
}
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest))
router.Handle(http.MethodPost, "/api/embed", endpoint)
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tc.body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
var errResp openai.ErrorResponse
if resp.Code != http.StatusOK {
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatal(err)
}
}
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
t.Fatal("requests did not match")
}
if !reflect.DeepEqual(tc.err, errResp) {
t.Fatal("errors did not match")
}
capturedRequest = nil
})
}
}
func TestListMiddleware(t *testing.T) {
type testCase struct {
name string
endpoint func(c *gin.Context)
resp string
}
testCases := []testCase{
{
name: "list handler",
endpoint: func(c *gin.Context) {
c.JSON(http.StatusOK, api.ListResponse{
Models: []api.ListModelResponse{
{
Name: "test-model",
ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
},
},
})
},
resp: `{
"object": "list",
"data": [
{
"id": "test-model",
"object": "model",
"created": 1686935002,
"owned_by": "library"
}
]
}`,
},
{
name: "list handler empty output",
endpoint: func(c *gin.Context) {
c.JSON(http.StatusOK, api.ListResponse{})
},
resp: `{
"object": "list",
"data": null
}`,
},
}
gin.SetMode(gin.TestMode)
for _, tc := range testCases {
router := gin.New()
router.Use(ListMiddleware())
router.Handle(http.MethodGet, "/api/tags", tc.endpoint)
req, _ := http.NewRequest(http.MethodGet, "/api/tags", nil)
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
var expected, actual map[string]any
err := json.Unmarshal([]byte(tc.resp), &expected)
if err != nil {
t.Fatalf("failed to unmarshal expected response: %v", err)
}
err = json.Unmarshal(resp.Body.Bytes(), &actual)
if err != nil {
t.Fatalf("failed to unmarshal actual response: %v", err)
}
if !reflect.DeepEqual(expected, actual) {
t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
}
}
}
func TestRetrieveMiddleware(t *testing.T) {
type testCase struct {
name string
endpoint func(c *gin.Context)
resp string
}
testCases := []testCase{
{
name: "retrieve handler",
endpoint: func(c *gin.Context) {
c.JSON(http.StatusOK, api.ShowResponse{
ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
})
},
resp: `{
"id":"test-model",
"object":"model",
"created":1686935002,
"owned_by":"library"}
`,
},
{
name: "retrieve handler error forwarding",
endpoint: func(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": "model not found"})
},
resp: `{
"error": {
"code": null,
"message": "model not found",
"param": null,
"type": "api_error"
}
}`,
},
}
gin.SetMode(gin.TestMode)
for _, tc := range testCases {
router := gin.New()
router.Use(RetrieveMiddleware())
router.Handle(http.MethodGet, "/api/show/:model", tc.endpoint)
req, _ := http.NewRequest(http.MethodGet, "/api/show/test-model", nil)
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
var expected, actual map[string]any
err := json.Unmarshal([]byte(tc.resp), &expected)
if err != nil {
t.Fatalf("failed to unmarshal expected response: %v", err)
}
err = json.Unmarshal(resp.Body.Bytes(), &actual)
if err != nil {
t.Fatalf("failed to unmarshal actual response: %v", err)
}
if !reflect.DeepEqual(expected, actual) {
t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
}
}
}

View File

@@ -150,7 +150,9 @@ func (moe *sparse) Moe(ctx ml.Context, hiddenStates, topKIndices, topKWeights ml
}
func (moe *sparse) topKIndices(ctx ml.Context, scores ml.Tensor, opts *Options) ml.Tensor {
scores = scores.Add(ctx, moe.ExpProbsBias)
if moe.ExpProbsBias != nil {
scores = scores.Add(ctx, moe.ExpProbsBias)
}
topKIndices := scores.TopK(ctx, opts.numExpertsUsed)
return topKIndices
}

View File

@@ -1,21 +1,18 @@
// openai package provides middleware for partial compatibility with the OpenAI REST API
// openai package provides core transformation logic for partial compatibility with the OpenAI REST API
package openai
import (
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"math/rand"
"net/http"
"slices"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/types/model"
)
@@ -86,7 +83,7 @@ type StreamOptions struct {
}
type Reasoning struct {
Effort *string `json:"effort,omitempty"`
Effort string `json:"effort,omitempty"`
}
type ChatCompletionRequest struct {
@@ -220,11 +217,12 @@ func NewError(code int, message string) ErrorResponse {
return ErrorResponse{Error{Type: etype, Message: message}}
}
func toUsage(r api.ChatResponse) Usage {
// ToUsage converts an api.ChatResponse to Usage
func ToUsage(r api.ChatResponse) Usage {
return Usage{
PromptTokens: r.PromptEvalCount,
CompletionTokens: r.EvalCount,
TotalTokens: r.PromptEvalCount + r.EvalCount,
PromptTokens: r.Metrics.PromptEvalCount,
CompletionTokens: r.Metrics.EvalCount,
TotalTokens: r.Metrics.PromptEvalCount + r.Metrics.EvalCount,
}
}
@@ -256,7 +254,8 @@ func toToolCalls(tc []api.ToolCall) []ToolCall {
return toolCalls
}
func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
// ToChatCompletion converts an api.ChatResponse to ChatCompletion
func ToChatCompletion(id string, r api.ChatResponse) ChatCompletion {
toolCalls := toToolCalls(r.Message.ToolCalls)
return ChatCompletion{
Id: id,
@@ -276,12 +275,13 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
}
return nil
}(r.DoneReason),
}}, Usage: toUsage(r),
}}, Usage: ToUsage(r),
DebugInfo: r.DebugInfo,
}
}
func toChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChunk {
// ToChunk converts an api.ChatResponse to ChatCompletionChunk
func ToChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChunk {
toolCalls := toToolCalls(r.Message.ToolCalls)
return ChatCompletionChunk{
Id: id,
@@ -305,15 +305,17 @@ func toChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChu
}
}
func toUsageGenerate(r api.GenerateResponse) Usage {
// ToUsageGenerate converts an api.GenerateResponse to Usage
func ToUsageGenerate(r api.GenerateResponse) Usage {
return Usage{
PromptTokens: r.PromptEvalCount,
CompletionTokens: r.EvalCount,
TotalTokens: r.PromptEvalCount + r.EvalCount,
PromptTokens: r.Metrics.PromptEvalCount,
CompletionTokens: r.Metrics.EvalCount,
TotalTokens: r.Metrics.PromptEvalCount + r.Metrics.EvalCount,
}
}
func toCompletion(id string, r api.GenerateResponse) Completion {
// ToCompletion converts an api.GenerateResponse to Completion
func ToCompletion(id string, r api.GenerateResponse) Completion {
return Completion{
Id: id,
Object: "text_completion",
@@ -330,11 +332,12 @@ func toCompletion(id string, r api.GenerateResponse) Completion {
return nil
}(r.DoneReason),
}},
Usage: toUsageGenerate(r),
Usage: ToUsageGenerate(r),
}
}
func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
// ToCompleteChunk converts an api.GenerateResponse to CompletionChunk
func ToCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
return CompletionChunk{
Id: id,
Object: "text_completion",
@@ -354,7 +357,8 @@ func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
}
}
func toListCompletion(r api.ListResponse) ListCompletion {
// ToListCompletion converts an api.ListResponse to ListCompletion
func ToListCompletion(r api.ListResponse) ListCompletion {
var data []Model
for _, m := range r.Models {
data = append(data, Model{
@@ -371,7 +375,8 @@ func toListCompletion(r api.ListResponse) ListCompletion {
}
}
func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
// ToEmbeddingList converts an api.EmbedResponse to EmbeddingList
func ToEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
if r.Embeddings != nil {
var data []Embedding
for i, e := range r.Embeddings {
@@ -396,7 +401,8 @@ func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
return EmbeddingList{}
}
func toModel(r api.ShowResponse, m string) Model {
// ToModel converts an api.ShowResponse to Model
func ToModel(r api.ShowResponse, m string) Model {
return Model{
Id: m,
Object: "model",
@@ -405,7 +411,8 @@ func toModel(r api.ShowResponse, m string) Model {
}
}
func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
// FromChatRequest converts a ChatCompletionRequest to api.ChatRequest
func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
var messages []api.Message
for _, msg := range r.Messages {
toolName := ""
@@ -560,13 +567,23 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
}
var think *api.ThinkValue
var effort string
if r.Reasoning != nil {
think = &api.ThinkValue{
Value: *r.Reasoning.Effort,
}
effort = r.Reasoning.Effort
} else if r.ReasoningEffort != nil {
think = &api.ThinkValue{
Value: *r.ReasoningEffort,
effort = *r.ReasoningEffort
}
if effort != "" {
if !slices.Contains([]string{"high", "medium", "low", "none"}, effort) {
return nil, fmt.Errorf("invalid reasoning value: '%s' (must be \"high\", \"medium\", \"low\", or \"none\")", effort)
}
if effort == "none" {
think = &api.ThinkValue{Value: false}
} else {
think = &api.ThinkValue{Value: effort}
}
}
@@ -609,7 +626,8 @@ func fromCompletionToolCall(toolCalls []ToolCall) ([]api.ToolCall, error) {
return apiToolCalls, nil
}
func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
// FromCompleteRequest converts a CompletionRequest to api.GenerateRequest
func FromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
options := make(map[string]any)
switch stop := r.Stop.(type) {
@@ -660,413 +678,3 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
DebugRenderOnly: r.DebugRenderOnly,
}, nil
}
type BaseWriter struct {
gin.ResponseWriter
}
type ChatWriter struct {
stream bool
streamOptions *StreamOptions
id string
toolCallSent bool
BaseWriter
}
type CompleteWriter struct {
stream bool
streamOptions *StreamOptions
id string
BaseWriter
}
type ListWriter struct {
BaseWriter
}
type RetrieveWriter struct {
BaseWriter
model string
}
type EmbedWriter struct {
BaseWriter
model string
}
func (w *BaseWriter) writeError(data []byte) (int, error) {
var serr api.StatusError
err := json.Unmarshal(data, &serr)
if err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(NewError(http.StatusInternalServerError, serr.Error()))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *ChatWriter) writeResponse(data []byte) (int, error) {
var chatResponse api.ChatResponse
err := json.Unmarshal(data, &chatResponse)
if err != nil {
return 0, err
}
// chat chunk
if w.stream {
c := toChunk(w.id, chatResponse, w.toolCallSent)
d, err := json.Marshal(c)
if err != nil {
return 0, err
}
if !w.toolCallSent && len(c.Choices) > 0 && len(c.Choices[0].Delta.ToolCalls) > 0 {
w.toolCallSent = true
}
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
if err != nil {
return 0, err
}
if chatResponse.Done {
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
u := toUsage(chatResponse)
c.Usage = &u
c.Choices = []ChunkChoice{}
d, err := json.Marshal(c)
if err != nil {
return 0, err
}
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
if err != nil {
return 0, err
}
}
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
if err != nil {
return 0, err
}
}
return len(data), nil
}
// chat completion
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(toChatCompletion(w.id, chatResponse))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *ChatWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
var generateResponse api.GenerateResponse
err := json.Unmarshal(data, &generateResponse)
if err != nil {
return 0, err
}
// completion chunk
if w.stream {
c := toCompleteChunk(w.id, generateResponse)
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
c.Usage = &Usage{}
}
d, err := json.Marshal(c)
if err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
if err != nil {
return 0, err
}
if generateResponse.Done {
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
u := toUsageGenerate(generateResponse)
c.Usage = &u
c.Choices = []CompleteChunkChoice{}
d, err := json.Marshal(c)
if err != nil {
return 0, err
}
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
if err != nil {
return 0, err
}
}
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
if err != nil {
return 0, err
}
}
return len(data), nil
}
// completion
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(toCompletion(w.id, generateResponse))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *CompleteWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
func (w *ListWriter) writeResponse(data []byte) (int, error) {
var listResponse api.ListResponse
err := json.Unmarshal(data, &listResponse)
if err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(toListCompletion(listResponse))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *ListWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
var showResponse api.ShowResponse
err := json.Unmarshal(data, &showResponse)
if err != nil {
return 0, err
}
// retrieve completion
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(toModel(showResponse, w.model))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *RetrieveWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
var embedResponse api.EmbedResponse
err := json.Unmarshal(data, &embedResponse)
if err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *EmbedWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(data)
}
return w.writeResponse(data)
}
func ListMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
w := &ListWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
}
c.Writer = w
c.Next()
}
}
func RetrieveMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
// response writer
w := &RetrieveWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
model: c.Param("model"),
}
c.Writer = w
c.Next()
}
}
func CompletionsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req CompletionRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
return
}
var b bytes.Buffer
genReq, err := fromCompleteRequest(req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
return
}
if err := json.NewEncoder(&b).Encode(genReq); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
w := &CompleteWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream,
id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
streamOptions: req.StreamOptions,
}
c.Writer = w
c.Next()
}
}
func EmbeddingsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req EmbedRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
return
}
if req.Input == "" {
req.Input = []string{""}
}
if req.Input == nil {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
return
}
if v, ok := req.Input.([]any); ok && len(v) == 0 {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
return
}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input, Dimensions: req.Dimensions}); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
w := &EmbedWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
model: req.Model,
}
c.Writer = w
c.Next()
}
}
func ChatMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req ChatCompletionRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
return
}
if len(req.Messages) == 0 {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "[] is too short - 'messages'"))
return
}
var b bytes.Buffer
chatReq, err := fromChatRequest(req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
return
}
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
w := &ChatWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
stream: req.Stream,
id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
streamOptions: req.StreamOptions,
}
c.Writer = w
c.Next()
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -179,7 +179,7 @@ function buildROCm() {
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --install build --component "HIP" --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
rm -f $script:DIST_DIR\lib\ollama\rocm\rocblas\library\*gfx906*
Remove-Item -Path $script:DIST_DIR\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue
}
}
}

View File

@@ -37,8 +37,8 @@ import (
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/middleware"
"github.com/ollama/ollama/model/parsers"
"github.com/ollama/ollama/openai"
"github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/server/internal/registry"
"github.com/ollama/ollama/template"
@@ -330,12 +330,16 @@ func (s *Server) GenerateHandler(c *gin.Context) {
if req.Suffix != "" {
caps = append(caps, model.CapabilityInsert)
}
if req.Think != nil && req.Think.Bool() {
modelCaps := m.Capabilities()
if req.Think != nil {
caps = append(caps, model.CapabilityThinking)
// TODO(drifkin): consider adding a warning if it's false and the model
// doesn't support thinking. It's not strictly required, but it can be a
// hint that the user is on an older qwen3/r1 model that doesn't have an
// updated template supporting thinking
} else {
// add thinking if the model supports it
if slices.Contains(modelCaps, model.CapabilityThinking) {
caps = append(caps, model.CapabilityThinking)
req.Think = &api.ThinkValue{Value: true}
}
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
@@ -1449,11 +1453,11 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.POST("/api/embeddings", s.EmbeddingsHandler)
// Inference (OpenAI compatibility)
r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
r.POST("/v1/embeddings", openai.EmbeddingsMiddleware(), s.EmbedHandler)
r.GET("/v1/models", openai.ListMiddleware(), s.ListHandler)
r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowHandler)
r.POST("/v1/chat/completions", middleware.ChatMiddleware(), s.ChatHandler)
r.POST("/v1/completions", middleware.CompletionsMiddleware(), s.GenerateHandler)
r.POST("/v1/embeddings", middleware.EmbeddingsMiddleware(), s.EmbedHandler)
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
if rc != nil {
// wrap old with new
@@ -1871,8 +1875,16 @@ func (s *Server) ChatHandler(c *gin.Context) {
if len(req.Tools) > 0 {
caps = append(caps, model.CapabilityTools)
}
if req.Think != nil && req.Think.Bool() {
modelCaps := m.Capabilities()
if req.Think != nil {
caps = append(caps, model.CapabilityThinking)
} else {
// add thinking if the model supports it
if slices.Contains(modelCaps, model.CapabilityThinking) {
caps = append(caps, model.CapabilityThinking)
req.Think = &api.ThinkValue{Value: true}
}
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
@@ -1967,88 +1979,167 @@ func (s *Server) ChatHandler(c *gin.Context) {
toolParser = tools.NewParser(m.Template.Template, req.Tools)
}
type structuredOutputsState int
const (
structuredOutputsState_None structuredOutputsState = iota
structuredOutputsState_ReadyToApply
structuredOutputsState_Applying
)
ch := make(chan any)
go func() {
defer close(ch)
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt,
Images: images,
Format: req.Format,
Options: opts,
}, func(r llm.CompletionResponse) {
res := api.ChatResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Message: api.Message{Role: "assistant", Content: r.Content},
Done: r.Done,
Metrics: api.Metrics{
PromptEvalCount: r.PromptEvalCount,
PromptEvalDuration: r.PromptEvalDuration,
EvalCount: r.EvalCount,
EvalDuration: r.EvalDuration,
},
}
if r.Done {
res.DoneReason = r.DoneReason.String()
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
structuredOutputsState := structuredOutputsState_None
for {
var tb strings.Builder
currentFormat := req.Format
// structured outputs via double request is enabled when:
// 1. the model supports the thinking capability and
// 2. it uses a built-in parser or our generic thinking parser
// Note that the current approach does not work for (potential future)
// non-thinking models that emit anything before actual content. This
// current approach uses the transition from parsed thinking content to
// parsed non-thinking content as the signal to turn constraining on
if req.Format != nil && structuredOutputsState == structuredOutputsState_None && ((builtinParser != nil || thinkingState != nil) && slices.Contains(m.Capabilities(), model.CapabilityThinking)) {
currentFormat = nil
}
if builtinParser != nil {
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser input", "parser", m.Config.Parser, "content", r.Content)
content, thinking, toolCalls, err := builtinParser.Add(r.Content, r.Done)
if err != nil {
ch <- gin.H{"error": err.Error()}
return
// sets up new context given parent context per request
ctx, cancel := context.WithCancel(c.Request.Context())
err := r.Completion(ctx, llm.CompletionRequest{
Prompt: prompt,
Images: images,
Format: currentFormat,
Options: opts,
}, func(r llm.CompletionResponse) {
res := api.ChatResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Message: api.Message{Role: "assistant", Content: r.Content},
Done: r.Done,
Metrics: api.Metrics{
PromptEvalCount: r.PromptEvalCount,
PromptEvalDuration: r.PromptEvalDuration,
EvalCount: r.EvalCount,
EvalDuration: r.EvalDuration,
},
}
if r.Done {
res.DoneReason = r.DoneReason.String()
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}
res.Message.Content = content
res.Message.Thinking = thinking
res.Message.ToolCalls = toolCalls
if builtinParser != nil {
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser input", "parser", m.Config.Parser, "content", r.Content)
if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || r.Done {
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser output", "parser", m.Config.Parser, "content", content, "thinking", thinking, "toolCalls", toolCalls, "done", r.Done)
ch <- res
} else {
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser empty output", "parser", m.Config.Parser)
}
content, thinking, toolCalls, err := builtinParser.Add(r.Content, r.Done)
if err != nil {
ch <- gin.H{"error": err.Error()}
return
}
return
}
if thinkingState != nil {
thinkingContent, remainingContent := thinkingState.AddContent(res.Message.Content)
if thinkingContent == "" && remainingContent == "" && !r.Done {
// need to accumulate more to decide what to send
return
}
res.Message.Content = remainingContent
res.Message.Thinking = thinkingContent
}
if len(req.Tools) > 0 {
toolCalls, content := toolParser.Add(res.Message.Content)
if len(content) > 0 {
res.Message.Content = content
} else if len(toolCalls) > 0 {
res.Message.Thinking = thinking
res.Message.ToolCalls = toolCalls
res.Message.Content = ""
} else if res.Message.Thinking != "" {
// don't return
} else {
if r.Done {
res.Message.Content = toolParser.Content()
tb.WriteString(thinking)
// we are now receiving content from the model - we should start applying structured outputs
if structuredOutputsState == structuredOutputsState_None && req.Format != nil && tb.String() != "" && res.Message.Content != "" {
structuredOutputsState = structuredOutputsState_ReadyToApply
cancel()
return
}
if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || r.Done {
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser output", "parser", m.Config.Parser, "content", content, "thinking", thinking, "toolCalls", toolCalls, "done", r.Done)
ch <- res
} else {
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser empty output", "parser", m.Config.Parser)
}
return
}
if thinkingState != nil {
thinkingContent, remainingContent := thinkingState.AddContent(res.Message.Content)
if thinkingContent == "" && remainingContent == "" && !r.Done {
// need to accumulate more to decide what to send
return
}
res.Message.Thinking = thinkingContent
tb.WriteString(thinkingContent)
// emit the collected thinking text before restarting with structured outputs and clear unstructured content
// to avoid leaking mixed tokens like "</think>Hello"
if structuredOutputsState == structuredOutputsState_None && req.Format != nil && tb.String() != "" && remainingContent != "" {
structuredOutputsState = structuredOutputsState_ReadyToApply
res.Message.Content = ""
ch <- res
cancel()
return
}
res.Message.Content = remainingContent
}
if len(req.Tools) > 0 {
toolCalls, content := toolParser.Add(res.Message.Content)
if len(content) > 0 {
res.Message.Content = content
} else if len(toolCalls) > 0 {
res.Message.ToolCalls = toolCalls
res.Message.Content = ""
} else if res.Message.Thinking != "" {
// don't return
} else {
if r.Done {
res.Message.Content = toolParser.Content()
ch <- res
}
return
}
}
ch <- res
})
if err != nil {
if structuredOutputsState == structuredOutputsState_ReadyToApply && strings.Contains(err.Error(), "context canceled") && c.Request.Context().Err() == nil {
// only ignores error if it's a context cancellation due to setting structured outputs
} else {
ch <- gin.H{"error": err.Error()}
return
}
}
ch <- res
}); err != nil {
ch <- gin.H{"error": err.Error()}
// ignored structured outputs cancellation falls through to here, start a new request with the structured outputs and updated prompt. use the
if structuredOutputsState == structuredOutputsState_ReadyToApply {
structuredOutputsState = structuredOutputsState_Applying
msg := api.Message{
Role: "assistant",
Thinking: tb.String(),
}
msgs = append(msgs, msg)
prompt, _, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think)
if err != nil {
slog.Error("chat prompt error applying structured outputs", "error", err)
ch <- gin.H{"error": err.Error()}
return
}
// force constraining by terminating thinking header, the parser is already at this state
// when the last message is thinking, the rendered for gpt-oss cannot disambiguate between having the
// model continue thinking or ending thinking and outputting the final message.
// TODO(parthsareen): consider adding prefill disambiguation logic to the renderer for structured outputs.
if shouldUseHarmony(m) || (builtinParser != nil && m.Config.Parser == "harmony") {
prompt += "<|end|><|start|>assistant<|channel|>final<|message|>"
}
continue
}
break
}
}()

View File

@@ -1120,13 +1120,6 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
"The answer is 4.",
true)
testChatRequest(t, "thinking disabled but template still adds think tag",
"Simple question",
" My thoughts </think> The answer.",
"",
" My thoughts </think> The answer.",
false)
// Test streaming response with template-added <think>
t.Run("streaming with thinking", func(t *testing.T) {
var wg sync.WaitGroup
@@ -1198,4 +1191,238 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
t.Errorf("expected content %q, got %q", "Based on my analysis, the solution is straightforward.", got)
}
})
t.Run("structured outputs restart non-stream", func(t *testing.T) {
var (
requestsMu sync.Mutex
requests []llm.CompletionRequest
wg sync.WaitGroup
)
wg.Add(2)
format := json.RawMessage(`{"type":"object","properties":{"answer":{"type":"string"}}}`)
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
defer wg.Done()
requestsMu.Lock()
requests = append(requests, r)
callNum := len(requests)
requestsMu.Unlock()
switch callNum {
case 1:
fn(llm.CompletionResponse{
Content: " I am thinking through this problem. </think> {\"answer\":\"42\"}",
Done: false,
PromptEvalCount: 1,
PromptEvalDuration: 1,
})
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(time.Second):
t.Fatalf("timeout waiting for structured outputs cancellation")
return nil
}
case 2:
fn(llm.CompletionResponse{
Content: `{"answer":"42"}`,
Done: true,
DoneReason: llm.DoneReasonStop,
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
})
return nil
default:
t.Fatalf("unexpected number of completion calls: %d", callNum)
return nil
}
}
think := true
streamRequest := false
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test-thinking",
Messages: []api.Message{{Role: "user", Content: "Please respond in JSON."}},
Think: &api.ThinkValue{Value: think},
Stream: &streamRequest,
Format: format,
})
wg.Wait()
mock.CompletionFn = nil
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
if len(requests) != 2 {
t.Fatalf("expected two completion calls, got %d", len(requests))
}
if requests[0].Format != nil {
t.Errorf("expected first completion format to be nil, got %q", requests[0].Format)
}
if !bytes.Equal([]byte(format), []byte(requests[1].Format)) {
t.Errorf("expected second completion format to match original format")
}
var resp api.ChatResponse
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
t.Fatal(err)
}
if resp.Message.Thinking != "I am thinking through this problem. " {
t.Errorf("expected thinking %q, got %q", "I am thinking through this problem. ", resp.Message.Thinking)
}
if resp.Message.Content != `{"answer":"42"}` {
t.Errorf("expected content %q, got %q", `{"answer":"42"}`, resp.Message.Content)
}
if !resp.Done {
t.Errorf("expected response to be done")
}
if resp.DoneReason != "stop" {
t.Errorf("expected done reason stop, got %s", resp.DoneReason)
}
})
t.Run("structured outputs restart streaming", func(t *testing.T) {
var (
requestsMu sync.Mutex
requests []llm.CompletionRequest
wg sync.WaitGroup
)
wg.Add(2)
format := json.RawMessage(`{"type":"object","properties":{"answer":{"type":"string"}}}`)
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
defer wg.Done()
requestsMu.Lock()
requests = append(requests, r)
callNum := len(requests)
requestsMu.Unlock()
switch callNum {
case 1:
fn(llm.CompletionResponse{
Content: " I am thinking through this problem. </think> {\"answer\":\"42\"}",
Done: false,
PromptEvalCount: 1,
PromptEvalDuration: 1,
})
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(time.Second):
t.Fatalf("timeout waiting for structured outputs cancellation")
return nil
}
case 2:
fn(llm.CompletionResponse{
Content: `{"answer":"42"}`,
Done: true,
DoneReason: llm.DoneReasonStop,
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
})
return nil
default:
t.Fatalf("unexpected number of completion calls: %d", callNum)
return nil
}
}
think := true
streamRequest := true
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test-thinking",
Messages: []api.Message{{Role: "user", Content: "Please respond in JSON."}},
Think: &api.ThinkValue{Value: think},
Stream: &streamRequest,
Format: format,
})
wg.Wait()
mock.CompletionFn = nil
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
if len(requests) != 2 {
t.Fatalf("expected two completion calls, got %d", len(requests))
}
if requests[0].Format != nil {
t.Errorf("expected first completion format to be nil, got %q", requests[0].Format)
}
if !bytes.Equal([]byte(format), []byte(requests[1].Format)) {
t.Errorf("expected second completion format to match original format")
}
decoder := json.NewDecoder(w.Body)
var events []api.ChatResponse
for {
var event api.ChatResponse
if err := decoder.Decode(&event); err == io.EOF {
break
} else if err != nil {
t.Fatal(err)
}
events = append(events, event)
if event.Done {
break
}
}
if len(events) < 2 {
t.Fatalf("expected at least two streaming events, got %d", len(events))
}
first := events[0]
if first.Message.Thinking != "I am thinking through this problem. " {
t.Errorf("expected first event thinking %q, got %q", "I am thinking through this problem. ", first.Message.Thinking)
}
if first.Message.Content != "" {
t.Errorf("expected first event content to be empty, got %q", first.Message.Content)
}
if first.Done {
t.Error("expected first event to be non-terminal")
}
last := events[len(events)-1]
if last.Message.Thinking != "" {
t.Errorf("expected final event thinking to be empty, got %q", last.Message.Thinking)
}
if last.Message.Content != `{"answer":"42"}` {
t.Errorf("expected final event content %q, got %q", `{"answer":"42"}`, last.Message.Content)
}
if !last.Done {
t.Error("expected final event to be done")
}
if last.DoneReason != "stop" {
t.Errorf("expected final done reason stop, got %s", last.DoneReason)
}
})
}

View File

@@ -154,24 +154,55 @@ func TestTemplate(t *testing.T) {
}
func TestParse(t *testing.T) {
cases := []struct {
validCases := []struct {
name string
template string
vars []string
}{
{"{{ .Prompt }}", []string{"prompt", "response"}},
{"{{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system"}},
{"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}},
{"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system", "tools"}},
{"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}},
{"{{ range .Messages }}{{ if eq .Role \"tool\" }}Tool Result: {{ .ToolName }} {{ .Content }}{{ end }}{{ end }}", []string{"content", "messages", "role", "toolname"}},
{`{{- range .Messages }}
{
name: "PromptOnly",
template: "{{ .Prompt }}",
vars: []string{"prompt", "response"},
},
{
name: "SystemAndPrompt",
template: "{{ .System }} {{ .Prompt }}",
vars: []string{"prompt", "response", "system"},
},
{
name: "PromptResponseSystem",
template: "{{ .System }} {{ .Prompt }} {{ .Response }}",
vars: []string{"prompt", "response", "system"},
},
{
name: "ToolsBlock",
template: "{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}",
vars: []string{"prompt", "response", "system", "tools"},
},
{
name: "MessagesRange",
template: "{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}",
vars: []string{"content", "messages", "role"},
},
{
name: "ToolResultConditional",
template: "{{ range .Messages }}{{ if eq .Role \"tool\" }}Tool Result: {{ .ToolName }} {{ .Content }}{{ end }}{{ end }}",
vars: []string{"content", "messages", "role", "toolname"},
},
{
name: "MultilineSystemUserAssistant",
template: `{{- range .Messages }}
{{- if eq .Role "system" }}SYSTEM:
{{- else if eq .Role "user" }}USER:
{{- else if eq .Role "assistant" }}ASSISTANT:
{{- else if eq .Role "tool" }}TOOL:
{{- else if eq .Role "tool" }}TOOL:
{{- end }} {{ .Content }}
{{- end }}`, []string{"content", "messages", "role"}},
{`{{- if .Messages }}
{{- end }}`,
vars: []string{"content", "messages", "role"},
},
{
name: "ChatMLLike",
template: `{{- if .Messages }}
{{- range .Messages }}<|im_start|>{{ .Role }}
{{ .Content }}<|im_end|>
{{ end }}<|im_start|>assistant
@@ -182,22 +213,60 @@ func TestParse(t *testing.T) {
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ .Response }}<|im_end|>
{{- end -}}`, []string{"content", "messages", "prompt", "response", "role", "system"}},
{{- end -}}`,
vars: []string{"content", "messages", "prompt", "response", "role", "system"},
},
}
for _, tt := range cases {
t.Run("", func(t *testing.T) {
for _, tt := range validCases {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
tmpl, err := Parse(tt.template)
if err != nil {
t.Fatal(err)
t.Fatalf("Parse returned unexpected error: %v", err)
}
v, err := tmpl.Vars()
gotVars, err := tmpl.Vars()
if err != nil {
t.Fatal(err)
t.Fatalf("Vars returned unexpected error: %v", err)
}
if diff := cmp.Diff(v, tt.vars); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
if diff := cmp.Diff(gotVars, tt.vars); diff != "" {
t.Errorf("Vars mismatch (-got +want):\n%s", diff)
}
})
}
}
func TestParseError(t *testing.T) {
invalidCases := []struct {
name string
template string
errorStr string
}{
{
"TemplateNotClosed",
"{{ .Prompt ",
"unclosed action",
},
{
"Template",
`{{define "x"}}{{template "x"}}{{end}}{{template "x"}}`,
"undefined template specified",
},
}
for _, tt := range invalidCases {
t.Run(tt.name, func(t *testing.T) {
_, err := Parse(tt.template)
if err == nil {
t.Fatalf("expected Parse to return an error for an invalid template, got nil")
}
if !strings.Contains(strings.ToLower(err.Error()), strings.ToLower(tt.errorStr)) {
t.Errorf("unexpected error message.\n got: %q\n want substring (caseinsensitive): %q", err.Error(), tt.errorStr)
}
})
}