Compare commits
45 Commits
v0.12.4-rc ... jmorganca/

| SHA1 |
| --- |
| 38ed7c7a4f |
| 9ff8e5a64d |
| 6544e14735 |
| 5db8a818a1 |
| 6db8da9958 |
| 0c68ec8d6a |
| 70d9e363e1 |
| 1a2feb2a97 |
| aab2190420 |
| 629db9dc43 |
| e0cd511661 |
| 207332078f |
| 93085127f4 |
| c00fa9cc2b |
| df411c4b02 |
| 3d32249c74 |
| d681cd7c29 |
| 47298fce39 |
| 4a48937ef1 |
| 967a82f52f |
| bbbc73d637 |
| 15e3611d3d |
| 77060d462c |
| 1b91d4dda1 |
| 7d965258ce |
| 6a62b894c7 |
| 90d429f5a8 |
| 1fc35f1260 |
| aa45f7ce27 |
| 4e5d862ec4 |
| 303be9304c |
| bd15eba4e4 |
| bc71278670 |
| 918231931c |
| 04c1849878 |
| 2c2f4deaa9 |
| 292767afb4 |
| ae5e0f0889 |
| 19e6796eac |
| 33801c1597 |
| e4340667e3 |
| 2fa1e92a99 |
| 07e36761c3 |
| c29fb007c0 |
| 730ed6e9e1 |
.github/workflows/release.yaml (vendored, 42 changes)
@@ -94,7 +94,7 @@ jobs:
|
||||
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
|
||||
rocm-version: '6.2'
|
||||
flags: '-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
|
||||
runner_dir: ''
|
||||
runner_dir: 'rocm'
|
||||
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
|
||||
environment: release
|
||||
env:
|
||||
@@ -163,7 +163,7 @@ jobs:
|
||||
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} -DOLLAMA_RUNNER_DIR="${{ matrix.runner_dir }}"
|
||||
cmake --build --parallel --preset "${{ matrix.preset }}"
|
||||
cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || 'CPU' }}" --strip --parallel 8
|
||||
rm -force dist\lib\ollama\rocm\rocblas\library\*gfx906*
|
||||
Remove-Item -Path dist\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue
|
||||
env:
|
||||
CMAKE_GENERATOR: Ninja
|
||||
- uses: actions/upload-artifact@v4
|
||||
@@ -176,19 +176,19 @@ jobs:
|
||||
matrix:
|
||||
os: [windows]
|
||||
arch: [amd64, arm64]
|
||||
include:
|
||||
- os: windows
|
||||
arch: amd64
|
||||
llvmarch: x86_64
|
||||
- os: windows
|
||||
arch: arm64
|
||||
llvmarch: aarch64
|
||||
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
|
||||
environment: release
|
||||
needs: [setup-environment]
|
||||
env:
|
||||
GOFLAGS: ${{ needs.setup-environment.outputs.GOFLAGS }}
|
||||
steps:
|
||||
- name: Install AMD64 system dependencies
|
||||
if: matrix.arch == 'amd64'
|
||||
run: |
|
||||
$ErrorActionPreference = "Stop"
|
||||
Start-Process "C:\msys64\usr\bin\pacman.exe" -ArgumentList @("-S", "--noconfirm", "mingw-w64-clang-x86_64-gcc-compat", "mingw-w64-clang-x86_64-clang") -NoNewWindow -Wait
|
||||
echo "C:\msys64\usr\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
echo "C:\msys64\clang64\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
- name: Install ARM64 system dependencies
|
||||
if: matrix.arch == 'arm64'
|
||||
run: |
|
||||
@@ -200,15 +200,29 @@ jobs:
|
||||
|
||||
choco install -y --no-progress git gzip
|
||||
echo "C:\Program Files\Git\cmd" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
|
||||
Invoke-WebRequest -Uri "https://github.com/mstorsjo/llvm-mingw/releases/download/20240619/llvm-mingw-20240619-ucrt-aarch64.zip" -OutFile "${{ runner.temp }}\llvm-mingw-ucrt-aarch64.zip"
|
||||
Expand-Archive -Path ${{ runner.temp }}\llvm-mingw-ucrt-aarch64.zip -DestinationPath "C:\Program Files\"
|
||||
$installPath=(Resolve-Path -Path "C:\Program Files\llvm-mingw-*-ucrt-aarch64").path
|
||||
echo $installPath\bin | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
- name: Install clang and gcc-compat
|
||||
run: |
|
||||
$ErrorActionPreference = "Stop"
|
||||
Set-ExecutionPolicy Bypass -Scope Process -Force
|
||||
Invoke-WebRequest -Uri "https://github.com/mstorsjo/llvm-mingw/releases/download/20240619/llvm-mingw-20240619-ucrt-${{ matrix.llvmarch }}.zip" -OutFile "${{ runner.temp }}\llvm-mingw-ucrt.zip"
|
||||
Expand-Archive -Path ${{ runner.temp }}\llvm-mingw-ucrt.zip -DestinationPath "C:\Program Files\"
|
||||
$installPath=(Resolve-Path -Path "C:\Program Files\llvm-mingw-*-ucrt*").path
|
||||
echo "$installPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
- name: Verify gcc is actually clang
|
||||
run: |
|
||||
$ErrorActionPreference='Continue'
|
||||
$version=& gcc -v 2>&1
|
||||
$version=$version -join "`n"
|
||||
echo "gcc is $version"
|
||||
if ($version -notmatch 'clang') {
|
||||
echo "ERROR: GCC must be clang for proper utf16 handling"
|
||||
exit 1
|
||||
}
|
||||
$ErrorActionPreference='Stop'
|
||||
- run: |
|
||||
go build -o dist/${{ matrix.os }}-${{ matrix.arch }}/ .
|
||||
- uses: actions/upload-artifact@v4
|
||||
|
||||
api/types.go (18 changes)
@@ -106,6 +106,14 @@ type GenerateRequest struct {
	// before this option was introduced)
	Think *ThinkValue `json:"think,omitempty"`

	// Truncate is a boolean that, when set to true, truncates the chat history messages
	// if the rendered prompt exceeds the context length limit.
	Truncate *bool `json:"truncate,omitempty"`

	// Shift is a boolean that, when set to true, shifts the chat history
	// when hitting the context length limit instead of erroring.
	Shift *bool `json:"shift,omitempty"`

	// DebugRenderOnly is a debug option that, when set to true, returns the rendered
	// template instead of calling the model.
	DebugRenderOnly bool `json:"_debug_render_only,omitempty"`

@@ -140,6 +148,14 @@ type ChatRequest struct {
	// for supported models.
	Think *ThinkValue `json:"think,omitempty"`

	// Truncate is a boolean that, when set to true, truncates the chat history messages
	// if the rendered prompt exceeds the context length limit.
	Truncate *bool `json:"truncate,omitempty"`

	// Shift is a boolean that, when set to true, shifts the chat history
	// when hitting the context length limit instead of erroring.
	Shift *bool `json:"shift,omitempty"`

	// DebugRenderOnly is a debug option that, when set to true, returns the rendered
	// template instead of calling the model.
	DebugRenderOnly bool `json:"_debug_render_only,omitempty"`

@@ -936,7 +952,7 @@ func (t *ThinkValue) UnmarshalJSON(data []byte) error {
		return nil
	}

	return fmt.Errorf("think must be a boolean or string (\"high\", \"medium\", \"low\")")
	return fmt.Errorf("think must be a boolean or string (\"high\", \"medium\", \"low\", true, or false)")
}

// MarshalJSON implements json.Marshaler
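A minimal sketch of how a client could exercise the new `Truncate` and `Shift` request fields, assuming the `api` package with the additions above; the model name is only an example:

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	truncate := false // error out instead of silently dropping old messages
	shift := true     // slide the context window forward when it fills up

	req := &api.ChatRequest{
		Model: "llama3.1",
		Messages: []api.Message{
			{Role: "user", Content: "Why is the sky blue?"},
		},
		Truncate: &truncate,
		Shift:    &shift,
	}

	err = client.Chat(context.Background(), req, func(resp api.ChatResponse) error {
		fmt.Print(resp.Message.Content)
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}
```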
@@ -85,6 +85,19 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
case "scales":
|
||||
mxfp4s[name].scales = t
|
||||
}
|
||||
} else if strings.HasSuffix(t.Name(), "gate_up_exps.bias") {
|
||||
// gate_up_exps is interleaved, need to split into gate_exps and up_exps
|
||||
// e.g. gate_exps, up_exps = gate_up_exps[:, 0::2, ...], gate_up_exps[:, 1::2, ...]
|
||||
out = append(out, slices.Collect(splitDim(t, 1,
|
||||
split{
|
||||
Replacer: strings.NewReplacer("gate_up_exps", "gate_exps"),
|
||||
slices: []tensor.Slice{nil, tensor.S(0, int(t.Shape()[1]), 2)},
|
||||
},
|
||||
split{
|
||||
Replacer: strings.NewReplacer("gate_up_exps", "up_exps"),
|
||||
slices: []tensor.Slice{nil, tensor.S(1, int(t.Shape()[1]), 2)},
|
||||
},
|
||||
))...)
|
||||
} else {
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: t.Name(),
|
||||
@@ -97,17 +110,28 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
|
||||
|
||||
for name, mxfp4 := range mxfp4s {
|
||||
dims := mxfp4.blocks.Shape()
|
||||
|
||||
if !strings.HasSuffix(name, ".weight") {
|
||||
name += ".weight"
|
||||
if strings.Contains(name, "ffn_down_exps") {
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name + ".weight",
|
||||
Kind: uint32(ggml.TensorTypeMXFP4),
|
||||
Shape: []uint64{dims[0], dims[1], dims[2] * dims[3] * 2},
|
||||
WriterTo: mxfp4,
|
||||
})
|
||||
} else if strings.Contains(name, "ffn_gate_up_exps") {
|
||||
// gate_up_exps is interleaved, need to split into gate_exps and up_exps
|
||||
// e.g. gate_exps, up_exps = gate_up_exps[:, 0::2, ...], gate_up_exps[:, 1::2, ...]
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: strings.Replace(name, "gate_up", "gate", 1) + ".weight",
|
||||
Kind: uint32(ggml.TensorTypeMXFP4),
|
||||
Shape: []uint64{dims[0], dims[1] / 2, dims[2] * dims[3] * 2},
|
||||
WriterTo: mxfp4.slice(1, 0, int(dims[1]), 2),
|
||||
}, &ggml.Tensor{
|
||||
Name: strings.Replace(name, "gate_up", "up", 1) + ".weight",
|
||||
Kind: uint32(ggml.TensorTypeMXFP4),
|
||||
Shape: []uint64{dims[0], dims[1] / 2, dims[2] * dims[3] * 2},
|
||||
WriterTo: mxfp4.slice(1, 1, int(dims[1]), 2),
|
||||
})
|
||||
}
|
||||
|
||||
out = append(out, &ggml.Tensor{
|
||||
Name: name,
|
||||
Kind: uint32(ggml.TensorTypeMXFP4),
|
||||
Shape: []uint64{dims[0], dims[1], dims[2] * dims[3] * 2},
|
||||
WriterTo: mxfp4,
|
||||
})
|
||||
}
|
||||
|
||||
return out
|
||||
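The interleaved layout described in the comments above ("gate_exps, up_exps = gate_up_exps[:, 0::2, ...], gate_up_exps[:, 1::2, ...]") amounts to an even/odd split along one dimension. A self-contained sketch of that idea on plain Go slices, not the converter's tensor code:

```go
package main

import "fmt"

// splitInterleaved separates an interleaved gate/up buffer along the expert
// dimension: even rows become "gate", odd rows become "up".
func splitInterleaved(rows [][]float32) (gate, up [][]float32) {
	for i, row := range rows {
		if i%2 == 0 {
			gate = append(gate, row)
		} else {
			up = append(up, row)
		}
	}
	return gate, up
}

func main() {
	interleaved := [][]float32{{1, 1}, {2, 2}, {3, 3}, {4, 4}}
	gate, up := splitInterleaved(interleaved)
	fmt.Println(gate) // [[1 1] [3 3]]
	fmt.Println(up)   // [[2 2] [4 4]]
}
```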
@@ -158,9 +182,21 @@ func (m *gptossModel) Replacements() []string {
|
||||
}
|
||||
|
||||
type mxfp4 struct {
|
||||
slices []tensor.Slice
|
||||
|
||||
blocks, scales Tensor
|
||||
}
|
||||
|
||||
func (m *mxfp4) slice(dim, start, end, step int) *mxfp4 {
|
||||
slice := slices.Repeat([]tensor.Slice{nil}, len(m.blocks.Shape()))
|
||||
slice[dim] = tensor.S(start, end, step)
|
||||
return &mxfp4{
|
||||
slices: slice,
|
||||
blocks: m.blocks,
|
||||
scales: m.scales,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mxfp4) WriteTo(w io.Writer) (int64, error) {
|
||||
var b bytes.Buffer
|
||||
if _, err := m.blocks.WriteTo(&b); err != nil {
|
||||
@@ -204,6 +240,13 @@ func (m *mxfp4) WriteTo(w io.Writer) (int64, error) {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if len(m.slices) > 0 {
|
||||
out, err = out.Slice(m.slices...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
out = tensor.Materialize(out)
|
||||
|
||||
if err := out.Reshape(out.Shape().TotalSize()); err != nil {
|
||||
|
||||
@@ -16,7 +16,8 @@ import (
|
||||
|
||||
type split struct {
|
||||
*strings.Replacer
|
||||
dim int
|
||||
dim int
|
||||
slices []tensor.Slice
|
||||
|
||||
// fn is an optional function to apply to the tensor after slicing
|
||||
fn func(tensor.Tensor) (tensor.Tensor, error)
|
||||
@@ -32,9 +33,12 @@ func splitDim(t Tensor, dim int, splits ...split) iter.Seq[*ggml.Tensor] {
|
||||
shape := slices.Clone(t.Shape())
|
||||
shape[dim] = cmp.Or(uint64(split.dim), shape[dim]/uint64(len(splits)))
|
||||
|
||||
slice := slices.Repeat([]tensor.Slice{nil}, len(shape))
|
||||
slice[dim] = tensor.S(offset, offset+int(shape[dim]))
|
||||
offset += int(shape[dim])
|
||||
slice := split.slices
|
||||
if len(slice) == 0 {
|
||||
slice = slices.Repeat([]tensor.Slice{nil}, len(shape))
|
||||
slice[dim] = tensor.S(offset, offset+int(shape[dim]))
|
||||
offset += int(shape[dim])
|
||||
}
|
||||
|
||||
t.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) {
|
||||
dims := make([]int, len(shape))
|
||||
|
||||
@@ -2,11 +2,12 @@ package discover

import (
	"context"
	"fmt"
	"log/slog"
	"os"
	"path/filepath"
	"regexp"
	"runtime"
	"strconv"
	"strings"

	"github.com/ollama/ollama/format"

@@ -60,17 +61,14 @@ func devInfoToInfoList(devs []ml.DeviceInfo) GpuInfoList {
		DependencyPath: dev.LibraryPath,
		DriverMajor:    dev.DriverMajor,
		DriverMinor:    dev.DriverMinor,
		ComputeMajor:   dev.ComputeMajor,
		ComputeMinor:   dev.ComputeMinor,
	}
	if dev.Library == "CUDA" || dev.Library == "ROCm" {
		info.MinimumMemory = 457 * format.MebiByte
	}
	if dev.Library == "ROCm" {
		info.Compute = fmt.Sprintf("gfx%x%02x", dev.ComputeMajor, dev.ComputeMinor)
		if rocmDir != "" {
			info.DependencyPath = append(info.DependencyPath, rocmDir)
		}
	} else {
		info.Compute = fmt.Sprintf("%d.%d", dev.ComputeMajor, dev.ComputeMinor)
		if dev.Library == "ROCm" && rocmDir != "" {
			info.DependencyPath = append(info.DependencyPath, rocmDir)
		}
	resp = append(resp, info)
}
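As a side note on the `Compute` strings built above, a tiny standalone example (assuming plain integer major/minor values) of how the ROCm and CUDA forms differ:

```go
package main

import "fmt"

func main() {
	// ROCm devices get a gfx identifier: major in hex, minor as two hex digits,
	// e.g. major 9, minor 10 (0x0a) -> "gfx90a" (MI210/MI250 class parts).
	fmt.Printf("gfx%x%02x\n", 9, 10)
	// CUDA devices keep the familiar compute capability form, e.g. 8.9.
	fmt.Printf("%d.%d\n", 8, 9)
}
```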
@@ -146,3 +144,35 @@ func GetSystemInfo() SystemInfo {
|
||||
GPUs: gpus,
|
||||
}
|
||||
}
|
||||
|
||||
func cudaJetpack() string {
|
||||
if runtime.GOARCH == "arm64" && runtime.GOOS == "linux" {
|
||||
if CudaTegra != "" {
|
||||
ver := strings.Split(CudaTegra, ".")
|
||||
if len(ver) > 0 {
|
||||
return "jetpack" + ver[0]
|
||||
}
|
||||
} else if data, err := os.ReadFile("/etc/nv_tegra_release"); err == nil {
|
||||
r := regexp.MustCompile(` R(\d+) `)
|
||||
m := r.FindSubmatch(data)
|
||||
if len(m) != 2 {
|
||||
slog.Info("Unexpected format for /etc/nv_tegra_release. Set JETSON_JETPACK to select version")
|
||||
} else {
|
||||
if l4t, err := strconv.Atoi(string(m[1])); err == nil {
|
||||
// Note: mapping from L4t -> JP is inconsistent (can't just subtract 30)
|
||||
// https://developer.nvidia.com/embedded/jetpack-archive
|
||||
switch l4t {
|
||||
case 35:
|
||||
return "jetpack5"
|
||||
case 36:
|
||||
return "jetpack6"
|
||||
default:
|
||||
// Newer Jetson systems use the SBSU runtime
|
||||
slog.Debug("unrecognized L4T version", "nv_tegra_release", string(data))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -78,6 +78,8 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev
|
||||
}
|
||||
|
||||
slog.Info("discovering available GPUs...")
|
||||
requested := envconfig.LLMLibrary()
|
||||
jetpack := cudaJetpack()
|
||||
|
||||
// For our initial discovery pass, we gather all the known GPUs through
|
||||
// all the libraries that were detected. This pass may include GPUs that
|
||||
@@ -86,6 +88,14 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev
|
||||
// times concurrently leading to memory contention
|
||||
for dir := range libDirs {
|
||||
var dirs []string
|
||||
if dir != "" {
|
||||
if requested != "" && filepath.Base(dir) != requested {
|
||||
slog.Debug("skipping available library at users request", "requested", requested, "libDir", dir)
|
||||
continue
|
||||
} else if jetpack != "" && filepath.Base(dir) != "cuda_"+jetpack {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if dir == "" {
|
||||
dirs = []string{LibOllamaPath}
|
||||
} else {
|
||||
@@ -330,6 +340,9 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev
|
||||
}
|
||||
}
|
||||
|
||||
// Apply any iGPU workarounds
|
||||
iGPUWorkarounds(devices)
|
||||
|
||||
return devices
|
||||
}
|
||||
|
||||
@@ -395,7 +408,7 @@ func (r *bootstrapRunner) HasExited() bool {
|
||||
|
||||
func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs []string) []ml.DeviceInfo {
|
||||
// TODO DRY out with llm/server.go
|
||||
slog.Debug("spawing runner with", "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs)
|
||||
slog.Debug("spawning runner with", "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs)
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
slog.Debug("bootstrap discovery took", "duration", time.Since(start), "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs)
|
||||
@@ -439,15 +452,18 @@ func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs []s
|
||||
cmd.Stderr = os.Stderr
|
||||
}
|
||||
// cmd.SysProcAttr = llm.LlamaServerSysProcAttr // circular dependency - bring back once refactored
|
||||
cmd.Env = append(cmd.Env, "OLLAMA_LIBRARY_PATH="+strings.Join(ollamaLibDirs, string(filepath.ListSeparator)))
|
||||
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
|
||||
pathNeeded := true
|
||||
ollamaPathNeeded := true
|
||||
extraDone := make([]bool, len(extraEnvs))
|
||||
for i := range cmd.Env {
|
||||
cmp := strings.SplitN(cmd.Env[i], "=", 2)
|
||||
if strings.EqualFold(cmp[0], pathEnv) {
|
||||
cmd.Env[i] = pathEnv + "=" + pathEnvVal
|
||||
pathNeeded = false
|
||||
} else if strings.EqualFold(cmp[0], "OLLAMA_LIBRARY_PATH") {
|
||||
cmd.Env[i] = "OLLAMA_LIBRARY_PATH=" + strings.Join(ollamaLibDirs, string(filepath.ListSeparator))
|
||||
ollamaPathNeeded = false
|
||||
} else {
|
||||
for j := range extraEnvs {
|
||||
if extraDone[j] {
|
||||
@@ -464,6 +480,9 @@ func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs []s
|
||||
if pathNeeded {
|
||||
cmd.Env = append(cmd.Env, pathEnv+"="+pathEnvVal)
|
||||
}
|
||||
if ollamaPathNeeded {
|
||||
cmd.Env = append(cmd.Env, "OLLAMA_LIBRARY_PATH="+strings.Join(ollamaLibDirs, string(filepath.ListSeparator)))
|
||||
}
|
||||
for i := range extraDone {
|
||||
if !extraDone[i] {
|
||||
cmd.Env = append(cmd.Env, extraEnvs[i])
|
||||
@@ -540,3 +559,32 @@ func GetDevicesFromRunner(ctx context.Context, runner BaseRunner) ([]ml.DeviceIn
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func iGPUWorkarounds(devices []ml.DeviceInfo) {
|
||||
// short circuit if we have no iGPUs
|
||||
anyiGPU := false
|
||||
for i := range devices {
|
||||
if devices[i].Integrated {
|
||||
anyiGPU = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !anyiGPU {
|
||||
return
|
||||
}
|
||||
|
||||
memInfo, err := GetCPUMem()
|
||||
if err != nil {
|
||||
slog.Debug("failed to fetch system memory information for iGPU", "error", err)
|
||||
return
|
||||
}
|
||||
for i := range devices {
|
||||
if !devices[i].Integrated {
|
||||
continue
|
||||
}
|
||||
// NVIDIA iGPUs return useless free VRAM data which ignores system buff/cache
|
||||
if devices[i].Library == "CUDA" {
|
||||
devices[i].FreeMemory = memInfo.FreeMemory
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -37,9 +37,10 @@ type GpuInfo struct { // TODO better name maybe "InferenceProcessor"?
|
||||
UnreliableFreeMemory bool
|
||||
|
||||
// GPU information
|
||||
filterID string // AMD Workaround: The numeric ID of the device used to filter out other devices
|
||||
Name string `json:"name"` // user friendly name if available
|
||||
Compute string `json:"compute"` // Compute Capability or gfx
|
||||
filterID string // AMD Workaround: The numeric ID of the device used to filter out other devices
|
||||
Name string `json:"name"` // user friendly name if available
|
||||
ComputeMajor int `json:"compute_major"` // Compute Capability or gfx
|
||||
ComputeMinor int `json:"compute_minor"`
|
||||
|
||||
// Driver Information - TODO no need to put this on each GPU
|
||||
DriverMajor int `json:"driver_major,omitempty"`
|
||||
@@ -173,7 +174,7 @@ func (l GpuInfoList) FlashAttentionSupported() bool {
|
||||
for _, gpu := range l {
|
||||
supportsFA := gpu.Library == "cpu" ||
|
||||
gpu.Name == "Metal" || gpu.Library == "Metal" ||
|
||||
(gpu.Library == "CUDA" && gpu.DriverMajor >= 7) ||
|
||||
(gpu.Library == "CUDA" && gpu.DriverMajor >= 7 && !(gpu.ComputeMajor == 7 && gpu.ComputeMinor == 2)) || // We don't have kernels for Jetson Xavier
|
||||
gpu.Library == "ROCm"
|
||||
|
||||
if !supportsFA {
|
||||
|
||||
docs/gpu.md (10 changes)
@@ -51,11 +51,11 @@ sudo modprobe nvidia_uvm`
Ollama supports the following AMD GPUs:

### Linux Support
| Family | Cards and accelerators |
| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------- |
| AMD Radeon RX  | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` `Vega 64`              |
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` `V420` `V340` `V320` `Vega II Duo` `Vega II` `SSG`   |
| AMD Instinct   | `MI300X` `MI300A` `MI300` `MI250X` `MI250` `MI210` `MI200` `MI100` `MI60`                                                                  |
| Family | Cards and accelerators |
| -------------- | -------------------------------------------------------------------------------------------------------------------- |
| AMD Radeon RX  | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800`   |
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` `V420` `V340` `V320`            |
| AMD Instinct   | `MI300X` `MI300A` `MI300` `MI250X` `MI250` `MI210` `MI200` `MI100`                                                    |

### Windows Support
With ROCm v6.2, the following GPUs are supported on Windows.
@@ -38,26 +38,14 @@ Join the [Discord](https://discord.gg/ollama) for help interpreting the logs.

## LLM libraries

Ollama includes multiple LLM libraries compiled for different GPUs and CPU vector features. Ollama tries to pick the best one based on the capabilities of your system. If this autodetection has problems, or you run into other problems (e.g. crashes in your GPU) you can workaround this by forcing a specific LLM library. `cpu_avx2` will perform the best, followed by `cpu_avx` and the slowest but most compatible is `cpu`. Rosetta emulation under MacOS will work with the `cpu` library.

In the server log, you will see a message that looks something like this (varies from release to release):

```
Dynamic LLM libraries [rocm_v6 cpu cpu_avx cpu_avx2 cuda_v12 rocm_v5]
```
Ollama includes multiple LLM libraries compiled for different GPU libraries and versions. Ollama tries to pick the best one based on the capabilities of your system. If this autodetection has problems, or you run into other problems (e.g. crashes in your GPU) you can workaround this by forcing a specific LLM library.

**Experimental LLM Library Override**

You can set OLLAMA_LLM_LIBRARY to any of the available LLM libraries to bypass autodetection, so for example, if you have a CUDA card, but want to force the CPU LLM library with AVX2 vector support, use:
You can set OLLAMA_LLM_LIBRARY to any of the available LLM libraries to limit autodetection, so for example, if you have both CUDA and AMD GPUs, but want to force the CUDA v13 only, use:

```shell
OLLAMA_LLM_LIBRARY="cpu_avx2" ollama serve
```

You can see what features your CPU has with the following.

```shell
cat /proc/cpuinfo| grep flags | head -1
OLLAMA_LLM_LIBRARY="cuda_v13" ollama serve
```

## Installing older or pre-release versions on Linux
@@ -243,7 +243,6 @@ func (kv KV) OllamaEngineRequired() bool {
		"gemma3",
		"gemma3n",
		"mistral3",
		"qwen3",
		"qwen3moe",
		"llama4",
		"mllama",

@@ -870,11 +869,6 @@ func (f GGML) SupportsKVCacheType(cacheType string) bool {
		return true
	}

	if arch := f.KV().Architecture(); slices.Contains([]string{"gptoss", "gpt-oss"}, arch) {
		// gpt-oss uses attention with sinks which does not support quantized cache types
		slog.Warn("model only supports non-quantized cache types", "model", arch)
		return false
	}
	return slices.Contains([]string{"q8_0", "q4_0"}, cacheType)
}
@@ -17,16 +17,21 @@ func TestBlueSky(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
// Set up the test data
|
||||
req := api.GenerateRequest{
|
||||
Model: smol,
|
||||
Prompt: blueSkyPrompt,
|
||||
req := api.ChatRequest{
|
||||
Model: smol,
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: blueSkyPrompt,
|
||||
},
|
||||
},
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
},
|
||||
}
|
||||
GenerateTestHelper(ctx, t, req, blueSkyExpected)
|
||||
ChatTestHelper(ctx, t, req, blueSkyExpected)
|
||||
}
|
||||
|
||||
func TestUnicode(t *testing.T) {
|
||||
@@ -34,10 +39,15 @@ func TestUnicode(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
|
||||
defer cancel()
|
||||
// Set up the test data
|
||||
req := api.GenerateRequest{
|
||||
req := api.ChatRequest{
|
||||
// DeepSeek has a Unicode tokenizer regex, making it a unicode torture test
|
||||
Model: "deepseek-coder-v2:16b-lite-instruct-q2_K", // TODO is there an ollama-engine model we can switch to and keep the coverage?
|
||||
Prompt: "天空为什么是蓝色的?", // Why is the sky blue?
|
||||
Model: "deepseek-coder-v2:16b-lite-instruct-q2_K", // TODO is there an ollama-engine model we can switch to and keep the coverage?
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "天空为什么是蓝色的?", // Why is the sky blue?
|
||||
},
|
||||
},
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
@@ -57,9 +67,14 @@ func TestUnicode(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load model %s: %s", req.Model, err)
|
||||
}
|
||||
defer func() {
|
||||
// best effort unload once we're done with the model
|
||||
client.Generate(ctx, &api.GenerateRequest{Model: req.Model, KeepAlive: &api.Duration{Duration: 0}}, func(rsp api.GenerateResponse) error { return nil })
|
||||
}()
|
||||
|
||||
skipIfNotGPULoaded(ctx, t, client, req.Model, 100)
|
||||
|
||||
DoGenerate(ctx, t, client, req, []string{
|
||||
DoChat(ctx, t, client, req, []string{
|
||||
"散射", // scattering
|
||||
"频率", // frequency
|
||||
}, 120*time.Second, 120*time.Second)
|
||||
@@ -69,9 +84,14 @@ func TestExtendedUnicodeOutput(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
// Set up the test data
|
||||
req := api.GenerateRequest{
|
||||
Model: "gemma2:2b",
|
||||
Prompt: "Output some smily face emoji",
|
||||
req := api.ChatRequest{
|
||||
Model: "gemma2:2b",
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Output some smily face emoji",
|
||||
},
|
||||
},
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
@@ -83,7 +103,7 @@ func TestExtendedUnicodeOutput(t *testing.T) {
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
DoGenerate(ctx, t, client, req, []string{"😀", "😊", "😁", "😂", "😄", "😃"}, 120*time.Second, 120*time.Second)
|
||||
DoChat(ctx, t, client, req, []string{"😀", "😊", "😁", "😂", "😄", "😃"}, 120*time.Second, 120*time.Second)
|
||||
}
|
||||
|
||||
func TestUnicodeModelDir(t *testing.T) {
|
||||
@@ -108,14 +128,19 @@ func TestUnicodeModelDir(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
req := api.GenerateRequest{
|
||||
Model: smol,
|
||||
Prompt: blueSkyPrompt,
|
||||
req := api.ChatRequest{
|
||||
Model: smol,
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: blueSkyPrompt,
|
||||
},
|
||||
},
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
},
|
||||
}
|
||||
GenerateTestHelper(ctx, t, req, blueSkyExpected)
|
||||
ChatTestHelper(ctx, t, req, blueSkyExpected)
|
||||
}
|
||||
|
||||
@@ -20,9 +20,9 @@ import (
|
||||
)
|
||||
|
||||
// Send multiple requests in parallel (concurrently) to a single model and ensure responses are expected
|
||||
func TestConcurrentGenerate(t *testing.T) {
|
||||
func TestConcurrentChat(t *testing.T) {
|
||||
// Assumes all requests have the same model
|
||||
req, resp := GenerateRequests()
|
||||
req, resp := ChatRequests()
|
||||
numParallel := int(envconfig.NumParallel() + 1)
|
||||
iterLimit := 3
|
||||
|
||||
@@ -57,7 +57,7 @@ func TestConcurrentGenerate(t *testing.T) {
|
||||
slog.Info("Starting", "thread", i, "iter", j)
|
||||
// On slower GPUs it can take a while to process the concurrent requests
|
||||
// so we allow a much longer initial timeout
|
||||
DoGenerate(ctx, t, client, req[k], resp[k], 120*time.Second, 20*time.Second)
|
||||
DoChat(ctx, t, client, req[k], resp[k], 120*time.Second, 20*time.Second)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
@@ -163,7 +163,7 @@ chooseModels:
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
reqs, resps := GenerateRequests()
|
||||
reqs, resps := ChatRequests()
|
||||
for j := 0; j < 3; j++ {
|
||||
if time.Now().Sub(started) > softTimeout {
|
||||
slog.Info("exceeded soft timeout, winding down test")
|
||||
@@ -171,8 +171,8 @@ chooseModels:
|
||||
}
|
||||
k := r.Int() % len(reqs)
|
||||
reqs[k].Model = chosenModels[i]
|
||||
slog.Info("Starting", "model", reqs[k].Model, "iteration", j, "request", reqs[k].Prompt)
|
||||
DoGenerate(ctx, t, client, reqs[k], resps[k],
|
||||
slog.Info("Starting", "model", reqs[k].Model, "iteration", j, "request", reqs[k].Messages[0].Content)
|
||||
DoChat(ctx, t, client, reqs[k], resps[k],
|
||||
120*time.Second, // Be extra patient for the model to load initially
|
||||
10*time.Second, // Once results start streaming, fail if they stall
|
||||
)
|
||||
|
||||
@@ -21,9 +21,14 @@ func TestLongInputContext(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
// Set up the test data
|
||||
req := api.GenerateRequest{
|
||||
Model: smol,
|
||||
Prompt: "Oh, don’t speak to me of Austria. Perhaps I don’t understand things, but Austria never has wished, and does not wish, for war. She is betraying us! Russia alone must save Europe. Our gracious sovereign recognizes his high vocation and will be true to it. That is the one thing I have faith in! Our good and wonderful sovereign has to perform the noblest role on earth, and he is so virtuous and noble that God will not forsake him. He will fulfill his vocation and crush the hydra of revolution, which has become more terrible than ever in the person of this murderer and villain! We alone must avenge the blood of the just one.... Whom, I ask you, can we rely on?... England with her commercial spirit will not and cannot understand the Emperor Alexander’s loftiness of soul. She has refused to evacuate Malta. She wanted to find, and still seeks, some secret motive in our actions. What answer did Novosíltsev get? None. The English have not understood and cannot understand the self-abnegation of our Emperor who wants nothing for himself, but only desires the good of mankind. And what have they promised? Nothing! And what little they have promised they will not perform! Prussia has always declared that Buonaparte is invincible, and that all Europe is powerless before him.... And I don’t believe a word that Hardenburg says, or Haugwitz either. This famous Prussian neutrality is just a trap. I have faith only in God and the lofty destiny of our adored monarch. He will save Europe! What country is this referring to?",
|
||||
req := api.ChatRequest{
|
||||
Model: smol,
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Oh, don’t speak to me of Austria. Perhaps I don’t understand things, but Austria never has wished, and does not wish, for war. She is betraying us! Russia alone must save Europe. Our gracious sovereign recognizes his high vocation and will be true to it. That is the one thing I have faith in! Our good and wonderful sovereign has to perform the noblest role on earth, and he is so virtuous and noble that God will not forsake him. He will fulfill his vocation and crush the hydra of revolution, which has become more terrible than ever in the person of this murderer and villain! We alone must avenge the blood of the just one.... Whom, I ask you, can we rely on?... England with her commercial spirit will not and cannot understand the Emperor Alexander’s loftiness of soul. She has refused to evacuate Malta. She wanted to find, and still seeks, some secret motive in our actions. What answer did Novosíltsev get? None. The English have not understood and cannot understand the self-abnegation of our Emperor who wants nothing for himself, but only desires the good of mankind. And what have they promised? Nothing! And what little they have promised they will not perform! Prussia has always declared that Buonaparte is invincible, and that all Europe is powerless before him.... And I don’t believe a word that Hardenburg says, or Haugwitz either. This famous Prussian neutrality is just a trap. I have faith only in God and the lofty destiny of our adored monarch. He will save Europe! What country is this referring to?",
|
||||
},
|
||||
},
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
@@ -36,7 +41,7 @@ func TestLongInputContext(t *testing.T) {
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatalf("PullIfMissing failed: %v", err)
|
||||
}
|
||||
DoGenerate(ctx, t, client, req, []string{"russia", "german", "france", "england", "austria", "prussia", "europe", "individuals", "coalition", "conflict"}, 120*time.Second, 10*time.Second)
|
||||
DoChat(ctx, t, client, req, []string{"russia", "german", "france", "england", "austria", "prussia", "europe", "individuals", "coalition", "conflict"}, 120*time.Second, 10*time.Second)
|
||||
}
|
||||
|
||||
func TestContextExhaustion(t *testing.T) {
|
||||
@@ -48,9 +53,14 @@ func TestContextExhaustion(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
// Set up the test data
|
||||
req := api.GenerateRequest{
|
||||
Model: smol,
|
||||
Prompt: "Write me a story in english with a lot of emojis",
|
||||
req := api.ChatRequest{
|
||||
Model: smol,
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Write me a story in english with a lot of emojis",
|
||||
},
|
||||
},
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
@@ -63,12 +73,12 @@ func TestContextExhaustion(t *testing.T) {
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatalf("PullIfMissing failed: %v", err)
|
||||
}
|
||||
DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived", "sunny", "cloudy", "clear", "water", "time", "travel", "world"}, 120*time.Second, 10*time.Second)
|
||||
DoChat(ctx, t, client, req, []string{"once", "upon", "lived", "sunny", "cloudy", "clear", "water", "time", "travel", "world"}, 120*time.Second, 10*time.Second)
|
||||
}
|
||||
|
||||
// Send multiple generate requests with prior context and ensure the response is coherant and expected
|
||||
func TestParallelGenerateWithHistory(t *testing.T) {
|
||||
modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
|
||||
modelOverride := "gpt-oss:20b"
|
||||
req, resp := GenerateRequests()
|
||||
numParallel := 2
|
||||
iterLimit := 2
|
||||
@@ -155,7 +165,7 @@ func TestGenerateWithHistory(t *testing.T) {
|
||||
|
||||
// Send multiple chat requests with prior context and ensure the response is coherant and expected
|
||||
func TestParallelChatWithHistory(t *testing.T) {
|
||||
modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
|
||||
modelOverride := "gpt-oss:20b"
|
||||
req, resp := ChatRequests()
|
||||
numParallel := 2
|
||||
iterLimit := 2
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
|
||||
// First run of this scenario on a target system will take a long time to download
|
||||
// ~1.5TB of models. Set a sufficiently large -timeout for your network speed
|
||||
func TestLibraryModelsGenerate(t *testing.T) {
|
||||
func TestLibraryModelsChat(t *testing.T) {
|
||||
softTimeout, hardTimeout := getTimeouts(t)
|
||||
slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||
@@ -43,9 +43,14 @@ func TestLibraryModelsGenerate(t *testing.T) {
|
||||
t.Skip(fmt.Sprintf("Skipping %s architecture %s != %s", model, arch, targetArch))
|
||||
}
|
||||
}
|
||||
req := api.GenerateRequest{
|
||||
Model: model,
|
||||
Prompt: blueSkyPrompt,
|
||||
req := api.ChatRequest{
|
||||
Model: model,
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: blueSkyPrompt,
|
||||
},
|
||||
},
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]interface{}{
|
||||
"temperature": 0.1,
|
||||
@@ -58,13 +63,13 @@ func TestLibraryModelsGenerate(t *testing.T) {
|
||||
anyResp = []string{"select", "from"}
|
||||
} else if model == "granite3-guardian" || model == "shieldgemma" || model == "llama-guard3" || model == "bespoke-minicheck" {
|
||||
anyResp = []string{"yes", "no", "safe", "unsafe"}
|
||||
} else if model == "openthinker" || model == "nexusraven" {
|
||||
} else if model == "openthinker" {
|
||||
anyResp = []string{"plugin", "im_sep", "components", "function call"}
|
||||
} else if model == "starcoder" || model == "starcoder2" || model == "magicoder" || model == "deepseek-coder" {
|
||||
req.Prompt = "def fibonacci():"
|
||||
req.Messages[0].Content = "def fibonacci():"
|
||||
anyResp = []string{"f(n)", "sequence", "n-1", "main()", "__main__", "while"}
|
||||
}
|
||||
DoGenerate(ctx, t, client, req, anyResp, 120*time.Second, 30*time.Second)
|
||||
DoChat(ctx, t, client, req, anyResp, 120*time.Second, 30*time.Second)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -34,17 +34,22 @@ func TestVisionModels(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req := api.GenerateRequest{
|
||||
Model: v.model,
|
||||
Prompt: "what does the text in this image say?",
|
||||
req := api.ChatRequest{
|
||||
Model: v.model,
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "what does the text in this image say?",
|
||||
Images: []api.ImageData{
|
||||
image,
|
||||
},
|
||||
},
|
||||
},
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
Images: []api.ImageData{
|
||||
image,
|
||||
},
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
@@ -56,8 +61,15 @@ func TestVisionModels(t *testing.T) {
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Preload to skip if we're less than 80% on GPU to avoid extremely slow tests
|
||||
err = client.Generate(ctx, &api.GenerateRequest{Model: req.Model}, func(response api.GenerateResponse) error { return nil })
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load model %s: %s", req.Model, err)
|
||||
}
|
||||
skipIfNotGPULoaded(ctx, t, client, req.Model, 80)
|
||||
|
||||
// llava models on CPU can be quite slow to start
|
||||
DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
|
||||
DoChat(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ import (
|
||||
"github.com/ollama/ollama/format"
|
||||
)
|
||||
|
||||
func TestModelsGenerate(t *testing.T) {
|
||||
func TestModelsChat(t *testing.T) {
|
||||
softTimeout, hardTimeout := getTimeouts(t)
|
||||
slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||
@@ -66,15 +66,23 @@ func TestModelsGenerate(t *testing.T) {
|
||||
}
|
||||
}
|
||||
// TODO - fiddle with context size
|
||||
req := api.GenerateRequest{
|
||||
Model: model,
|
||||
Prompt: blueSkyPrompt,
|
||||
req := api.ChatRequest{
|
||||
Model: model,
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: blueSkyPrompt,
|
||||
},
|
||||
},
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]interface{}{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
},
|
||||
}
|
||||
DoGenerate(ctx, t, client, req, blueSkyExpected, 120*time.Second, 30*time.Second)
|
||||
DoChat(ctx, t, client, req, blueSkyExpected, 120*time.Second, 30*time.Second)
|
||||
// best effort unload once we're done with the model
|
||||
client.Generate(ctx, &api.GenerateRequest{Model: req.Model, KeepAlive: &api.Duration{Duration: 0}}, func(rsp api.GenerateResponse) error { return nil })
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -128,8 +136,9 @@ func TestModelsEmbed(t *testing.T) {
|
||||
}
|
||||
}
|
||||
req := api.EmbeddingRequest{
|
||||
Model: model,
|
||||
Prompt: "why is the sky blue?",
|
||||
Model: model,
|
||||
Prompt: "why is the sky blue?",
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]interface{}{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
@@ -139,6 +148,10 @@ func TestModelsEmbed(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("embeddings call failed %s", err)
|
||||
}
|
||||
defer func() {
|
||||
// best effort unload once we're done with the model
|
||||
client.Generate(ctx, &api.GenerateRequest{Model: req.Model, KeepAlive: &api.Duration{Duration: 0}}, func(rsp api.GenerateResponse) error { return nil })
|
||||
}()
|
||||
if len(resp.Embedding) == 0 {
|
||||
t.Errorf("zero length embedding response")
|
||||
}
|
||||
|
||||
@@ -173,9 +173,14 @@ func doModelPerfTest(t *testing.T, chatModels []string) {
|
||||
slog.Info("skipping long prompt", "model", model, "num_ctx", numCtx, "gpu_percent", gpuPercent)
|
||||
continue
|
||||
}
|
||||
req := api.GenerateRequest{
|
||||
Model: model,
|
||||
Prompt: tc.prompt,
|
||||
req := api.ChatRequest{
|
||||
Model: model,
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: tc.prompt,
|
||||
},
|
||||
},
|
||||
KeepAlive: &api.Duration{Duration: 20 * time.Second}, // long enough to ensure a ps returns
|
||||
Options: map[string]interface{}{
|
||||
"temperature": 0,
|
||||
@@ -184,7 +189,7 @@ func doModelPerfTest(t *testing.T, chatModels []string) {
|
||||
},
|
||||
}
|
||||
atLeastOne := false
|
||||
var resp api.GenerateResponse
|
||||
var resp api.ChatResponse
|
||||
|
||||
stream := false
|
||||
req.Stream = &stream
|
||||
@@ -198,7 +203,7 @@ func doModelPerfTest(t *testing.T, chatModels []string) {
|
||||
)
|
||||
defer cancel()
|
||||
|
||||
err = client.Generate(genCtx, &req, func(rsp api.GenerateResponse) error {
|
||||
err = client.Chat(genCtx, &req, func(rsp api.ChatResponse) error {
|
||||
resp = rsp
|
||||
return nil
|
||||
})
|
||||
@@ -214,13 +219,13 @@ func doModelPerfTest(t *testing.T, chatModels []string) {
|
||||
}
|
||||
loaded = true
|
||||
for _, expResp := range tc.anyResp {
|
||||
if strings.Contains(strings.ToLower(resp.Response), expResp) {
|
||||
if strings.Contains(strings.ToLower(resp.Message.Content), expResp) {
|
||||
atLeastOne = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !atLeastOne {
|
||||
t.Fatalf("response didn't contain expected values: ctx:%d expected:%v response:%s ", numCtx, tc.anyResp, resp.Response)
|
||||
t.Fatalf("response didn't contain expected values: ctx:%d expected:%v response:%s ", numCtx, tc.anyResp, resp.Message.Content)
|
||||
}
|
||||
models, err := client.ListRunning(ctx)
|
||||
if err != nil {
|
||||
|
||||
@@ -74,9 +74,14 @@ func TestQuantization(t *testing.T) {
|
||||
}
|
||||
|
||||
stream := true
|
||||
genReq := api.GenerateRequest{
|
||||
Model: newName,
|
||||
Prompt: blueSkyPrompt,
|
||||
chatReq := api.ChatRequest{
|
||||
Model: newName,
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: blueSkyPrompt,
|
||||
},
|
||||
},
|
||||
KeepAlive: &api.Duration{Duration: 3 * time.Second},
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
@@ -91,8 +96,8 @@ func TestQuantization(t *testing.T) {
|
||||
reqCtx, reqCancel := context.WithCancel(ctx)
|
||||
atLeastOne := false
|
||||
var buf bytes.Buffer
|
||||
genfn := func(response api.GenerateResponse) error {
|
||||
buf.Write([]byte(response.Response))
|
||||
chatfn := func(response api.ChatResponse) error {
|
||||
buf.Write([]byte(response.Message.Content))
|
||||
fullResp := strings.ToLower(buf.String())
|
||||
for _, resp := range blueSkyExpected {
|
||||
if strings.Contains(fullResp, resp) {
|
||||
@@ -108,14 +113,14 @@ func TestQuantization(t *testing.T) {
|
||||
done := make(chan int)
|
||||
var genErr error
|
||||
go func() {
|
||||
genErr = client.Generate(reqCtx, &genReq, genfn)
|
||||
genErr = client.Chat(reqCtx, &chatReq, chatfn)
|
||||
done <- 0
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
if genErr != nil && !atLeastOne {
|
||||
t.Fatalf("failed with %s request prompt %s ", genReq.Model, genReq.Prompt)
|
||||
t.Fatalf("failed with %s request prompt %s ", chatReq.Model, chatReq.Messages[0].Content)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
t.Error("outer test context done while waiting for generate")
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
@@ -24,7 +25,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/app/lifecycle"
|
||||
"github.com/ollama/ollama/format"
|
||||
)
|
||||
|
||||
@@ -38,6 +38,7 @@ var (
|
||||
|
||||
// Note: add newer models at the top of the list to test them first
|
||||
ollamaEngineChatModels = []string{
|
||||
"qwen3-coder:30b",
|
||||
"gpt-oss:20b",
|
||||
"gemma3n:e2b",
|
||||
"mistral-small3.2:latest",
|
||||
@@ -46,6 +47,7 @@ var (
|
||||
"qwen2.5-coder:latest",
|
||||
"qwen2.5vl:3b",
|
||||
"qwen3:0.6b", // dense
|
||||
"qwen3:1.7b", // dense
|
||||
"qwen3:30b", // MOE
|
||||
"gemma3:1b",
|
||||
"llama3.1:latest",
|
||||
@@ -265,16 +267,16 @@ var (
|
||||
"Explain the physics involved in them. Be breif in your reply",
|
||||
"Explain the chemistry involved in them. Be breif in your reply",
|
||||
"What are common myths related to them? Be brief in your reply",
|
||||
"What are common fairytales related to them? Be brief in your reply",
|
||||
"Can they form if there is no rain? Be breif in your reply",
|
||||
"Can they form if there are no clouds? Be breif in your reply",
|
||||
"Do they happen on other planets? Be brief in your reply",
|
||||
}
|
||||
rainbowExpected = []string{"water", "droplet", "mist", "glow", "refract", "reflect", "scatter", "wave", "color", "spectrum", "raindrop", "atmosphere", "frequency", "end", "gold", "fortune", "blessing", "prosperity", "magic", "shower", "sky", "shimmer", "light", "storm", "sunny"}
|
||||
rainbowExpected = []string{"water", "droplet", "mist", "glow", "refract", "reflect", "scatter", "particles", "wave", "color", "spectrum", "raindrop", "atmosphere", "frequency", "shower", "sky", "shimmer", "light", "storm", "sunny", "sunburst", "phenomenon", "mars", "venus", "jupiter"}
|
||||
)
|
||||
|
||||
func init() {
|
||||
lifecycle.InitLogging()
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
slog.SetDefault(logger)
|
||||
custom := os.Getenv("OLLAMA_TEST_DEFAULT_MODEL")
|
||||
if custom != "" {
|
||||
slog.Info("setting default test model to " + custom)
|
||||
@@ -335,6 +337,7 @@ func GetTestEndpoint() (*api.Client, string) {
|
||||
|
||||
var serverMutex sync.Mutex
|
||||
var serverReady bool
|
||||
var serverLogFile string
|
||||
|
||||
func startServer(t *testing.T, ctx context.Context, ollamaHost string) error {
|
||||
// Make sure the server has been built
|
||||
@@ -361,8 +364,9 @@ func startServer(t *testing.T, ctx context.Context, ollamaHost string) error {
|
||||
t.Setenv("OLLAMA_HOST", ollamaHost)
|
||||
}
|
||||
|
||||
logDir := t.TempDir()
|
||||
slog.Info("starting server", "url", ollamaHost)
|
||||
done, err := lifecycle.SpawnServer(ctx, "../ollama")
|
||||
done, err := SpawnServer(ctx, "../ollama", logDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start server: %w", err)
|
||||
}
|
||||
@@ -385,6 +389,36 @@ func startServer(t *testing.T, ctx context.Context, ollamaHost string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func SpawnServer(ctx context.Context, command, logDir string) (chan int, error) {
|
||||
done := make(chan int)
|
||||
fp, err := os.CreateTemp(logDir, "ollama-server-*.log")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create log file: %w", err)
|
||||
}
|
||||
serverLogFile = fp.Name()
|
||||
|
||||
cmd := exec.CommandContext(ctx, command, "serve")
|
||||
cmd.Stderr = fp
|
||||
cmd.Stdout = fp
|
||||
|
||||
go func() {
|
||||
slog.Info("starting server...")
|
||||
if err := cmd.Run(); err != nil {
|
||||
// "signal: killed" expected
|
||||
if !strings.Contains(err.Error(), "signal") {
|
||||
slog.Info("failed to run server", "error", err)
|
||||
}
|
||||
}
|
||||
var code int
|
||||
if cmd.ProcessState != nil {
|
||||
code = cmd.ProcessState.ExitCode()
|
||||
}
|
||||
slog.Info("server exited")
|
||||
done <- code
|
||||
}()
|
||||
return done, nil
|
||||
}
|
||||
|
||||
func PullIfMissing(ctx context.Context, client *api.Client, modelName string) error {
|
||||
slog.Info("checking status of model", "model", modelName)
|
||||
showReq := &api.ShowRequest{Name: modelName}
|
||||
@@ -445,12 +479,6 @@ func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, strin
|
||||
client, testEndpoint := GetTestEndpoint()
|
||||
if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
|
||||
serverProcMutex.Lock()
|
||||
fp, err := os.CreateTemp("", "ollama-server-*.log")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate log file: %s", err)
|
||||
}
|
||||
lifecycle.ServerLogFile = fp.Name()
|
||||
fp.Close()
|
||||
if err := startServer(t, ctx, testEndpoint); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -478,36 +506,32 @@ func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, strin
|
||||
if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
|
||||
defer serverProcMutex.Unlock()
|
||||
if t.Failed() {
|
||||
fp, err := os.Open(lifecycle.ServerLogFile)
|
||||
fp, err := os.Open(serverLogFile)
|
||||
if err != nil {
|
||||
slog.Error("failed to open server log", "logfile", lifecycle.ServerLogFile, "error", err)
|
||||
slog.Error("failed to open server log", "logfile", serverLogFile, "error", err)
|
||||
return
|
||||
}
|
||||
defer fp.Close()
|
||||
data, err := io.ReadAll(fp)
|
||||
if err != nil {
|
||||
slog.Error("failed to read server log", "logfile", lifecycle.ServerLogFile, "error", err)
|
||||
slog.Error("failed to read server log", "logfile", serverLogFile, "error", err)
|
||||
return
|
||||
}
|
||||
slog.Warn("SERVER LOG FOLLOWS")
|
||||
os.Stderr.Write(data)
|
||||
slog.Warn("END OF SERVER")
|
||||
}
|
||||
err := os.Remove(lifecycle.ServerLogFile)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func GenerateTestHelper(ctx context.Context, t *testing.T, genReq api.GenerateRequest, anyResp []string) {
|
||||
func ChatTestHelper(ctx context.Context, t *testing.T, req api.ChatRequest, anyResp []string) {
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
if err := PullIfMissing(ctx, client, genReq.Model); err != nil {
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
DoGenerate(ctx, t, client, genReq, anyResp, 30*time.Second, 10*time.Second)
|
||||
DoChat(ctx, t, client, req, anyResp, 30*time.Second, 10*time.Second)
|
||||
}
|
||||
|
||||
func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq api.GenerateRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) []int {
|
||||
@@ -726,8 +750,14 @@ func skipIfNotGPULoaded(ctx context.Context, t *testing.T, client *api.Client, m
|
||||
loaded := []string{}
|
||||
for _, m := range models.Models {
|
||||
loaded = append(loaded, m.Name)
|
||||
if m.Name != model {
|
||||
continue
|
||||
if strings.Contains(model, ":") {
|
||||
if m.Name != model {
|
||||
continue
|
||||
}
|
||||
} else if strings.Contains(m.Name, ":") {
|
||||
if !strings.HasPrefix(m.Name, model+":") {
|
||||
continue
|
||||
}
|
||||
}
|
||||
gpuPercent := 0
|
||||
switch {
|
||||
|
||||
@@ -160,7 +160,15 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity
|
||||
if c.swaMemorySize == 0 {
|
||||
c.swaMemorySize = c.swaWindowSize
|
||||
}
|
||||
if int(c.swaMemorySize) > capacity {
|
||||
// We will allocate space in the cache for the stop token, which won't be part of a follow on
|
||||
// sequence, so allocate an extra token of storage to ensure that we can jump back without
|
||||
// causing a cache break. As an optimization, only do this when we have parallel sequences
|
||||
// because the extra token will live in the batch buffer and won't get overwritten if we
|
||||
// only have a single sequence.
|
||||
if c.swaMemorySize != math.MaxInt32 && maxSequences > 1 {
|
||||
c.swaMemorySize = max(c.swaMemorySize, c.swaWindowSize+1)
|
||||
}
|
||||
if int(c.swaMemorySize) >= capacity {
|
||||
c.swaMemorySize = math.MaxInt32
|
||||
}
|
||||
|
||||
@@ -214,7 +222,6 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) e
|
||||
c.curLoc, err = c.findStartLoc()
|
||||
}
|
||||
if err != nil {
|
||||
slog.Warn("unable to find a kv cache slot", "cache", c)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -288,23 +295,44 @@ func (c *Causal) updateSlidingWindow() {
|
||||
return
|
||||
}
|
||||
|
||||
type lowestPosition struct {
|
||||
pos int32
|
||||
curBatch bool
|
||||
}
|
||||
|
||||
// create a map of unique sequences to the lowest position in that sequence
|
||||
lowestPos := make(map[int]int32)
|
||||
lowestPos := make(map[int]lowestPosition)
|
||||
for i := range c.curPositions {
|
||||
seq := c.curSequences[i]
|
||||
|
||||
pos, ok := lowestPos[seq]
|
||||
lowest, ok := lowestPos[seq]
|
||||
if !ok {
|
||||
pos = c.curPositions[i]
|
||||
} else if c.curPositions[i] < pos {
|
||||
pos = c.curPositions[i]
|
||||
lowest = lowestPosition{pos: c.curPositions[i], curBatch: true}
|
||||
} else if c.curPositions[i] < lowest.pos {
|
||||
lowest.pos = c.curPositions[i]
|
||||
}
|
||||
|
||||
lowestPos[seq] = pos
|
||||
lowestPos[seq] = lowest
|
||||
}
|
||||
|
||||
// for any sequences are not part of this batch, clean up any tokens
|
||||
// that are no longer needed after the processing of the previous
|
||||
// batch
|
||||
for seq, seqRange := range c.cellRanges {
|
||||
if _, ok := lowestPos[seq]; !ok {
|
||||
var last int32
|
||||
for i := seqRange.min; i <= seqRange.max; i++ {
|
||||
if slices.Contains(c.cells[i].sequences, seq) {
|
||||
last = max(last, c.cells[i].pos)
|
||||
}
|
||||
}
|
||||
|
||||
lowestPos[seq] = lowestPosition{pos: last + 1, curBatch: false}
|
||||
}
|
||||
}
|
||||
|
||||
// delete any entries that are beyond the window of the oldest position in the sequence
|
||||
for seq, pos := range lowestPos {
|
||||
for seq, lowest := range lowestPos {
|
||||
oldRange, ok := c.cellRanges[seq]
|
||||
if !ok {
|
||||
continue
|
||||
@@ -314,13 +342,13 @@ func (c *Causal) updateSlidingWindow() {
|
||||
|
||||
for i := oldRange.min; i <= oldRange.max; i++ {
|
||||
if slices.Contains(c.cells[i].sequences, seq) {
|
||||
if c.cells[i].pos < pos-c.swaMemorySize {
|
||||
if c.cells[i].pos < lowest.pos-c.swaMemorySize {
|
||||
c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
|
||||
} else {
|
||||
newRange.min = min(newRange.min, i)
|
||||
newRange.max = max(newRange.max, i)
|
||||
}
|
||||
if c.cells[i].pos >= pos-c.swaWindowSize {
|
||||
if lowest.curBatch && c.cells[i].pos >= lowest.pos-c.swaWindowSize {
|
||||
c.curCellRange.min = min(c.curCellRange.min, i)
|
||||
c.curCellRange.max = max(c.curCellRange.max, i)
|
||||
}
|
||||
@@ -657,9 +685,11 @@ func (c *Causal) CanResume(seq int, pos int32) bool {
|
||||
|
||||
// for sliding window, check that the window of the new sequence is contained in
|
||||
// the window of what we are storing
|
||||
var first int32 = math.MaxInt32
|
||||
var last int32 = -1
|
||||
for i := seqRange.min; i <= seqRange.max; i++ {
|
||||
if slices.Contains(c.cells[i].sequences, seq) {
|
||||
first = min(first, c.cells[i].pos)
|
||||
last = max(last, c.cells[i].pos)
|
||||
}
|
||||
}
|
||||
@@ -668,10 +698,8 @@ func (c *Causal) CanResume(seq int, pos int32) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
lastWindowStart := max(0, last-c.swaMemorySize)
|
||||
posWindowStart := max(0, pos-c.swaWindowSize)
|
||||
|
||||
return posWindowStart >= lastWindowStart
|
||||
return posWindowStart >= first && pos <= last+1
|
||||
}
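As a rough illustration of the new resume condition (a sketch with hypothetical names, inferred from the hunk above, not the actual implementation): a sequence can resume at pos only if the attention window ending at pos is fully covered by the stored positions [first, last] and pos does not jump past what is cached.

package main

import "fmt"

// canResumeSWA restates the return expression above as a pure function over
// the oldest and newest stored positions for a sequence.
func canResumeSWA(pos, first, last, windowSize int32) bool {
	posWindowStart := max(int32(0), pos-windowSize)
	return posWindowStart >= first && pos <= last+1
}

func main() {
	// window size 4, cache holds positions 2..5
	fmt.Println(canResumeSWA(6, 2, 5, 4)) // true: window starts at 2 and 6 == last+1
	fmt.Println(canResumeSWA(3, 2, 5, 4)) // false: the window would need position 0, which was evicted
}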
|
||||
|
||||
func (c *Causal) shift(seq int, beginIndex, offset int32) error {
|
||||
|
||||
@@ -96,6 +96,86 @@ func TestSWA(t *testing.T) {
|
||||
testCache(t, backend, cache, tests)
|
||||
}
|
||||
|
||||
func TestSWASeparateBatches(t *testing.T) {
|
||||
backend := &testBackend{}
|
||||
cache := NewSWACache(1, nil)
|
||||
defer cache.Close()
|
||||
|
||||
cache.Init(backend, ml.DTypeF16, 2, 16, 2)
|
||||
|
||||
x := float32(math.Inf(-1))
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
name: "First seq 0",
|
||||
in: []float32{1, 2},
|
||||
inShape: []int{1, 1, 2},
|
||||
seqs: []int{0, 0},
|
||||
pos: []int32{0, 1},
|
||||
expected: []float32{1, 2},
|
||||
expectedShape: []int{1, 1, 2},
|
||||
expectedMask: []float32{
|
||||
0, x,
|
||||
0, 0,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Second seq 0",
|
||||
in: []float32{3, 4},
|
||||
inShape: []int{1, 1, 2},
|
||||
seqs: []int{0, 0},
|
||||
pos: []int32{2, 3},
|
||||
expected: []float32{2, 3, 4},
|
||||
expectedShape: []int{1, 1, 3},
|
||||
expectedMask: []float32{
|
||||
0, 0, x,
|
||||
x, 0, 0,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "First seq 1",
|
||||
in: []float32{5, 6},
|
||||
inShape: []int{1, 1, 2},
|
||||
seqs: []int{1, 1},
|
||||
pos: []int32{0, 1},
|
||||
expected: []float32{5, 6},
|
||||
expectedShape: []int{1, 1, 2},
|
||||
expectedMask: []float32{
|
||||
0, x,
|
||||
0, 0,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Second seq 1",
|
||||
in: []float32{7, 8},
|
||||
inShape: []int{1, 1, 2},
|
||||
seqs: []int{1, 1},
|
||||
pos: []int32{2, 3},
|
||||
expected: []float32{6, 3, 4, 7, 8},
|
||||
expectedShape: []int{1, 1, 5},
|
||||
expectedMask: []float32{
|
||||
0, x, x, 0, x,
|
||||
x, x, x, 0, 0,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Third seq 0",
|
||||
in: []float32{9, 10},
|
||||
inShape: []int{1, 1, 2},
|
||||
seqs: []int{0, 0},
|
||||
pos: []int32{4, 5},
|
||||
expected: []float32{9, 10, 3, 4},
|
||||
expectedShape: []int{1, 1, 4},
|
||||
expectedMask: []float32{
|
||||
0, x, x, 0,
|
||||
0, 0, x, x,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
testCache(t, backend, cache, tests)
|
||||
}
|
||||
|
||||
func TestSWAMem(t *testing.T) {
|
||||
backend := &testBackend{}
|
||||
cache := NewSWAMemCache(1, 3, nil)
|
||||
@@ -431,15 +511,15 @@ func TestCanResume(t *testing.T) {
|
||||
defer context.Close()
|
||||
|
||||
err := cache.StartForward(context, input.Batch{
|
||||
Positions: []int32{0, 1, 2, 3},
|
||||
Sequences: []int{0, 0, 0, 0},
|
||||
Positions: []int32{0, 1, 2, 3, 4},
|
||||
Sequences: []int{0, 0, 0, 0, 0},
|
||||
}, false)
|
||||
if err != nil {
|
||||
t.Fatalf("StartForward failed: %v", err)
|
||||
}
|
||||
|
||||
cache.SetLayer(0)
|
||||
tensor := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4)
|
||||
tensor := context.FromFloatSlice([]float32{1, 2, 3, 4, 5}, 1, 1, 5)
|
||||
cache.Put(context, tensor, tensor)
|
||||
|
||||
// with window size 4, nothing has slid out of the window yet
|
||||
@@ -455,18 +535,21 @@ func TestCanResume(t *testing.T) {
|
||||
if !cache.CanResume(0, 3) {
|
||||
t.Errorf("CanResume(0, 3) = false, want true (latest position)")
|
||||
}
|
||||
if !cache.CanResume(0, 4) {
|
||||
t.Errorf("CanResume(0, 4) = false, want true (latest position)")
|
||||
}
|
||||
|
||||
// shift window by adding position 4
|
||||
// shift window by adding position 5
|
||||
err = cache.StartForward(context, input.Batch{
|
||||
Positions: []int32{4, 5},
|
||||
Sequences: []int{0, 0},
|
||||
Positions: []int32{5},
|
||||
Sequences: []int{0},
|
||||
}, false)
|
||||
if err != nil {
|
||||
t.Fatalf("StartForward failed: %v", err)
|
||||
}
|
||||
|
||||
cache.SetLayer(0)
|
||||
tensor = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2)
|
||||
tensor = context.FromFloatSlice([]float32{6}, 1, 1, 1)
|
||||
cache.Put(context, tensor, tensor)
|
||||
|
||||
// only the latest position has overlapping windows
|
||||
@@ -503,28 +586,28 @@ func TestCanResumeSWAMem(t *testing.T) {
|
||||
defer context.Close()
|
||||
|
||||
err := cache.StartForward(context, input.Batch{
|
||||
Positions: []int32{0, 1, 2, 3, 4, 5},
|
||||
Sequences: []int{0, 0, 0, 0, 0, 0},
|
||||
Positions: []int32{0, 1, 2, 3, 4, 5, 6},
|
||||
Sequences: []int{0, 0, 0, 0, 0, 0, 0},
|
||||
}, false)
|
||||
if err != nil {
|
||||
t.Fatalf("StartForward failed: %v", err)
|
||||
}
|
||||
|
||||
cache.SetLayer(0)
|
||||
tensor := context.FromFloatSlice([]float32{1, 2, 3, 4, 5, 6}, 1, 1, 6)
|
||||
tensor := context.FromFloatSlice([]float32{1, 2, 3, 4, 5, 6, 7}, 1, 1, 7)
|
||||
cache.Put(context, tensor, tensor)
|
||||
|
||||
// shift window by adding position 6
|
||||
// shift window by adding position 7
|
||||
err = cache.StartForward(context, input.Batch{
|
||||
Positions: []int32{6, 7},
|
||||
Sequences: []int{0, 0},
|
||||
Positions: []int32{7},
|
||||
Sequences: []int{0},
|
||||
}, false)
|
||||
if err != nil {
|
||||
t.Fatalf("StartForward failed: %v", err)
|
||||
}
|
||||
|
||||
cache.SetLayer(0)
|
||||
tensor = context.FromFloatSlice([]float32{7, 8}, 1, 1, 2)
|
||||
tensor = context.FromFloatSlice([]float32{8}, 1, 1, 1)
|
||||
cache.Put(context, tensor, tensor)
|
||||
|
||||
// only the latest position has overlapping windows
|
||||
|
||||
@@ -13,13 +13,13 @@ management libraries for more accurate VRAM usage reporting if available.
|
||||
ggml/src/ggml-impl.h | 8 +
|
||||
ggml/src/ggml-metal/ggml-metal.cpp | 3 +-
|
||||
ggml/src/mem_hip.cpp | 449 +++++++++++++++++++++++++++++
|
||||
ggml/src/mem_nvml.cpp | 172 +++++++++++
|
||||
8 files changed, 718 insertions(+), 1 deletion(-)
|
||||
ggml/src/mem_nvml.cpp | 209 ++++++++++++++
|
||||
8 files changed, 755 insertions(+), 1 deletion(-)
|
||||
create mode 100644 ggml/src/mem_hip.cpp
|
||||
create mode 100644 ggml/src/mem_nvml.cpp
|
||||
|
||||
diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
|
||||
index 0a2dae26..a6bf3378 100644
|
||||
index 0a2dae26a..a6bf33785 100644
|
||||
--- a/ggml/include/ggml-backend.h
|
||||
+++ b/ggml/include/ggml-backend.h
|
||||
@@ -169,6 +169,15 @@ extern "C" {
|
||||
@@ -39,7 +39,7 @@ index 0a2dae26..a6bf3378 100644
|
||||
|
||||
GGML_API const char * ggml_backend_dev_name(ggml_backend_dev_t device);
|
||||
diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
|
||||
index 33b3a15f..86191ef2 100644
|
||||
index 33b3a15f0..86191ef2c 100644
|
||||
--- a/ggml/src/CMakeLists.txt
|
||||
+++ b/ggml/src/CMakeLists.txt
|
||||
@@ -206,6 +206,8 @@ add_library(ggml-base
|
||||
@@ -52,7 +52,7 @@ index 33b3a15f..86191ef2 100644
|
||||
|
||||
target_include_directories(ggml-base PRIVATE .)
|
||||
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
index 531d6e27..3fa3a057 100644
|
||||
index 531d6e272..3fa3a0575 100644
|
||||
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
|
||||
@@ -261,6 +261,16 @@ static ggml_cuda_device_info ggml_cuda_init() {
|
||||
@@ -184,7 +184,7 @@ index 531d6e27..3fa3a057 100644
|
||||
/* .iface = */ ggml_backend_cuda_device_interface,
|
||||
/* .reg = */ ®,
|
||||
diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h
|
||||
index 06f9e7c1..eb8f66cb 100644
|
||||
index 06f9e7c1e..eb8f66cb0 100644
|
||||
--- a/ggml/src/ggml-cuda/vendors/hip.h
|
||||
+++ b/ggml/src/ggml-cuda/vendors/hip.h
|
||||
@@ -5,6 +5,9 @@
|
||||
@@ -206,7 +206,7 @@ index 06f9e7c1..eb8f66cb 100644
|
||||
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
|
||||
#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
|
||||
diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
|
||||
index 86a1ebf6..9fc9fbfc 100644
|
||||
index 86a1ebf62..9fc9fbfcf 100644
|
||||
--- a/ggml/src/ggml-impl.h
|
||||
+++ b/ggml/src/ggml-impl.h
|
||||
@@ -635,6 +635,14 @@ static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx
|
||||
@@ -225,7 +225,7 @@ index 86a1ebf6..9fc9fbfc 100644
|
||||
}
|
||||
#endif
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp
|
||||
index 08ab4fc9..17999a61 100644
|
||||
index 08ab4fc91..17999a616 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal.cpp
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal.cpp
|
||||
@@ -535,6 +535,7 @@ static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backen
|
||||
@@ -247,7 +247,7 @@ index 08ab4fc9..17999a61 100644
|
||||
/* .host_buffer = */ false,
|
||||
diff --git a/ggml/src/mem_hip.cpp b/ggml/src/mem_hip.cpp
|
||||
new file mode 100644
|
||||
index 00000000..8ef19b8c
|
||||
index 000000000..8ef19b8cf
|
||||
--- /dev/null
|
||||
+++ b/ggml/src/mem_hip.cpp
|
||||
@@ -0,0 +1,449 @@
|
||||
@@ -703,10 +703,10 @@ index 00000000..8ef19b8c
|
||||
\ No newline at end of file
|
||||
diff --git a/ggml/src/mem_nvml.cpp b/ggml/src/mem_nvml.cpp
|
||||
new file mode 100644
|
||||
index 00000000..aa05e9dc
|
||||
index 000000000..c9073cef0
|
||||
--- /dev/null
|
||||
+++ b/ggml/src/mem_nvml.cpp
|
||||
@@ -0,0 +1,172 @@
|
||||
@@ -0,0 +1,209 @@
|
||||
+// NVIDIA Management Library (NVML)
|
||||
+//
|
||||
+// https://developer.nvidia.com/management-library-nvml
|
||||
@@ -721,6 +721,7 @@ index 00000000..aa05e9dc
|
||||
+#include "ggml-impl.h"
|
||||
+#include <filesystem>
|
||||
+#include <mutex>
|
||||
+#include <array>
|
||||
+
|
||||
+#ifdef _WIN32
|
||||
+# define WIN32_LEAN_AND_MEAN
|
||||
@@ -787,6 +788,7 @@ index 00000000..aa05e9dc
|
||||
+ nvmlReturn_t (*nvmlShutdown)(void);
|
||||
+ nvmlReturn_t (*nvmlDeviceGetHandleByUUID)(const char *, nvmlDevice_t *);
|
||||
+ nvmlReturn_t (*nvmlDeviceGetMemoryInfo)(nvmlDevice_t, nvmlMemory_t *);
|
||||
+ const char * (*nvmlErrorString)(nvmlReturn_t result);
|
||||
+} nvml { NULL, NULL, NULL, NULL, NULL };
|
||||
+static std::mutex ggml_nvml_lock;
|
||||
+
|
||||
@@ -824,7 +826,8 @@ index 00000000..aa05e9dc
|
||||
+ nvml.nvmlShutdown = (nvmlReturn_enum (*)()) GetProcAddress((HMODULE)(nvml.handle), "nvmlShutdown");
|
||||
+ nvml.nvmlDeviceGetHandleByUUID = (nvmlReturn_t (*)(const char *, nvmlDevice_t *)) GetProcAddress((HMODULE)(nvml.handle), "nvmlDeviceGetHandleByUUID");
|
||||
+ nvml.nvmlDeviceGetMemoryInfo = (nvmlReturn_t (*)(nvmlDevice_t, nvmlMemory_t *)) GetProcAddress((HMODULE)(nvml.handle), "nvmlDeviceGetMemoryInfo");
|
||||
+ if (nvml.nvmlInit_v2 == NULL || nvml.nvmlShutdown == NULL || nvml.nvmlDeviceGetHandleByUUID == NULL || nvml.nvmlDeviceGetMemoryInfo == NULL) {
|
||||
+ nvml.nvmlErrorString = (const char * (*)(nvmlReturn_enum)) GetProcAddress((HMODULE)(nvml.handle), "nvmlErrorString");
|
||||
+ if (nvml.nvmlInit_v2 == NULL || nvml.nvmlShutdown == NULL || nvml.nvmlDeviceGetHandleByUUID == NULL || nvml.nvmlDeviceGetMemoryInfo == NULL || nvml.nvmlErrorString == NULL) {
|
||||
+ GGML_LOG_INFO("%s unable to locate required symbols in NVML.dll", __func__);
|
||||
+ FreeLibrary((HMODULE)(nvml.handle));
|
||||
+ nvml.handle = NULL;
|
||||
@@ -833,11 +836,45 @@ index 00000000..aa05e9dc
|
||||
+
|
||||
+ SetErrorMode(old_mode);
|
||||
+
|
||||
+ nvmlReturn_t status = nvml.nvmlInit_v2();
|
||||
+ if (status != NVML_SUCCESS) {
|
||||
+ GGML_LOG_INFO("%s unable to initialize NVML: %s\n", __func__, nvml.nvmlErrorString(status));
|
||||
+ FreeLibrary((HMODULE)(nvml.handle));
|
||||
+ nvml.handle = NULL;
|
||||
+ return status;
|
||||
+ }
|
||||
+#else
|
||||
+ // Not currently wired up on Linux
|
||||
+ return NVML_ERROR_NOT_SUPPORTED;
|
||||
+ constexpr std::array<const char*, 2> libPaths = {
|
||||
+ "/usr/lib/wsl/lib/libnvidia-ml.so.1", // Favor WSL2 path if present
|
||||
+ "libnvidia-ml.so.1" // On a non-WSL2 system, it should be in the path
|
||||
+ };
|
||||
+ for (const char* path : libPaths) {
|
||||
+ nvml.handle = dlopen(path, RTLD_LAZY);
|
||||
+ if (nvml.handle) break;
|
||||
+ }
|
||||
+ if (nvml.handle == NULL) {
|
||||
+ GGML_LOG_INFO("%s unable to load libnvidia-ml: %s\n", __func__, dlerror());
|
||||
+ return NVML_ERROR_NOT_FOUND;
|
||||
+ }
|
||||
+ nvml.nvmlInit_v2 = (nvmlReturn_enum (*)()) dlsym(nvml.handle, "nvmlInit_v2");
|
||||
+ nvml.nvmlShutdown = (nvmlReturn_enum (*)()) dlsym(nvml.handle, "nvmlShutdown");
|
||||
+ nvml.nvmlDeviceGetHandleByUUID = (nvmlReturn_t (*)(const char *, nvmlDevice_t *)) dlsym(nvml.handle, "nvmlDeviceGetHandleByUUID");
|
||||
+ nvml.nvmlDeviceGetMemoryInfo = (nvmlReturn_t (*)(nvmlDevice_t, nvmlMemory_t *)) dlsym(nvml.handle, "nvmlDeviceGetMemoryInfo");
|
||||
+ nvml.nvmlErrorString = (const char * (*)(nvmlReturn_enum)) dlsym(nvml.handle, "nvmlErrorString");
|
||||
+ if (nvml.nvmlInit_v2 == NULL || nvml.nvmlShutdown == NULL || nvml.nvmlDeviceGetHandleByUUID == NULL || nvml.nvmlDeviceGetMemoryInfo == NULL) {
|
||||
+ GGML_LOG_INFO("%s unable to locate required symbols in libnvidia-ml.so", __func__);
|
||||
+ dlclose(nvml.handle);
|
||||
+ nvml.handle = NULL;
|
||||
+ return NVML_ERROR_NOT_FOUND;
|
||||
+ }
|
||||
+ nvmlReturn_t status = nvml.nvmlInit_v2();
|
||||
+ if (status != NVML_SUCCESS) {
|
||||
+ GGML_LOG_INFO("%s unable to initialize NVML: %s\n", __func__, nvml.nvmlErrorString(status));
|
||||
+ dlclose(nvml.handle);
|
||||
+ nvml.handle = NULL;
|
||||
+ return status;
|
||||
+ }
|
||||
+#endif
|
||||
+ int status = nvml.nvmlInit_v2();
|
||||
+ return NVML_SUCCESS;
|
||||
+}
|
||||
+
|
||||
@@ -849,14 +886,14 @@ index 00000000..aa05e9dc
|
||||
+ }
|
||||
+ nvmlReturn_enum status = nvml.nvmlShutdown();
|
||||
+ if (status != NVML_SUCCESS) {
|
||||
+ GGML_LOG_INFO("%s failed to shutdown NVML: %d\n", __func__, status);
|
||||
+ GGML_LOG_INFO("%s failed to shutdown NVML: %s\n", __func__, nvml.nvmlErrorString(status));
|
||||
+ }
|
||||
+#ifdef _WIN32
|
||||
+ FreeLibrary((HMODULE)(nvml.handle));
|
||||
+ nvml.handle = NULL;
|
||||
+#else
|
||||
+ // Not currently wired up on Linux
|
||||
+ dlclose(nvml.handle);
|
||||
+#endif
|
||||
+ nvml.handle = NULL;
|
||||
+}
|
||||
+
|
||||
+int ggml_nvml_get_device_memory(const char *uuid, size_t *free, size_t *total) {
|
||||
|
||||
@@ -266,11 +266,18 @@ func estimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
|
||||
}
|
||||
// Only include GPUs that can fit the graph, the GPU minimum, the layer buffer, and at least one more layer
|
||||
if gpus[i].FreeMemory < overhead+gzo+max(graphPartialOffload, graphFullOffload)+gpus[i].MinimumMemory+2*layerSize {
|
||||
var compute string
|
||||
if gpus[i].Library == "ROCm" {
|
||||
compute = fmt.Sprintf("gfx%x%02x", gpus[i].ComputeMajor, gpus[i].ComputeMinor)
|
||||
} else {
|
||||
compute = fmt.Sprintf("%d.%d", gpus[i].ComputeMajor, gpus[i].ComputeMinor)
|
||||
}
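A quick worked example (illustrative only, independent of this diff) of the two format strings above: ROCm devices get the hexadecimal gfx-style name, while other libraries keep the dotted decimal compute version.

package main

import "fmt"

func main() {
	// ROCm: hex major plus zero-padded hex minor, e.g. major 9 / minor 6 -> gfx906
	fmt.Printf("gfx%x%02x\n", 9, 6) // gfx906
	// other backends: plain "major.minor", e.g. a CUDA compute capability
	fmt.Printf("%d.%d\n", 8, 6) // 8.6
}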
|
||||
|
||||
slog.Debug("gpu has too little memory to allocate any layers",
|
||||
"id", gpus[i].ID,
|
||||
"library", gpus[i].Library,
|
||||
"variant", gpus[i].Variant,
|
||||
"compute", gpus[i].Compute,
|
||||
"compute", compute,
|
||||
"driver", fmt.Sprintf("%d.%d", gpus[i].DriverMajor, gpus[i].DriverMinor),
|
||||
"name", gpus[i].Name,
|
||||
"total", format.HumanBytes2(gpus[i].TotalMemory),
|
||||
|
||||
@@ -359,20 +359,22 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
|
||||
s.cmd.Stderr = s.status
|
||||
s.cmd.SysProcAttr = LlamaServerSysProcAttr
|
||||
|
||||
s.cmd.Env = append(s.cmd.Env, "OLLAMA_LIBRARY_PATH="+strings.Join(ggmlPaths, string(filepath.ListSeparator)))
|
||||
|
||||
// Always filter down the set of GPUs in case there are any unsupported devices that might crash
|
||||
envWorkarounds := gpus.GetVisibleDevicesEnv()
|
||||
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
|
||||
|
||||
// Update or add the path variable with our adjusted version
|
||||
pathNeeded := true
|
||||
ollamaPathNeeded := true
|
||||
envWorkaroundDone := make([]bool, len(envWorkarounds))
|
||||
for i := range s.cmd.Env {
|
||||
cmp := strings.SplitN(s.cmd.Env[i], "=", 2)
|
||||
if strings.EqualFold(cmp[0], pathEnv) {
|
||||
s.cmd.Env[i] = pathEnv + "=" + pathEnvVal
|
||||
pathNeeded = false
|
||||
} else if strings.EqualFold(cmp[0], "OLLAMA_LIBRARY_PATH") {
|
||||
s.cmd.Env[i] = "OLLAMA_LIBRARY_PATH=" + strings.Join(ggmlPaths, string(filepath.ListSeparator))
|
||||
ollamaPathNeeded = false
|
||||
} else if len(envWorkarounds) != 0 {
|
||||
for j, kv := range envWorkarounds {
|
||||
tmp := strings.SplitN(kv, "=", 2)
|
||||
@@ -386,6 +388,9 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
|
||||
if pathNeeded {
|
||||
s.cmd.Env = append(s.cmd.Env, pathEnv+"="+pathEnvVal)
|
||||
}
|
||||
if ollamaPathNeeded {
|
||||
s.cmd.Env = append(s.cmd.Env, "OLLAMA_LIBRARY_PATH="+strings.Join(ggmlPaths, string(filepath.ListSeparator)))
|
||||
}
|
||||
for i, done := range envWorkaroundDone {
|
||||
if !done {
|
||||
s.cmd.Env = append(s.cmd.Env, envWorkarounds[i])
|
||||
@@ -1374,7 +1379,9 @@ type CompletionRequest struct {
|
||||
Images []ImageData
|
||||
Options *api.Options
|
||||
|
||||
Grammar string // set before sending the request to the subprocess
|
||||
Grammar string // set before sending the request to the subprocess
|
||||
Shift bool
|
||||
Truncate bool
|
||||
}
|
||||
|
||||
// DoneReason represents the reason why a completion response is done
|
||||
@@ -1481,7 +1488,10 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
serverReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
res, err := http.DefaultClient.Do(serverReq)
|
||||
if err != nil {
|
||||
if err != nil && errors.Is(err, context.Canceled) {
|
||||
// client closed connection
|
||||
return err
|
||||
} else if err != nil {
|
||||
slog.Error("post predict", "error", err)
|
||||
return errors.New("model runner has unexpectedly stopped, this may be due to resource limitations or an internal error, check ollama server logs for details")
|
||||
}
|
||||
@@ -1493,7 +1503,7 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
return fmt.Errorf("failed reading llm error response: %w", err)
|
||||
}
|
||||
log.Printf("llm predict error: %s", bodyBytes)
|
||||
return fmt.Errorf("%s", bodyBytes)
|
||||
return api.StatusError{StatusCode: res.StatusCode, ErrorMessage: strings.TrimSpace(string(bodyBytes))}
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(res.Body)
|
||||
|
||||
424 middleware/openai.go Normal file
@@ -0,0 +1,424 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/openai"
|
||||
)
|
||||
|
||||
type BaseWriter struct {
|
||||
gin.ResponseWriter
|
||||
}
|
||||
|
||||
type ChatWriter struct {
|
||||
stream bool
|
||||
streamOptions *openai.StreamOptions
|
||||
id string
|
||||
toolCallSent bool
|
||||
BaseWriter
|
||||
}
|
||||
|
||||
type CompleteWriter struct {
|
||||
stream bool
|
||||
streamOptions *openai.StreamOptions
|
||||
id string
|
||||
BaseWriter
|
||||
}
|
||||
|
||||
type ListWriter struct {
|
||||
BaseWriter
|
||||
}
|
||||
|
||||
type RetrieveWriter struct {
|
||||
BaseWriter
|
||||
model string
|
||||
}
|
||||
|
||||
type EmbedWriter struct {
|
||||
BaseWriter
|
||||
model string
|
||||
}
|
||||
|
||||
func (w *BaseWriter) writeError(data []byte) (int, error) {
|
||||
var serr api.StatusError
|
||||
err := json.Unmarshal(data, &serr)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w.ResponseWriter).Encode(openai.NewError(http.StatusInternalServerError, serr.Error()))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *ChatWriter) writeResponse(data []byte) (int, error) {
|
||||
var chatResponse api.ChatResponse
|
||||
err := json.Unmarshal(data, &chatResponse)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// chat chunk
|
||||
if w.stream {
|
||||
c := openai.ToChunk(w.id, chatResponse, w.toolCallSent)
|
||||
d, err := json.Marshal(c)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if !w.toolCallSent && len(c.Choices) > 0 && len(c.Choices[0].Delta.ToolCalls) > 0 {
|
||||
w.toolCallSent = true
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
|
||||
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if chatResponse.Done {
|
||||
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
|
||||
u := openai.ToUsage(chatResponse)
|
||||
c.Usage = &u
|
||||
c.Choices = []openai.ChunkChoice{}
|
||||
d, err := json.Marshal(c)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
// chat completion
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToChatCompletion(w.id, chatResponse))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *ChatWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
|
||||
var generateResponse api.GenerateResponse
|
||||
err := json.Unmarshal(data, &generateResponse)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// completion chunk
|
||||
if w.stream {
|
||||
c := openai.ToCompleteChunk(w.id, generateResponse)
|
||||
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
|
||||
c.Usage = &openai.Usage{}
|
||||
}
|
||||
d, err := json.Marshal(c)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
|
||||
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if generateResponse.Done {
|
||||
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
|
||||
u := openai.ToUsageGenerate(generateResponse)
|
||||
c.Usage = &u
|
||||
c.Choices = []openai.CompleteChunkChoice{}
|
||||
d, err := json.Marshal(c)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
// completion
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToCompletion(w.id, generateResponse))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *CompleteWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
func (w *ListWriter) writeResponse(data []byte) (int, error) {
|
||||
var listResponse api.ListResponse
|
||||
err := json.Unmarshal(data, &listResponse)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToListCompletion(listResponse))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *ListWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
|
||||
var showResponse api.ShowResponse
|
||||
err := json.Unmarshal(data, &showResponse)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// retrieve completion
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToModel(showResponse, w.model))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *RetrieveWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
|
||||
var embedResponse api.EmbedResponse
|
||||
err := json.Unmarshal(data, &embedResponse)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToEmbeddingList(w.model, embedResponse))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *EmbedWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
func ListMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
w := &ListWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func RetrieveMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
w := &RetrieveWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
model: c.Param("model"),
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func CompletionsMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var req openai.CompletionRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
genReq, err := openai.FromCompleteRequest(req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(&b).Encode(genReq); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
w := &CompleteWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
stream: req.Stream,
|
||||
id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
|
||||
streamOptions: req.StreamOptions,
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func EmbeddingsMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var req openai.EmbedRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if req.Input == "" {
|
||||
req.Input = []string{""}
|
||||
}
|
||||
|
||||
if req.Input == nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "invalid input"))
|
||||
return
|
||||
}
|
||||
|
||||
if v, ok := req.Input.([]any); ok && len(v) == 0 {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "invalid input"))
|
||||
return
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input, Dimensions: req.Dimensions}); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
w := &EmbedWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
model: req.Model,
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func ChatMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var req openai.ChatCompletionRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Messages) == 0 {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "[] is too short - 'messages'"))
|
||||
return
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
|
||||
chatReq, err := openai.FromChatRequest(req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
w := &ChatWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
stream: req.Stream,
|
||||
id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
|
||||
streamOptions: req.StreamOptions,
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
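A short usage sketch (not part of the new file) showing how these middlewares could be mounted on a gin router; the routes and the stub handler are assumptions for illustration, not necessarily how Ollama's server wires them up:

package main

import (
	"net/http"

	"github.com/gin-gonic/gin"

	"github.com/ollama/ollama/middleware"
)

// stub stands in for the native Ollama handlers (chat, generate, embed, ...).
func stub(c *gin.Context) { c.Status(http.StatusOK) }

func main() {
	r := gin.Default()

	// Each middleware rewrites the OpenAI-style request body into the native
	// Ollama request before the handler runs, then converts the response on
	// the way out via the writer types defined above.
	r.POST("/v1/chat/completions", middleware.ChatMiddleware(), stub)
	r.POST("/v1/completions", middleware.CompletionsMiddleware(), stub)
	r.POST("/v1/embeddings", middleware.EmbeddingsMiddleware(), stub)
	r.GET("/v1/models", middleware.ListMiddleware(), stub)
	r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), stub)

	r.Run(":11434")
}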
|
||||
928 middleware/openai_test.go Normal file
@@ -0,0 +1,928 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/openai"
|
||||
)
|
||||
|
||||
const (
|
||||
prefix = `data:image/jpeg;base64,`
|
||||
image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
|
||||
)
|
||||
|
||||
var (
|
||||
False = false
|
||||
True = true
|
||||
)
|
||||
|
||||
func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
bodyBytes, _ := io.ReadAll(c.Request.Body)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
err := json.Unmarshal(bodyBytes, capturedRequest)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request")
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatMiddleware(t *testing.T) {
|
||||
type testCase struct {
|
||||
name string
|
||||
body string
|
||||
req api.ChatRequest
|
||||
err openai.ErrorResponse
|
||||
}
|
||||
|
||||
var capturedRequest *api.ChatRequest
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
name: "chat handler",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Hello",
|
||||
},
|
||||
},
|
||||
Options: map[string]any{
|
||||
"temperature": 1.0,
|
||||
"top_p": 1.0,
|
||||
},
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "chat handler with options",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
],
|
||||
"stream": true,
|
||||
"max_tokens": 999,
|
||||
"seed": 123,
|
||||
"stop": ["\n", "stop"],
|
||||
"temperature": 3.0,
|
||||
"frequency_penalty": 4.0,
|
||||
"presence_penalty": 5.0,
|
||||
"top_p": 6.0,
|
||||
"response_format": {"type": "json_object"}
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Hello",
|
||||
},
|
||||
},
|
||||
Options: map[string]any{
|
||||
"num_predict": 999.0, // float because JSON doesn't distinguish between float and int
|
||||
"seed": 123.0,
|
||||
"stop": []any{"\n", "stop"},
|
||||
"temperature": 3.0,
|
||||
"frequency_penalty": 4.0,
|
||||
"presence_penalty": 5.0,
|
||||
"top_p": 6.0,
|
||||
},
|
||||
Format: json.RawMessage(`"json"`),
|
||||
Stream: &True,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "chat handler with streaming usage",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
],
|
||||
"stream": true,
|
||||
"stream_options": {"include_usage": true},
|
||||
"max_tokens": 999,
|
||||
"seed": 123,
|
||||
"stop": ["\n", "stop"],
|
||||
"temperature": 3.0,
|
||||
"frequency_penalty": 4.0,
|
||||
"presence_penalty": 5.0,
|
||||
"top_p": 6.0,
|
||||
"response_format": {"type": "json_object"}
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Hello",
|
||||
},
|
||||
},
|
||||
Options: map[string]any{
|
||||
"num_predict": 999.0, // float because JSON doesn't distinguish between float and int
|
||||
"seed": 123.0,
|
||||
"stop": []any{"\n", "stop"},
|
||||
"temperature": 3.0,
|
||||
"frequency_penalty": 4.0,
|
||||
"presence_penalty": 5.0,
|
||||
"top_p": 6.0,
|
||||
},
|
||||
Format: json.RawMessage(`"json"`),
|
||||
Stream: &True,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "chat handler with image content",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Hello"
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "` + prefix + image + `"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Hello",
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Images: []api.ImageData{
|
||||
func() []byte {
|
||||
img, _ := base64.StdEncoding.DecodeString(image)
|
||||
return img
|
||||
}(),
|
||||
},
|
||||
},
|
||||
},
|
||||
Options: map[string]any{
|
||||
"temperature": 1.0,
|
||||
"top_p": 1.0,
|
||||
},
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "chat handler with tools",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What's the weather like in Paris Today?"},
|
||||
{"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
|
||||
]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "What's the weather like in Paris Today?",
|
||||
},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_weather",
|
||||
Arguments: map[string]any{
|
||||
"location": "Paris, France",
|
||||
"format": "celsius",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Options: map[string]any{
|
||||
"temperature": 1.0,
|
||||
"top_p": 1.0,
|
||||
},
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "chat handler with tools and content",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What's the weather like in Paris Today?"},
|
||||
{"role": "assistant", "content": "Let's see what the weather is like in Paris", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
|
||||
]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "What's the weather like in Paris Today?",
|
||||
},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "Let's see what the weather is like in Paris",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_weather",
|
||||
Arguments: map[string]any{
|
||||
"location": "Paris, France",
|
||||
"format": "celsius",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Options: map[string]any{
|
||||
"temperature": 1.0,
|
||||
"top_p": 1.0,
|
||||
},
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "chat handler with tools and empty content",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What's the weather like in Paris Today?"},
|
||||
{"role": "assistant", "content": "", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
|
||||
]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "What's the weather like in Paris Today?",
|
||||
},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_weather",
|
||||
Arguments: map[string]any{
|
||||
"location": "Paris, France",
|
||||
"format": "celsius",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Options: map[string]any{
|
||||
"temperature": 1.0,
|
||||
"top_p": 1.0,
|
||||
},
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "chat handler with tools and thinking content",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What's the weather like in Paris Today?"},
|
||||
{"role": "assistant", "reasoning": "Let's see what the weather is like in Paris", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
|
||||
]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "What's the weather like in Paris Today?",
|
||||
},
|
||||
{
|
||||
Role: "assistant",
|
||||
Thinking: "Let's see what the weather is like in Paris",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_weather",
|
||||
Arguments: map[string]any{
|
||||
"location": "Paris, France",
|
||||
"format": "celsius",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Options: map[string]any{
|
||||
"temperature": 1.0,
|
||||
"top_p": 1.0,
|
||||
},
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool response with call ID",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What's the weather like in Paris Today?"},
|
||||
{"role": "assistant", "tool_calls": [{"id": "id_abc", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]},
|
||||
{"role": "tool", "tool_call_id": "id_abc", "content": "The weather in Paris is 20 degrees Celsius"}
|
||||
]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "What's the weather like in Paris Today?",
|
||||
},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_weather",
|
||||
Arguments: map[string]any{
|
||||
"location": "Paris, France",
|
||||
"format": "celsius",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: "tool",
|
||||
Content: "The weather in Paris is 20 degrees Celsius",
|
||||
ToolName: "get_current_weather",
|
||||
},
|
||||
},
|
||||
Options: map[string]any{
|
||||
"temperature": 1.0,
|
||||
"top_p": 1.0,
|
||||
},
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool response with name",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What's the weather like in Paris Today?"},
|
||||
{"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]},
|
||||
{"role": "tool", "name": "get_current_weather", "content": "The weather in Paris is 20 degrees Celsius"}
|
||||
]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "What's the weather like in Paris Today?",
|
||||
},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_weather",
|
||||
Arguments: map[string]any{
|
||||
"location": "Paris, France",
|
||||
"format": "celsius",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: "tool",
|
||||
Content: "The weather in Paris is 20 degrees Celsius",
|
||||
ToolName: "get_current_weather",
|
||||
},
|
||||
},
|
||||
Options: map[string]any{
|
||||
"temperature": 1.0,
|
||||
"top_p": 1.0,
|
||||
},
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "chat handler with streaming tools",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What's the weather like in Paris?"}
|
||||
],
|
||||
"stream": true,
|
||||
"tools": [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get the current weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"required": ["location"],
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "What's the weather like in Paris?",
|
||||
},
|
||||
},
|
||||
Tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather",
|
||||
Parameters: struct {
|
||||
Type string `json:"type"`
|
||||
Defs any `json:"$defs,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Required []string `json:"required"`
|
||||
Properties map[string]api.ToolProperty `json:"properties"`
|
||||
}{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The city and state",
|
||||
},
|
||||
"unit": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Enum: []any{"celsius", "fahrenheit"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Options: map[string]any{
|
||||
"temperature": 1.0,
|
||||
"top_p": 1.0,
|
||||
},
|
||||
Stream: &True,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "chat handler error forwarding",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{"role": "user", "content": 2}
|
||||
]
|
||||
}`,
|
||||
err: openai.ErrorResponse{
|
||||
Error: openai.Error{
|
||||
Message: "invalid message content type: float64",
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
endpoint := func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
}
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(ChatMiddleware(), captureRequestMiddleware(&capturedRequest))
|
||||
router.Handle(http.MethodPost, "/api/chat", endpoint)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(tc.body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
defer func() { capturedRequest = nil }()
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
var errResp openai.ErrorResponse
|
||||
if resp.Code != http.StatusOK {
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
|
||||
t.Fatalf("requests did not match: %+v", diff)
|
||||
}
|
||||
if diff := cmp.Diff(tc.err, errResp); diff != "" {
|
||||
t.Fatalf("errors did not match for %s:\n%s", tc.name, diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompletionsMiddleware(t *testing.T) {
|
||||
type testCase struct {
|
||||
name string
|
||||
body string
|
||||
req api.GenerateRequest
|
||||
err openai.ErrorResponse
|
||||
}
|
||||
|
||||
var capturedRequest *api.GenerateRequest
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
name: "completions handler",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"prompt": "Hello",
|
||||
"temperature": 0.8,
|
||||
"stop": ["\n", "stop"],
|
||||
"suffix": "suffix"
|
||||
}`,
|
||||
req: api.GenerateRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "Hello",
|
||||
Options: map[string]any{
|
||||
"frequency_penalty": 0.0,
|
||||
"presence_penalty": 0.0,
|
||||
"temperature": 0.8,
|
||||
"top_p": 1.0,
|
||||
"stop": []any{"\n", "stop"},
|
||||
},
|
||||
Suffix: "suffix",
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "completions handler stream",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"prompt": "Hello",
|
||||
"stream": true,
|
||||
"temperature": 0.8,
|
||||
"stop": ["\n", "stop"],
|
||||
"suffix": "suffix"
|
||||
}`,
|
||||
req: api.GenerateRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "Hello",
|
||||
Options: map[string]any{
|
||||
"frequency_penalty": 0.0,
|
||||
"presence_penalty": 0.0,
|
||||
"temperature": 0.8,
|
||||
"top_p": 1.0,
|
||||
"stop": []any{"\n", "stop"},
|
||||
},
|
||||
Suffix: "suffix",
|
||||
Stream: &True,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "completions handler stream with usage",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"prompt": "Hello",
|
||||
"stream": true,
|
||||
"stream_options": {"include_usage": true},
|
||||
"temperature": 0.8,
|
||||
"stop": ["\n", "stop"],
|
||||
"suffix": "suffix"
|
||||
}`,
|
||||
req: api.GenerateRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "Hello",
|
||||
Options: map[string]any{
|
||||
"frequency_penalty": 0.0,
|
||||
"presence_penalty": 0.0,
|
||||
"temperature": 0.8,
|
||||
"top_p": 1.0,
|
||||
"stop": []any{"\n", "stop"},
|
||||
},
|
||||
Suffix: "suffix",
|
||||
Stream: &True,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "completions handler error forwarding",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"prompt": "Hello",
|
||||
"temperature": null,
|
||||
"stop": [1, 2],
|
||||
"suffix": "suffix"
|
||||
}`,
|
||||
err: openai.ErrorResponse{
|
||||
Error: openai.Error{
|
||||
Message: "invalid type for 'stop' field: float64",
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
endpoint := func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
}
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest))
|
||||
router.Handle(http.MethodPost, "/api/generate", endpoint)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
var errResp openai.ErrorResponse
|
||||
if resp.Code != http.StatusOK {
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
|
||||
t.Fatal("requests did not match")
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(tc.err, errResp) {
|
||||
t.Fatal("errors did not match")
|
||||
}
|
||||
|
||||
capturedRequest = nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmbeddingsMiddleware(t *testing.T) {
|
||||
type testCase struct {
|
||||
name string
|
||||
body string
|
||||
req api.EmbedRequest
|
||||
err openai.ErrorResponse
|
||||
}
|
||||
|
||||
var capturedRequest *api.EmbedRequest
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
name: "embed handler single input",
|
||||
body: `{
|
||||
"input": "Hello",
|
||||
"model": "test-model"
|
||||
}`,
|
||||
req: api.EmbedRequest{
|
||||
Input: "Hello",
|
||||
Model: "test-model",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "embed handler batch input",
|
||||
body: `{
|
||||
"input": ["Hello", "World"],
|
||||
"model": "test-model"
|
||||
}`,
|
||||
req: api.EmbedRequest{
|
||||
Input: []any{"Hello", "World"},
|
||||
Model: "test-model",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "embed handler error forwarding",
|
||||
body: `{
|
||||
"model": "test-model"
|
||||
}`,
|
||||
err: openai.ErrorResponse{
|
||||
Error: openai.Error{
|
||||
Message: "invalid input",
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
endpoint := func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
}
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest))
|
||||
router.Handle(http.MethodPost, "/api/embed", endpoint)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tc.body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
var errResp openai.ErrorResponse
|
||||
if resp.Code != http.StatusOK {
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
|
||||
t.Fatal("requests did not match")
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(tc.err, errResp) {
|
||||
t.Fatal("errors did not match")
|
||||
}
|
||||
|
||||
capturedRequest = nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestListMiddleware(t *testing.T) {
|
||||
type testCase struct {
|
||||
name string
|
||||
endpoint func(c *gin.Context)
|
||||
resp string
|
||||
}
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
name: "list handler",
|
||||
endpoint: func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, api.ListResponse{
|
||||
Models: []api.ListModelResponse{
|
||||
{
|
||||
Name: "test-model",
|
||||
ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
|
||||
},
|
||||
},
|
||||
})
|
||||
},
|
||||
resp: `{
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"id": "test-model",
|
||||
"object": "model",
|
||||
"created": 1686935002,
|
||||
"owned_by": "library"
|
||||
}
|
||||
]
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "list handler empty output",
|
||||
endpoint: func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, api.ListResponse{})
|
||||
},
|
||||
resp: `{
|
||||
"object": "list",
|
||||
"data": null
|
||||
}`,
|
||||
},
|
||||
}
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
for _, tc := range testCases {
|
||||
router := gin.New()
|
||||
router.Use(ListMiddleware())
|
||||
router.Handle(http.MethodGet, "/api/tags", tc.endpoint)
|
||||
req, _ := http.NewRequest(http.MethodGet, "/api/tags", nil)
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
var expected, actual map[string]any
|
||||
err := json.Unmarshal([]byte(tc.resp), &expected)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to unmarshal expected response: %v", err)
|
||||
}
|
||||
|
||||
err = json.Unmarshal(resp.Body.Bytes(), &actual)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to unmarshal actual response: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetrieveMiddleware(t *testing.T) {
|
||||
type testCase struct {
|
||||
name string
|
||||
endpoint func(c *gin.Context)
|
||||
resp string
|
||||
}
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
name: "retrieve handler",
|
||||
endpoint: func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, api.ShowResponse{
|
||||
ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
|
||||
})
|
||||
},
|
||||
resp: `{
|
||||
"id":"test-model",
|
||||
"object":"model",
|
||||
"created":1686935002,
|
||||
"owned_by":"library"}
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "retrieve handler error forwarding",
|
||||
endpoint: func(c *gin.Context) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "model not found"})
|
||||
},
|
||||
resp: `{
|
||||
"error": {
|
||||
"code": null,
|
||||
"message": "model not found",
|
||||
"param": null,
|
||||
"type": "api_error"
|
||||
}
|
||||
}`,
|
||||
},
|
||||
}
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
for _, tc := range testCases {
|
||||
router := gin.New()
|
||||
router.Use(RetrieveMiddleware())
|
||||
router.Handle(http.MethodGet, "/api/show/:model", tc.endpoint)
|
||||
req, _ := http.NewRequest(http.MethodGet, "/api/show/test-model", nil)
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
var expected, actual map[string]any
|
||||
err := json.Unmarshal([]byte(tc.resp), &expected)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to unmarshal expected response: %v", err)
|
||||
}
|
||||
|
||||
err = json.Unmarshal(resp.Body.Bytes(), &actual)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to unmarshal actual response: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
51 ml/backend/ggml/ggml/src/mem_nvml.cpp vendored
@@ -12,6 +12,7 @@
|
||||
#include "ggml-impl.h"
|
||||
#include <filesystem>
|
||||
#include <mutex>
|
||||
#include <array>
|
||||
|
||||
#ifdef _WIN32
|
||||
# define WIN32_LEAN_AND_MEAN
|
||||
@@ -78,6 +79,7 @@ struct {
|
||||
nvmlReturn_t (*nvmlShutdown)(void);
|
||||
nvmlReturn_t (*nvmlDeviceGetHandleByUUID)(const char *, nvmlDevice_t *);
|
||||
nvmlReturn_t (*nvmlDeviceGetMemoryInfo)(nvmlDevice_t, nvmlMemory_t *);
|
||||
const char * (*nvmlErrorString)(nvmlReturn_t result);
|
||||
} nvml { NULL, NULL, NULL, NULL, NULL };
|
||||
static std::mutex ggml_nvml_lock;
|
||||
|
||||
@@ -115,7 +117,8 @@ int ggml_nvml_init() {
|
||||
nvml.nvmlShutdown = (nvmlReturn_enum (*)()) GetProcAddress((HMODULE)(nvml.handle), "nvmlShutdown");
|
||||
nvml.nvmlDeviceGetHandleByUUID = (nvmlReturn_t (*)(const char *, nvmlDevice_t *)) GetProcAddress((HMODULE)(nvml.handle), "nvmlDeviceGetHandleByUUID");
|
||||
nvml.nvmlDeviceGetMemoryInfo = (nvmlReturn_t (*)(nvmlDevice_t, nvmlMemory_t *)) GetProcAddress((HMODULE)(nvml.handle), "nvmlDeviceGetMemoryInfo");
|
||||
if (nvml.nvmlInit_v2 == NULL || nvml.nvmlShutdown == NULL || nvml.nvmlDeviceGetHandleByUUID == NULL || nvml.nvmlDeviceGetMemoryInfo == NULL) {
|
||||
nvml.nvmlErrorString = (const char * (*)(nvmlReturn_enum)) GetProcAddress((HMODULE)(nvml.handle), "nvmlErrorString");
|
||||
if (nvml.nvmlInit_v2 == NULL || nvml.nvmlShutdown == NULL || nvml.nvmlDeviceGetHandleByUUID == NULL || nvml.nvmlDeviceGetMemoryInfo == NULL || nvml.nvmlErrorString == NULL) {
|
||||
GGML_LOG_INFO("%s unable to locate required symbols in NVML.dll", __func__);
|
||||
FreeLibrary((HMODULE)(nvml.handle));
|
||||
nvml.handle = NULL;
|
||||
@@ -124,11 +127,45 @@ int ggml_nvml_init() {
|
||||
|
||||
SetErrorMode(old_mode);
|
||||
|
||||
nvmlReturn_t status = nvml.nvmlInit_v2();
|
||||
if (status != NVML_SUCCESS) {
|
||||
GGML_LOG_INFO("%s unable to initialize NVML: %s\n", __func__, nvml.nvmlErrorString(status));
|
||||
FreeLibrary((HMODULE)(nvml.handle));
|
||||
nvml.handle = NULL;
|
||||
return status;
|
||||
}
|
||||
#else
|
||||
// Not currently wired up on Linux
|
||||
return NVML_ERROR_NOT_SUPPORTED;
|
||||
constexpr std::array<const char*, 2> libPaths = {
|
||||
"/usr/lib/wsl/lib/libnvidia-ml.so.1", // Favor WSL2 path if present
|
||||
"libnvidia-ml.so.1" // On a non-WSL2 system, it should be in the path
|
||||
};
|
||||
for (const char* path : libPaths) {
|
||||
nvml.handle = dlopen(path, RTLD_LAZY);
|
||||
if (nvml.handle) break;
|
||||
}
|
||||
if (nvml.handle == NULL) {
|
||||
GGML_LOG_INFO("%s unable to load libnvidia-ml: %s\n", __func__, dlerror());
|
||||
return NVML_ERROR_NOT_FOUND;
|
||||
}
|
||||
nvml.nvmlInit_v2 = (nvmlReturn_enum (*)()) dlsym(nvml.handle, "nvmlInit_v2");
|
||||
nvml.nvmlShutdown = (nvmlReturn_enum (*)()) dlsym(nvml.handle, "nvmlShutdown");
|
||||
nvml.nvmlDeviceGetHandleByUUID = (nvmlReturn_t (*)(const char *, nvmlDevice_t *)) dlsym(nvml.handle, "nvmlDeviceGetHandleByUUID");
|
||||
nvml.nvmlDeviceGetMemoryInfo = (nvmlReturn_t (*)(nvmlDevice_t, nvmlMemory_t *)) dlsym(nvml.handle, "nvmlDeviceGetMemoryInfo");
|
||||
nvml.nvmlErrorString = (const char * (*)(nvmlReturn_enum)) dlsym(nvml.handle, "nvmlErrorString");
|
||||
if (nvml.nvmlInit_v2 == NULL || nvml.nvmlShutdown == NULL || nvml.nvmlDeviceGetHandleByUUID == NULL || nvml.nvmlDeviceGetMemoryInfo == NULL) {
|
||||
GGML_LOG_INFO("%s unable to locate required symbols in libnvidia-ml.so", __func__);
|
||||
dlclose(nvml.handle);
|
||||
nvml.handle = NULL;
|
||||
return NVML_ERROR_NOT_FOUND;
|
||||
}
|
||||
nvmlReturn_t status = nvml.nvmlInit_v2();
|
||||
if (status != NVML_SUCCESS) {
|
||||
GGML_LOG_INFO("%s unable to initialize NVML: %s\n", __func__, nvml.nvmlErrorString(status));
|
||||
dlclose(nvml.handle);
|
||||
nvml.handle = NULL;
|
||||
return status;
|
||||
}
|
||||
#endif
|
||||
int status = nvml.nvmlInit_v2();
|
||||
return NVML_SUCCESS;
|
||||
}
|
||||
|
||||
@@ -140,14 +177,14 @@ void ggml_nvml_release() {
|
||||
}
|
||||
nvmlReturn_enum status = nvml.nvmlShutdown();
|
||||
if (status != NVML_SUCCESS) {
|
||||
GGML_LOG_INFO("%s failed to shutdown NVML: %d\n", __func__, status);
|
||||
GGML_LOG_INFO("%s failed to shutdown NVML: %s\n", __func__, nvml.nvmlErrorString(status));
|
||||
}
|
||||
#ifdef _WIN32
|
||||
FreeLibrary((HMODULE)(nvml.handle));
|
||||
nvml.handle = NULL;
|
||||
#else
|
||||
// Not currently wired up on Linux
|
||||
dlclose(nvml.handle);
|
||||
#endif
|
||||
nvml.handle = NULL;
|
||||
}
|
||||
|
||||
int ggml_nvml_get_device_memory(const char *uuid, size_t *free, size_t *total) {
|
||||
|
||||
@@ -251,7 +251,7 @@ func BenchmarkBytePairEncoding(b *testing.B) {
|
||||
bts := bts[:n]
|
||||
b.Run("encode"+strconv.Itoa(n), func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
for b.Loop() {
|
||||
_, err := tokenizer.Encode(string(bts), true)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
@@ -266,7 +266,7 @@ func BenchmarkBytePairEncoding(b *testing.B) {
|
||||
}
|
||||
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
for b.Loop() {
|
||||
_, err := tokenizer.Decode(ids)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
@@ -276,7 +276,7 @@ func BenchmarkBytePairEncoding(b *testing.B) {
|
||||
|
||||
b.Run("split"+strconv.Itoa(n), func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
for b.Loop() {
|
||||
slices.Collect(tokenizer.split(string(bts)))
|
||||
}
|
||||
})
|
||||
|
||||
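Aside: the benchmark hunks above move from counting iterations with b.N to the b.Loop() helper added in Go 1.24, which also keeps setup and cleanup outside the loop out of the measured time. A minimal self-contained sketch of the pattern (illustrative only, not code from this change):

package sketch

import (
	"strings"
	"testing"
)

func BenchmarkJoinSketch(b *testing.B) {
	parts := make([]string, 1024) // setup before b.Loop is excluded from timing
	for i := range parts {
		parts[i] = "token"
	}
	for b.Loop() {
		_ = strings.Join(parts, " ") // only the loop body is measured
	}
}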
@@ -150,7 +150,9 @@ func (moe *sparse) Moe(ctx ml.Context, hiddenStates, topKIndices, topKWeights ml
|
||||
}
|
||||
|
||||
func (moe *sparse) topKIndices(ctx ml.Context, scores ml.Tensor, opts *Options) ml.Tensor {
|
||||
scores = scores.Add(ctx, moe.ExpProbsBias)
|
||||
if moe.ExpProbsBias != nil {
|
||||
scores = scores.Add(ctx, moe.ExpProbsBias)
|
||||
}
|
||||
topKIndices := scores.TopK(ctx, opts.numExpertsUsed)
|
||||
return topKIndices
|
||||
}
|
||||
|
||||
@@ -73,7 +73,7 @@ func (p ImageProcessor) bestResolution(img image.Point, possibleResolutions []im
|
||||
for i, res := range possibleResolutions {
|
||||
scaleW := float64(res.X) / float64(w)
|
||||
scaleH := float64(res.Y) / float64(h)
|
||||
scale := math.Min(scaleW, scaleH)
|
||||
scale := min(scaleW, scaleH)
|
||||
|
||||
scales[i] = scale
|
||||
}
|
||||
@@ -124,11 +124,11 @@ func (p ImageProcessor) maxResolution(imageRes, targetRes image.Point) image.Poi
|
||||
if scaleW < scaleH {
|
||||
newRes = image.Point{
|
||||
targetRes.X,
|
||||
int(math.Min(math.Floor(float64(imageRes.Y)*scaleW), float64(targetRes.Y))),
|
||||
int(min(math.Floor(float64(imageRes.Y)*scaleW), float64(targetRes.Y))),
|
||||
}
|
||||
} else {
|
||||
newRes = image.Point{
|
||||
int(math.Min(math.Floor(float64(imageRes.X)*scaleH), float64(targetRes.X))),
|
||||
int(min(math.Floor(float64(imageRes.X)*scaleH), float64(targetRes.X))),
|
||||
targetRes.Y,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,7 +53,7 @@ func (p ImageProcessor) fitToCanvas(imageSize, canvasSize image.Point) image.Poi
|
||||
tw := min(max(imageSize.X, p.imageSize), canvasSize.X)
|
||||
th := min(max(imageSize.Y, p.imageSize), canvasSize.Y)
|
||||
|
||||
r := math.Min(
|
||||
r := min(
|
||||
float64(tw)/float64(imageSize.X),
|
||||
float64(th)/float64(imageSize.Y),
|
||||
)
|
||||
@@ -89,10 +89,10 @@ func (p ImageProcessor) optimalTiledCanvas(imageSize image.Point) image.Point {
|
||||
if minUpscale == 0 {
|
||||
minUpscale = s
|
||||
} else {
|
||||
minUpscale = math.Min(minUpscale, s)
|
||||
minUpscale = min(minUpscale, s)
|
||||
}
|
||||
} else {
|
||||
maxDownscale = math.Max(maxDownscale, s)
|
||||
maxDownscale = max(maxDownscale, s)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
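Aside: these image-processor hunks replace math.Min and math.Max with Go's built-in generic min and max (available since Go 1.21), which work on any ordered type and read more directly than the math package calls. A small illustrative sketch (not code from this change):

package main

import "fmt"

func main() {
	scaleW, scaleH := 0.75, 1.25
	fmt.Println(min(scaleW, scaleH)) // 0.75, same result as math.Min for these values
	fmt.Println(max(3, 7))           // the built-ins also accept integers without conversions
}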
219
model/parsers/glm46.go
Normal file
@@ -0,0 +1,219 @@
|
||||
package parsers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
)
|
||||
|
||||
const (
|
||||
glm46CollectingContent glm46ParserState = iota
glm46CollectingThinkingContent
glm46CollectingToolContent
|
||||
)
|
||||
|
||||
const (
|
||||
thinkingCloseTag = "</think>"
|
||||
)
|
||||
|
||||
// TODO(gguo): add a field for isThinking
|
||||
type GLM46Parser struct {
|
||||
state glm46ParserState
|
||||
buffer strings.Builder
|
||||
tools []api.Tool
|
||||
}
|
||||
|
||||
func (p *GLM46Parser) HasToolSupport() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// TODO(gguo): change this to reference an objects param
|
||||
func (p *GLM46Parser) HasThinkingSupport() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (p *GLM46Parser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
|
||||
p.tools = tools
|
||||
// p.state = p.initialState()
|
||||
return tools
|
||||
}
|
||||
|
||||
type glm46EventThinkingContent struct {
|
||||
content string
|
||||
}
|
||||
|
||||
func (glm46EventThinkingContent) isGLM46Event() {}
|
||||
|
||||
func (p *GLM46Parser) Add(s string, done bool) (content string, thinking string, calls []api.ToolCall, err error) {
|
||||
p.buffer.WriteString(s)
|
||||
events := p.parseEvents()
|
||||
|
||||
var toolCalls []api.ToolCall
|
||||
var sb strings.Builder
|
||||
for _, event := range events {
|
||||
switch event := event.(type) {
|
||||
case glm46EventRawToolCall:
|
||||
toolCall, err := parseJSONToolCall(event, p.tools)
|
||||
if err != nil {
|
||||
slog.Warn("qwen tool call parsing failed", "error", err)
|
||||
return "", "", nil, err
|
||||
}
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
case glm46EventThinkingContent:
|
||||
sb.WriteString(event.content)
|
||||
case glm46EventContent:
|
||||
// TODO(drifkin): if the same turn contains multiple interleaved content
|
||||
// events, we naively append them together here.
|
||||
sb.WriteString(event.content)
|
||||
}
|
||||
}
|
||||
|
||||
return sb.String(), "", toolCalls, nil
|
||||
}
|
||||
|
||||
func (p *GLM46Parser) parseEvents() []glm46Event {
|
||||
var all []glm46Event
|
||||
|
||||
keepLooping := true
|
||||
for keepLooping {
|
||||
var events []glm46Event
|
||||
events, keepLooping = p.eat()
|
||||
if len(events) > 0 {
|
||||
all = append(all, events...)
|
||||
}
|
||||
}
|
||||
|
||||
if len(all) > 0 {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "qwen events parsed", "events", all, "state", p.state, "buffer", p.buffer.String())
|
||||
}
|
||||
|
||||
return all
|
||||
}
|
||||
|
||||
func emitContentBeforeTag(p *GLM46Parser, events []glm46Event, tag string) []glm46Event {
|
||||
split := strings.SplitN(p.buffer.String(), tag, 2)
|
||||
before := split[0]
|
||||
before = strings.TrimRightFunc(before, unicode.IsSpace)
|
||||
if len(before) > 0 {
|
||||
events = append(events, glm46EventContent{content: before})
|
||||
}
|
||||
after := split[1]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(after)
|
||||
return events
|
||||
}
|
||||
|
||||
func (p *GLM46Parser) eat() ([]glm46Event, bool) {
|
||||
var events []glm46Event
|
||||
|
||||
switch p.state {
|
||||
case glm46CollectingContent:
|
||||
if strings.Contains(p.buffer.String(), toolOpenTag) {
|
||||
events = emitContentBeforeTag(p, events, toolOpenTag)
|
||||
p.state = glm46CollectingToolContent
|
||||
return events, true
|
||||
} else if overlapLen := overlap(p.buffer.String(), toolOpenTag); overlapLen > 0 {
|
||||
beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen]
|
||||
trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
|
||||
ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen
|
||||
|
||||
unambiguous := p.buffer.String()[:ambiguousStart]
|
||||
ambiguous := p.buffer.String()[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, glm46EventContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
} else {
|
||||
whitespaceLen := trailingWhitespaceLen(p.buffer.String())
|
||||
ambiguousStart := len(p.buffer.String()) - whitespaceLen
|
||||
|
||||
unambiguous := p.buffer.String()[:ambiguousStart]
|
||||
ambiguous := p.buffer.String()[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, glm46EventContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
}
|
||||
case glm46CollectingToolContent:
|
||||
if strings.Contains(p.buffer.String(), toolCloseTag) {
|
||||
split := strings.SplitN(p.buffer.String(), toolCloseTag, 2)
|
||||
before := split[0]
|
||||
if len(before) == 0 {
|
||||
slog.Warn("qwen tool call closing tag found but no content before it")
|
||||
}
|
||||
|
||||
after := strings.TrimLeftFunc(split[1], unicode.IsSpace)
|
||||
events = append(events, glm46EventRawToolCall{raw: before})
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(after)
|
||||
p.state = glm46CollectingContent
|
||||
return events, true
|
||||
} else {
|
||||
return events, false
|
||||
}
|
||||
case glm46CollectingThinkingContent: // emit the unambiguous thinking content and hold back anything that may be part of a closing tag
|
||||
if strings.Contains(p.buffer.String(), thinkingCloseTag) {
|
||||
split := strings.SplitN(p.buffer.String(), thinkingCloseTag, 2)
|
||||
before := split[0]
|
||||
if len(before) == 0 {
|
||||
slog.Warn("qwen tool call closing tag found but no content before it")
|
||||
}
|
||||
after := strings.TrimLeftFunc(split[1], unicode.IsSpace)
|
||||
if len(before) > 0 {
|
||||
events = append(events, glm46EventThinkingContent{content: before})
|
||||
}
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(after)
|
||||
p.state = glm46CollectingContent
|
||||
return events, true
|
||||
} else if overlapLen := overlap(p.buffer.String(), thinkingCloseTag); overlapLen > 0 { // we see part of a close thinking tag
|
||||
beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen]
|
||||
trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
|
||||
ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen
|
||||
|
||||
unambiguous := p.buffer.String()[:ambiguousStart]
|
||||
ambiguous := p.buffer.String()[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, glm46EventThinkingContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
} else {
|
||||
whitespaceLen := trailingWhitespaceLen(p.buffer.String())
|
||||
ambiguousStart := len(p.buffer.String()) - whitespaceLen
|
||||
|
||||
unambiguous := p.buffer.String()[:ambiguousStart]
|
||||
ambiguous := p.buffer.String()[ambiguousStart:]
|
||||
p.buffer.Reset()
|
||||
p.buffer.WriteString(ambiguous)
|
||||
if len(unambiguous) > 0 {
|
||||
events = append(events, glm46EventThinkingContent{content: unambiguous})
|
||||
}
|
||||
return events, false
|
||||
}
|
||||
default:
|
||||
panic("unreachable")
|
||||
}
|
||||
}
|
||||
|
||||
func parseJSONToolCall(raw glm46EventRawToolCall, tools []api.Tool) (api.ToolCall, error) {
|
||||
var toolCallFunction api.ToolCallFunction
|
||||
if err := json.Unmarshal([]byte(raw.raw), &toolCallFunction); err != nil {
|
||||
return api.ToolCall{}, err
|
||||
}
|
||||
|
||||
toolCall := api.ToolCall{}
|
||||
toolCall.Function = toolCallFunction
|
||||
|
||||
return toolCall, nil
|
||||
}
|
||||
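For orientation, a rough sketch of how a streaming parser like the new GLM46Parser above is typically driven. The Init and Add signatures come from the file; the import path, the chunk slice, and the surrounding main function are assumptions made for illustration:

package main

import (
	"fmt"

	"github.com/ollama/ollama/model/parsers" // assumed import path for the package above
)

func main() {
	chunks := []string{"The capital of Fra", "nce is Paris."} // made-up streamed model output
	p := &parsers.GLM46Parser{}
	p.Init(nil, nil) // no tools, no prior assistant message
	for i, chunk := range chunks {
		content, thinking, calls, err := p.Add(chunk, i == len(chunks)-1)
		if err != nil {
			panic(err)
		}
		fmt.Println(content, thinking, calls) // in the server these would be forwarded to the response stream
	}
}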
@@ -21,6 +21,9 @@ func ParserForName(name string) Parser {
|
||||
case "qwen3-coder":
|
||||
parser := &Qwen3CoderParser{}
|
||||
return parser
|
||||
case "glm-4.6":
|
||||
parser := &GLM46Parser{}
|
||||
return parser
|
||||
case "passthrough":
|
||||
return &PassthroughParser{}
|
||||
case "harmony":
|
||||
|
||||
239
model/renderers/glm46_test.go
Normal file
@@ -0,0 +1,239 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestGLM46Renderer(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
messages []api.Message
|
||||
tools []api.Tool
|
||||
thinkValue *api.ThinkValue
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "basic",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Hello, how are you?"},
|
||||
},
|
||||
expected: `[gMASK]<sop><|user|>
|
||||
Hello, how are you?<|assistant|>`,
|
||||
},
|
||||
{
|
||||
name: "basic with system message",
|
||||
messages: []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant."},
|
||||
{Role: "user", Content: "Hello, how are you?"},
|
||||
},
|
||||
expected: `[gMASK]<sop><|system|>
|
||||
You are a helpful assistant.<|user|>
|
||||
Hello, how are you?<|assistant|>`,
|
||||
},
|
||||
{
|
||||
name: "basic with user assistant user",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "What is the capital of France?"},
|
||||
{Role: "assistant", Thinking: "Let me analyze the request...", Content: "The capital of France is Paris."},
|
||||
{Role: "user", Content: "Fantastic!"},
|
||||
},
|
||||
expected: `[gMASK]<sop><|user|>
|
||||
What is the capital of France?<|assistant|>
|
||||
The capital of France is Paris.<|user|>
|
||||
Fantastic!<|assistant|>`,
|
||||
},
|
||||
{
|
||||
name: "tools",
|
||||
messages: []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant with access to tools."},
|
||||
{Role: "user", Content: "What is the weather like in Tokyo?"},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather in a given location",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"unit": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Enum: []any{"celsius", "fahrenheit"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: `[gMASK]<sop><|system|>
|
||||
# Tools
|
||||
|
||||
You may call one or more functions to assist with the user query.
|
||||
|
||||
You are provided with function signatures within <tools></tools> XML tags:
|
||||
<tools>
|
||||
{"type":"function","function":{"name":"get_weather","description":"Get the current weather in a given location","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","description":"","enum":["celsius","fahrenheit"]}}}}}
|
||||
</tools>
|
||||
|
||||
For each function call, output the function name and arguments within the following XML format:
|
||||
<tool_call>{function-name}
|
||||
<arg_key>{arg-key-1}</arg_key>
|
||||
<arg_value>{arg-value-1}</arg_value>
|
||||
<arg_key>{arg-key-2}</arg_key>
|
||||
<arg_value>{arg-value-2}</arg_value>
|
||||
...
|
||||
</tool_call><|system|>
|
||||
You are a helpful assistant with access to tools.<|user|>
|
||||
What is the weather like in Tokyo?<|assistant|>`,
|
||||
},
|
||||
{
|
||||
name: "tool calls",
|
||||
messages: []api.Message{
|
||||
{Role: "system", Content: "You are a helpful assistant with access to tools."},
|
||||
{Role: "user", Content: "What is the weather like in Tokyo?"},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "Tokyo, Japan",
|
||||
"unit": "celsius",
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_weather",
|
||||
Arguments: api.ToolCallFunctionArguments{
|
||||
"location": "Japan",
|
||||
"unit": "fahrenheit",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: "tool",
|
||||
Content: "{\"temperature\": 22, \"weather\": \"partly cloudy\", \"humidity\": 65}",
|
||||
ToolName: "get_weather",
|
||||
},
|
||||
{
|
||||
Role: "tool",
|
||||
Content: "{\"temperature\": 68, \"weather\": \"sunny\", \"humidity\": 75}",
|
||||
ToolName: "get_weather",
|
||||
},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "The weather in Tokyo is currently partly cloudy with a temperature of 22°C and 65% humidity. It's a pleasant day with moderate temperatures.",
|
||||
},
|
||||
},
|
||||
tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather in a given location",
|
||||
Parameters: api.ToolFunctionParameters{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"unit": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Enum: []any{"celsius", "fahrenheit"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
expected: `[gMASK]<sop><|system|>
|
||||
# Tools
|
||||
|
||||
You may call one or more functions to assist with the user query.
|
||||
|
||||
You are provided with function signatures within <tools></tools> XML tags:
|
||||
<tools>
|
||||
{"type":"function","function":{"name":"get_weather","description":"Get the current weather in a given location","parameters":{"type":"object","required":["location"],"properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","description":"","enum":["celsius","fahrenheit"]}}}}}
|
||||
</tools>
|
||||
|
||||
For each function call, output the function name and arguments within the following XML format:
|
||||
<tool_call>{function-name}
|
||||
<arg_key>{arg-key-1}</arg_key>
|
||||
<arg_value>{arg-value-1}</arg_value>
|
||||
<arg_key>{arg-key-2}</arg_key>
|
||||
<arg_value>{arg-value-2}</arg_value>
|
||||
...
|
||||
</tool_call><|system|>
|
||||
You are a helpful assistant with access to tools.<|user|>
|
||||
What is the weather like in Tokyo?<|assistant|>
|
||||
<think></think>
|
||||
<tool_call>get_weather
|
||||
<arg_key>location</arg_key>
|
||||
<arg_value>Tokyo, Japan</arg_value>
|
||||
<arg_key>unit</arg_key>
|
||||
<arg_value>celsius</arg_value>
|
||||
</tool_call>
|
||||
<tool_call>get_weather
|
||||
<arg_key>location</arg_key>
|
||||
<arg_value>Japan</arg_value>
|
||||
<arg_key>unit</arg_key>
|
||||
<arg_value>fahrenheit</arg_value>
|
||||
</tool_call><|observation|>
|
||||
<tool_response>
|
||||
{"temperature": 22, "weather": "partly cloudy", "humidity": 65}
|
||||
</tool_response>
|
||||
<tool_response>
|
||||
{"temperature": 68, "weather": "sunny", "humidity": 75}
|
||||
</tool_response><|assistant|>
|
||||
<think></think>
|
||||
The weather in Tokyo is currently partly cloudy with a temperature of 22°C and 65% humidity. It's a pleasant day with moderate temperatures.<|assistant|>`,
|
||||
},
|
||||
{
|
||||
name: "think true",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Hello, how are you?"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: true},
|
||||
expected: `[gMASK]<sop><|user|>
|
||||
Hello, how are you?<|assistant|>`,
|
||||
},
|
||||
{
|
||||
name: "think false",
|
||||
messages: []api.Message{
|
||||
{Role: "user", Content: "Hello, how are you?"},
|
||||
},
|
||||
thinkValue: &api.ThinkValue{Value: false},
|
||||
expected: `[gMASK]<sop><|user|>
|
||||
Hello, how are you?/nothink<|assistant|>
|
||||
<think></think>`,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
rendered, err := GLM46Renderer(tt.messages, tt.tools, tt.thinkValue)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if diff := cmp.Diff(rendered, tt.expected); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
t.Logf("Got:\n%s", rendered)
|
||||
t.Logf("Expected:\n%s", tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
109
model/renderers/gml46.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package renderers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func GLM46Renderer(messages []api.Message, tools []api.Tool, thinkValue *api.ThinkValue) (string, error) {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString("[gMASK]<sop>")
|
||||
|
||||
var lastUserIndex int
|
||||
for i, message := range messages {
|
||||
if message.Role == "user" {
|
||||
lastUserIndex = i
|
||||
}
|
||||
}
|
||||
|
||||
if len(tools) > 0 {
|
||||
sb.WriteString("<|system|>\n")
|
||||
sb.WriteString("# Tools\n\n")
|
||||
sb.WriteString("You may call one or more functions to assist with the user query.\n\n")
|
||||
sb.WriteString("You are provided with function signatures within <tools></tools> XML tags:\n")
|
||||
sb.WriteString("<tools>\n")
|
||||
for _, tool := range tools {
|
||||
d, _ := json.Marshal(tool)
|
||||
sb.WriteString(string(d) + "\n")
|
||||
}
|
||||
sb.WriteString("</tools>\n\n")
|
||||
sb.WriteString("For each function call, output the function name and arguments within the following XML format:\n")
|
||||
sb.WriteString("<tool_call>{function-name}\n")
|
||||
sb.WriteString("<arg_key>{arg-key-1}</arg_key>\n")
|
||||
sb.WriteString("<arg_value>{arg-value-1}</arg_value>\n")
|
||||
sb.WriteString("<arg_key>{arg-key-2}</arg_key>\n")
|
||||
sb.WriteString("<arg_value>{arg-value-2}</arg_value>\n")
|
||||
sb.WriteString("...\n")
|
||||
sb.WriteString("</tool_call>")
|
||||
}
|
||||
|
||||
for i, message := range messages {
|
||||
switch message.Role {
|
||||
case "user":
|
||||
sb.WriteString("<|user|>\n")
|
||||
sb.WriteString(message.Content)
|
||||
if thinkValue != nil && !thinkValue.Bool() && !strings.HasSuffix(message.Content, "/nothink") {
|
||||
sb.WriteString("/nothink")
|
||||
}
|
||||
case "assistant":
|
||||
sb.WriteString("<|assistant|>")
|
||||
if i > lastUserIndex {
|
||||
if message.Thinking != "" {
|
||||
sb.WriteString("\n<think>" + message.Thinking + "</think>")
|
||||
} else {
|
||||
sb.WriteString("\n<think></think>")
|
||||
}
|
||||
}
|
||||
if message.Content != "" {
|
||||
sb.WriteString("\n" + message.Content)
|
||||
}
|
||||
if len(message.ToolCalls) > 0 {
|
||||
for _, toolCall := range message.ToolCalls {
|
||||
sb.WriteString("\n<tool_call>" + toolCall.Function.Name + "\n")
|
||||
for key, value := range toolCall.Function.Arguments {
|
||||
sb.WriteString("<arg_key>" + key + "</arg_key>\n")
|
||||
|
||||
var valueStr string
|
||||
if str, ok := value.(string); ok {
|
||||
valueStr = str
|
||||
} else {
|
||||
jsonBytes, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
valueStr = fmt.Sprintf("%v", value)
|
||||
} else {
|
||||
valueStr = string(jsonBytes)
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString("<arg_value>" + valueStr + "</arg_value>\n")
|
||||
}
|
||||
|
||||
sb.WriteString("</tool_call>")
|
||||
}
|
||||
}
|
||||
case "tool":
|
||||
if i == 0 || messages[i-1].Role != "tool" {
|
||||
sb.WriteString("<|observation|>")
|
||||
}
|
||||
sb.WriteString("\n<tool_response>\n")
|
||||
sb.WriteString(message.Content)
|
||||
sb.WriteString("\n</tool_response>")
|
||||
case "system":
|
||||
sb.WriteString("<|system|>\n")
|
||||
sb.WriteString(message.Content)
|
||||
}
|
||||
}
|
||||
|
||||
// Add generation prompt
|
||||
sb.WriteString("<|assistant|>")
|
||||
fmt.Println("thinkValue", thinkValue, thinkValue.Bool())
|
||||
if thinkValue != nil && !thinkValue.Bool() {
|
||||
sb.WriteString("\n<think></think>")
|
||||
}
|
||||
|
||||
return sb.String(), nil
|
||||
}
|
||||
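A matching sketch for the renderer side. GLM46Renderer's signature is taken from the file above; the import path and the example messages are illustrative assumptions:

package main

import (
	"fmt"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/model/renderers" // assumed import path for the package above
)

func main() {
	msgs := []api.Message{
		{Role: "system", Content: "You are a helpful assistant."},
		{Role: "user", Content: "Hello, how are you?"},
	}
	prompt, err := renderers.GLM46Renderer(msgs, nil, &api.ThinkValue{Value: true})
	if err != nil {
		panic(err)
	}
	fmt.Println(prompt) // expect the [gMASK]<sop><|system|>... prompt shown in the tests above
}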
@@ -20,6 +20,8 @@ func rendererForName(name string) rendererFunc {
|
||||
switch name {
|
||||
case "qwen3-coder":
|
||||
return Qwen3CoderRenderer
|
||||
case "glm-4.6":
|
||||
return GLM46Renderer
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
|
||||
488
openai/openai.go
@@ -1,21 +1,18 @@
|
||||
// openai package provides middleware for partial compatibility with the OpenAI REST API
|
||||
// openai package provides core transformation logic for partial compatibility with the OpenAI REST API
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
@@ -86,7 +83,7 @@ type StreamOptions struct {
|
||||
}
|
||||
|
||||
type Reasoning struct {
|
||||
Effort *string `json:"effort,omitempty"`
|
||||
Effort string `json:"effort,omitempty"`
|
||||
}
|
||||
|
||||
type ChatCompletionRequest struct {
|
||||
@@ -220,11 +217,12 @@ func NewError(code int, message string) ErrorResponse {
|
||||
return ErrorResponse{Error{Type: etype, Message: message}}
|
||||
}
|
||||
|
||||
func toUsage(r api.ChatResponse) Usage {
|
||||
// ToUsage converts an api.ChatResponse to Usage
|
||||
func ToUsage(r api.ChatResponse) Usage {
|
||||
return Usage{
|
||||
PromptTokens: r.PromptEvalCount,
|
||||
CompletionTokens: r.EvalCount,
|
||||
TotalTokens: r.PromptEvalCount + r.EvalCount,
|
||||
PromptTokens: r.Metrics.PromptEvalCount,
|
||||
CompletionTokens: r.Metrics.EvalCount,
|
||||
TotalTokens: r.Metrics.PromptEvalCount + r.Metrics.EvalCount,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -256,7 +254,8 @@ func toToolCalls(tc []api.ToolCall) []ToolCall {
|
||||
return toolCalls
|
||||
}
|
||||
|
||||
func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
||||
// ToChatCompletion converts an api.ChatResponse to ChatCompletion
|
||||
func ToChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
||||
toolCalls := toToolCalls(r.Message.ToolCalls)
|
||||
return ChatCompletion{
|
||||
Id: id,
|
||||
@@ -276,12 +275,13 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
||||
}
|
||||
return nil
|
||||
}(r.DoneReason),
|
||||
}}, Usage: toUsage(r),
|
||||
}}, Usage: ToUsage(r),
|
||||
DebugInfo: r.DebugInfo,
|
||||
}
|
||||
}
|
||||
|
||||
func toChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChunk {
|
||||
// ToChunk converts an api.ChatResponse to ChatCompletionChunk
|
||||
func ToChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChunk {
|
||||
toolCalls := toToolCalls(r.Message.ToolCalls)
|
||||
return ChatCompletionChunk{
|
||||
Id: id,
|
||||
@@ -305,15 +305,17 @@ func toChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChu
|
||||
}
|
||||
}
|
||||
|
||||
func toUsageGenerate(r api.GenerateResponse) Usage {
|
||||
// ToUsageGenerate converts an api.GenerateResponse to Usage
|
||||
func ToUsageGenerate(r api.GenerateResponse) Usage {
|
||||
return Usage{
|
||||
PromptTokens: r.PromptEvalCount,
|
||||
CompletionTokens: r.EvalCount,
|
||||
TotalTokens: r.PromptEvalCount + r.EvalCount,
|
||||
PromptTokens: r.Metrics.PromptEvalCount,
|
||||
CompletionTokens: r.Metrics.EvalCount,
|
||||
TotalTokens: r.Metrics.PromptEvalCount + r.Metrics.EvalCount,
|
||||
}
|
||||
}
|
||||
|
||||
func toCompletion(id string, r api.GenerateResponse) Completion {
|
||||
// ToCompletion converts an api.GenerateResponse to Completion
|
||||
func ToCompletion(id string, r api.GenerateResponse) Completion {
|
||||
return Completion{
|
||||
Id: id,
|
||||
Object: "text_completion",
|
||||
@@ -330,11 +332,12 @@ func toCompletion(id string, r api.GenerateResponse) Completion {
|
||||
return nil
|
||||
}(r.DoneReason),
|
||||
}},
|
||||
Usage: toUsageGenerate(r),
|
||||
Usage: ToUsageGenerate(r),
|
||||
}
|
||||
}
|
||||
|
||||
func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
|
||||
// ToCompleteChunk converts an api.GenerateResponse to CompletionChunk
|
||||
func ToCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
|
||||
return CompletionChunk{
|
||||
Id: id,
|
||||
Object: "text_completion",
|
||||
@@ -354,7 +357,8 @@ func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
|
||||
}
|
||||
}
|
||||
|
||||
func toListCompletion(r api.ListResponse) ListCompletion {
|
||||
// ToListCompletion converts an api.ListResponse to ListCompletion
|
||||
func ToListCompletion(r api.ListResponse) ListCompletion {
|
||||
var data []Model
|
||||
for _, m := range r.Models {
|
||||
data = append(data, Model{
|
||||
@@ -371,7 +375,8 @@ func toListCompletion(r api.ListResponse) ListCompletion {
|
||||
}
|
||||
}
|
||||
|
||||
func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
|
||||
// ToEmbeddingList converts an api.EmbedResponse to EmbeddingList
|
||||
func ToEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
|
||||
if r.Embeddings != nil {
|
||||
var data []Embedding
|
||||
for i, e := range r.Embeddings {
|
||||
@@ -396,7 +401,8 @@ func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
|
||||
return EmbeddingList{}
|
||||
}
|
||||
|
||||
func toModel(r api.ShowResponse, m string) Model {
|
||||
// ToModel converts an api.ShowResponse to Model
|
||||
func ToModel(r api.ShowResponse, m string) Model {
|
||||
return Model{
|
||||
Id: m,
|
||||
Object: "model",
|
||||
@@ -405,7 +411,8 @@ func toModel(r api.ShowResponse, m string) Model {
|
||||
}
|
||||
}
|
||||
|
||||
func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
||||
// FromChatRequest converts a ChatCompletionRequest to api.ChatRequest
|
||||
func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
||||
var messages []api.Message
|
||||
for _, msg := range r.Messages {
|
||||
toolName := ""
|
||||
@@ -560,13 +567,23 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
||||
}
|
||||
|
||||
var think *api.ThinkValue
|
||||
var effort string
|
||||
|
||||
if r.Reasoning != nil {
|
||||
think = &api.ThinkValue{
|
||||
Value: *r.Reasoning.Effort,
|
||||
}
|
||||
effort = r.Reasoning.Effort
|
||||
} else if r.ReasoningEffort != nil {
|
||||
think = &api.ThinkValue{
|
||||
Value: *r.ReasoningEffort,
|
||||
effort = *r.ReasoningEffort
|
||||
}
|
||||
|
||||
if effort != "" {
|
||||
if !slices.Contains([]string{"high", "medium", "low", "none"}, effort) {
|
||||
return nil, fmt.Errorf("invalid reasoning value: '%s' (must be \"high\", \"medium\", \"low\", or \"none\")", effort)
|
||||
}
|
||||
|
||||
if effort == "none" {
|
||||
think = &api.ThinkValue{Value: false}
|
||||
} else {
|
||||
think = &api.ThinkValue{Value: effort}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -609,7 +626,8 @@ func fromCompletionToolCall(toolCalls []ToolCall) ([]api.ToolCall, error) {
|
||||
return apiToolCalls, nil
|
||||
}
|
||||
|
||||
func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
|
||||
// FromCompleteRequest converts a CompletionRequest to api.GenerateRequest
|
||||
func FromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
|
||||
options := make(map[string]any)
|
||||
|
||||
switch stop := r.Stop.(type) {
|
||||
@@ -660,413 +678,3 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
|
||||
DebugRenderOnly: r.DebugRenderOnly,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type BaseWriter struct {
|
||||
gin.ResponseWriter
|
||||
}
|
||||
|
||||
type ChatWriter struct {
|
||||
stream bool
|
||||
streamOptions *StreamOptions
|
||||
id string
|
||||
toolCallSent bool
|
||||
BaseWriter
|
||||
}
|
||||
|
||||
type CompleteWriter struct {
|
||||
stream bool
|
||||
streamOptions *StreamOptions
|
||||
id string
|
||||
BaseWriter
|
||||
}
|
||||
|
||||
type ListWriter struct {
|
||||
BaseWriter
|
||||
}
|
||||
|
||||
type RetrieveWriter struct {
|
||||
BaseWriter
|
||||
model string
|
||||
}
|
||||
|
||||
type EmbedWriter struct {
|
||||
BaseWriter
|
||||
model string
|
||||
}
|
||||
|
||||
func (w *BaseWriter) writeError(data []byte) (int, error) {
|
||||
var serr api.StatusError
|
||||
err := json.Unmarshal(data, &serr)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w.ResponseWriter).Encode(NewError(http.StatusInternalServerError, serr.Error()))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *ChatWriter) writeResponse(data []byte) (int, error) {
|
||||
var chatResponse api.ChatResponse
|
||||
err := json.Unmarshal(data, &chatResponse)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// chat chunk
|
||||
if w.stream {
|
||||
c := toChunk(w.id, chatResponse, w.toolCallSent)
|
||||
d, err := json.Marshal(c)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if !w.toolCallSent && len(c.Choices) > 0 && len(c.Choices[0].Delta.ToolCalls) > 0 {
|
||||
w.toolCallSent = true
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
|
||||
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if chatResponse.Done {
|
||||
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
|
||||
u := toUsage(chatResponse)
|
||||
c.Usage = &u
|
||||
c.Choices = []ChunkChoice{}
|
||||
d, err := json.Marshal(c)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
// chat completion
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w.ResponseWriter).Encode(toChatCompletion(w.id, chatResponse))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *ChatWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
|
||||
var generateResponse api.GenerateResponse
|
||||
err := json.Unmarshal(data, &generateResponse)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// completion chunk
|
||||
if w.stream {
|
||||
c := toCompleteChunk(w.id, generateResponse)
|
||||
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
|
||||
c.Usage = &Usage{}
|
||||
}
|
||||
d, err := json.Marshal(c)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
|
||||
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if generateResponse.Done {
|
||||
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
|
||||
u := toUsageGenerate(generateResponse)
|
||||
c.Usage = &u
|
||||
c.Choices = []CompleteChunkChoice{}
|
||||
d, err := json.Marshal(c)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
// completion
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w.ResponseWriter).Encode(toCompletion(w.id, generateResponse))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *CompleteWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
func (w *ListWriter) writeResponse(data []byte) (int, error) {
|
||||
var listResponse api.ListResponse
|
||||
err := json.Unmarshal(data, &listResponse)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w.ResponseWriter).Encode(toListCompletion(listResponse))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *ListWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
|
||||
var showResponse api.ShowResponse
|
||||
err := json.Unmarshal(data, &showResponse)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// retrieve completion
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w.ResponseWriter).Encode(toModel(showResponse, w.model))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *RetrieveWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
|
||||
var embedResponse api.EmbedResponse
|
||||
err := json.Unmarshal(data, &embedResponse)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *EmbedWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
func ListMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
w := &ListWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func RetrieveMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
// response writer
|
||||
w := &RetrieveWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
model: c.Param("model"),
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func CompletionsMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var req CompletionRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
genReq, err := fromCompleteRequest(req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(&b).Encode(genReq); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
w := &CompleteWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
stream: req.Stream,
|
||||
id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
|
||||
streamOptions: req.StreamOptions,
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func EmbeddingsMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var req EmbedRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if req.Input == "" {
|
||||
req.Input = []string{""}
|
||||
}
|
||||
|
||||
if req.Input == nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
|
||||
return
|
||||
}
|
||||
|
||||
if v, ok := req.Input.([]any); ok && len(v) == 0 {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
|
||||
return
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input, Dimensions: req.Dimensions}); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
w := &EmbedWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
model: req.Model,
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func ChatMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var req ChatCompletionRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Messages) == 0 {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "[] is too short - 'messages'"))
|
||||
return
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
|
||||
chatReq, err := fromChatRequest(req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
w := &ChatWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
stream: req.Stream,
|
||||
id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
|
||||
streamOptions: req.StreamOptions,
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
@@ -79,13 +79,16 @@ type Sequence struct {
|
||||
// true if an embedding is to be returned instead of text generation
|
||||
embeddingOnly bool
|
||||
|
||||
// shift if context window is exceeded
|
||||
shift bool
|
||||
|
||||
doneReason llm.DoneReason
|
||||
|
||||
// Metrics
|
||||
startProcessingTime time.Time
|
||||
startGenerationTime time.Time
|
||||
numDecoded int
|
||||
numPromptInputs int
|
||||
processingDuration time.Duration
|
||||
generationDuration time.Duration
|
||||
numDecoded int
|
||||
numPromptInputs int
|
||||
}
|
||||
|
||||
type NewSequenceParams struct {
|
||||
@@ -94,13 +97,15 @@ type NewSequenceParams struct {
|
||||
numKeep int
|
||||
samplingParams *llama.SamplingParams
|
||||
embedding bool
|
||||
shift bool
|
||||
truncate bool
|
||||
}
|
||||
|
||||
var errorInputTooLong = errors.New("the input length exceeds the context length")
|
||||
|
||||
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
|
||||
s.ready.Wait()
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
inputs, err := s.inputs(prompt, images)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to process inputs: %w", err)
|
||||
@@ -121,6 +126,10 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
||||
|
||||
if len(inputs) > s.cache.numCtx {
|
||||
discard := len(inputs) - s.cache.numCtx
|
||||
if !params.truncate {
|
||||
return nil, errorInputTooLong
|
||||
}
|
||||
|
||||
newInputs := inputs[:params.numKeep]
|
||||
newInputs = append(newInputs, inputs[params.numKeep+discard:]...)
|
||||
|
||||
@@ -142,18 +151,17 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
||||
}
|
||||
|
||||
return &Sequence{
|
||||
inputs: inputs,
|
||||
numPromptInputs: len(inputs),
|
||||
startProcessingTime: startTime,
|
||||
numPredict: params.numPredict,
|
||||
pendingResponses: make([]string, 0),
|
||||
responses: make(chan string, 100),
|
||||
quit: make(chan bool, 1),
|
||||
embedding: make(chan []float32, 1),
|
||||
samplingCtx: sc,
|
||||
embeddingOnly: params.embedding,
|
||||
stop: params.stop,
|
||||
numKeep: params.numKeep,
|
||||
inputs: inputs,
|
||||
numPromptInputs: len(inputs),
|
||||
numPredict: params.numPredict,
|
||||
pendingResponses: make([]string, 0),
|
||||
responses: make(chan string, 100),
|
||||
quit: make(chan bool, 1),
|
||||
embedding: make(chan []float32, 1),
|
||||
samplingCtx: sc,
|
||||
embeddingOnly: params.embedding,
|
||||
stop: params.stop,
|
||||
numKeep: params.numKeep,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -388,6 +396,11 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
||||
for i, input := range seq.inputs {
|
||||
if len(seq.cache.Inputs)+len(seq.pendingInputs)+1 > s.cache.numCtx {
|
||||
if len(seq.pendingInputs) == 0 {
|
||||
if !seq.shift {
|
||||
s.removeSequence(seqIdx, llm.DoneReasonLength)
|
||||
break
|
||||
}
|
||||
|
||||
err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
|
||||
if err != nil {
|
||||
var reprocess *ErrReprocessInputs
|
||||
@@ -438,8 +451,8 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
||||
return nil
|
||||
}
|
||||
|
||||
err := s.lc.Decode(batch)
|
||||
if err != nil {
|
||||
t := time.Now()
|
||||
if err := s.lc.Decode(batch); err != nil {
|
||||
return fmt.Errorf("failed to decode batch: %w", err)
|
||||
}
|
||||
|
||||
@@ -459,9 +472,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
||||
continue
|
||||
}
|
||||
|
||||
seq.numDecoded += 1
|
||||
if seq.numDecoded == 1 {
|
||||
seq.startGenerationTime = time.Now()
|
||||
s.lc.Synchronize()
|
||||
seq.numDecoded++
|
||||
if seq.numDecoded > 1 {
|
||||
seq.generationDuration += time.Since(t)
|
||||
} else {
|
||||
seq.processingDuration += time.Since(t)
|
||||
}
|
||||
|
||||
// if done processing the prompt, generate an embedding and return
|
||||
@@ -583,8 +599,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
numKeep: req.Options.NumKeep,
|
||||
samplingParams: &samplingParams,
|
||||
embedding: false,
|
||||
shift: req.Shift,
|
||||
truncate: req.Truncate,
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, errorInputTooLong) {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -646,9 +668,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
Done: true,
|
||||
DoneReason: seq.doneReason,
|
||||
PromptEvalCount: seq.numPromptInputs,
|
||||
PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
|
||||
PromptEvalDuration: seq.processingDuration,
|
||||
EvalCount: seq.numDecoded,
|
||||
EvalDuration: time.Since(seq.startGenerationTime),
|
||||
EvalDuration: seq.generationDuration,
|
||||
}); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
@@ -88,13 +88,17 @@ type Sequence struct {
|
||||
// true if an embedding is to be returned instead of text generation
|
||||
embeddingOnly bool
|
||||
|
||||
// shift if context window is exceeded
|
||||
shift bool
|
||||
|
||||
doneReason llm.DoneReason
|
||||
|
||||
// Metrics
|
||||
startProcessingTime time.Time
|
||||
startGenerationTime time.Time
|
||||
numPredicted int
|
||||
numPromptInputs int
|
||||
startedAt, lastUpdatedAt time.Time
|
||||
processingDuration time.Duration
|
||||
samplingDuration time.Duration
|
||||
numPredicted int
|
||||
numPromptInputs int
|
||||
}
|
||||
|
||||
type NewSequenceParams struct {
|
||||
@@ -103,13 +107,15 @@ type NewSequenceParams struct {
|
||||
numKeep int32
|
||||
sampler sample.Sampler
|
||||
embedding bool
|
||||
shift bool
|
||||
truncate bool
|
||||
}
|
||||
|
||||
var errorInputTooLong = errors.New("the input length exceeds the context length")
|
||||
|
||||
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
|
||||
s.ready.Wait()
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
inputs, ctxs, mmStore, err := s.inputs(prompt, images)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to process inputs: %w", err)
|
||||
@@ -126,6 +132,11 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
||||
|
||||
if int32(len(inputs)) > s.cache.numCtx {
|
||||
discard := int32(len(inputs)) - s.cache.numCtx
|
||||
|
||||
if !params.truncate {
|
||||
return nil, errorInputTooLong
|
||||
}
|
||||
|
||||
promptStart := params.numKeep + discard
|
||||
|
||||
// If we need to truncate in the middle of an unbreakable batch, remove the entire batch
|
||||
@@ -164,20 +175,20 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
||||
// TODO(jessegross): Ingest cached history for grammar
|
||||
|
||||
return &Sequence{
|
||||
ctxs: ctxs,
|
||||
mmStore: mmStore,
|
||||
inputs: inputs,
|
||||
numPromptInputs: len(inputs),
|
||||
startProcessingTime: startTime,
|
||||
numPredict: params.numPredict,
|
||||
pendingResponses: make([]string, 0),
|
||||
responses: make(chan string, 100),
|
||||
quit: make(chan bool, 1),
|
||||
embedding: make(chan []float32, 1),
|
||||
sampler: params.sampler,
|
||||
embeddingOnly: params.embedding,
|
||||
stop: params.stop,
|
||||
numKeep: params.numKeep,
|
||||
ctxs: ctxs,
|
||||
mmStore: mmStore,
|
||||
inputs: inputs,
|
||||
numPromptInputs: len(inputs),
|
||||
numPredict: params.numPredict,
|
||||
pendingResponses: make([]string, 0),
|
||||
responses: make(chan string, 100),
|
||||
quit: make(chan bool, 1),
|
||||
embedding: make(chan []float32, 1),
|
||||
sampler: params.sampler,
|
||||
embeddingOnly: params.embedding,
|
||||
stop: params.stop,
|
||||
numKeep: params.numKeep,
|
||||
shift: params.shift,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -323,9 +334,6 @@ type Server struct {
|
||||
// TODO (jmorganca): make this n_batch
|
||||
batchSize int
|
||||
|
||||
// Used to signal a hard failure during async processing which will panic the runner
|
||||
hardErrCh chan error
|
||||
|
||||
// Simple counter used only for trace logging batches
|
||||
batchID int
|
||||
|
||||
@@ -408,25 +416,25 @@ func (s *Server) run(ctx context.Context) {
|
||||
|
||||
supportsAsync := pooling.Type(s.model.Backend().Config().Uint("pooling_type")) == pooling.TypeNone
|
||||
|
||||
var activeBatch batchState
|
||||
var previousBatch batchState
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case err := <-s.hardErrCh:
|
||||
panic(err)
|
||||
default:
|
||||
var err error
|
||||
activeBatch, err = s.forwardBatch(activeBatch)
|
||||
nextBatch, err := s.forwardBatch(previousBatch)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if supportsAsync {
|
||||
go s.computeBatch(activeBatch)
|
||||
go s.computeBatch(nextBatch)
|
||||
} else {
|
||||
s.computeBatch(activeBatch)
|
||||
s.computeBatch(nextBatch)
|
||||
}
|
||||
|
||||
previousBatch = nextBatch
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -522,6 +530,12 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
|
||||
break
|
||||
}
|
||||
|
||||
if !seq.shift {
|
||||
s.removeSequence(seqIdx, llm.DoneReasonLength)
|
||||
nextBatch.seqs[seqIdx] = nil
|
||||
break
|
||||
}
|
||||
|
||||
err = s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
|
||||
if err != nil {
|
||||
var reprocess *ErrReprocessInputs
|
||||
@@ -562,6 +576,13 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
|
||||
seq.inputs = seq.inputs[len(seq.pendingInputs):]
|
||||
}
|
||||
|
||||
startedAt := time.Now()
|
||||
for i := range nextBatch.seqs {
|
||||
if nextBatch.seqs[i] != nil && nextBatch.seqs[i].startedAt.IsZero() {
|
||||
nextBatch.seqs[i].startedAt = startedAt
|
||||
}
|
||||
}
|
||||
|
||||
if resumeSeq != -1 {
|
||||
s.nextSeq = resumeSeq
|
||||
} else {
|
||||
@@ -656,9 +677,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
// don't sample prompt processing
|
||||
if len(seq.inputs) != 0 {
|
||||
if !s.cache.enabled {
|
||||
s.hardErrCh <- fmt.Errorf("caching disabled but unable to fit entire input in a batch")
|
||||
s.mu.Unlock()
|
||||
return
|
||||
panic("caching disabled but unable to fit entire input in a batch")
|
||||
}
|
||||
continue
|
||||
}
|
||||
@@ -682,6 +701,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
activeBatch.modelOutput)
|
||||
|
||||
outputs := activeBatch.modelOutput.Floats()
|
||||
t := time.Now()
|
||||
|
||||
logutil.Trace("computeBatch: logits ready", "batchID", activeBatch.id)
|
||||
|
||||
@@ -694,8 +714,10 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
continue
|
||||
}
|
||||
|
||||
seq.lastUpdatedAt = t
|
||||
if seq.numPredicted == 1 {
|
||||
seq.startGenerationTime = time.Now()
|
||||
seq.processingDuration = seq.lastUpdatedAt.Sub(seq.startedAt)
|
||||
seq.startedAt = seq.lastUpdatedAt
|
||||
}
|
||||
|
||||
// if done processing the prompt, generate an embedding and return
|
||||
@@ -710,8 +732,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
logutil.Trace("computeBatch: vocab details", "batchID", activeBatch.id, "seqIdx", i, "len(logits)", len(outputs), "len(activeBatch.batch.Outputs)", activeBatch.batch.Outputs.Dim(0), "vocabSize", vocabSize, "iBatches", iBatches)
|
||||
token, err := seq.sampler.Sample(outputs[iBatches[i]*vocabSize : (iBatches[i]+1)*vocabSize])
|
||||
if err != nil {
|
||||
s.hardErrCh <- fmt.Errorf("failed to sample token: %w", err)
|
||||
return
|
||||
panic("failed to sample token")
|
||||
}
|
||||
|
||||
nextBatchTokens[i].Token = token
|
||||
@@ -728,8 +749,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
|
||||
piece, err := s.model.(model.TextProcessor).Decode([]int32{token})
|
||||
if err != nil {
|
||||
s.hardErrCh <- fmt.Errorf("failed to decode token: %w", err)
|
||||
return
|
||||
panic("failed to decode token")
|
||||
}
|
||||
|
||||
seq.pendingResponses = append(seq.pendingResponses, piece)
|
||||
@@ -774,6 +794,13 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
s.removeSequence(i, llm.DoneReasonConnectionClosed)
|
||||
}
|
||||
}
|
||||
|
||||
samplingDuration := time.Since(t)
|
||||
for i, seq := range s.seqs {
|
||||
if seq != nil && nextBatchTokens[i] != nil {
|
||||
s.seqs[i].samplingDuration += samplingDuration
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -824,8 +851,14 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
numKeep: int32(req.Options.NumKeep),
|
||||
sampler: sampler,
|
||||
embedding: false,
|
||||
shift: req.Shift,
|
||||
truncate: req.Truncate,
|
||||
})
|
||||
if err != nil {
|
||||
if errors.Is(err, errorInputTooLong) {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
@@ -887,9 +920,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
Done: true,
DoneReason: seq.doneReason,
PromptEvalCount: seq.numPromptInputs,
PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
PromptEvalDuration: seq.processingDuration,
EvalCount: seq.numPredicted,
EvalDuration: time.Since(seq.startGenerationTime),
EvalDuration: seq.lastUpdatedAt.Sub(seq.startedAt) - seq.samplingDuration,
}); err != nil {
http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
}
|
||||
@@ -1304,7 +1337,6 @@ func Execute(args []string) error {
|
||||
server := &Server{
|
||||
modelPath: *mpath,
|
||||
status: llm.ServerStatusLaunched,
|
||||
hardErrCh: make(chan error, 1),
|
||||
}
|
||||
|
||||
server.cond = sync.NewCond(&server.mu)
|
||||
|
||||
@@ -179,7 +179,7 @@ function buildROCm() {
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
& cmake --install build --component "HIP" --strip
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
rm -f $script:DIST_DIR\lib\ollama\rocm\rocblas\library\*gfx906*
Remove-Item -Path $script:DIST_DIR\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue
}
}
}
|
||||
|
||||
@@ -20,7 +20,7 @@ type tokenizeFunc func(context.Context, string) ([]int, error)
// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
// latest message and 2) system messages
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool, think *api.ThinkValue) (prompt string, images []llm.ImageData, _ error) {
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool, think *api.ThinkValue, truncate bool) (prompt string, images []llm.ImageData, _ error) {
var system []api.Message

// TODO: Ideally we would compute this from the projector metadata but some pieces are implementation dependent
@@ -59,7 +59,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
}
}

if ctxLen > opts.NumCtx {
if truncate && ctxLen > opts.NumCtx {
slog.Debug("truncating input messages which exceed context length", "truncated", len(msgs[i:]))
break
} else {
|
||||
|
||||
@@ -27,16 +27,18 @@ func TestChatPrompt(t *testing.T) {
|
||||
visionModel := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
model Model
|
||||
limit int
|
||||
msgs []api.Message
|
||||
name string
|
||||
model Model
|
||||
limit int
|
||||
truncate bool
|
||||
msgs []api.Message
|
||||
expect
|
||||
}{
|
||||
{
|
||||
name: "messages",
|
||||
model: visionModel,
|
||||
limit: 64,
|
||||
name: "messages",
|
||||
model: visionModel,
|
||||
limit: 64,
|
||||
truncate: true,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "You're a test, Harry!"},
|
||||
{Role: "assistant", Content: "I-I'm a what?"},
|
||||
@@ -47,9 +49,10 @@ func TestChatPrompt(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "truncate messages",
|
||||
model: visionModel,
|
||||
limit: 1,
|
||||
name: "truncate messages",
|
||||
model: visionModel,
|
||||
limit: 1,
|
||||
truncate: true,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "You're a test, Harry!"},
|
||||
{Role: "assistant", Content: "I-I'm a what?"},
|
||||
@@ -60,9 +63,10 @@ func TestChatPrompt(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "truncate messages with image",
|
||||
model: visionModel,
|
||||
limit: 64,
|
||||
name: "truncate messages with image",
|
||||
model: visionModel,
|
||||
limit: 64,
|
||||
truncate: true,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "You're a test, Harry!"},
|
||||
{Role: "assistant", Content: "I-I'm a what?"},
|
||||
@@ -76,9 +80,10 @@ func TestChatPrompt(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "truncate messages with images",
|
||||
model: visionModel,
|
||||
limit: 64,
|
||||
name: "truncate messages with images",
|
||||
model: visionModel,
|
||||
limit: 64,
|
||||
truncate: true,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
|
||||
{Role: "assistant", Content: "I-I'm a what?"},
|
||||
@@ -92,9 +97,10 @@ func TestChatPrompt(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "messages with images",
|
||||
model: visionModel,
|
||||
limit: 2048,
|
||||
name: "messages with images",
|
||||
model: visionModel,
|
||||
limit: 2048,
|
||||
truncate: true,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "You're a test, Harry!", Images: []api.ImageData{[]byte("something")}},
|
||||
{Role: "assistant", Content: "I-I'm a what?"},
|
||||
@@ -109,9 +115,10 @@ func TestChatPrompt(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "message with image tag",
|
||||
model: visionModel,
|
||||
limit: 2048,
|
||||
name: "message with image tag",
|
||||
model: visionModel,
|
||||
limit: 2048,
|
||||
truncate: true,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "You're a test, Harry! [img]", Images: []api.ImageData{[]byte("something")}},
|
||||
{Role: "assistant", Content: "I-I'm a what?"},
|
||||
@@ -126,9 +133,10 @@ func TestChatPrompt(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "messages with interleaved images",
|
||||
model: visionModel,
|
||||
limit: 2048,
|
||||
name: "messages with interleaved images",
|
||||
model: visionModel,
|
||||
limit: 2048,
|
||||
truncate: true,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "You're a test, Harry!"},
|
||||
{Role: "user", Images: []api.ImageData{[]byte("something")}},
|
||||
@@ -145,9 +153,10 @@ func TestChatPrompt(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "truncate message with interleaved images",
|
||||
model: visionModel,
|
||||
limit: 1024,
|
||||
name: "truncate message with interleaved images",
|
||||
model: visionModel,
|
||||
limit: 1024,
|
||||
truncate: true,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "You're a test, Harry!"},
|
||||
{Role: "user", Images: []api.ImageData{[]byte("something")}},
|
||||
@@ -163,9 +172,10 @@ func TestChatPrompt(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "message with system prompt",
|
||||
model: visionModel,
|
||||
limit: 2048,
|
||||
name: "message with system prompt",
|
||||
model: visionModel,
|
||||
limit: 2048,
|
||||
truncate: true,
|
||||
msgs: []api.Message{
|
||||
{Role: "system", Content: "You are the Test Who Lived."},
|
||||
{Role: "user", Content: "You're a test, Harry!"},
|
||||
@@ -177,9 +187,10 @@ func TestChatPrompt(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "out of order system",
|
||||
model: visionModel,
|
||||
limit: 2048,
|
||||
name: "out of order system",
|
||||
model: visionModel,
|
||||
limit: 2048,
|
||||
truncate: true,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "You're a test, Harry!"},
|
||||
{Role: "assistant", Content: "I-I'm a what?"},
|
||||
@@ -191,9 +202,10 @@ func TestChatPrompt(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple images same prompt",
|
||||
model: visionModel,
|
||||
limit: 2048,
|
||||
name: "multiple images same prompt",
|
||||
model: visionModel,
|
||||
limit: 2048,
|
||||
truncate: true,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "Compare these two pictures of hotdogs", Images: []api.ImageData{[]byte("one hotdog"), []byte("two hotdogs")}},
|
||||
},
|
||||
@@ -202,6 +214,20 @@ func TestChatPrompt(t *testing.T) {
|
||||
images: [][]byte{[]byte("one hotdog"), []byte("two hotdogs")},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no truncate with limit exceeded",
|
||||
model: visionModel,
|
||||
limit: 10,
|
||||
truncate: false,
|
||||
msgs: []api.Message{
|
||||
{Role: "user", Content: "You're a test, Harry!"},
|
||||
{Role: "assistant", Content: "I-I'm a what?"},
|
||||
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
|
||||
},
|
||||
expect: expect{
|
||||
prompt: "You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
@@ -209,7 +235,7 @@ func TestChatPrompt(t *testing.T) {
|
||||
model := tt.model
|
||||
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
|
||||
think := false
|
||||
prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &api.ThinkValue{Value: think})
|
||||
prompt, images, err := chatPrompt(t.Context(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil, &api.ThinkValue{Value: think}, tt.truncate)
|
||||
if tt.error == nil && err != nil {
|
||||
t.Fatal(err)
|
||||
} else if tt.error != nil && err != tt.error {
|
||||
|
||||
348
server/routes.go
@@ -37,8 +37,8 @@ import (
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/middleware"
|
||||
"github.com/ollama/ollama/model/parsers"
|
||||
"github.com/ollama/ollama/openai"
|
||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||
"github.com/ollama/ollama/server/internal/registry"
|
||||
"github.com/ollama/ollama/template"
|
||||
@@ -330,12 +330,18 @@ func (s *Server) GenerateHandler(c *gin.Context) {
if req.Suffix != "" {
caps = append(caps, model.CapabilityInsert)
}
if req.Think != nil && req.Think.Bool() {

modelCaps := m.Capabilities()
if slices.Contains(modelCaps, model.CapabilityThinking) {
caps = append(caps, model.CapabilityThinking)
// TODO(drifkin): consider adding a warning if it's false and the model
// doesn't support thinking. It's not strictly required, but it can be a
// hint that the user is on an older qwen3/r1 model that doesn't have an
// updated template supporting thinking
if req.Think == nil {
req.Think = &api.ThinkValue{Value: true}
}
} else {
if req.Think != nil && req.Think.Bool() {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)})
return
}
}
|
||||
|
||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
|
||||
@@ -397,12 +403,11 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
msgs = append(msgs, m.Messages...)
|
||||
}
|
||||
|
||||
userMsg := api.Message{Role: "user", Content: req.Prompt}
|
||||
for _, i := range images {
|
||||
imgPrompt := ""
|
||||
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]"+imgPrompt, i.ID)})
|
||||
userMsg.Images = append(userMsg.Images, i.Data)
|
||||
}
|
||||
|
||||
values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})
|
||||
values.Messages = append(msgs, userMsg)
|
||||
}
|
||||
|
||||
values.Think = req.Think != nil && req.Think.Bool()
|
||||
@@ -423,12 +428,31 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
b.WriteString(s)
|
||||
}
|
||||
|
||||
if err := tmpl.Execute(&b, values); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
// check that we're in the `api/chat`-like flow, and if so, generate the
|
||||
// prompt the same way
|
||||
// TEMP(drifkin): we should really just detect the chat-like flow and call
|
||||
// the real chat handler, but doing this as a stopgap to get renderer
|
||||
// support for generate
|
||||
if values.Messages != nil && values.Suffix == "" && req.Template == "" {
|
||||
prompt, images, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, values.Messages, []api.Tool{}, req.Think, req.Truncate == nil || *req.Truncate)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
// TEMP(drifkin): req.Context will be removed very soon, but we're temporarily supporting it in this flow here
|
||||
if req.Context != nil {
|
||||
b.WriteString(prompt)
|
||||
prompt = b.String()
|
||||
}
|
||||
} else {
|
||||
// legacy flow
|
||||
if err := tmpl.Execute(&b, values); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
prompt = b.String()
|
||||
prompt = b.String()
|
||||
}
|
||||
}
|
||||
|
||||
// If debug mode is enabled, return the rendered template instead of calling the model
|
||||
@@ -464,10 +488,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
var sb strings.Builder
|
||||
defer close(ch)
|
||||
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
Images: images,
|
||||
Format: req.Format,
|
||||
Options: opts,
|
||||
Prompt: prompt,
|
||||
Images: images,
|
||||
Format: req.Format,
|
||||
Options: opts,
|
||||
Shift: req.Shift == nil || *req.Shift,
|
||||
Truncate: req.Truncate == nil || *req.Truncate,
|
||||
}, func(cr llm.CompletionResponse) {
|
||||
res := api.GenerateResponse{
|
||||
Model: req.Model,
|
||||
@@ -529,7 +555,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
|
||||
ch <- res
|
||||
}); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
var serr api.StatusError
|
||||
if errors.As(err, &serr) {
|
||||
ch <- gin.H{"error": serr.ErrorMessage, "status": serr.StatusCode}
|
||||
} else {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -549,7 +580,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
msg = "unexpected error format in response"
|
||||
}
|
||||
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
|
||||
status, ok := t["status"].(int)
|
||||
if !ok {
|
||||
status = http.StatusInternalServerError
|
||||
}
|
||||
|
||||
c.JSON(status, gin.H{"error": msg})
|
||||
return
|
||||
default:
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
|
||||
@@ -1449,11 +1485,11 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
r.POST("/api/embeddings", s.EmbeddingsHandler)

// Inference (OpenAI compatibility)
r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
r.POST("/v1/embeddings", openai.EmbeddingsMiddleware(), s.EmbedHandler)
r.GET("/v1/models", openai.ListMiddleware(), s.ListHandler)
r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowHandler)
r.POST("/v1/chat/completions", middleware.ChatMiddleware(), s.ChatHandler)
r.POST("/v1/completions", middleware.CompletionsMiddleware(), s.GenerateHandler)
r.POST("/v1/embeddings", middleware.EmbeddingsMiddleware(), s.EmbedHandler)
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
|
||||
|
||||
if rc != nil {
|
||||
// wrap old with new
|
||||
@@ -1614,6 +1650,30 @@ func streamResponse(c *gin.Context, ch chan any) {
|
||||
return false
|
||||
}
|
||||
|
||||
// errors are provided as a gin.H with an "error" field and
|
||||
// an optional "status" field. For errors that are streamed
|
||||
// before any content, we need to set the status code and
|
||||
// content type for the error.
|
||||
if h, ok := val.(gin.H); ok {
|
||||
if e, ok := h["error"].(string); ok {
|
||||
status, ok := h["status"].(int)
|
||||
if !ok {
|
||||
status = http.StatusInternalServerError
|
||||
}
|
||||
|
||||
if !c.Writer.Written() {
|
||||
c.Header("Content-Type", "application/json")
|
||||
c.JSON(status, gin.H{"error": e})
|
||||
} else {
|
||||
if err := json.NewEncoder(c.Writer).Encode(gin.H{"error": e}); err != nil {
|
||||
slog.Error("streamResponse failed to encode json error", "error", err)
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
bts, err := json.Marshal(val)
|
||||
if err != nil {
|
||||
slog.Info(fmt.Sprintf("streamResponse: json.Marshal failed with %s", err))
|
||||
@@ -1871,8 +1931,18 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
if len(req.Tools) > 0 {
|
||||
caps = append(caps, model.CapabilityTools)
|
||||
}
|
||||
if req.Think != nil && req.Think.Bool() {
|
||||
|
||||
modelCaps := m.Capabilities()
|
||||
if slices.Contains(modelCaps, model.CapabilityThinking) {
|
||||
caps = append(caps, model.CapabilityThinking)
|
||||
if req.Think == nil {
|
||||
req.Think = &api.ThinkValue{Value: true}
|
||||
}
|
||||
} else {
|
||||
if req.Think != nil && req.Think.Bool() {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
|
||||
@@ -1923,7 +1993,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think)
|
||||
truncate := req.Truncate == nil || *req.Truncate
|
||||
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
|
||||
if err != nil {
|
||||
slog.Error("chat prompt error", "error", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
@@ -1967,88 +2038,174 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
toolParser = tools.NewParser(m.Template.Template, req.Tools)
|
||||
}
|
||||
|
||||
type structuredOutputsState int
|
||||
const (
|
||||
structuredOutputsState_None structuredOutputsState = iota
|
||||
structuredOutputsState_ReadyToApply
|
||||
structuredOutputsState_Applying
|
||||
)
|
||||
|
||||
ch := make(chan any)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
|
||||
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
Images: images,
|
||||
Format: req.Format,
|
||||
Options: opts,
|
||||
}, func(r llm.CompletionResponse) {
|
||||
res := api.ChatResponse{
|
||||
Model: req.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Message: api.Message{Role: "assistant", Content: r.Content},
|
||||
Done: r.Done,
|
||||
Metrics: api.Metrics{
|
||||
PromptEvalCount: r.PromptEvalCount,
|
||||
PromptEvalDuration: r.PromptEvalDuration,
|
||||
EvalCount: r.EvalCount,
|
||||
EvalDuration: r.EvalDuration,
|
||||
},
|
||||
}
|
||||
if r.Done {
|
||||
res.DoneReason = r.DoneReason.String()
|
||||
res.TotalDuration = time.Since(checkpointStart)
|
||||
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||
structuredOutputsState := structuredOutputsState_None
|
||||
|
||||
for {
|
||||
var tb strings.Builder
|
||||
|
||||
currentFormat := req.Format
|
||||
// structured outputs via double request is enabled when:
|
||||
// 1. the model supports the thinking capability and
|
||||
// 2. it uses a built-in parser or our generic thinking parser
|
||||
|
||||
// Note that the current approach does not work for (potential future)
|
||||
// non-thinking models that emit anything before actual content. This
|
||||
// current approach uses the transition from parsed thinking content to
|
||||
// parsed non-thinking content as the signal to turn constraining on
|
||||
|
||||
if req.Format != nil && structuredOutputsState == structuredOutputsState_None && ((builtinParser != nil || thinkingState != nil) && slices.Contains(m.Capabilities(), model.CapabilityThinking)) {
|
||||
currentFormat = nil
|
||||
}
|
||||
|
||||
if builtinParser != nil {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser input", "parser", m.Config.Parser, "content", r.Content)
|
||||
|
||||
content, thinking, toolCalls, err := builtinParser.Add(r.Content, r.Done)
|
||||
if err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
return
|
||||
// sets up new context given parent context per request
|
||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||
err := r.Completion(ctx, llm.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
Images: images,
|
||||
Format: currentFormat,
|
||||
Options: opts,
|
||||
Shift: req.Shift == nil || *req.Shift,
|
||||
Truncate: truncate,
|
||||
}, func(r llm.CompletionResponse) {
|
||||
res := api.ChatResponse{
|
||||
Model: req.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Message: api.Message{Role: "assistant", Content: r.Content},
|
||||
Done: r.Done,
|
||||
Metrics: api.Metrics{
|
||||
PromptEvalCount: r.PromptEvalCount,
|
||||
PromptEvalDuration: r.PromptEvalDuration,
|
||||
EvalCount: r.EvalCount,
|
||||
EvalDuration: r.EvalDuration,
|
||||
},
|
||||
}
|
||||
if r.Done {
|
||||
res.DoneReason = r.DoneReason.String()
|
||||
res.TotalDuration = time.Since(checkpointStart)
|
||||
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||
}
|
||||
|
||||
res.Message.Content = content
|
||||
res.Message.Thinking = thinking
|
||||
res.Message.ToolCalls = toolCalls
|
||||
if builtinParser != nil {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser input", "parser", m.Config.Parser, "content", r.Content)
|
||||
|
||||
if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || r.Done {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser output", "parser", m.Config.Parser, "content", content, "thinking", thinking, "toolCalls", toolCalls, "done", r.Done)
|
||||
ch <- res
|
||||
} else {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser empty output", "parser", m.Config.Parser)
|
||||
}
|
||||
content, thinking, toolCalls, err := builtinParser.Add(r.Content, r.Done)
|
||||
if err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if thinkingState != nil {
|
||||
thinkingContent, remainingContent := thinkingState.AddContent(res.Message.Content)
|
||||
if thinkingContent == "" && remainingContent == "" && !r.Done {
|
||||
// need to accumulate more to decide what to send
|
||||
return
|
||||
}
|
||||
res.Message.Content = remainingContent
|
||||
res.Message.Thinking = thinkingContent
|
||||
}
|
||||
|
||||
if len(req.Tools) > 0 {
|
||||
toolCalls, content := toolParser.Add(res.Message.Content)
|
||||
if len(content) > 0 {
|
||||
res.Message.Content = content
|
||||
} else if len(toolCalls) > 0 {
|
||||
res.Message.Thinking = thinking
|
||||
res.Message.ToolCalls = toolCalls
|
||||
res.Message.Content = ""
|
||||
} else if res.Message.Thinking != "" {
|
||||
// don't return
|
||||
} else {
|
||||
if r.Done {
|
||||
res.Message.Content = toolParser.Content()
|
||||
|
||||
tb.WriteString(thinking)
|
||||
// we are now receiving content from the model - we should start applying structured outputs
|
||||
if structuredOutputsState == structuredOutputsState_None && req.Format != nil && tb.String() != "" && res.Message.Content != "" {
|
||||
structuredOutputsState = structuredOutputsState_ReadyToApply
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
|
||||
if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || r.Done {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser output", "parser", m.Config.Parser, "content", content, "thinking", thinking, "toolCalls", toolCalls, "done", r.Done)
|
||||
ch <- res
|
||||
} else {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser empty output", "parser", m.Config.Parser)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if thinkingState != nil {
|
||||
thinkingContent, remainingContent := thinkingState.AddContent(res.Message.Content)
|
||||
if thinkingContent == "" && remainingContent == "" && !r.Done {
|
||||
// need to accumulate more to decide what to send
|
||||
return
|
||||
}
|
||||
res.Message.Thinking = thinkingContent
|
||||
tb.WriteString(thinkingContent)
|
||||
// emit the collected thinking text before restarting with structured outputs and clear unstructured content
|
||||
// to avoid leaking mixed tokens like "</think>Hello"
|
||||
if structuredOutputsState == structuredOutputsState_None && req.Format != nil && tb.String() != "" && remainingContent != "" {
|
||||
structuredOutputsState = structuredOutputsState_ReadyToApply
|
||||
res.Message.Content = ""
|
||||
ch <- res
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
res.Message.Content = remainingContent
|
||||
}
|
||||
|
||||
if len(req.Tools) > 0 {
|
||||
toolCalls, content := toolParser.Add(res.Message.Content)
|
||||
if len(content) > 0 {
|
||||
res.Message.Content = content
|
||||
} else if len(toolCalls) > 0 {
|
||||
res.Message.ToolCalls = toolCalls
|
||||
res.Message.Content = ""
|
||||
} else if res.Message.Thinking != "" {
|
||||
// don't return
|
||||
} else {
|
||||
if r.Done {
|
||||
res.Message.Content = toolParser.Content()
|
||||
ch <- res
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
ch <- res
|
||||
})
|
||||
if err != nil {
|
||||
if structuredOutputsState == structuredOutputsState_ReadyToApply && strings.Contains(err.Error(), "context canceled") && c.Request.Context().Err() == nil {
|
||||
// only ignores error if it's a context cancellation due to setting structured outputs
|
||||
} else {
|
||||
var serr api.StatusError
|
||||
if errors.As(err, &serr) {
|
||||
ch <- gin.H{"error": serr.ErrorMessage, "status": serr.StatusCode}
|
||||
} else {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
ch <- res
|
||||
}); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
// ignored structured outputs cancellation falls through to here, start a new request with the structured outputs and updated prompt. use the
|
||||
if structuredOutputsState == structuredOutputsState_ReadyToApply {
|
||||
structuredOutputsState = structuredOutputsState_Applying
|
||||
msg := api.Message{
|
||||
Role: "assistant",
|
||||
Thinking: tb.String(),
|
||||
}
|
||||
|
||||
msgs = append(msgs, msg)
|
||||
prompt, _, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think, truncate)
|
||||
if err != nil {
|
||||
slog.Error("chat prompt error applying structured outputs", "error", err)
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
return
|
||||
}
|
||||
// force constraining by terminating thinking header, the parser is already at this state
|
||||
// when the last message is thinking, the rendered for gpt-oss cannot disambiguate between having the
|
||||
// model continue thinking or ending thinking and outputting the final message.
|
||||
// TODO(parthsareen): consider adding prefill disambiguation logic to the renderer for structured outputs.
|
||||
if shouldUseHarmony(m) || (builtinParser != nil && m.Config.Parser == "harmony") {
|
||||
prompt += "<|end|><|start|>assistant<|channel|>final<|message|>"
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -2072,7 +2229,12 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
msg = "unexpected error format in response"
|
||||
}
|
||||
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": msg})
|
||||
status, ok := t["status"].(int)
|
||||
if !ok {
|
||||
status = http.StatusInternalServerError
|
||||
}
|
||||
|
||||
c.JSON(status, gin.H{"error": msg})
|
||||
return
|
||||
default:
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected response"})
|
||||
|
||||
@@ -146,7 +146,7 @@ func TestGenerateDebugRenderOnly(t *testing.T) {
|
||||
DebugRenderOnly: true,
|
||||
},
|
||||
expectDebug: true,
|
||||
expectTemplate: "[img-0]\n\nDescribe this image",
|
||||
expectTemplate: "[img-0]Describe this image",
|
||||
expectNumImages: 1,
|
||||
},
|
||||
{
|
||||
|
||||
313
server/routes_generate_renderer_test.go
Normal file
@@ -0,0 +1,313 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/discover"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/llm"
|
||||
)
|
||||
|
||||
// TestGenerateWithBuiltinRenderer tests that api/generate uses built-in renderers
|
||||
// when in chat-like flow (messages present, no suffix, no template)
|
||||
func TestGenerateWithBuiltinRenderer(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
mock := mockRunner{
|
||||
CompletionResponse: llm.CompletionResponse{
|
||||
Done: true,
|
||||
DoneReason: llm.DoneReasonStop,
|
||||
PromptEvalCount: 1,
|
||||
PromptEvalDuration: 1,
|
||||
EvalCount: 1,
|
||||
EvalDuration: 1,
|
||||
},
|
||||
}
|
||||
|
||||
s := Server{
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
expiredCh: make(chan *runnerRef, 1),
|
||||
unloadedCh: make(chan any, 1),
|
||||
loaded: make(map[string]*runnerRef),
|
||||
newServerFn: newMockServer(&mock),
|
||||
getGpuFn: getGpuFn,
|
||||
getCpuFn: getCpuFn,
|
||||
reschedDelay: 250 * time.Millisecond,
|
||||
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
|
||||
time.Sleep(time.Millisecond)
|
||||
req.successCh <- &runnerRef{
|
||||
llama: &mock,
|
||||
}
|
||||
return false
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
go s.sched.Run(t.Context())
|
||||
|
||||
// Create a model with a built-in renderer (qwen3-coder)
|
||||
_, digest := createBinFile(t, ggml.KV{
|
||||
"general.architecture": "qwen3",
|
||||
"qwen3.block_count": uint32(1),
|
||||
"qwen3.context_length": uint32(8192),
|
||||
"qwen3.embedding_length": uint32(4096),
|
||||
"qwen3.attention.head_count": uint32(32),
|
||||
"qwen3.attention.head_count_kv": uint32(8),
|
||||
"tokenizer.ggml.tokens": []string{""},
|
||||
"tokenizer.ggml.scores": []float32{0},
|
||||
"tokenizer.ggml.token_type": []int32{0},
|
||||
}, []*ggml.Tensor{
|
||||
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
})
|
||||
|
||||
// Create a model with the qwen3-coder renderer
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "test-renderer",
|
||||
Files: map[string]string{"file.gguf": digest},
|
||||
Renderer: "qwen3-coder",
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
mock.CompletionResponse.Content = "Hi!"
|
||||
|
||||
t.Run("chat-like flow uses renderer", func(t *testing.T) {
|
||||
// Test that when using messages (chat-like flow), the built-in renderer is used
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test-renderer",
|
||||
Prompt: "Write a hello world function",
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
// The qwen3-coder renderer produces output with <|im_start|> and <|im_end|> tags
|
||||
// When messages are built internally from prompt, it should use the renderer
|
||||
if !strings.Contains(mock.CompletionRequest.Prompt, "<|im_start|>") {
|
||||
t.Errorf("expected prompt to contain <|im_start|> from qwen3-coder renderer, got: %s", mock.CompletionRequest.Prompt)
|
||||
}
|
||||
|
||||
if !strings.Contains(mock.CompletionRequest.Prompt, "<|im_end|>") {
|
||||
t.Errorf("expected prompt to contain <|im_end|> from qwen3-coder renderer, got: %s", mock.CompletionRequest.Prompt)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("chat-like flow with system message uses renderer", func(t *testing.T) {
|
||||
// Test that system messages work with the renderer
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test-renderer",
|
||||
Prompt: "Write a hello world function",
|
||||
System: "You are a helpful coding assistant.",
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
// Should contain the system message and use renderer format
|
||||
if !strings.Contains(mock.CompletionRequest.Prompt, "<|im_start|>system") {
|
||||
t.Errorf("expected prompt to contain system message with renderer format, got: %s", mock.CompletionRequest.Prompt)
|
||||
}
|
||||
|
||||
if !strings.Contains(mock.CompletionRequest.Prompt, "You are a helpful coding assistant.") {
|
||||
t.Errorf("expected prompt to contain system message content, got: %s", mock.CompletionRequest.Prompt)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("custom template bypasses renderer", func(t *testing.T) {
|
||||
// Test that providing a custom template uses the legacy flow
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test-renderer",
|
||||
Prompt: "Write a hello world function",
|
||||
Template: "{{ .Prompt }}",
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
// Should NOT use the renderer format when custom template is provided
|
||||
if strings.Contains(mock.CompletionRequest.Prompt, "<|im_start|>") {
|
||||
t.Errorf("expected prompt to NOT use renderer when custom template provided, got: %s", mock.CompletionRequest.Prompt)
|
||||
}
|
||||
|
||||
// Should just be the raw prompt from the template
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "Write a hello world function"); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
// Create a model with suffix support for the next test
|
||||
w = createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "test-suffix-renderer",
|
||||
From: "test-renderer",
|
||||
Template: `{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
|
||||
{{- else }}{{ .Prompt }}
|
||||
{{- end }}`,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
t.Run("suffix bypasses renderer", func(t *testing.T) {
|
||||
// Test that providing a suffix uses the legacy flow
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test-suffix-renderer",
|
||||
Prompt: "def add(",
|
||||
Suffix: " return c",
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
// Should NOT use the renderer format when suffix is provided
|
||||
if strings.Contains(mock.CompletionRequest.Prompt, "<|im_start|>") {
|
||||
t.Errorf("expected prompt to NOT use renderer when suffix provided, got: %s", mock.CompletionRequest.Prompt)
|
||||
}
|
||||
|
||||
// Should use the suffix template format
|
||||
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "<PRE> def add( <SUF> return c <MID>"); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// TestGenerateWithDebugRenderOnly tests that debug_render_only works with built-in renderers
|
||||
func TestGenerateWithDebugRenderOnly(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
mock := mockRunner{
|
||||
CompletionResponse: llm.CompletionResponse{
|
||||
Done: true,
|
||||
DoneReason: llm.DoneReasonStop,
|
||||
PromptEvalCount: 1,
|
||||
PromptEvalDuration: 1,
|
||||
EvalCount: 1,
|
||||
EvalDuration: 1,
|
||||
},
|
||||
}
|
||||
|
||||
s := Server{
|
||||
sched: &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, 1),
|
||||
finishedReqCh: make(chan *LlmRequest, 1),
|
||||
expiredCh: make(chan *runnerRef, 1),
|
||||
unloadedCh: make(chan any, 1),
|
||||
loaded: make(map[string]*runnerRef),
|
||||
newServerFn: newMockServer(&mock),
|
||||
getGpuFn: getGpuFn,
|
||||
getCpuFn: getCpuFn,
|
||||
reschedDelay: 250 * time.Millisecond,
|
||||
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
|
||||
time.Sleep(time.Millisecond)
|
||||
req.successCh <- &runnerRef{
|
||||
llama: &mock,
|
||||
}
|
||||
return false
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
go s.sched.Run(t.Context())
|
||||
|
||||
// Create a model with a built-in renderer
|
||||
_, digest := createBinFile(t, ggml.KV{
|
||||
"general.architecture": "qwen3",
|
||||
"qwen3.block_count": uint32(1),
|
||||
"qwen3.context_length": uint32(8192),
|
||||
"qwen3.embedding_length": uint32(4096),
|
||||
"qwen3.attention.head_count": uint32(32),
|
||||
"qwen3.attention.head_count_kv": uint32(8),
|
||||
"tokenizer.ggml.tokens": []string{""},
|
||||
"tokenizer.ggml.scores": []float32{0},
|
||||
"tokenizer.ggml.token_type": []int32{0},
|
||||
}, []*ggml.Tensor{
|
||||
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
|
||||
})
|
||||
|
||||
w := createRequest(t, s.CreateHandler, api.CreateRequest{
|
||||
Model: "test-debug-renderer",
|
||||
Files: map[string]string{"file.gguf": digest},
|
||||
Renderer: "qwen3-coder",
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
t.Run("debug_render_only with renderer", func(t *testing.T) {
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test-debug-renderer",
|
||||
Prompt: "Write a hello world function",
|
||||
System: "You are a coding assistant",
|
||||
DebugRenderOnly: true,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var resp api.GenerateResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if resp.DebugInfo == nil {
|
||||
t.Fatalf("expected debug info, got nil")
|
||||
}
|
||||
|
||||
// Verify that the rendered template uses the built-in renderer
|
||||
if !strings.Contains(resp.DebugInfo.RenderedTemplate, "<|im_start|>") {
|
||||
t.Errorf("expected rendered template to use qwen3-coder renderer format, got: %s", resp.DebugInfo.RenderedTemplate)
|
||||
}
|
||||
|
||||
if !strings.Contains(resp.DebugInfo.RenderedTemplate, "You are a coding assistant") {
|
||||
t.Errorf("expected rendered template to contain system message, got: %s", resp.DebugInfo.RenderedTemplate)
|
||||
}
|
||||
|
||||
if !strings.Contains(resp.DebugInfo.RenderedTemplate, "Write a hello world function") {
|
||||
t.Errorf("expected rendered template to contain prompt, got: %s", resp.DebugInfo.RenderedTemplate)
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -158,11 +158,26 @@ func TestGenerateChat(t *testing.T) {
|
||||
t.Errorf("expected status 400, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(w.Body.String(), `{"error":"registry.ollama.ai/library/test:latest does not support thinking"}`); diff != "" {
|
||||
if diff := cmp.Diff(w.Body.String(), `{"error":"\"test\" does not support thinking"}`); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("model can't think but think set false", func(t *testing.T) {
|
||||
think := false
|
||||
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "test",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
Think: &api.ThinkValue{Value: think},
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing model", func(t *testing.T) {
|
||||
w := createRequest(t, s.ChatHandler, api.ChatRequest{})
|
||||
if w.Code != http.StatusBadRequest {
|
||||
@@ -594,6 +609,58 @@ func TestGenerateChat(t *testing.T) {
|
||||
t.Errorf("final tool call mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("status error non-streaming", func(t *testing.T) {
|
||||
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
|
||||
return api.StatusError{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Status: "Service Unavailable",
|
||||
ErrorMessage: "model is overloaded",
|
||||
}
|
||||
}
|
||||
|
||||
stream := false
|
||||
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "test",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("expected status 503, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(w.Body.String(), `{"error":"model is overloaded"}`); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("status error streaming", func(t *testing.T) {
|
||||
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
|
||||
return api.StatusError{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Status: "Too Many Requests",
|
||||
ErrorMessage: "rate limit exceeded",
|
||||
}
|
||||
}
|
||||
|
||||
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "test",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
})
|
||||
|
||||
if w.Code != http.StatusTooManyRequests {
|
||||
t.Errorf("expected status 429, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(w.Body.String(), `{"error":"rate limit exceeded"}`); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGenerate(t *testing.T) {
|
||||
@@ -968,6 +1035,55 @@ func TestGenerate(t *testing.T) {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("status error non-streaming", func(t *testing.T) {
|
||||
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
|
||||
return api.StatusError{
|
||||
StatusCode: http.StatusServiceUnavailable,
|
||||
Status: "Service Unavailable",
|
||||
ErrorMessage: "model is overloaded",
|
||||
}
|
||||
}
|
||||
|
||||
streamRequest := false
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test",
|
||||
Prompt: "Hello!",
|
||||
Stream: &streamRequest,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusServiceUnavailable {
|
||||
t.Errorf("expected status 503, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(w.Body.String(), `{"error":"model is overloaded"}`); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("status error streaming", func(t *testing.T) {
|
||||
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
|
||||
return api.StatusError{
|
||||
StatusCode: http.StatusTooManyRequests,
|
||||
Status: "Too Many Requests",
|
||||
ErrorMessage: "rate limit exceeded",
|
||||
}
|
||||
}
|
||||
|
||||
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
|
||||
Model: "test",
|
||||
Prompt: "Hello!",
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusTooManyRequests {
|
||||
t.Errorf("expected status 429, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(w.Body.String(), `{"error":"rate limit exceeded"}`); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestChatWithPromptEndingInThinkTag(t *testing.T) {
|
||||
@@ -1120,13 +1236,6 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
|
||||
"The answer is 4.",
|
||||
true)
|
||||
|
||||
testChatRequest(t, "thinking disabled but template still adds think tag",
|
||||
"Simple question",
|
||||
" My thoughts </think> The answer.",
|
||||
"",
|
||||
" My thoughts </think> The answer.",
|
||||
false)
|
||||
|
||||
// Test streaming response with template-added <think>
|
||||
t.Run("streaming with thinking", func(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
@@ -1198,4 +1307,238 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
|
||||
t.Errorf("expected content %q, got %q", "Based on my analysis, the solution is straightforward.", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("structured outputs restart non-stream", func(t *testing.T) {
|
||||
var (
|
||||
requestsMu sync.Mutex
|
||||
requests []llm.CompletionRequest
|
||||
wg sync.WaitGroup
|
||||
)
|
||||
|
||||
wg.Add(2)
|
||||
|
||||
format := json.RawMessage(`{"type":"object","properties":{"answer":{"type":"string"}}}`)
|
||||
|
||||
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
|
||||
defer wg.Done()
|
||||
|
||||
requestsMu.Lock()
|
||||
requests = append(requests, r)
|
||||
callNum := len(requests)
|
||||
requestsMu.Unlock()
|
||||
|
||||
switch callNum {
|
||||
case 1:
|
||||
fn(llm.CompletionResponse{
|
||||
Content: " I am thinking through this problem. </think> {\"answer\":\"42\"}",
|
||||
Done: false,
|
||||
PromptEvalCount: 1,
|
||||
PromptEvalDuration: 1,
|
||||
})
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("timeout waiting for structured outputs cancellation")
|
||||
return nil
|
||||
}
|
||||
case 2:
|
||||
fn(llm.CompletionResponse{
|
||||
Content: `{"answer":"42"}`,
|
||||
Done: true,
|
||||
DoneReason: llm.DoneReasonStop,
|
||||
PromptEvalCount: 1,
|
||||
PromptEvalDuration: 1,
|
||||
EvalCount: 1,
|
||||
EvalDuration: 1,
|
||||
})
|
||||
return nil
|
||||
default:
|
||||
t.Fatalf("unexpected number of completion calls: %d", callNum)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
think := true
|
||||
streamRequest := false
|
||||
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "test-thinking",
|
||||
Messages: []api.Message{{Role: "user", Content: "Please respond in JSON."}},
|
||||
Think: &api.ThinkValue{Value: think},
|
||||
Stream: &streamRequest,
|
||||
Format: format,
|
||||
})
|
||||
|
||||
wg.Wait()
|
||||
mock.CompletionFn = nil
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if len(requests) != 2 {
|
||||
t.Fatalf("expected two completion calls, got %d", len(requests))
|
||||
}
|
||||
|
||||
if requests[0].Format != nil {
|
||||
t.Errorf("expected first completion format to be nil, got %q", requests[0].Format)
|
||||
}
|
||||
|
||||
if !bytes.Equal([]byte(format), []byte(requests[1].Format)) {
|
||||
t.Errorf("expected second completion format to match original format")
|
||||
}
|
||||
|
||||
var resp api.ChatResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if resp.Message.Thinking != "I am thinking through this problem. " {
|
||||
t.Errorf("expected thinking %q, got %q", "I am thinking through this problem. ", resp.Message.Thinking)
|
||||
}
|
||||
|
||||
if resp.Message.Content != `{"answer":"42"}` {
|
||||
t.Errorf("expected content %q, got %q", `{"answer":"42"}`, resp.Message.Content)
|
||||
}
|
||||
|
||||
if !resp.Done {
|
||||
t.Errorf("expected response to be done")
|
||||
}
|
||||
|
||||
if resp.DoneReason != "stop" {
|
||||
t.Errorf("expected done reason stop, got %s", resp.DoneReason)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("structured outputs restart streaming", func(t *testing.T) {
|
||||
var (
|
||||
requestsMu sync.Mutex
|
||||
requests []llm.CompletionRequest
|
||||
wg sync.WaitGroup
|
||||
)
|
||||
|
||||
wg.Add(2)
|
||||
|
||||
format := json.RawMessage(`{"type":"object","properties":{"answer":{"type":"string"}}}`)
|
||||
|
||||
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
|
||||
defer wg.Done()
|
||||
|
||||
requestsMu.Lock()
|
||||
requests = append(requests, r)
|
||||
callNum := len(requests)
|
||||
requestsMu.Unlock()
|
||||
|
||||
switch callNum {
|
||||
case 1:
|
||||
fn(llm.CompletionResponse{
|
||||
Content: " I am thinking through this problem. </think> {\"answer\":\"42\"}",
|
||||
Done: false,
|
||||
PromptEvalCount: 1,
|
||||
PromptEvalDuration: 1,
|
||||
})
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("timeout waiting for structured outputs cancellation")
|
||||
return nil
|
||||
}
|
||||
case 2:
|
||||
fn(llm.CompletionResponse{
|
||||
Content: `{"answer":"42"}`,
|
||||
Done: true,
|
||||
DoneReason: llm.DoneReasonStop,
|
||||
PromptEvalCount: 1,
|
||||
PromptEvalDuration: 1,
|
||||
EvalCount: 1,
|
||||
EvalDuration: 1,
|
||||
})
|
||||
return nil
|
||||
default:
|
||||
t.Fatalf("unexpected number of completion calls: %d", callNum)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
think := true
|
||||
streamRequest := true
|
||||
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "test-thinking",
|
||||
Messages: []api.Message{{Role: "user", Content: "Please respond in JSON."}},
|
||||
Think: &api.ThinkValue{Value: think},
|
||||
Stream: &streamRequest,
|
||||
Format: format,
|
||||
})
|
||||
|
||||
wg.Wait()
|
||||
mock.CompletionFn = nil
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if len(requests) != 2 {
|
||||
t.Fatalf("expected two completion calls, got %d", len(requests))
|
||||
}
|
||||
|
||||
if requests[0].Format != nil {
|
||||
t.Errorf("expected first completion format to be nil, got %q", requests[0].Format)
|
||||
}
|
||||
|
||||
if !bytes.Equal([]byte(format), []byte(requests[1].Format)) {
|
||||
t.Errorf("expected second completion format to match original format")
|
||||
}
|
||||
|
||||
decoder := json.NewDecoder(w.Body)
|
||||
var events []api.ChatResponse
|
||||
for {
|
||||
var event api.ChatResponse
|
||||
if err := decoder.Decode(&event); err == io.EOF {
|
||||
break
|
||||
} else if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
events = append(events, event)
|
||||
if event.Done {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(events) < 2 {
|
||||
t.Fatalf("expected at least two streaming events, got %d", len(events))
|
||||
}
|
||||
|
||||
first := events[0]
|
||||
if first.Message.Thinking != "I am thinking through this problem. " {
|
||||
t.Errorf("expected first event thinking %q, got %q", "I am thinking through this problem. ", first.Message.Thinking)
|
||||
}
|
||||
|
||||
if first.Message.Content != "" {
|
||||
t.Errorf("expected first event content to be empty, got %q", first.Message.Content)
|
||||
}
|
||||
|
||||
if first.Done {
|
||||
t.Error("expected first event to be non-terminal")
|
||||
}
|
||||
|
||||
last := events[len(events)-1]
|
||||
if last.Message.Thinking != "" {
|
||||
t.Errorf("expected final event thinking to be empty, got %q", last.Message.Thinking)
|
||||
}
|
||||
|
||||
if last.Message.Content != `{"answer":"42"}` {
|
||||
t.Errorf("expected final event content %q, got %q", `{"answer":"42"}`, last.Message.Content)
|
||||
}
|
||||
|
||||
if !last.Done {
|
||||
t.Error("expected final event to be done")
|
||||
}
|
||||
|
||||
if last.DoneReason != "stop" {
|
||||
t.Errorf("expected final done reason stop, got %s", last.DoneReason)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/logutil"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/types/model"
)
@@ -229,8 +230,9 @@ func (s *Scheduler) processPending(ctx context.Context) {
}

if runnerToExpire == nil {
// Shouildn't happen
slog.Error("runner to expire was nil!")
// While we were performing load calculations, the loaded runner(s) unloaded in parallel
// so findRunnerToUnload returned no runners. We'll try again and the loadedCount should be zero
slog.Debug("runner to expire was nil, retrying")
continue
}
// Trigger an expiration to unload once it's done
@@ -644,27 +646,35 @@ func (s *Scheduler) waitForVRAMRecovery(runner *runnerRef, runners []discover.Fi
totalMemoryBefore += gpu.TotalMemory
freeMemoryBefore += gpu.FreeMemory
}
totalMemoryNow := totalMemoryBefore
freeMemoryNow := freeMemoryBefore

go func() {
expiresAt := start.Add(5 * time.Second) // typical convergence is 0.5-1.5s
// typical convergence is 0.5-1.5s - If it takes more than 5 seconds to discover and converge, let the scheduler estimate VRAM usage
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
ticker := time.NewTicker(250 * time.Millisecond)
defer ticker.Stop()
for {
<-ticker.C
if time.Now().After(expiresAt) {
slog.Warn("gpu VRAM usage didn't recover within timeout", "seconds", time.Since(start).Seconds(), "runner", runner)
finished <- struct{}{}
}

// Query GPUs, look for free to go back up
gpusNow := s.getGpuFn(context.Background(), runners)
var totalMemoryNow, freeMemoryNow uint64
for _, gpu := range gpusNow {
totalMemoryNow += gpu.TotalMemory
freeMemoryNow += gpu.FreeMemory
}
// If we're within ~80% of the estimated memory usage recovered, bail out
if float32(freeMemoryNow-freeMemoryBefore) > float32(runner.vramSize)*0.8 {
slog.Debug(fmt.Sprintf("gpu VRAM free memory converged after %0.2f seconds", time.Since(start).Seconds()), "runner", runner)
select {
case <-ticker.C:
// Query GPUs, look for free to go back up
gpusNow := s.getGpuFn(ctx, runners)
totalMemoryNow = 0
freeMemoryNow = 0
for _, gpu := range gpusNow {
totalMemoryNow += gpu.TotalMemory
freeMemoryNow += gpu.FreeMemory
}
logutil.Trace("gpu VRAM convergence", "percent", int(max(float32(freeMemoryNow-freeMemoryBefore), 0.0)/float32(runner.vramSize)*100))
// If we're within ~75% of the estimated memory usage recovered, bail out
if float32(freeMemoryNow-freeMemoryBefore) > float32(runner.vramSize)*0.75 {
slog.Debug(fmt.Sprintf("gpu VRAM free memory converged after %0.2f seconds", time.Since(start).Seconds()), "free_before", format.HumanBytes2(freeMemoryBefore), "free_now", format.HumanBytes2(freeMemoryNow), "runner", runner)
finished <- struct{}{}
return
}
case <-ctx.Done():
slog.Debug("gpu VRAM usage didn't recover within timeout", "seconds", time.Since(start).Seconds(), "free_before", format.HumanBytes2(freeMemoryBefore), "free_now", format.HumanBytes2(freeMemoryNow), "runner", runner)
finished <- struct{}{}
return
}

@@ -154,24 +154,55 @@ func TestTemplate(t *testing.T) {
}

func TestParse(t *testing.T) {
cases := []struct {
validCases := []struct {
name string
template string
vars []string
}{
{"{{ .Prompt }}", []string{"prompt", "response"}},
{"{{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system"}},
{"{{ .System }} {{ .Prompt }} {{ .Response }}", []string{"prompt", "response", "system"}},
{"{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}", []string{"prompt", "response", "system", "tools"}},
{"{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}", []string{"content", "messages", "role"}},
{"{{ range .Messages }}{{ if eq .Role \"tool\" }}Tool Result: {{ .ToolName }} {{ .Content }}{{ end }}{{ end }}", []string{"content", "messages", "role", "toolname"}},
{`{{- range .Messages }}
{
name: "PromptOnly",
template: "{{ .Prompt }}",
vars: []string{"prompt", "response"},
},
{
name: "SystemAndPrompt",
template: "{{ .System }} {{ .Prompt }}",
vars: []string{"prompt", "response", "system"},
},
{
name: "PromptResponseSystem",
template: "{{ .System }} {{ .Prompt }} {{ .Response }}",
vars: []string{"prompt", "response", "system"},
},
{
name: "ToolsBlock",
template: "{{ with .Tools }}{{ . }}{{ end }} {{ .System }} {{ .Prompt }}",
vars: []string{"prompt", "response", "system", "tools"},
},
{
name: "MessagesRange",
template: "{{ range .Messages }}{{ .Role }} {{ .Content }}{{ end }}",
vars: []string{"content", "messages", "role"},
},
{
name: "ToolResultConditional",
template: "{{ range .Messages }}{{ if eq .Role \"tool\" }}Tool Result: {{ .ToolName }} {{ .Content }}{{ end }}{{ end }}",
vars: []string{"content", "messages", "role", "toolname"},
},
{
name: "MultilineSystemUserAssistant",
template: `{{- range .Messages }}
{{- if eq .Role "system" }}SYSTEM:
{{- else if eq .Role "user" }}USER:
{{- else if eq .Role "assistant" }}ASSISTANT:
{{- else if eq .Role "tool" }}TOOL:
{{- else if eq .Role "tool" }}TOOL:
{{- end }} {{ .Content }}
{{- end }}`, []string{"content", "messages", "role"}},
{`{{- if .Messages }}
{{- end }}`,
vars: []string{"content", "messages", "role"},
},
{
name: "ChatMLLike",
template: `{{- if .Messages }}
{{- range .Messages }}<|im_start|>{{ .Role }}
{{ .Content }}<|im_end|>
{{ end }}<|im_start|>assistant
@@ -182,22 +213,60 @@ func TestParse(t *testing.T) {
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ .Response }}<|im_end|>
{{- end -}}`, []string{"content", "messages", "prompt", "response", "role", "system"}},
{{- end -}}`,
vars: []string{"content", "messages", "prompt", "response", "role", "system"},
},
}

for _, tt := range cases {
t.Run("", func(t *testing.T) {
for _, tt := range validCases {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()

tmpl, err := Parse(tt.template)
if err != nil {
t.Fatal(err)
t.Fatalf("Parse returned unexpected error: %v", err)
}

v, err := tmpl.Vars()
gotVars, err := tmpl.Vars()
if err != nil {
t.Fatal(err)
t.Fatalf("Vars returned unexpected error: %v", err)
}
if diff := cmp.Diff(v, tt.vars); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)

if diff := cmp.Diff(gotVars, tt.vars); diff != "" {
t.Errorf("Vars mismatch (-got +want):\n%s", diff)
}
})
}
}

func TestParseError(t *testing.T) {
invalidCases := []struct {
name string
template string
errorStr string
}{
{
"TemplateNotClosed",
"{{ .Prompt ",
"unclosed action",
},
{
"Template",
`{{define "x"}}{{template "x"}}{{end}}{{template "x"}}`,
"undefined template specified",
},
}

for _, tt := range invalidCases {
t.Run(tt.name, func(t *testing.T) {
_, err := Parse(tt.template)
if err == nil {
t.Fatalf("expected Parse to return an error for an invalid template, got nil")
}

if !strings.Contains(strings.ToLower(err.Error()), strings.ToLower(tt.errorStr)) {
t.Errorf("unexpected error message.\n got: %q\n want substring (case‑insensitive): %q", err.Error(), tt.errorStr)
}
})
}