Compare commits: v0.12.6...nicole/emb (26 commits)

075241f8cd, 15c7d30d9a, 9862317174, ec9eb28f4c, 1bdd816910, 5d347f6d6f,
b97eb2b858, ad6f6a1d29, 6723a40be6, 3258a89b6e, 1c093e97af, a8d9c2648e,
0334e67ffd, e0ead1adee, d515aed6c3, 5fe7ba1b9b, d2b63c19b3, 94f110b35a,
5d22953ba7, d245dffed8, bc1a818fdc, ba2253dc30, 68e04c7ff8, 270679932f,
65fb3ff49d, e2a0b24435
@@ -30,7 +30,7 @@
       "name": "CUDA 12",
       "inherits": [ "CUDA" ],
       "cacheVariables": {
-        "CMAKE_CUDA_ARCHITECTURES": "50;52;60;61;70;75;80;86;87;89;90;90a;120",
+        "CMAKE_CUDA_ARCHITECTURES": "50;52;60;61;70;75;80;86;89;90;90a;120",
         "CMAKE_CUDA_FLAGS": "-Wno-deprecated-gpu-targets -t 2"
       }
     },
@@ -38,7 +38,7 @@
       "name": "CUDA 13",
       "inherits": [ "CUDA" ],
       "cacheVariables": {
-        "CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;110-virtual;120-virtual;121-virtual",
+        "CMAKE_CUDA_ARCHITECTURES": "75-virtual;80-virtual;86-virtual;87-virtual;89-virtual;90-virtual;90a-virtual;100-virtual;103-virtual;110-virtual;120-virtual;121-virtual",
         "CMAKE_CUDA_FLAGS": "-t 2"
       }
     },
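A note on the `NN-virtual` entries in these lists: in CMake's `CMAKE_CUDA_ARCHITECTURES`, a bare number embeds both compiled SASS and PTX for that architecture, `NN-real` embeds SASS only, and `NN-virtual` embeds PTX only, which the driver JIT-compiles at load time. A minimal, hypothetical preset fragment mixing the forms (not from this changeset):

```json
{
  "cacheVariables": {
    "CMAKE_CUDA_ARCHITECTURES": "86;90a-real;120-virtual"
  }
}
```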
@@ -461,6 +461,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [AWS-Strands-With-Ollama](https://github.com/rapidarchitect/ollama_strands) - AWS Strands Agents with Ollama examples
 - [ollama-multirun](https://github.com/attogram/ollama-multirun) - A bash shell script to run a single prompt against any or all of your locally installed Ollama models, saving the output and performance statistics as easily navigable web pages. ([Demo](https://attogram.github.io/ai_test_zone/))
 - [ollama-bash-toolshed](https://github.com/attogram/ollama-bash-toolshed) - Bash scripts to chat with tool-using models. Add new tools to your shed with ease. Runs on Ollama.
+- [VT Code](https://github.com/vinhnx/vtcode) - A Rust-based terminal coding agent with semantic code intelligence via Tree-sitter and Ollama integration for running local/cloud models with configurable endpoints.

### Apple Vision Pro
@@ -18,6 +18,7 @@ import (
 	"strings"
 	"testing"
 
+	"github.com/google/go-cmp/cmp"
 	"github.com/ollama/ollama/fs/ggml"
 )
 
@@ -339,13 +340,8 @@ func TestConvertAdapter(t *testing.T) {
 			}
 
 			actual := generateResultsJSON(t, r, m.KV(), m.Tensors())
 
-			for _, k := range slices.Sorted(maps.Keys(c.Expected)) {
-				if v, ok := actual[k]; !ok {
-					t.Errorf("missing %s", k)
-				} else if v != c.Expected[k] {
-					t.Errorf("unexpected %s: want %s, got %s", k, c.Expected[k], v)
-				}
+			if diff := cmp.Diff(c.Expected, actual); diff != "" {
+				t.Errorf("mismatch (-want +got):\n%s", diff)
 			}
 		})
 	}
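For readers unfamiliar with `go-cmp`, a toy illustration of the pattern adopted above: `cmp.Diff` returns an empty string when the two values are equal, and a human-readable `-want +got` report otherwise, which replaces the manual key-by-key loop.

```go
package main

import (
	"fmt"

	"github.com/google/go-cmp/cmp"
)

func main() {
	want := map[string]string{"alignment": "32"}
	got := map[string]string{"alignment": "16"}
	// Empty diff means equal; otherwise print the -want/+got report.
	if diff := cmp.Diff(want, got); diff != "" {
		fmt.Printf("mismatch (-want +got):\n%s", diff)
	}
}
```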
@@ -2065,12 +2065,6 @@ power management:
 			cpus := linuxCPUDetails(buf)
 
-			slog.Info("example", "scenario", k, "cpus", cpus)
-			si := SystemInfo{
-				System: CPUInfo{
-					CPUs: cpus,
-				},
-			}
-			threadCount := si.GetOptimalThreadCount()
 			if len(v.expCPUs) != len(cpus) {
 				t.Fatalf("incorrect number of sockets: expected:%v got:%v", v.expCPUs, cpus)
 			}
@@ -2085,10 +2079,6 @@ power management:
 				t.Fatalf("incorrect number of threads: expected:%v got:%v", v.expCPUs[i], c)
 			}
 		}
-
-		if threadCount != v.expThreadCount {
-			t.Fatalf("incorrect thread count expected:%d got:%d", v.expThreadCount, threadCount)
-		}
 		})
 	}
 }
discover/gpu.go (164 lines changed)
@@ -1,16 +1,13 @@
 package discover
 
 import (
-	"context"
 	"log/slog"
 	"os"
-	"path/filepath"
 	"regexp"
 	"runtime"
 	"strconv"
 	"strings"
 
-	"github.com/ollama/ollama/format"
 	"github.com/ollama/ollama/ml"
 )
 
@@ -18,159 +15,28 @@ import (
 // Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices.
 var CudaTegra string = os.Getenv("JETSON_JETPACK")
 
-func GetCPUInfo() GpuInfo {
-	mem, err := GetCPUMem()
+// GetSystemInfo returns the last cached state of the GPUs on the system
+func GetSystemInfo() ml.SystemInfo {
+	memInfo, err := GetCPUMem()
 	if err != nil {
 		slog.Warn("error looking up system memory", "error", err)
 	}
 
-	return GpuInfo{
-		memInfo: mem,
-		DeviceID: ml.DeviceID{
-			Library: "cpu",
-			ID:      "0",
-		},
-	}
-}
-
-func GetGPUInfo(ctx context.Context, runners []FilteredRunnerDiscovery) GpuInfoList {
-	devs := GPUDevices(ctx, runners)
-	return devInfoToInfoList(devs)
-}
-
-func devInfoToInfoList(devs []ml.DeviceInfo) GpuInfoList {
-	resp := []GpuInfo{}
-	// Our current packaging model places ggml-hip in the main directory
-	// but keeps rocm in an isolated directory. We have to add it to
-	// the [LD_LIBRARY_]PATH so ggml-hip will load properly
-	rocmDir := filepath.Join(LibOllamaPath, "rocm")
-	if _, err := os.Stat(rocmDir); err != nil {
-		rocmDir = ""
+	var threadCount int
+	cpus := GetCPUDetails()
+	for _, c := range cpus {
+		threadCount += c.CoreCount - c.EfficiencyCoreCount
 	}
 
-	for _, dev := range devs {
-		info := GpuInfo{
-			DeviceID: dev.DeviceID,
-			filterID: dev.FilteredID,
-			Name:     dev.Description,
-			memInfo: memInfo{
-				TotalMemory: dev.TotalMemory,
-				FreeMemory:  dev.FreeMemory,
-			},
-			// TODO can we avoid variant
-			DependencyPath: dev.LibraryPath,
-			DriverMajor:    dev.DriverMajor,
-			DriverMinor:    dev.DriverMinor,
-			ComputeMajor:   dev.ComputeMajor,
-			ComputeMinor:   dev.ComputeMinor,
-		}
-		if dev.Library == "CUDA" || dev.Library == "ROCm" {
-			info.MinimumMemory = 457 * format.MebiByte
-		}
-		if dev.Library == "ROCm" && rocmDir != "" {
-			info.DependencyPath = append(info.DependencyPath, rocmDir)
-		}
-		// TODO any special processing of Vulkan devices?
-		resp = append(resp, info)
-	}
-	if len(resp) == 0 {
-		mem, err := GetCPUMem()
-		if err != nil {
-			slog.Warn("error looking up system memory", "error", err)
-		}
-
-		resp = append(resp, GpuInfo{
-			memInfo: mem,
-			DeviceID: ml.DeviceID{
-				Library: "cpu",
-				ID:      "0",
-			},
-		})
-	}
-	return resp
-}
-
-// Given the list of GPUs this instantiation is targeted for,
-// figure out the visible devices environment variable
-//
-// If different libraries are detected, the first one is what we use
-func (l GpuInfoList) GetVisibleDevicesEnv() []string {
-	if len(l) == 0 {
-		return nil
-	}
-	res := []string{}
-	envVar := rocmGetVisibleDevicesEnv(l)
-	if envVar != "" {
-		res = append(res, envVar)
-	}
-	envVar = vkGetVisibleDevicesEnv(l)
-	if envVar != "" {
-		res = append(res, envVar)
-	}
-	return res
-}
-
-func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) string {
-	ids := []string{}
-	for _, info := range gpuInfo {
-		if info.Library != "ROCm" {
-			continue
-		}
-		// If the device requires a numeric ID, for filtering purposes, we use the unfiltered ID number
-		if info.filterID != "" {
-			ids = append(ids, info.filterID)
-		} else {
-			ids = append(ids, info.ID)
-		}
-	}
-	if len(ids) == 0 {
-		return ""
-	}
-	envVar := "ROCR_VISIBLE_DEVICES="
-	if runtime.GOOS != "linux" {
-		envVar = "HIP_VISIBLE_DEVICES="
-	}
-	// There are 3 potential env vars to use to select GPUs.
-	// ROCR_VISIBLE_DEVICES supports UUID or numeric but does not work on Windows
-	// HIP_VISIBLE_DEVICES supports numeric IDs only
-	// GPU_DEVICE_ORDINAL supports numeric IDs only
-	return envVar + strings.Join(ids, ",")
-}
-
-func vkGetVisibleDevicesEnv(gpuInfo []GpuInfo) string {
-	ids := []string{}
-	for _, info := range gpuInfo {
-		if info.Library != "Vulkan" {
-			continue
-		}
-		if info.filterID != "" {
-			ids = append(ids, info.filterID)
-		} else {
-			ids = append(ids, info.ID)
-		}
-	}
-	if len(ids) == 0 {
-		return ""
-	}
-	envVar := "GGML_VK_VISIBLE_DEVICES="
-	return envVar + strings.Join(ids, ",")
-}
-
-// GetSystemInfo returns the last cached state of the GPUs on the system
-func GetSystemInfo() SystemInfo {
-	deviceMu.Lock()
-	defer deviceMu.Unlock()
-	gpus := devInfoToInfoList(devices)
-	if len(gpus) == 1 && gpus[0].Library == "cpu" {
-		gpus = []GpuInfo{}
+	if threadCount == 0 {
+		// Fall back to Go's num CPU
+		threadCount = runtime.NumCPU()
 	}
 
-	return SystemInfo{
-		System: CPUInfo{
-			CPUs:    GetCPUDetails(),
-			GpuInfo: GetCPUInfo(),
-		},
-		GPUs: gpus,
+	return ml.SystemInfo{
+		ThreadCount: threadCount,
+		TotalMemory: memInfo.TotalMemory,
+		FreeMemory:  memInfo.FreeMemory,
+		FreeSwap:    memInfo.FreeSwap,
 	}
 }
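The rewritten GetSystemInfo above counts only performance cores toward the inference thread count, falling back to runtime.NumCPU() when core details are unavailable. A toy fragment illustrating the rule (hypothetical values, not from a real machine; assumes the CPU type shown later in this diff):

```go
// One socket: 8 cores total, 4 of them efficiency cores.
cpus := []CPU{{CoreCount: 8, EfficiencyCoreCount: 4}}

threadCount := 0
for _, c := range cpus {
	threadCount += c.CoreCount - c.EfficiencyCoreCount
}
fmt.Println(threadCount) // 4 inference threads
```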
@@ -4,13 +4,8 @@ package discover
 
 import (
 	"context"
-	"encoding/json"
-	"fmt"
 	"io"
 	"log/slog"
-	"math/rand"
-	"net"
-	"net/http"
 	"os"
 	"os/exec"
 	"path/filepath"
@@ -23,6 +18,7 @@ import (
 
 	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/format"
+	"github.com/ollama/ollama/llm"
 	"github.com/ollama/ollama/logutil"
 	"github.com/ollama/ollama/ml"
 )
@@ -36,7 +32,7 @@ var (
 	bootstrapped bool
 )
 
-func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.DeviceInfo {
+func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.DeviceInfo {
 	deviceMu.Lock()
 	defer deviceMu.Unlock()
 	startDiscovery := time.Now()
@@ -88,6 +84,7 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev
 	// times concurrently leading to memory contention
 	// TODO refactor so we group the lib dirs and do serial per version, but parallel for different libs
 	for dir := range libDirs {
+		bootstrapTimeout := 30 * time.Second
 		var dirs []string
 		if dir != "" {
 			if requested != "" && filepath.Base(dir) != requested {
@@ -102,11 +99,16 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev
 		} else {
 			dirs = []string{LibOllamaPath, dir}
 		}
 
+		// ROCm can take a long time on some systems, so give it more time before giving up
+		if dir != "" && strings.Contains(filepath.Base(dir), "rocm") {
+			bootstrapTimeout = 60 * time.Second
+		}
 		// Typically bootstrapping takes < 1s, but on some systems, with devices
 		// in low power/idle mode, initialization can take multiple seconds. We
 		// set a long timeout just for bootstrap discovery to reduce the chance
 		// of giving up too quickly
-		ctx1stPass, cancel := context.WithTimeout(ctx, 30*time.Second)
+		ctx1stPass, cancel := context.WithTimeout(ctx, bootstrapTimeout)
 		defer cancel()
 
 		// For this pass, we retain duplicates in case any are incompatible with some libraries
@@ -148,9 +150,9 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev
 			slog.Error("Unknown Library:" + devices[i].Library)
 		}
 
-		extraEnvs := []string{
-			"GGML_CUDA_INIT=1", // force deep initialization to trigger crash on unsupported GPUs
-			envVar + "=" + id,  // Filter to just this one GPU
+		extraEnvs := map[string]string{
+			"GGML_CUDA_INIT": "1", // force deep initialization to trigger crash on unsupported GPUs
+			envVar:           id,  // Filter to just this one GPU
 		}
 		if len(bootstrapDevices(ctx2ndPass, devices[i].LibraryPath, extraEnvs)) == 0 {
 			needsDelete[i] = true
@@ -443,100 +445,35 @@ func (r *bootstrapRunner) HasExited() bool {
 	return false
 }
 
-func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs []string) []ml.DeviceInfo {
-	// TODO DRY out with llm/server.go
-	slog.Debug("spawning runner with", "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs)
+func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs map[string]string) []ml.DeviceInfo {
+	var out io.Writer
+	if envconfig.LogLevel() == logutil.LevelTrace {
+		out = os.Stderr
+	}
 	start := time.Now()
 	defer func() {
 		slog.Debug("bootstrap discovery took", "duration", time.Since(start), "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "extra_envs", extraEnvs)
 	}()
-	port := 0
-	if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
-		var l *net.TCPListener
-		if l, err = net.ListenTCP("tcp", a); err == nil {
-			port = l.Addr().(*net.TCPAddr).Port
-			l.Close()
-		}
-	}
-	if port == 0 {
-		slog.Debug("ResolveTCPAddr failed, using random port")
-		port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
-	}
-	params := []string{"runner", "--ollama-engine", "--port", strconv.Itoa(port)}
-	var pathEnv string
-	switch runtime.GOOS {
-	case "windows":
-		pathEnv = "PATH"
-	case "darwin":
-		pathEnv = "DYLD_LIBRARY_PATH"
-	default:
-		pathEnv = "LD_LIBRARY_PATH"
-	}
-	libraryPaths := append([]string{LibOllamaPath}, ollamaLibDirs...)
-	if rocmDir != "" {
-		libraryPaths = append(libraryPaths, rocmDir)
-	}
-	// Note: we always put our dependency paths first
-	// since these are the exact version we compiled/linked against
-	if libraryPath, ok := os.LookupEnv(pathEnv); ok {
-		libraryPaths = append(libraryPaths, filepath.SplitList(libraryPath)...)
-	}
-
-	cmd := exec.Command(exe, params...)
-	cmd.Env = os.Environ()
-	if envconfig.LogLevel() == logutil.LevelTrace {
-		cmd.Stdout = os.Stdout
-		cmd.Stderr = os.Stderr
-	}
-
-	// cmd.SysProcAttr = llm.LlamaServerSysProcAttr // circular dependency - bring back once refactored
-	pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
-	pathNeeded := true
-	ollamaPathNeeded := true
-	extraDone := make([]bool, len(extraEnvs))
-	for i := range cmd.Env {
-		cmp := strings.SplitN(cmd.Env[i], "=", 2)
-		if strings.EqualFold(cmp[0], pathEnv) {
-			cmd.Env[i] = pathEnv + "=" + pathEnvVal
-			pathNeeded = false
-		} else if strings.EqualFold(cmp[0], "OLLAMA_LIBRARY_PATH") {
-			cmd.Env[i] = "OLLAMA_LIBRARY_PATH=" + strings.Join(ollamaLibDirs, string(filepath.ListSeparator))
-			ollamaPathNeeded = false
-		} else {
-			for j := range extraEnvs {
-				if extraDone[j] {
-					continue
-				}
-				extra := strings.SplitN(extraEnvs[j], "=", 2)
-				if cmp[0] == extra[0] {
-					cmd.Env[i] = extraEnvs[j]
-					extraDone[j] = true
-				}
-			}
-		}
-	}
-	if pathNeeded {
-		cmd.Env = append(cmd.Env, pathEnv+"="+pathEnvVal)
-	}
-	if ollamaPathNeeded {
-		cmd.Env = append(cmd.Env, "OLLAMA_LIBRARY_PATH="+strings.Join(ollamaLibDirs, string(filepath.ListSeparator)))
-	}
-	for i := range extraDone {
-		if !extraDone[i] {
-			cmd.Env = append(cmd.Env, extraEnvs[i])
-		}
-	}
-	logutil.Trace("starting runner for device discovery", "env", cmd.Env, "cmd", cmd)
-	if err := cmd.Start(); err != nil {
-		slog.Warn("unable to start discovery subprocess", "cmd", cmd, "error", err)
+	logutil.Trace("starting runner for device discovery", "libDirs", ollamaLibDirs, "extraEnvs", extraEnvs)
+	cmd, port, err := llm.StartRunner(
+		true, // ollama engine
+		"",   // no model
+		ollamaLibDirs,
+		out,
+		extraEnvs,
+	)
+	if err != nil {
+		slog.Debug("failed to start runner to discover GPUs", "error", err)
 		return nil
 	}
 
 	go func() {
 		cmd.Wait() // exit status ignored
 	}()
 
 	defer cmd.Process.Kill()
-	devices, err := GetDevicesFromRunner(ctx, &bootstrapRunner{port: port, cmd: cmd})
+	devices, err := ml.GetDevicesFromRunner(ctx, &bootstrapRunner{port: port, cmd: cmd})
 	if err != nil {
 		if cmd.ProcessState != nil && cmd.ProcessState.ExitCode() >= 0 {
 			// Expected during bootstrapping while we filter out unsupported AMD GPUs
@@ -549,52 +486,3 @@ func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs []s
 
 	return devices
 }
-
-func GetDevicesFromRunner(ctx context.Context, runner BaseRunner) ([]ml.DeviceInfo, error) {
-	var moreDevices []ml.DeviceInfo
-	port := runner.GetPort()
-	tick := time.Tick(10 * time.Millisecond)
-	for {
-		select {
-		case <-ctx.Done():
-			return nil, fmt.Errorf("failed to finish discovery before timeout")
-		case <-tick:
-			r, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d/info", port), nil)
-			if err != nil {
-				return nil, fmt.Errorf("failed to create request: %w", err)
-			}
-			r.Header.Set("Content-Type", "application/json")
-
-			resp, err := http.DefaultClient.Do(r)
-			if err != nil {
-				// slog.Warn("failed to send request", "error", err)
-				if runner.HasExited() {
-					return nil, fmt.Errorf("runner crashed")
-				}
-				continue
-			}
-			defer resp.Body.Close()
-
-			if resp.StatusCode == http.StatusNotFound {
-				// old runner, fall back to bootstrapping model
-				return nil, fmt.Errorf("llamarunner free vram reporting not supported")
-			}
-
-			body, err := io.ReadAll(resp.Body)
-			if err != nil {
-				slog.Warn("failed to read response", "error", err)
-				continue
-			}
-			if resp.StatusCode != 200 {
-				logutil.Trace("runner failed to discover free VRAM", "status", resp.StatusCode, "response", body)
-				return nil, fmt.Errorf("runner error: %s", string(body))
-			}
-
-			if err := json.Unmarshal(body, &moreDevices); err != nil {
-				slog.Warn("unmarshal encode response", "error", err)
-				continue
-			}
-			return moreDevices, nil
-		}
-	}
-}
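A hedged sketch of how a caller might drive the discovery entry point above (hypothetical caller code; assumes the discover, format, and ml packages as shown in this diff):

```go
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()

// Passing nil runners: nothing is loaded yet, so bootstrap discovery runs.
for _, d := range discover.GPUDevices(ctx, nil) {
	fmt.Printf("%s %s (%s): %s free of %s\n",
		d.Library, d.ID, d.Description,
		format.HumanBytes2(d.FreeMemory), format.HumanBytes2(d.TotalMemory))
}
```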
@@ -1,10 +1,8 @@
 package discover
 
 import (
-	"context"
 	"log/slog"
 	"path/filepath"
-	"runtime"
 	"strings"
 
 	"github.com/ollama/ollama/format"
@@ -17,50 +15,6 @@ type memInfo struct {
 	FreeSwap uint64 `json:"free_swap,omitempty"` // TODO split this out for system only
 }
 
-// Beginning of an `ollama info` command
-type GpuInfo struct { // TODO better name maybe "InferenceProcessor"?
-	ml.DeviceID
-	memInfo
-
-	// Optional variant to select (e.g. versions, cpu feature flags)
-	Variant string `json:"variant"`
-
-	// MinimumMemory represents the minimum memory required to use the GPU
-	MinimumMemory uint64 `json:"-"`
-
-	// Any extra PATH/LD_LIBRARY_PATH dependencies required for the Library to operate properly
-	DependencyPath []string `json:"lib_path,omitempty"`
-
-	// Set to true if we can NOT reliably discover FreeMemory. A value of true indicates
-	// the FreeMemory is best effort, and may over or under report actual memory usage
-	// False indicates FreeMemory can generally be trusted on this GPU
-	UnreliableFreeMemory bool
-
-	// GPU information
-	filterID string // AMD/Vulkan Workaround: The numeric ID of the device used to filter out other devices
-	Name         string `json:"name"`          // user friendly name if available
-	ComputeMajor int    `json:"compute_major"` // Compute Capability or gfx
-	ComputeMinor int    `json:"compute_minor"`
-
-	// Driver Information - TODO no need to put this on each GPU
-	DriverMajor int `json:"driver_major,omitempty"`
-	DriverMinor int `json:"driver_minor,omitempty"`
-
-	// TODO other performance capability info to help in scheduling decisions
-}
-
-func (gpu GpuInfo) RunnerName() string {
-	if gpu.Variant != "" {
-		return gpu.Library + "_" + gpu.Variant
-	}
-	return gpu.Library
-}
-
-type CPUInfo struct {
-	GpuInfo
-	CPUs []CPU
-}
-
 // CPU type represents a CPU Package occupying a socket
 type CPU struct {
 	ID string `cpuinfo:"processor"`
@@ -71,32 +25,6 @@ type CPU struct {
 	ThreadCount int
 }
 
-type GpuInfoList []GpuInfo
-
-func (l GpuInfoList) ByLibrary() []GpuInfoList {
-	resp := []GpuInfoList{}
-	libs := []string{}
-	for _, info := range l {
-		found := false
-		requested := info.Library
-		if info.Variant != "" {
-			requested += "_" + info.Variant
-		}
-		for i, lib := range libs {
-			if lib == requested {
-				resp[i] = append(resp[i], info)
-				found = true
-				break
-			}
-		}
-		if !found {
-			libs = append(libs, requested)
-			resp = append(resp, []GpuInfo{info})
-		}
-	}
-	return resp
-}
-
 func LogDetails(devices []ml.DeviceInfo) {
 	for _, dev := range devices {
 		var libs []string
@@ -141,74 +69,3 @@ func LogDetails(devices []ml.DeviceInfo) {
 		)
 	}
 }
-
-// Sort by Free Space
-type ByFreeMemory []GpuInfo
-
-func (a ByFreeMemory) Len() int           { return len(a) }
-func (a ByFreeMemory) Swap(i, j int)      { a[i], a[j] = a[j], a[i] }
-func (a ByFreeMemory) Less(i, j int) bool { return a[i].FreeMemory < a[j].FreeMemory }
-
-type SystemInfo struct {
-	System CPUInfo   `json:"system"`
-	GPUs   []GpuInfo `json:"gpus"`
-}
-
-// Return the optimal number of threads to use for inference
-func (si SystemInfo) GetOptimalThreadCount() int {
-	if len(si.System.CPUs) == 0 {
-		// Fall back to Go's num CPU
-		return runtime.NumCPU()
-	}
-
-	coreCount := 0
-	for _, c := range si.System.CPUs {
-		coreCount += c.CoreCount - c.EfficiencyCoreCount
-	}
-
-	return coreCount
-}
-
-// For each GPU, check if it does NOT support flash attention
-func (l GpuInfoList) FlashAttentionSupported() bool {
-	for _, gpu := range l {
-		supportsFA := gpu.Library == "cpu" ||
-			gpu.Name == "Metal" || gpu.Library == "Metal" ||
-			(gpu.Library == "CUDA" && gpu.DriverMajor >= 7 && !(gpu.ComputeMajor == 7 && gpu.ComputeMinor == 2)) || // We don't have kernels for Jetson Xavier
-			gpu.Library == "ROCm" ||
-			gpu.Library == "Vulkan"
-
-		if !supportsFA {
-			return false
-		}
-	}
-	return true
-}
-
-type BaseRunner interface {
-	// GetPort returns the localhost port number the runner is running on
-	GetPort() int
-
-	// HasExited indicates if the runner is no longer running. This can be used during
-	// bootstrap to detect if a given filtered device is incompatible and triggered an assert
-	HasExited() bool
-}
-
-type RunnerDiscovery interface {
-	BaseRunner
-
-	// GetDeviceInfos will perform a query of the underlying device libraries
-	// for device identification and free VRAM information
-	// During bootstrap scenarios, this routine may take seconds to complete
-	GetDeviceInfos(ctx context.Context) []ml.DeviceInfo
-}
-
-type FilteredRunnerDiscovery interface {
-	RunnerDiscovery
-
-	// GetActiveDeviceIDs returns the filtered set of devices actively in
-	// use by this runner for running models. If the runner is a bootstrap runner, no devices
-	// will be active yet so no device IDs are returned.
-	// This routine will not query the underlying device and will return immediately
-	GetActiveDeviceIDs() []ml.DeviceID
-}
@@ -9,15 +9,20 @@ Check your compute compatibility to see if your card is supported:
 | ------------------ | ------------------- | --------------------------------------------------------------------------------------------------------------------------- |
 | 12.0               | GeForce RTX 50xx    | `RTX 5060` `RTX 5060 Ti` `RTX 5070` `RTX 5070 Ti` `RTX 5080` `RTX 5090` |
+|                    | NVIDIA Professional | `RTX PRO 4000 Blackwell` `RTX PRO 4500 Blackwell` `RTX PRO 5000 Blackwell` `RTX PRO 6000 Blackwell` |
-| 9.0                | NVIDIA              | `H200` `H100` |
+| 11.0               | Jetson              | `T4000` `T5000` (Requires driver 580 or newer) |
+| 10.3               | NVIDIA Professional | `B300` `GB300` (Requires driver 580 or newer) |
+| 10.0               | NVIDIA Professional | `B200` `GB200` (Requires driver 580 or newer) |
+| 9.0                | NVIDIA              | `H200` `H100` `GH200` |
 | 8.9                | GeForce RTX 40xx    | `RTX 4090` `RTX 4080 SUPER` `RTX 4080` `RTX 4070 Ti SUPER` `RTX 4070 Ti` `RTX 4070 SUPER` `RTX 4070` `RTX 4060 Ti` `RTX 4060` |
 |                    | NVIDIA Professional | `L4` `L40` `RTX 6000` |
 | 8.7                | Jetson              | `Orin Nano` `Orin NX` `AGX Orin` |
 | 8.6                | GeForce RTX 30xx    | `RTX 3090 Ti` `RTX 3090` `RTX 3080 Ti` `RTX 3080` `RTX 3070 Ti` `RTX 3070` `RTX 3060 Ti` `RTX 3060` `RTX 3050 Ti` `RTX 3050` |
 |                    | NVIDIA Professional | `A40` `RTX A6000` `RTX A5000` `RTX A4000` `RTX A3000` `RTX A2000` `A10` `A16` `A2` |
 | 8.0                | NVIDIA              | `A100` `A30` |
 | 7.5                | GeForce GTX/RTX     | `GTX 1650 Ti` `TITAN RTX` `RTX 2080 Ti` `RTX 2080` `RTX 2070` `RTX 2060` |
 |                    | NVIDIA Professional | `T4` `RTX 5000` `RTX 4000` `RTX 3000` `T2000` `T1200` `T1000` `T600` `T500` |
 |                    | Quadro              | `RTX 8000` `RTX 6000` `RTX 5000` `RTX 4000` |
 | 7.2                | Jetson              | `Xavier NX` `AGX Xavier` (Jetpack 5) |
 | 7.0                | NVIDIA              | `TITAN V` `V100` `Quadro GV100` |
 | 6.1                | NVIDIA TITAN        | `TITAN Xp` `TITAN X` |
 |                    | GeForce GTX         | `GTX 1080 Ti` `GTX 1080` `GTX 1070 Ti` `GTX 1070` `GTX 1060` `GTX 1050 Ti` `GTX 1050` |
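If you're unsure which row applies to your card, recent NVIDIA drivers can report the compute capability directly (hedged: the `compute_cap` query field requires a reasonably recent driver, roughly the 510 series or newer):

```shell
nvidia-smi --query-gpu=name,compute_cap --format=csv
```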
@@ -509,7 +509,10 @@ func writeGGUFArray[S ~[]E, E any](w io.Writer, t uint32, s S) error {
 }
 
 func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
-	alignment := kv.Uint("general.alignment", 32)
+	arch := kv.String("general.architecture")
+	if arch == "" {
+		return fmt.Errorf("architecture not set")
+	}
 
 	if err := binary.Write(f, binary.LittleEndian, []byte("GGUF")); err != nil {
 		return err
@@ -528,7 +531,7 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
 	}
 
 	for _, key := range slices.Sorted(maps.Keys(kv)) {
-		if err := ggufWriteKV(f, key, kv[key]); err != nil {
+		if err := ggufWriteKV(f, arch, key, kv[key]); err != nil {
 			return err
 		}
 	}
@@ -543,6 +546,8 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
 		},
 	)
 
+	alignment := kv.Uint("general.alignment", 32)
+
 	var s uint64
 	for i := range ts {
 		ts[i].Offset = s
@@ -574,7 +579,14 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
 	return g.Wait()
 }
 
-func ggufWriteKV(ws io.WriteSeeker, k string, v any) error {
+func ggufWriteKV(ws io.WriteSeeker, arch, k string, v any) error {
+	if !strings.HasPrefix(k, arch+".") &&
+		!strings.HasPrefix(k, "general.") &&
+		!strings.HasPrefix(k, "adapter.") &&
+		!strings.HasPrefix(k, "tokenizer.") {
+		k = arch + "." + k
+	}
+
 	slog.Debug(k, "type", fmt.Sprintf("%T", v))
 	if err := binary.Write(ws, binary.LittleEndian, uint64(len(k))); err != nil {
 		return err
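To make the new key-prefixing rule concrete, a minimal standalone sketch (hypothetical helper name; mirrors the ggufWriteKV logic above):

```go
package main

import (
	"fmt"
	"strings"
)

// normalizeKey prepends the architecture to keys that carry no
// recognized namespace, matching the ggufWriteKV change above.
func normalizeKey(arch, k string) string {
	for _, p := range []string{arch + ".", "general.", "adapter.", "tokenizer."} {
		if strings.HasPrefix(k, p) {
			return k
		}
	}
	return arch + "." + k
}

func main() {
	fmt.Println(normalizeKey("test", "attention.key"))     // test.attention.key
	fmt.Println(normalizeKey("test", "general.alignment")) // general.alignment
	fmt.Println(normalizeKey("test", "tokenizer.key"))     // tokenizer.key
}
```

This matches the expectations in the updated TestWriteGGUF below, where "attention.key" round-trips as "test.attention.key".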
@@ -39,7 +39,12 @@ func TestWriteGGUF(t *testing.T) {
 	defer w.Close()
 
 	if err := WriteGGUF(w, KV{
-		"general.alignment": uint32(16),
+		"general.architecture": "test",
+		"general.alignment":    uint32(16),
+		"test.key":             "value",
+		"attention.key":        "value2",
+		"tokenizer.key":        "value3",
+		"adapter.key":          "value4",
 	}, ts); err != nil {
 		t.Fatal(err)
 	}
@@ -56,14 +61,19 @@ func TestWriteGGUF(t *testing.T) {
 	}
 
 	if diff := cmp.Diff(KV{
+		"general.architecture":    "test",
 		"general.alignment":       uint32(16),
 		"general.parameter_count": uint64(54),
+		"test.key":                "value",
+		"test.attention.key":      "value2",
+		"tokenizer.key":           "value3",
+		"adapter.key":             "value4",
 	}, ff.KV()); diff != "" {
 		t.Errorf("Mismatch (-want +got):\n%s", diff)
 	}
 
 	if diff := cmp.Diff(Tensors{
-		Offset: 592,
+		Offset: 800,
 		items: []*Tensor{
 			{Name: "blk.0.attn_k.weight", Offset: 0, Shape: []uint64{2, 3}},
 			{Name: "blk.0.attn_norm.weight", Offset: 32, Shape: []uint64{2, 3}},
@@ -109,6 +109,8 @@ func TestMultiModelStress(t *testing.T) {
 	defer cancel()
 	client, _, cleanup := InitServerConnection(ctx, t)
 	defer cleanup()
+	initialTimeout := 120 * time.Second
+	streamTimeout := 20 * time.Second
 
 	// Make sure all the models are pulled before we get started
 	for _, model := range chosenModels {
@@ -147,6 +149,8 @@ chooseModels:
 	for _, m := range models.Models {
 		if m.SizeVRAM == 0 {
 			slog.Info("model running on CPU", "name", m.Name, "target", targetLoadCount, "chosen", chosenModels[:targetLoadCount])
+			initialTimeout = 240 * time.Second
+			streamTimeout = 30 * time.Second
 			break chooseModels
 		}
 	}
@@ -172,10 +176,7 @@ chooseModels:
 				k := r.Int() % len(reqs)
 				reqs[k].Model = chosenModels[i]
 				slog.Info("Starting", "model", reqs[k].Model, "iteration", j, "request", reqs[k].Messages[0].Content)
-				DoChat(ctx, t, client, reqs[k], resps[k],
-					120*time.Second, // Be extra patient for the model to load initially
-					10*time.Second,  // Once results start streaming, fail if they stall
-				)
+				DoChat(ctx, t, client, reqs[k], resps[k], initialTimeout, streamTimeout)
 			}
 		}(i)
 	}
@@ -78,7 +78,7 @@ func TestContextExhaustion(t *testing.T) {
 
 // Send multiple generate requests with prior context and ensure the responses are coherent and expected
 func TestParallelGenerateWithHistory(t *testing.T) {
-	modelOverride := "gpt-oss:20b"
+	modelName := "gpt-oss:20b"
 	req, resp := GenerateRequests()
 	numParallel := 2
 	iterLimit := 2
@@ -88,15 +88,23 @@ func TestParallelGenerateWithHistory(t *testing.T) {
 	defer cancel()
 	client, _, cleanup := InitServerConnection(ctx, t)
 	defer cleanup()
+	initialTimeout := 120 * time.Second
+	streamTimeout := 20 * time.Second
 
 	// Get the server running (if applicable) and warm the model up with a single initial request
-	slog.Info("loading", "model", modelOverride)
+	slog.Info("loading", "model", modelName)
 	err := client.Generate(ctx,
-		&api.GenerateRequest{Model: modelOverride, KeepAlive: &api.Duration{Duration: 10 * time.Second}},
+		&api.GenerateRequest{Model: modelName, KeepAlive: &api.Duration{Duration: 10 * time.Second}},
 		func(response api.GenerateResponse) error { return nil },
 	)
 	if err != nil {
-		t.Fatalf("failed to load model %s: %s", modelOverride, err)
+		t.Fatalf("failed to load model %s: %s", modelName, err)
 	}
+	gpuPercent := getGPUPercent(ctx, t, client, modelName)
+	if gpuPercent < 80 {
+		slog.Warn("Low GPU percentage - increasing timeouts", "percent", gpuPercent)
+		initialTimeout = 240 * time.Second
+		streamTimeout = 30 * time.Second
+	}
 
 	var wg sync.WaitGroup
@@ -105,7 +113,7 @@ func TestParallelGenerateWithHistory(t *testing.T) {
 		go func(i int) {
 			defer wg.Done()
 			k := i % len(req)
-			req[k].Model = modelOverride
+			req[k].Model = modelName
 			for j := 0; j < iterLimit; j++ {
 				if time.Now().Sub(started) > softTimeout {
 					slog.Info("exceeded soft timeout, winding down test")
@@ -114,7 +122,7 @@ func TestParallelGenerateWithHistory(t *testing.T) {
 				slog.Info("Starting", "thread", i, "iter", j)
 				// On slower GPUs it can take a while to process the concurrent requests
 				// so we allow a much longer initial timeout
-				c := DoGenerate(ctx, t, client, req[k], resp[k], 120*time.Second, 20*time.Second)
+				c := DoGenerate(ctx, t, client, req[k], resp[k], initialTimeout, streamTimeout)
 				req[k].Context = c
 				req[k].Prompt = "tell me more!"
 			}
@@ -165,7 +173,7 @@ func TestGenerateWithHistory(t *testing.T) {
 
 // Send multiple chat requests with prior context and ensure the responses are coherent and expected
 func TestParallelChatWithHistory(t *testing.T) {
-	modelOverride := "gpt-oss:20b"
+	modelName := "gpt-oss:20b"
 	req, resp := ChatRequests()
 	numParallel := 2
 	iterLimit := 2
@@ -175,15 +183,23 @@ func TestParallelChatWithHistory(t *testing.T) {
 	defer cancel()
 	client, _, cleanup := InitServerConnection(ctx, t)
 	defer cleanup()
+	initialTimeout := 120 * time.Second
+	streamTimeout := 20 * time.Second
 
 	// Get the server running (if applicable) and warm the model up with a single initial empty request
-	slog.Info("loading", "model", modelOverride)
+	slog.Info("loading", "model", modelName)
 	err := client.Generate(ctx,
-		&api.GenerateRequest{Model: modelOverride, KeepAlive: &api.Duration{Duration: 10 * time.Second}},
+		&api.GenerateRequest{Model: modelName, KeepAlive: &api.Duration{Duration: 10 * time.Second}},
 		func(response api.GenerateResponse) error { return nil },
 	)
 	if err != nil {
-		t.Fatalf("failed to load model %s: %s", modelOverride, err)
+		t.Fatalf("failed to load model %s: %s", modelName, err)
 	}
+	gpuPercent := getGPUPercent(ctx, t, client, modelName)
+	if gpuPercent < 80 {
+		slog.Warn("Low GPU percentage - increasing timeouts", "percent", gpuPercent)
+		initialTimeout = 240 * time.Second
+		streamTimeout = 30 * time.Second
+	}
 
 	var wg sync.WaitGroup
@@ -192,7 +208,7 @@ func TestParallelChatWithHistory(t *testing.T) {
 		go func(i int) {
 			defer wg.Done()
 			k := i % len(req)
-			req[k].Model = modelOverride
+			req[k].Model = modelName
 			for j := 0; j < iterLimit; j++ {
 				if time.Now().Sub(started) > softTimeout {
 					slog.Info("exceeded soft timeout, winding down test")
@@ -201,7 +217,7 @@ func TestParallelChatWithHistory(t *testing.T) {
 				slog.Info("Starting", "thread", i, "iter", j)
 				// On slower GPUs it can take a while to process the concurrent requests
 				// so we allow a much longer initial timeout
-				assistant := DoChat(ctx, t, client, req[k], resp[k], 120*time.Second, 20*time.Second)
+				assistant := DoChat(ctx, t, client, req[k], resp[k], initialTimeout, streamTimeout)
 				if assistant == nil {
 					t.Fatalf("didn't get an assistant response for context")
 				}
@@ -4,7 +4,9 @@ package integration
 
 import (
 	"context"
+	"errors"
 	"math"
+	"strings"
 	"testing"
 	"time"
 
@@ -258,6 +260,19 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
 				}
 			},
 		},
+		{
+			name: "boundary truncation",
+			request: api.EmbedRequest{
+				Model:   "all-minilm",
+				Input:   "why is the sky blue? Why is the sky blue? hi there my",
+				Options: map[string]any{"num_ctx": 16},
+			},
+			check: func(res *api.EmbedResponse, err error) {
+				if err != nil {
+					t.Fatal(err)
+				}
+			},
+		},
 	}
 
 	for _, req := range cases {
@@ -286,3 +301,197 @@ func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req
 
 	return client.Embed(ctx, &req)
 }
+
+func TestEmbedTruncation(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
+	defer cancel()
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
+
+	t.Run("single input token count", func(t *testing.T) {
+		req := api.EmbedRequest{
+			Model: "all-minilm",
+			Input: "why is the sky blue?",
+		}
+
+		res, err := embedTestHelper(ctx, client, t, req)
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		if res.PromptEvalCount <= 0 {
+			t.Fatalf("expected positive token count, got %d", res.PromptEvalCount)
+		}
+	})
+
+	t.Run("batch parallel token counting", func(t *testing.T) {
+		req := api.EmbedRequest{
+			Model: "all-minilm",
+			Input: []string{"cat", "dog and mouse", "bird"},
+		}
+
+		res, err := embedTestHelper(ctx, client, t, req)
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		if len(res.Embeddings) != 3 {
+			t.Fatalf("expected 3 embeddings, got %d", len(res.Embeddings))
+		}
+
+		if res.PromptEvalCount <= 0 {
+			t.Fatalf("expected positive token count, got %d", res.PromptEvalCount)
+		}
+	})
+
+	t.Run("truncation single input", func(t *testing.T) {
+		truncTrue := true
+		longInput := strings.Repeat("word ", 100)
+
+		req := api.EmbedRequest{
+			Model:    "all-minilm",
+			Input:    longInput,
+			Truncate: &truncTrue,
+			Options:  map[string]any{"num_ctx": 50},
+		}
+
+		res, err := embedTestHelper(ctx, client, t, req)
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		if res.PromptEvalCount > 50 {
+			t.Fatalf("expected tokens <= 50 after truncation, got %d", res.PromptEvalCount)
+		}
+
+		if res.PromptEvalCount == 0 {
+			t.Fatal("expected non-zero token count after truncation")
+		}
+	})
+
+	t.Run("truncation batch", func(t *testing.T) {
+		truncTrue := true
+		req := api.EmbedRequest{
+			Model:    "all-minilm",
+			Input:    []string{"short", strings.Repeat("long ", 100), "medium text"},
+			Truncate: &truncTrue,
+			Options:  map[string]any{"num_ctx": 30},
+		}
+
+		res, err := embedTestHelper(ctx, client, t, req)
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		if len(res.Embeddings) != 3 {
+			t.Fatalf("expected 3 embeddings, got %d", len(res.Embeddings))
+		}
+
+		if res.PromptEvalCount > 90 {
+			t.Fatalf("expected tokens <= 90 (3 × 30 max), got %d", res.PromptEvalCount)
+		}
+	})
+
+	t.Run("runner token count accuracy", func(t *testing.T) {
+		baseline := api.EmbedRequest{Model: "all-minilm", Input: "test"}
+		baseRes, err := embedTestHelper(ctx, client, t, baseline)
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		batch := api.EmbedRequest{
+			Model: "all-minilm",
+			Input: []string{"test", "test", "test"},
+		}
+		batchRes, err := embedTestHelper(ctx, client, t, batch)
+		if err != nil {
+			t.Fatal(err)
+		}
+
+		expectedCount := baseRes.PromptEvalCount * 3
+		if batchRes.PromptEvalCount < expectedCount-2 || batchRes.PromptEvalCount > expectedCount+2 {
+			t.Fatalf("expected ~%d tokens (3 × %d), got %d",
+				expectedCount, baseRes.PromptEvalCount, batchRes.PromptEvalCount)
+		}
+	})
+}
+
+// TestEmbedStatusCode tests that errors from the embedding endpoint
+// properly preserve their HTTP status codes when returned to the client.
+// This test specifically checks the error handling path in EmbedHandler
+// where api.StatusError errors should maintain their original status code.
+func TestEmbedStatusCode(t *testing.T) {
+	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
+	defer cancel()
+	client, _, cleanup := InitServerConnection(ctx, t)
+	defer cleanup()
+
+	// Pull the model if needed
+	if err := PullIfMissing(ctx, client, "all-minilm"); err != nil {
+		t.Fatal(err)
+	}
+
+	t.Run("truncation error status code", func(t *testing.T) {
+		truncFalse := false
+		longInput := strings.Repeat("word ", 100)
+
+		req := api.EmbedRequest{
+			Model:    "all-minilm",
+			Input:    longInput,
+			Truncate: &truncFalse,
+			Options:  map[string]any{"num_ctx": 10},
+		}
+
+		_, err := embedTestHelper(ctx, client, t, req)
+		if err == nil {
+			t.Fatal("expected error when truncate=false with long input")
+		}
+
+		// Check that it's a StatusError with the correct status code
+		var statusErr api.StatusError
+		if !errors.As(err, &statusErr) {
+			t.Fatalf("expected api.StatusError, got %T: %v", err, err)
+		}
+
+		// The error should be a 4xx client error (likely 400 Bad Request),
+		// not a 500 Internal Server Error
+		if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 {
+			t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode)
+		}
+
+		// Verify the error message is meaningful
+		if !strings.Contains(err.Error(), "context length") {
+			t.Errorf("expected error message to mention context length, got: %v", err)
+		}
+	})
+
+	t.Run("batch truncation error status code", func(t *testing.T) {
+		truncFalse := false
+		req := api.EmbedRequest{
+			Model: "all-minilm",
+			Input: []string{
+				"short input",
+				strings.Repeat("very long input ", 100),
+				"another short input",
+			},
+			Truncate: &truncFalse,
+			Options:  map[string]any{"num_ctx": 10},
+		}
+
+		_, err := embedTestHelper(ctx, client, t, req)
+		if err == nil {
+			t.Fatal("expected error when one input exceeds context with truncate=false")
+		}
+
+		// Check that it's a StatusError with the correct status code
+		var statusErr api.StatusError
+		if !errors.As(err, &statusErr) {
+			t.Fatalf("expected api.StatusError, got %T: %v", err, err)
+		}
+
+		// The error should be a 4xx client error, not a 500 Internal Server Error
+		if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 {
+			t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode)
+		}
+	})
+}
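The status-code guarantee exercised above lets API consumers branch on truncation failures without string matching. A hedged client-side sketch (assumes a running server, the Go api client, and the all-minilm model):

```go
truncate := false
_, err := client.Embed(ctx, &api.EmbedRequest{
	Model:    "all-minilm",
	Input:    strings.Repeat("word ", 200),
	Truncate: &truncate,
	Options:  map[string]any{"num_ctx": 10},
})
var se api.StatusError
if errors.As(err, &se) && se.StatusCode >= 400 && se.StatusCode < 500 {
	// Client-side problem: shorten the input or set Truncate to true.
}
```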
@@ -65,6 +65,23 @@ func TestModelsChat(t *testing.T) {
 				}
 			}
 		}
+		initialTimeout := 120 * time.Second
+		streamTimeout := 30 * time.Second
+		slog.Info("loading", "model", model)
+		err := client.Generate(ctx,
+			&api.GenerateRequest{Model: model, KeepAlive: &api.Duration{Duration: 10 * time.Second}},
+			func(response api.GenerateResponse) error { return nil },
+		)
+		if err != nil {
+			t.Fatalf("failed to load model %s: %s", model, err)
+		}
+		gpuPercent := getGPUPercent(ctx, t, client, model)
+		if gpuPercent < 80 {
+			slog.Warn("Low GPU percentage - increasing timeouts", "percent", gpuPercent)
+			initialTimeout = 240 * time.Second
+			streamTimeout = 40 * time.Second
+		}
 
 		// TODO - fiddle with context size
 		req := api.ChatRequest{
 			Model: model,
@@ -80,7 +97,7 @@ func TestModelsChat(t *testing.T) {
 				"seed": 123,
 			},
 		}
-		DoChat(ctx, t, client, req, blueSkyExpected, 120*time.Second, 30*time.Second)
+		DoChat(ctx, t, client, req, blueSkyExpected, initialTimeout, streamTimeout)
 		// best effort unload once we're done with the model
 		client.Generate(ctx, &api.GenerateRequest{Model: req.Model, KeepAlive: &api.Duration{Duration: 0}}, func(rsp api.GenerateResponse) error { return nil })
 	})
@@ -743,6 +743,13 @@ func skipUnderMinVRAM(t *testing.T, gb uint64) {
 
 // Skip if the target model isn't X% GPU loaded to avoid excessive runtime
 func skipIfNotGPULoaded(ctx context.Context, t *testing.T, client *api.Client, model string, minPercent int) {
+	gpuPercent := getGPUPercent(ctx, t, client, model)
+	if gpuPercent < minPercent {
+		t.Skip(fmt.Sprintf("test requires minimum %d%% GPU load, but model %s only has %d%%", minPercent, model, gpuPercent))
+	}
+}
+
+func getGPUPercent(ctx context.Context, t *testing.T, client *api.Client, model string) int {
 	models, err := client.ListRunning(ctx)
 	if err != nil {
 		t.Fatalf("failed to list running models: %s", err)
@@ -772,12 +779,10 @@ func skipIfNotGPULoaded(ctx context.Context, t *testing.T, client *api.Client, m
 			cpuPercent := math.Round(float64(sizeCPU) / float64(m.Size) * 100)
 			gpuPercent = int(100 - cpuPercent)
 		}
-		if gpuPercent < minPercent {
-			t.Skip(fmt.Sprintf("test requires minimum %d%% GPU load, but model %s only has %d%%", minPercent, model, gpuPercent))
-		}
-		return
+		return gpuPercent
 	}
-	t.Skip(fmt.Sprintf("model %s not loaded - actually loaded: %v", model, loaded))
+	t.Fatalf("model %s not loaded - actually loaded: %v", model, loaded)
+	return 0
 }
 
 func getTimeouts(t *testing.T) (soft time.Duration, hard time.Duration) {
@@ -40,11 +40,6 @@ type Causal struct {
 
 	// ** current forward pass **
 
-	// curReserve indicates that this forward pass is only for
-	// memory reservation and we should not update our metadata
-	// based on it.
-	curReserve bool
-
 	// the active layer for Get and Put
 	curLayer int
 
@@ -206,13 +201,12 @@ func (c *Causal) Close() {
 }
 
 func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) error {
-	c.curReserve = reserve
 	c.curBatchSize = len(batch.Positions)
 	c.curSequences = batch.Sequences
 	c.curPositions = batch.Positions
 	c.opts.Except = nil
 
-	if !c.curReserve {
+	if !reserve {
 		c.updateSlidingWindow()
 
 		var err error
@@ -379,10 +373,6 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
 
 	length := c.curCellRange.max - c.curCellRange.min + 1
 
-	if c.curReserve {
-		return ctx.Input().Empty(c.config.MaskDType, length, batchSize)
-	}
-
 	mask := make([]float32, batchSize*length)
 
 	for i := range c.curBatchSize {
@@ -6,20 +6,20 @@ Subject: [PATCH] GPU discovery enhancements
 Expose more information about the devices through backend props, and leverage
 management libraries for more accurate VRAM usage reporting if available.
 ---
- ggml/include/ggml-backend.h        |   9 +
+ ggml/include/ggml-backend.h        |  11 +
  ggml/src/CMakeLists.txt            |   2 +
- ggml/src/ggml-cuda/ggml-cuda.cu    |  72 +++++
+ ggml/src/ggml-cuda/ggml-cuda.cu    |  74 +++++
  ggml/src/ggml-cuda/vendors/hip.h   |   3 +
  ggml/src/ggml-impl.h               |   8 +
  ggml/src/ggml-metal/ggml-metal.cpp |   2 +
  ggml/src/mem_hip.cpp               | 449 +++++++++++++++++++++++++++++
  ggml/src/mem_nvml.cpp              | 209 ++++++++++++++
- 8 files changed, 754 insertions(+)
+ 8 files changed, 758 insertions(+)
  create mode 100644 ggml/src/mem_hip.cpp
  create mode 100644 ggml/src/mem_nvml.cpp
 
 diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h
-index ba181d09..09ff75f9 100644
+index ba181d09d..094fc3c82 100644
 --- a/ggml/include/ggml-backend.h
 +++ b/ggml/include/ggml-backend.h
 @@ -169,6 +169,17 @@ extern "C" {
@@ -41,7 +41,7 @@ index ba181d09..09ff75f9 100644
 
  GGML_API const char * ggml_backend_dev_name(ggml_backend_dev_t device);
 diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt
-index 0609c650..aefe43bd 100644
+index 0609c6503..aefe43bdd 100644
 --- a/ggml/src/CMakeLists.txt
 +++ b/ggml/src/CMakeLists.txt
 @@ -209,6 +209,8 @@ add_library(ggml-base
@@ -54,7 +54,7 @@ index 0609c650..aefe43bd 100644
 
  target_include_directories(ggml-base PRIVATE .)
 diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
-index 87c6c34a..6a278b5e 100644
+index 87c6c34a4..816597d2f 100644
 --- a/ggml/src/ggml-cuda/ggml-cuda.cu
 +++ b/ggml/src/ggml-cuda/ggml-cuda.cu
 @@ -261,6 +261,16 @@ static ggml_cuda_device_info ggml_cuda_init() {
@@ -161,21 +161,23 @@ index 87c6c34a..6a278b5e 100644
  bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr;
  #ifdef GGML_CUDA_NO_PEER_COPY
  bool events = false;
-@@ -4087,6 +4149,8 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
+@@ -4087,6 +4149,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
      std::lock_guard<std::mutex> lock(mutex);
      if (!initialized) {
          ggml_backend_cuda_reg_context * ctx = new ggml_backend_cuda_reg_context;
 +        int driverVersion = 0;
-+        CUDA_CHECK(cudaDriverGetVersion(&driverVersion));
 
          for (int i = 0; i < ggml_cuda_info().device_count; i++) {
              ggml_backend_cuda_device_context * dev_ctx = new ggml_backend_cuda_device_context;
-@@ -4102,6 +4166,14 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
+@@ -4102,6 +4165,17 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
              snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID);
              dev_ctx->pci_bus_id = pci_bus_id;
 
 +            dev_ctx->major = prop.major;
 +            dev_ctx->minor = prop.minor;
++            if (driverVersion == 0) {
++                CUDA_CHECK(cudaDriverGetVersion(&driverVersion));
++            }
 +            dev_ctx->driver_major = driverVersion / 1000;
 +            dev_ctx->driver_minor = (driverVersion - (dev_ctx->driver_major * 1000)) / 10;
 +            dev_ctx->integrated = prop.integrated;
@@ -186,7 +188,7 @@ index 87c6c34a..6a278b5e 100644
      /* .iface = */ ggml_backend_cuda_device_interface,
      /* .reg = */ &reg,
 diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h
-index 1f06be80..2f9ef2dc 100644
+index 1f06be80e..2f9ef2dc0 100644
 --- a/ggml/src/ggml-cuda/vendors/hip.h
 +++ b/ggml/src/ggml-cuda/vendors/hip.h
 @@ -5,6 +5,8 @@
@@ -207,7 +209,7 @@ index 1f06be80..2f9ef2dc 100644
  #define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
  #define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
 diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
-index d0fb3bcc..80597b6e 100644
+index d0fb3bcca..80597b6ea 100644
 --- a/ggml/src/ggml-impl.h
 +++ b/ggml/src/ggml-impl.h
 @@ -638,6 +638,14 @@ static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx
@@ -226,7 +228,7 @@ index d0fb3bcc..80597b6e 100644
  }
  #endif
 diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp
-index f2ff9f32..f356e4a0 100644
+index f2ff9f322..f356e4a0a 100644
 --- a/ggml/src/ggml-metal/ggml-metal.cpp
 +++ b/ggml/src/ggml-metal/ggml-metal.cpp
 @@ -535,6 +535,7 @@ static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backen
@@ -247,7 +249,7 @@ index f2ff9f32..f356e4a0 100644
      /* .host_buffer = */ false,
 diff --git a/ggml/src/mem_hip.cpp b/ggml/src/mem_hip.cpp
 new file mode 100644
-index 00000000..8ef19b8c
+index 000000000..8ef19b8cf
 --- /dev/null
 +++ b/ggml/src/mem_hip.cpp
 @@ -0,0 +1,449 @@
@@ -703,7 +705,7 @@ index 00000000..8ef19b8c
  \ No newline at end of file
 diff --git a/ggml/src/mem_nvml.cpp b/ggml/src/mem_nvml.cpp
 new file mode 100644
-index 00000000..c9073cef
+index 000000000..c9073cef0
 --- /dev/null
 +++ b/ggml/src/mem_nvml.cpp
 @@ -0,0 +1,209 @@
llama/patches/0031-report-LoadLibrary-failures.patch (new file, 32 lines)
@@ -0,0 +1,32 @@
+From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
+From: Daniel Hiltgen <daniel@ollama.com>
+Date: Fri, 17 Oct 2025 14:17:00 -0700
+Subject: [PATCH] report LoadLibrary failures
+
+---
+ ggml/src/ggml-backend-reg.cpp | 12 ++++++++++++
+ 1 file changed, 12 insertions(+)
+
+diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
+index f794d9cfa..3a855ab2e 100644
+--- a/ggml/src/ggml-backend-reg.cpp
++++ b/ggml/src/ggml-backend-reg.cpp
+@@ -118,6 +118,18 @@ static dl_handle * dl_load_library(const fs::path & path) {
+     SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
+ 
+     HMODULE handle = LoadLibraryW(path.wstring().c_str());
++    if (!handle) {
++        DWORD error_code = GetLastError();
++        std::string msg;
++        LPSTR lpMsgBuf = NULL;
++        DWORD bufLen = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
++                                      NULL, error_code, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&lpMsgBuf, 0, NULL);
++        if (bufLen) {
++            msg = lpMsgBuf;
++            LocalFree(lpMsgBuf);
++            GGML_LOG_INFO("%s unable to load library %s: %s\n", __func__, path_str(path).c_str(), msg.c_str());
++        }
++    }
+ 
+     SetErrorMode(old_mode);
@@ -4,27 +4,28 @@ import (
	"fmt"
	"log/slog"
	"os"
	"slices"
	"sort"
	"strings"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/discover"
	"github.com/ollama/ollama/envconfig"
	"github.com/ollama/ollama/format"
	"github.com/ollama/ollama/fs/ggml"
	"github.com/ollama/ollama/ml"
)

// pickBestFullFitByLibrary will try to find the optimal placement of the model in the available GPUs where the model fully fits
// The list of GPUs returned will always be the same brand (library)
// If the model can not be fit fully within the available GPU(s) nil is returned
func pickBestFullFitByLibrary(f *ggml.GGML, modelPath string, projectors []string, adapters []string, opts api.Options, gpus discover.GpuInfoList, numParallel int) discover.GpuInfoList {
	for _, gl := range gpus.ByLibrary() {
		sgl := append(make(discover.GpuInfoList, 0, len(gl)), gl...)
func pickBestFullFitByLibrary(f *ggml.GGML, modelPath string, projectors []string, adapters []string, opts api.Options, gpus []ml.DeviceInfo, numParallel int) []ml.DeviceInfo {
	for _, gl := range ml.ByLibrary(gpus) {
		sgl := append(make([]ml.DeviceInfo, 0, len(gl)), gl...)

		// TODO - potentially sort by performance capability, existing models loaded, etc.
		// TODO - Eliminate any GPUs that already have envconfig.MaxRunners loaded on them
		// Note: at present, this will favor most current available VRAM descending and ignoring faster GPU speed in mixed setups
		sort.Sort(sort.Reverse(discover.ByFreeMemory(sgl)))
		sort.Sort(sort.Reverse(ml.ByFreeMemory(sgl)))

		if !envconfig.SchedSpread() {
			// Try to pack into as few GPUs as possible, starting from 1 GPU
@@ -63,8 +64,8 @@ func pickBestFullFitByLibrary(f *ggml.GGML, modelPath string, projectors []strin
}

// If multiple Libraries are detected, pick the Library which loads the most layers for the model
func pickBestPartialFitByLibrary(f *ggml.GGML, projectors []string, adapters []string, opts api.Options, gpus discover.GpuInfoList, numParallel int) discover.GpuInfoList {
	byLibrary := gpus.ByLibrary()
func pickBestPartialFitByLibrary(f *ggml.GGML, projectors []string, adapters []string, opts api.Options, gpus []ml.DeviceInfo, numParallel int) []ml.DeviceInfo {
	byLibrary := ml.ByLibrary(gpus)
	if len(byLibrary) <= 1 {
		return gpus
	}
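Both pickers now group devices through ml.ByLibrary rather than a method on discover.GpuInfoList. The helper itself is not shown in this diff, so the following is an inferred sketch of the grouping semantics implied by the call sites (DeviceInfo is a stand-in for ml.DeviceInfo):

```go
// groupByLibrary partitions devices by their Library field while preserving
// first-seen order, which is what the ml.ByLibrary call sites above appear
// to rely on.
func groupByLibrary(devices []DeviceInfo) [][]DeviceInfo {
	index := map[string]int{}
	var groups [][]DeviceInfo
	for _, d := range devices {
		i, ok := index[d.Library]
		if !ok {
			i = len(groups)
			index[d.Library] = i
			groups = append(groups, nil)
		}
		groups[i] = append(groups[i], d)
	}
	return groups
}
```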
@@ -81,10 +82,10 @@ func pickBestPartialFitByLibrary(f *ggml.GGML, projectors []string, adapters []s
}

// This algorithm looks for a complete fit to determine if we need to unload other models
func predictServerFit(allGpus discover.GpuInfoList, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (bool, uint64) {
func predictServerFit(allGpus []ml.DeviceInfo, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (bool, uint64) {
	// Split up the GPUs by type and try them
	var estimatedVRAM uint64
	for _, gpus := range allGpus.ByLibrary() {
	for _, gpus := range ml.ByLibrary(allGpus) {
		var layerCount int
		estimate := estimateGPULayers(gpus, f, projectors, opts, numParallel)
		layerCount, estimatedVRAM = estimate.Layers, estimate.VRAMSize
@@ -97,14 +98,23 @@ func predictServerFit(allGpus discover.GpuInfoList, f *ggml.GGML, adapters, proj
			return true, estimatedVRAM
		}
	}

	if len(gpus) == 1 && gpus[0].Library == "cpu" && estimate.TotalSize <= gpus[0].FreeMemory {
		return true, estimatedVRAM
	}
	}
	return false, estimatedVRAM
}

func verifyCPUFit(f *ggml.GGML, modelPath string, projectors []string, adapters []string, opts api.Options, systemInfo ml.SystemInfo, numParallel int) bool {
	estimate := estimateGPULayers(nil, f, projectors, opts, numParallel)
	if estimate.TotalSize > systemInfo.FreeMemory {
		return false
	}
	slog.Info("new model will fit in available system memory for CPU inference, loading",
		"model", modelPath,
		"parallel", numParallel,
		"required", format.HumanBytes2(estimate.TotalSize),
	)
	return true
}
type MemoryEstimate struct {
	// How many layers we predict we can load
	Layers int
@@ -141,7 +151,7 @@ type MemoryEstimate struct {

// Given a model and one or more GPU targets, predict how many layers and bytes we can load, and the total size
// The GPUs provided must all be the same Library
func estimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []string, opts api.Options, numParallel int) MemoryEstimate {
func estimateGPULayers(gpus []ml.DeviceInfo, f *ggml.GGML, projectors []string, opts api.Options, numParallel int) MemoryEstimate {
	// Graph size for a partial offload, applies to all GPUs
	var graphPartialOffload uint64

@@ -175,10 +185,17 @@ func estimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin

	overhead := envconfig.GpuOverhead()
	availableList := make([]string, len(gpus))
	libraries := []string{}
	for i, gpu := range gpus {
		availableList[i] = format.HumanBytes2(gpu.FreeMemory)
		if !slices.Contains(libraries, gpu.Library) {
			libraries = append(libraries, gpu.Library)
		}
	}
	slog.Debug("evaluating", "library", gpus[0].Library, "gpu_count", len(gpus), "available", availableList)
	if len(libraries) == 0 {
		libraries = []string{"cpu"}
	}
	slog.Debug("evaluating", "library", strings.Join(libraries, ","), "gpu_count", len(gpus), "available", availableList)

	for _, projector := range projectors {
		llamaEngineProjectorWeights += projectorMemoryRequirements(projector)
@@ -196,7 +213,7 @@ func estimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
	}

	useFlashAttention := envconfig.FlashAttention(f.FlashAttention()) &&
		(discover.GpuInfoList)(gpus).FlashAttentionSupported() &&
		ml.FlashAttentionSupported(gpus) &&
		f.SupportsFlashAttention()

	var kvct string
@@ -231,7 +248,7 @@ func estimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
	}

	// on metal there's no partial offload overhead
	if gpus[0].Library == "Metal" {
	if len(gpus) > 0 && gpus[0].Library == "Metal" {
		graphPartialOffload = graphFullOffload
	} else if len(gpus) > 1 {
		// multigpu should always use the partial graph size
@@ -256,7 +273,7 @@ func estimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
	gpuAllocations := make([]uint64, len(gpus))
	type gs struct {
		i int
		g *discover.GpuInfo
		g *ml.DeviceInfo
	}
	gpusWithSpace := []gs{}
	for i := range gpus {
@@ -265,19 +282,11 @@ func estimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
			gzo = gpuZeroOverhead
		}
		// Only include GPUs that can fit the graph, gpu minimum, the layer buffer and at least one more layer
		if gpus[i].FreeMemory < overhead+gzo+max(graphPartialOffload, graphFullOffload)+gpus[i].MinimumMemory+2*layerSize {
			var compute string
			if gpus[i].Library == "ROCm" {
				compute = fmt.Sprintf("gfx%x%02x", gpus[i].ComputeMajor, gpus[i].ComputeMinor)
			} else {
				compute = fmt.Sprintf("%d.%d", gpus[i].ComputeMajor, gpus[i].ComputeMinor)
			}

		if gpus[i].FreeMemory < overhead+gzo+max(graphPartialOffload, graphFullOffload)+gpus[i].MinimumMemory()+2*layerSize {
			slog.Debug("gpu has too little memory to allocate any layers",
				"id", gpus[i].ID,
				"library", gpus[i].Library,
				"variant", gpus[i].Variant,
				"compute", compute,
				"compute", gpus[i].Compute(),
				"driver", fmt.Sprintf("%d.%d", gpus[i].DriverMajor, gpus[i].DriverMinor),
				"name", gpus[i].Name,
				"total", format.HumanBytes2(gpus[i].TotalMemory),
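The inlined compute-capability formatting removed above now lives behind a Compute() accessor. Based only on the removed lines, a plausible sketch of that method (the real ml.DeviceInfo implementation may differ):

```go
// Compute renders the device's compute capability the way the removed
// inline code did: ROCm devices as a gfx identifier, everything else as
// a major.minor version string.
func (d DeviceInfo) Compute() string {
	if d.Library == "ROCm" {
		return fmt.Sprintf("gfx%x%02x", d.ComputeMajor, d.ComputeMinor)
	}
	return fmt.Sprintf("%d.%d", d.ComputeMajor, d.ComputeMinor)
}
```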
@@ -291,7 +300,7 @@ func estimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
			continue
		}
		gpusWithSpace = append(gpusWithSpace, gs{i, &gpus[i]})
		gpuAllocations[i] += gpus[i].MinimumMemory + layerSize // We hold off on graph until we know partial vs. full
		gpuAllocations[i] += gpus[i].MinimumMemory() + layerSize // We hold off on graph until we know partial vs. full
	}

	var gpuZeroID int
@@ -397,7 +406,7 @@ func estimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
		VRAMSize: 0,
		GPUSizes: []uint64{},

		inferenceLibrary: gpus[0].Library,
		inferenceLibrary: strings.Join(libraries, ","),
		layersRequested:  opts.NumGPU,
		layersModel:      int(f.KV().BlockCount()) + 1,
		availableList:    availableList,
@@ -411,7 +420,7 @@ func estimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
		projectorGraph: ollamaEngineProjectorGraph,
	}

	if gpus[0].Library == "cpu" {
	if len(gpus) == 0 {
		return estimate
	}
	if layerCount == 0 {
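With the CPU path now represented as an empty device list, callers request a CPU-only estimate by passing nil GPUs, as verifyCPUFit does earlier in this diff. A hypothetical caller, to make the flow concrete (the surrounding variables are stand-ins):

```go
// CPU-only estimate: nil devices means nothing is offloaded, so
// estimate.TotalSize is the full host-memory footprint of the load.
estimate := estimateGPULayers(nil, f, projectors, opts, numParallel)
slog.Info("cpu-only estimate",
	"layers", estimate.Layers, // expected to be 0 with no devices
	"required", format.HumanBytes2(estimate.TotalSize),
)
```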
@@ -10,7 +10,7 @@ import (
	"github.com/stretchr/testify/require"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/discover"
	"github.com/ollama/ollama/format"
	"github.com/ollama/ollama/fs/ggml"
	"github.com/ollama/ollama/ml"
)
@@ -54,13 +54,7 @@ func TestEstimateGPULayers(t *testing.T) {
	}

	// Simple CPU scenario
	gpus := []discover.GpuInfo{
		{
			DeviceID: ml.DeviceID{
				Library: "cpu",
			},
		},
	}
	gpus := []ml.DeviceInfo{}
	projectors := []string{}
	opts := api.DefaultOptions()
	t.Run("cpu", func(t *testing.T) {
@@ -77,19 +71,17 @@ func TestEstimateGPULayers(t *testing.T) {
	memoryLayerOutput := uint64(4)

	// Dual CUDA scenario with asymmetry
	gpuMinimumMemory := uint64(2048)
	gpus = []discover.GpuInfo{
	gpuMinimumMemory := uint64(457 * format.MebiByte)
	gpus = []ml.DeviceInfo{
		{
			DeviceID: ml.DeviceID{
				Library: "cuda",
				Library: "CUDA",
			},
			MinimumMemory: gpuMinimumMemory,
		},
		{
			DeviceID: ml.DeviceID{
				Library: "cuda",
				Library: "CUDA",
			},
			MinimumMemory: gpuMinimumMemory,
		},
	}
	// Nested array: GPU0 layer space, GPU1 layer space, expected gpu0, expected gpu1
llm/server.go (555 changes)

@@ -27,7 +27,6 @@ import (
	"golang.org/x/sync/semaphore"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/discover"
	"github.com/ollama/ollama/envconfig"
	"github.com/ollama/ollama/format"
	"github.com/ollama/ollama/fs/ggml"
@@ -66,11 +65,11 @@ func (e filteredEnv) LogValue() slog.Value {

type LlamaServer interface {
	ModelPath() string
	Load(ctx context.Context, gpus discover.GpuInfoList, requireFull bool) ([]ml.DeviceID, error)
	Load(ctx context.Context, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error)
	Ping(ctx context.Context) error
	WaitUntilRunning(ctx context.Context) error
	Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
	Embedding(ctx context.Context, input string) ([]float32, error)
	Embedding(ctx context.Context, input string, truncate bool) ([]float32, int, error)
	Tokenize(ctx context.Context, content string) ([]int, error)
	Detokenize(ctx context.Context, tokens []int) (string, error)
	Close() error
@@ -115,7 +114,7 @@ type llamaServer struct {
	llmServer

	ggml *ggml.GGML
	gpus discover.GpuInfoList // The set of GPUs covered by the memory estimate
	gpus []ml.DeviceInfo // The set of GPUs covered by the memory estimate
	estimate MemoryEstimate
}

@@ -146,7 +145,7 @@ func LoadModel(model string, maxArraySize int) (*ggml.GGML, error) {
}

// NewLlamaServer will run a server for the given GPUs
func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
func NewLlamaServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, modelPath string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
	var llamaModel *llama.Model
	var textProcessor model.TextProcessor
	var err error
@@ -179,7 +178,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a

	loadRequest := LoadRequest{LoraPath: adapters, KvSize: opts.NumCtx * numParallel, BatchSize: opts.NumBatch, Parallel: numParallel, MultiUserCache: envconfig.MultiUserCache()}

	defaultThreads := discover.GetSystemInfo().GetOptimalThreadCount()
	defaultThreads := systemInfo.ThreadCount
	if opts.NumThread > 0 {
		loadRequest.NumThreads = opts.NumThread
	} else if defaultThreads > 0 {
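NewLlamaServer and Load now receive an ml.SystemInfo instead of calling discover.GetSystemInfo() themselves. The struct itself never appears in this diff; the following is an inferred sketch assembled purely from the fields the new code reads, so the real definition may differ:

```go
// SystemInfo as implied by its call sites in this diff: a flat snapshot of
// host memory plus a precomputed thread count, replacing the old nested
// discover.SystemInfo.System.* accessors and GetOptimalThreadCount().
type SystemInfo struct {
	TotalMemory uint64 // bytes of physical RAM
	FreeMemory  uint64 // bytes currently available
	FreeSwap    uint64 // bytes of free swap (not meaningful on darwin)
	ThreadCount int    // optimal number of runner threads
}
```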
@@ -200,7 +199,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a

	// This will disable flash attention unless all GPUs on the system support it, even if we end up selecting a subset
	// that can handle it.
	if fa && !gpus.FlashAttentionSupported() {
	if fa && !ml.FlashAttentionSupported(gpus) {
		slog.Warn("flash attention enabled but not supported by gpu")
		fa = false
	}
@@ -227,218 +226,170 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
		slog.Warn("quantized kv cache requested but flash attention disabled", "type", kvct)
	}

	availableLibs := make(map[string]string)
	if entries, err := os.ReadDir(discover.LibOllamaPath); err == nil {
		for _, entry := range entries {
			availableLibs[entry.Name()] = filepath.Join(discover.LibOllamaPath, entry.Name())
		}
	gpuLibs := ml.LibraryPaths(gpus)
	status := NewStatusWriter(os.Stderr)
	cmd, port, err := StartRunner(
		textProcessor != nil,
		modelPath,
		gpuLibs,
		status,
		ml.GetVisibleDevicesEnv(gpus),
	)

	s := llmServer{
		port:           port,
		cmd:            cmd,
		status:         status,
		options:        opts,
		modelPath:      modelPath,
		loadRequest:    loadRequest,
		llamaModel:     llamaModel,
		llamaModelLock: &sync.Mutex{},
		textProcessor:  textProcessor,
		numParallel:    numParallel,
		sem:            semaphore.NewWeighted(int64(numParallel)),
		totalLayers:    f.KV().BlockCount() + 1,
		loadStart:      time.Now(),
		done:           make(chan error, 1),
	}

	var gpuLibs []string
	for _, gpu := range gpus {
		gpuLibs = append(gpuLibs, gpu.RunnerName())
	}

	requested := envconfig.LLMLibrary()
	if availableLibs[requested] != "" {
		slog.Info("using requested gpu library", "requested", requested)
		gpuLibs = []string{requested}
	}

	var compatible []string
	for _, gpuLib := range gpuLibs {
		var matchingLibs []string
		for k := range availableLibs {
			// exact match first
			if k == gpuLib {
				matchingLibs = append([]string{k}, matchingLibs...)
				continue
			}

			// then match the family (e.g. 'cuda')
			if strings.Split(k, "_")[0] == strings.Split(gpuLib, "_")[0] {
				matchingLibs = append(matchingLibs, k)
			}
		}

		if len(matchingLibs) > 0 {
			compatible = append(compatible, matchingLibs[0])
		}
	}

	exe, err := os.Executable()
	if err != nil {
		return nil, fmt.Errorf("unable to lookup executable path: %w", err)
		var msg string
		if s.status != nil && s.status.LastErrMsg != "" {
			msg = s.status.LastErrMsg
		}
		err := fmt.Errorf("error starting runner: %v %s", err, msg)
		if llamaModel != nil {
			llama.FreeModel(llamaModel)
		}
		return nil, err
	}

	// reap subprocess when it exits
	go func() {
		err := s.cmd.Wait()
		// Favor a more detailed message over the process exit status
		if err != nil && s.status != nil && s.status.LastErrMsg != "" {
			slog.Error("llama runner terminated", "error", err)
			if strings.Contains(s.status.LastErrMsg, "unknown model") {
				s.status.LastErrMsg = "this model is not supported by your version of Ollama. You may need to upgrade"
			}
			s.done <- errors.New(s.status.LastErrMsg)
		} else {
			s.done <- err
		}
	}()

	if textProcessor != nil {
		return &ollamaServer{llmServer: s}, nil
	} else {
		return &llamaServer{llmServer: s, ggml: f}, nil
	}
}
func StartRunner(ollamaEngine bool, modelPath string, gpuLibs []string, out io.Writer, extraEnvs map[string]string) (cmd *exec.Cmd, port int, err error) {
	var exe string
	exe, err = os.Executable()
	if err != nil {
		return nil, 0, fmt.Errorf("unable to lookup executable path: %w", err)
	}

	if eval, err := filepath.EvalSymlinks(exe); err == nil {
		exe = eval
	}

	// iterate through compatible GPU libraries such as 'cuda_v12', 'rocm', etc.
	// adding each library's respective path to the LD_LIBRARY_PATH, until finally running
	// without any LD_LIBRARY_PATH flags
	for {
		port := 0
		if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
			var l *net.TCPListener
			if l, err = net.ListenTCP("tcp", a); err == nil {
				port = l.Addr().(*net.TCPAddr).Port
				l.Close()
			}
		}
		if port == 0 {
			slog.Debug("ResolveTCPAddr failed, using random port")
			port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
		}
		params := []string{"runner"}
		if textProcessor != nil {
			// New engine
			// TODO - if we have failure to load scenarios, add logic to retry with the old runner
			params = append(params, "--ollama-engine")
		}
		params = append(params, "--model", modelPath)
		params = append(params, "--port", strconv.Itoa(port))

		var pathEnv string
		switch runtime.GOOS {
		case "windows":
			pathEnv = "PATH"
		case "darwin":
			pathEnv = "DYLD_LIBRARY_PATH"
		default:
			pathEnv = "LD_LIBRARY_PATH"
		}

		// Note: we always put our dependency paths first
		// since these are the exact version we compiled/linked against
		libraryPaths := []string{discover.LibOllamaPath}
		if libraryPath, ok := os.LookupEnv(pathEnv); ok {
			libraryPaths = append(libraryPaths, filepath.SplitList(libraryPath)...)
		}

		ggmlPaths := []string{discover.LibOllamaPath}
		for _, c := range compatible {
			if libpath, ok := availableLibs[c]; ok {
				slog.Debug("adding gpu library", "path", libpath)
				libraryPaths = append([]string{libpath}, libraryPaths...)
				ggmlPaths = append(ggmlPaths, libpath)
			}
		}

		for _, gpu := range gpus {
			if gpu.DependencyPath != nil {
				slog.Debug("adding gpu dependency paths", "paths", gpu.DependencyPath)
				libraryPaths = append(gpu.DependencyPath, libraryPaths...)
				ggmlPaths = append(ggmlPaths, gpu.DependencyPath...)
			}
		}

		// finally, add the root library path
		libraryPaths = append(libraryPaths, discover.LibOllamaPath)

		s := llmServer{
			port:           port,
			cmd:            exec.Command(exe, params...),
			status:         NewStatusWriter(os.Stderr),
			options:        opts,
			modelPath:      modelPath,
			loadRequest:    loadRequest,
			llamaModel:     llamaModel,
			llamaModelLock: &sync.Mutex{},
			textProcessor:  textProcessor,
			numParallel:    numParallel,
			sem:            semaphore.NewWeighted(int64(numParallel)),
			totalLayers:    f.KV().BlockCount() + 1,
			loadStart:      time.Now(),
			done:           make(chan error, 1),
		}

		s.cmd.Env = os.Environ()
		s.cmd.Stdout = os.Stdout
		s.cmd.Stderr = s.status
		s.cmd.SysProcAttr = LlamaServerSysProcAttr

		// Always filter down the set of GPUs in case there are any unsupported devices that might crash
		envWorkarounds := gpus.GetVisibleDevicesEnv()
		pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))

		// Update or add the path variable with our adjusted version
		pathNeeded := true
		ollamaPathNeeded := true
		envWorkaroundDone := make([]bool, len(envWorkarounds))
		for i := range s.cmd.Env {
			cmp := strings.SplitN(s.cmd.Env[i], "=", 2)
			if strings.EqualFold(cmp[0], pathEnv) {
				s.cmd.Env[i] = pathEnv + "=" + pathEnvVal
				pathNeeded = false
			} else if strings.EqualFold(cmp[0], "OLLAMA_LIBRARY_PATH") {
				s.cmd.Env[i] = "OLLAMA_LIBRARY_PATH=" + strings.Join(ggmlPaths, string(filepath.ListSeparator))
				ollamaPathNeeded = false
			} else if len(envWorkarounds) != 0 {
				for j, kv := range envWorkarounds {
					tmp := strings.SplitN(kv, "=", 2)
					if strings.EqualFold(cmp[0], tmp[0]) {
						s.cmd.Env[i] = kv
						envWorkaroundDone[j] = true
					}
				}
			}
		}
		if pathNeeded {
			s.cmd.Env = append(s.cmd.Env, pathEnv+"="+pathEnvVal)
		}
		if ollamaPathNeeded {
			s.cmd.Env = append(s.cmd.Env, "OLLAMA_LIBRARY_PATH="+strings.Join(ggmlPaths, string(filepath.ListSeparator)))
		}
		for i, done := range envWorkaroundDone {
			if !done {
				s.cmd.Env = append(s.cmd.Env, envWorkarounds[i])
			}
		}

		slog.Info("starting runner", "cmd", s.cmd)
		slog.Debug("subprocess", "", filteredEnv(s.cmd.Env))

		if err = s.cmd.Start(); err != nil {
			var msg string
			if s.status != nil && s.status.LastErrMsg != "" {
				msg = s.status.LastErrMsg
			}
			err := fmt.Errorf("error starting runner: %v %s", err, msg)
			if len(compatible) == 0 {
				if llamaModel != nil {
					llama.FreeModel(llamaModel)
				}
				return nil, err
			}

			slog.Warn("unable to start runner with compatible gpu", "error", err, "compatible", compatible)
			compatible = compatible[1:]
			continue
		}

		// reap subprocess when it exits
		go func() {
			err := s.cmd.Wait()
			// Favor a more detailed message over the process exit status
			if err != nil && s.status != nil && s.status.LastErrMsg != "" {
				slog.Error("llama runner terminated", "error", err)
				if strings.Contains(s.status.LastErrMsg, "unknown model") {
					s.status.LastErrMsg = "this model is not supported by your version of Ollama. You may need to upgrade"
				}
				s.done <- errors.New(s.status.LastErrMsg)
			} else {
				s.done <- err
			}
		}()

		if textProcessor != nil {
			return &ollamaServer{llmServer: s}, nil
		} else {
			return &llamaServer{llmServer: s, ggml: f}, nil
	port = 0
	if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
		var l *net.TCPListener
		if l, err = net.ListenTCP("tcp", a); err == nil {
			port = l.Addr().(*net.TCPAddr).Port
			l.Close()
		}
	}
	if port == 0 {
		slog.Debug("ResolveTCPAddr failed, using random port")
		port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
	}
	params := []string{"runner"}
	if ollamaEngine {
		params = append(params, "--ollama-engine")
	}
	if modelPath != "" {
		params = append(params, "--model", modelPath)
	}
	params = append(params, "--port", strconv.Itoa(port))

	var pathEnv string
	switch runtime.GOOS {
	case "windows":
		pathEnv = "PATH"
	case "darwin":
		pathEnv = "DYLD_LIBRARY_PATH"
	default:
		pathEnv = "LD_LIBRARY_PATH"
	}

	// Note: we always put our dependency paths first
	// since these are the exact version we compiled/linked against
	libraryPaths := append([]string{}, gpuLibs...)
	if libraryPath, ok := os.LookupEnv(pathEnv); ok {
		libraryPaths = append(libraryPaths, filepath.SplitList(libraryPath)...)
	}

	cmd = exec.Command(exe, params...)

	cmd.Env = os.Environ()
	cmd.Stdout = out
	cmd.Stderr = out
	cmd.SysProcAttr = LlamaServerSysProcAttr

	// Always filter down the set of GPUs in case there are any unsupported devices that might crash
	pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))

	// Update or add the path variable with our adjusted version
	pathNeeded := true
	ollamaPathNeeded := true
	extraEnvsDone := map[string]bool{}
	for k := range extraEnvs {
		extraEnvsDone[k] = false
	}
	for i := range cmd.Env {
		cmp := strings.SplitN(cmd.Env[i], "=", 2)
		if strings.EqualFold(cmp[0], pathEnv) {
			cmd.Env[i] = pathEnv + "=" + pathEnvVal
			pathNeeded = false
		} else if strings.EqualFold(cmp[0], "OLLAMA_LIBRARY_PATH") {
			cmd.Env[i] = "OLLAMA_LIBRARY_PATH=" + strings.Join(gpuLibs, string(filepath.ListSeparator))
			ollamaPathNeeded = false
		} else if len(extraEnvs) != 0 {
			for k, v := range extraEnvs {
				if strings.EqualFold(cmp[0], k) {
					cmd.Env[i] = k + "=" + v
					extraEnvsDone[k] = true
				}
			}
		}
	}
	if pathNeeded {
		cmd.Env = append(cmd.Env, pathEnv+"="+pathEnvVal)
	}
	if ollamaPathNeeded {
		cmd.Env = append(cmd.Env, "OLLAMA_LIBRARY_PATH="+strings.Join(gpuLibs, string(filepath.ListSeparator)))
	}
	for k, done := range extraEnvsDone {
		if !done {
			cmd.Env = append(cmd.Env, k+"="+extraEnvs[k])
		}
	}

	slog.Info("starting runner", "cmd", cmd)
	slog.Debug("subprocess", "", filteredEnv(cmd.Env))

	if err = cmd.Start(); err != nil {
		return nil, 0, err
	}
	err = nil
	return
}
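StartRunner now encapsulates port selection, argument construction, and environment setup that NewLlamaServer previously did inline. A hypothetical caller, mirroring the invocation shown earlier in this diff (the model path is a placeholder):

```go
// Launch an ollama-engine runner for one model, capturing runner output in
// a status writer and forwarding any per-GPU environment workarounds.
status := NewStatusWriter(os.Stderr)
cmd, port, err := StartRunner(
	true,                          // ollamaEngine: use the new engine
	"/path/to/model.gguf",         // placeholder model path
	ml.LibraryPaths(gpus),         // GPU library dirs, prepended to the loader path
	status,                        // stdout/stderr sink
	ml.GetVisibleDevicesEnv(gpus), // e.g. device-visibility env filtering
)
if err != nil {
	return err
}
slog.Info("runner started", "port", port, "pid", cmd.Process.Pid)
```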
func (s *llmServer) ModelPath() string {
@@ -497,47 +448,58 @@ type LoadResponse struct {

var ErrLoadRequiredFull = errors.New("unable to load full model on GPU")

func (s *llamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requireFull bool) ([]ml.DeviceID, error) {
	systemInfo := discover.GetSystemInfo()
	systemTotalMemory := systemInfo.System.TotalMemory
	systemFreeMemory := systemInfo.System.FreeMemory
	systemSwapFreeMemory := systemInfo.System.FreeSwap
func (s *llamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) {
	systemTotalMemory := systemInfo.TotalMemory
	systemFreeMemory := systemInfo.FreeMemory
	systemSwapFreeMemory := systemInfo.FreeSwap
	slog.Info("system memory", "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "free_swap", format.HumanBytes2(systemSwapFreeMemory))

	g := pickBestFullFitByLibrary(s.ggml, s.modelPath, []string{s.loadRequest.ProjectorPath}, s.loadRequest.LoraPath, s.options, gpus, s.numParallel)
	if g == nil {
		if !requireFull {
			g = pickBestPartialFitByLibrary(s.ggml, []string{s.loadRequest.ProjectorPath}, s.loadRequest.LoraPath, s.options, gpus, s.numParallel)
		} else {
	if len(gpus) == 0 || s.options.NumGPU == 0 {
		if !verifyCPUFit(s.ggml, s.modelPath, []string{s.loadRequest.ProjectorPath}, s.loadRequest.LoraPath, s.options, systemInfo, s.numParallel) {
			slog.Info("model requires more memory than is currently available, evicting a model to make space", "estimate", s.estimate)
			return nil, ErrLoadRequiredFull
			return nil, fmt.Errorf("model requires more system memory than is currently available %w", ErrLoadRequiredFull)
		}
	} else {
		g := pickBestFullFitByLibrary(s.ggml, s.modelPath, []string{s.loadRequest.ProjectorPath}, s.loadRequest.LoraPath, s.options, gpus, s.numParallel)
		if g == nil {
			if !requireFull {
				g = pickBestPartialFitByLibrary(s.ggml, []string{s.loadRequest.ProjectorPath}, s.loadRequest.LoraPath, s.options, gpus, s.numParallel)
			} else {
				slog.Info("model requires more memory than is currently available, evicting a model to make space", "estimate", s.estimate)
				return nil, ErrLoadRequiredFull
			}
		}
		gpus = g
	}

	gpus = g
	s.estimate = estimateGPULayers(gpus, s.ggml, []string{s.loadRequest.ProjectorPath}, s.options, s.numParallel)

	if len(gpus) > 1 || gpus[0].Library != "cpu" {
	if len(gpus) >= 1 {
		switch {
		case gpus[0].Library == "Metal" && s.estimate.VRAMSize > systemInfo.System.TotalMemory:
		case s.options.NumGPU == 0:
			gpus = []ml.DeviceInfo{}
		case gpus[0].Library == "Metal" && s.estimate.VRAMSize > systemInfo.TotalMemory:
			// disable partial offloading when model is greater than total system memory as this
			// can lead to locking up the system
			s.options.NumGPU = 0
			gpus = []ml.DeviceInfo{}
		case gpus[0].Library != "Metal" && s.estimate.Layers == 0:
			// Don't bother loading into the GPU if no layers can fit
			gpus = discover.GpuInfoList{discover.GetCPUInfo()}
		case s.options.NumGPU < 0 && s.estimate.Layers > 0 && gpus[0].Library != "cpu":
			gpus = []ml.DeviceInfo{}
		case s.options.NumGPU < 0 && s.estimate.Layers > 0:
			s.options.NumGPU = s.estimate.Layers
		}
	} else {
		s.options.NumGPU = 0
	}

	// On linux and windows, over-allocating CPU memory will almost always result in an error
	// Darwin has fully dynamic swap so has no direct concept of free swap space
	if runtime.GOOS != "darwin" {
		systemMemoryRequired := s.estimate.TotalSize - s.estimate.VRAMSize
		available := systemInfo.System.FreeMemory + systemInfo.System.FreeSwap
		available := systemInfo.FreeMemory + systemInfo.FreeSwap
		if systemMemoryRequired > available {
			slog.Warn("model request too large for system", "requested", format.HumanBytes2(systemMemoryRequired), "available", format.HumanBytes2(available), "total", format.HumanBytes2(systemInfo.System.TotalMemory), "free", format.HumanBytes2(systemInfo.System.FreeMemory), "swap", format.HumanBytes2(systemInfo.System.FreeSwap))
			slog.Warn("model request too large for system", "requested", format.HumanBytes2(systemMemoryRequired), "available", format.HumanBytes2(available), "total", format.HumanBytes2(systemInfo.TotalMemory), "free", format.HumanBytes2(systemInfo.FreeMemory), "swap", format.HumanBytes2(systemInfo.FreeSwap))
			return nil, fmt.Errorf("model requires more system memory (%s) than is available (%s)", format.HumanBytes2(systemMemoryRequired), format.HumanBytes2(available))
		}
	}
@@ -564,10 +526,10 @@ func (s *llamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requi
	// Windows CUDA should not use mmap for best performance
	// Linux with a model larger than free space, mmap leads to thrashing
	// For CPU loads we want the memory to be allocated, not FS cache
	if (runtime.GOOS == "windows" && gpus[0].Library == "CUDA" && s.options.UseMMap == nil) ||
		(runtime.GOOS == "linux" && systemInfo.System.FreeMemory < s.estimate.TotalSize && s.options.UseMMap == nil) ||
		(gpus[0].Library == "cpu" && s.options.UseMMap == nil) ||
		(gpus[0].Library == "Vulkan" && s.options.UseMMap == nil) ||
	if (runtime.GOOS == "windows" && len(gpus) > 0 && gpus[0].Library == "CUDA" && s.options.UseMMap == nil) ||
		(runtime.GOOS == "linux" && systemInfo.FreeMemory < s.estimate.TotalSize && s.options.UseMMap == nil) ||
		(len(gpus) == 0 && s.options.UseMMap == nil) ||
		(len(gpus) > 0 && gpus[0].Library == "Vulkan" && s.options.UseMMap == nil) ||
		(s.options.UseMMap != nil && !*s.options.UseMMap) {
		s.loadRequest.UseMmap = false
	}
@@ -605,8 +567,8 @@ func (s *llamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requi

// createGPULayers maps from the tensor splits assigned by the memory estimates to explicit assignment
// of particular layers onto GPUs
func createGPULayers(estimate MemoryEstimate, ggml *ggml.GGML, gpus discover.GpuInfoList, numGPU int) ml.GPULayersList {
	if numGPU <= 0 {
func createGPULayers(estimate MemoryEstimate, ggml *ggml.GGML, gpus []ml.DeviceInfo, numGPU int) ml.GPULayersList {
	if numGPU <= 0 || len(gpus) == 0 {
		return nil
	}

@@ -662,7 +624,7 @@ func createGPULayers(estimate MemoryEstimate, ggml *ggml.GGML, gpus discover.Gpu
// allowing for faster iteration, but may return less information.
//
// Returns the list of GPU IDs that were used in the final allocation on success
func (s *ollamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requireFull bool) ([]ml.DeviceID, error) {
func (s *ollamaServer) Load(ctx context.Context, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) {
	var success bool
	defer func() {
		if !success {
@@ -675,24 +637,21 @@ func (s *ollamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requ

	slog.Info("loading model", "model layers", s.totalLayers, "requested", s.options.NumGPU)

	systemInfo := discover.GetSystemInfo()
	systemTotalMemory := systemInfo.System.TotalMemory
	systemFreeMemory := systemInfo.System.FreeMemory
	systemSwapFreeMemory := systemInfo.System.FreeSwap
	systemTotalMemory := systemInfo.TotalMemory
	systemFreeMemory := systemInfo.FreeMemory
	systemSwapFreeMemory := systemInfo.FreeSwap
	slog.Info("system memory", "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "free_swap", format.HumanBytes2(systemSwapFreeMemory))

	if !(len(gpus) == 1 && gpus[0].Library == "cpu") {
		for _, gpu := range gpus {
			available := gpu.FreeMemory - envconfig.GpuOverhead() - gpu.MinimumMemory
			if gpu.FreeMemory < envconfig.GpuOverhead()+gpu.MinimumMemory {
				available = 0
			}
			slog.Info("gpu memory", "id", gpu.ID, "library", gpu.Library,
				"available", format.HumanBytes2(available),
				"free", format.HumanBytes2(gpu.FreeMemory),
				"minimum", format.HumanBytes2(gpu.MinimumMemory),
				"overhead", format.HumanBytes2(envconfig.GpuOverhead()))
	for _, gpu := range gpus {
		available := gpu.FreeMemory - envconfig.GpuOverhead() - gpu.MinimumMemory()
		if gpu.FreeMemory < envconfig.GpuOverhead()+gpu.MinimumMemory() {
			available = 0
		}
		slog.Info("gpu memory", "id", gpu.ID, "library", gpu.Library,
			"available", format.HumanBytes2(available),
			"free", format.HumanBytes2(gpu.FreeMemory),
			"minimum", format.HumanBytes2(gpu.MinimumMemory()),
			"overhead", format.HumanBytes2(envconfig.GpuOverhead()))
	}

	pastAllocations := make(map[uint64]struct{})
@@ -762,7 +721,6 @@ nextOperation:
			if err != nil {
				return nil, err
			}

			slog.Debug("new layout created", "layers", newGPULayers)

			s.loadRequest.GPULayers = newGPULayers
@@ -808,15 +766,12 @@ nextOperation:
			// Memory allocation failed even though we created a layout that we thought should
			// fit in available memory. This could happen if either our free memory reports
			// are incorrect or if available memory is changing between layout and allocation
			// time. Apply an exponential backoff to try to find the real amount of available
			// space.
			// time. Apply a backoff to try to find the real amount of available space.
			if backoff > 1 {
				slog.Warn("memory layout cannot be allocated", "memory", resp.Memory)
				return nil, errors.New("memory layout cannot be allocated")
			} else if backoff == 0 {
				backoff = 0.01
			} else {
				backoff *= 2
				backoff += 0.1
			}

			slog.Info("model layout did not fit, applying backoff", "backoff", fmt.Sprintf("%.2f", backoff))
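The reserve factor now grows as double-plus-0.1 rather than plain doubling, so retries give up after a handful of attempts. A standalone sketch of the progression (not the actual scheduler loop):

```go
// Progression of the reserve factor: 0.01, 0.12, 0.34, 0.78, then the next
// update pushes it past 1.0 and the allocation attempt is abandoned.
backoff := float32(0)
for attempt := 1; ; attempt++ {
	if backoff > 1 {
		fmt.Println("giving up after", attempt-1, "retries")
		break
	} else if backoff == 0 {
		backoff = 0.01
	} else {
		backoff *= 2
		backoff += 0.1
	}
	fmt.Printf("attempt %d: reserve %.2f of reported free memory\n", attempt, backoff)
}
```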
@@ -864,20 +819,27 @@ func uniqueDeviceIDs(gpuLayers ml.GPULayersList) []ml.DeviceID {
// - Calculating how much space each GPU has available for layers, based on free memory and space occupied by the graph
// - Assigning layers
// - Ensuring that we don't exceed limits, such as requirements about partial offloading or system memory
func (s *ollamaServer) createLayout(systemInfo discover.SystemInfo, systemGPUs discover.GpuInfoList, memory *ml.BackendMemory, requireFull bool, backoff float32) (ml.GPULayersList, error) {
	if s.totalLayers == 0 || s.options.NumGPU == 0 || len(systemGPUs) == 0 || (len(systemGPUs) == 1 && systemGPUs[0].Library == "cpu") {
		return ml.GPULayersList{}, nil
	}

	gpus := append(make(discover.GpuInfoList, 0, len(systemGPUs)), systemGPUs...)
	sort.Sort(sort.Reverse(discover.ByFreeMemory(gpus)))

func (s *ollamaServer) createLayout(systemInfo ml.SystemInfo, systemGPUs []ml.DeviceInfo, memory *ml.BackendMemory, requireFull bool, backoff float32) (ml.GPULayersList, error) {
	if memory == nil {
		memory = &ml.BackendMemory{CPU: ml.DeviceMemory{
			Weights: make([]uint64, s.totalLayers),
			Cache:   make([]uint64, s.totalLayers),
		}}
	}
	gpuLayers, layers, err := s.buildLayout(systemGPUs, memory, requireFull, backoff)
	if err != nil {
		return nil, err
	}
	err = s.verifyLayout(systemInfo, memory, requireFull, gpuLayers, layers)
	if err != nil {
		return nil, err
	}
	return gpuLayers, nil
}

func (s *ollamaServer) buildLayout(systemGPUs []ml.DeviceInfo, memory *ml.BackendMemory, requireFull bool, backoff float32) (ml.GPULayersList, []uint64, error) {
	gpus := append(make([]ml.DeviceInfo, 0, len(systemGPUs)), systemGPUs...)
	sort.Sort(sort.Reverse(ml.ByFreeMemory(gpus)))

	layers := make([]uint64, len(memory.CPU.Weights))
	for i := range layers {
@@ -891,7 +853,7 @@ func (s *ollamaServer) createLayout(systemInfo discover.SystemInfo, systemGPUs d
	}

	gpuLayers := ml.GPULayersList{}
	for _, gl := range gpus.ByLibrary() {
	for _, gl := range ml.ByLibrary(gpus) {
		// If a GPU already has a graph allocated on it, then we should continue to use it.
		// Otherwise, we lose information that we got from previous allocations, which can
		// cause cycling. Plus, we get more information about required allocation from each
@@ -905,7 +867,7 @@ func (s *ollamaServer) createLayout(systemInfo discover.SystemInfo, systemGPUs d
			lastUsedGPU = i
		}

		reserved := uint64(float32(gl[i].FreeMemory)*backoff) + gl[i].MinimumMemory + envconfig.GpuOverhead() + memory.GPUs[j].Graph
		reserved := uint64(float32(gl[i].FreeMemory)*backoff) + gl[i].MinimumMemory() + envconfig.GpuOverhead() + memory.GPUs[j].Graph
		if gl[i].FreeMemory > reserved {
			gl[i].FreeMemory -= reserved
		} else {
@@ -914,7 +876,7 @@ func (s *ollamaServer) createLayout(systemInfo discover.SystemInfo, systemGPUs d

		slog.Debug("available gpu", "id", gl[i].ID, "library", gl[i].Library,
			"available layer vram", format.HumanBytes2(gl[i].FreeMemory),
			"backoff", fmt.Sprintf("%.2f", backoff), "minimum", format.HumanBytes2(gl[i].MinimumMemory),
			"backoff", fmt.Sprintf("%.2f", backoff), "minimum", format.HumanBytes2(gl[i].MinimumMemory()),
			"overhead", format.HumanBytes2(envconfig.GpuOverhead()),
			"graph", format.HumanBytes2(memory.GPUs[j].Graph))

@@ -933,7 +895,11 @@ func (s *ollamaServer) createLayout(systemInfo discover.SystemInfo, systemGPUs d
			gpuLayers = libraryGpuLayers
		}
	}
	return gpuLayers, layers, nil
}

// verifyLayout ensures that we don't exceed limits, such as requirements about partial offloading or system memory
func (s *ollamaServer) verifyLayout(systemInfo ml.SystemInfo, memory *ml.BackendMemory, requireFull bool, gpuLayers ml.GPULayersList, layers []uint64) error {
	// These sizes will only increase as we go through additional iterations and get additional information.
	cpuSize := memory.InputWeights + memory.CPU.Graph
	var vramSize uint64
@@ -961,24 +927,24 @@ nextLayer:

	if requireFull {
		if gpuLayers.Sum() < len(layers) && (s.options.NumGPU < 0 || gpuLayers.Sum() < s.options.NumGPU) {
			return nil, ErrLoadRequiredFull
			return ErrLoadRequiredFull
		}

		if cpuSize > systemInfo.System.FreeMemory {
			return nil, ErrLoadRequiredFull
		if cpuSize > systemInfo.FreeMemory {
			return ErrLoadRequiredFull
		}
	}

	// On linux and windows, over-allocating CPU memory will almost always result in an error
	// Darwin has fully dynamic swap so has no direct concept of free swap space
	if runtime.GOOS != "darwin" {
		available := systemInfo.System.FreeMemory + systemInfo.System.FreeSwap
		available := systemInfo.FreeMemory + systemInfo.FreeSwap
		if cpuSize > available {
			slog.Warn("model request too large for system", "requested", format.HumanBytes2(cpuSize), "available", format.HumanBytes2(available), "total", format.HumanBytes2(systemInfo.System.TotalMemory), "free", format.HumanBytes2(systemInfo.System.FreeMemory), "swap", format.HumanBytes2(systemInfo.System.FreeSwap))
			return nil, fmt.Errorf("model requires more system memory (%s) than is available (%s)", format.HumanBytes2(cpuSize), format.HumanBytes2(available))
			slog.Warn("model request too large for system", "requested", format.HumanBytes2(cpuSize), "available", format.HumanBytes2(available), "total", format.HumanBytes2(systemInfo.TotalMemory), "free", format.HumanBytes2(systemInfo.FreeMemory), "swap", format.HumanBytes2(systemInfo.FreeSwap))
			return fmt.Errorf("model requires more system memory (%s) than is available (%s)", format.HumanBytes2(cpuSize), format.HumanBytes2(available))
		}
	} else {
		if vramSize > systemInfo.System.TotalMemory {
		if vramSize > systemInfo.TotalMemory {
			// disable partial offloading when model is greater than total system memory as this
			// can lead to locking up the system
			s.options.NumGPU = 0
@@ -990,11 +956,11 @@ nextLayer:
		slog.Debug("insufficient VRAM to load any model layers")
	}

	return gpuLayers, nil
	return nil
}

// assignLayers packs the maximum number of layers onto the smallest set of GPUs and comes up with a layer assignment
func assignLayers(layers []uint64, gpus discover.GpuInfoList, requireFull bool, requestedLayers int, lastUsedGPU int) (gpuLayers ml.GPULayersList) {
func assignLayers(layers []uint64, gpus []ml.DeviceInfo, requireFull bool, requestedLayers int, lastUsedGPU int) (gpuLayers ml.GPULayersList) {
	// If we can't fit everything then prefer offloading layers other than the output layer
	for range 2 {
		// requestedLayers may be -1 if nothing was requested
@@ -1028,7 +994,7 @@ func assignLayers(layers []uint64, gpus discover.GpuInfoList, requireFull bool,
// findBestFit binary searches to find the smallest capacity factor that can fit
// the max number of layers. The capacity factor is multiplied by the free space on
// each GPU and a small one will force even balancing.
func findBestFit(layers []uint64, gpus discover.GpuInfoList, requestedLayers int, forceRequest bool) (gpuLayers ml.GPULayersList) {
func findBestFit(layers []uint64, gpus []ml.DeviceInfo, requestedLayers int, forceRequest bool) (gpuLayers ml.GPULayersList) {
	var high float32 = 1
	var low float32 = 0

@@ -1053,12 +1019,11 @@ func findBestFit(layers []uint64, gpus discover.GpuInfoList, requestedLayers int
			low = mid
		}
	}

	return bestAssignments
}

// greedyFit assigns layers incrementally to GPUs, spilling over as each runs out of free space
func greedyFit(layers []uint64, gpus discover.GpuInfoList, capacity float32, requestedLayers int) (gpuLayers ml.GPULayersList) {
func greedyFit(layers []uint64, gpus []ml.DeviceInfo, capacity float32, requestedLayers int) (gpuLayers ml.GPULayersList) {
	device := len(gpus) - 1
	gpuLayers = ml.GPULayersList{{DeviceID: gpus[device].DeviceID}}
	freeSpace := uint64(float32(gpus[device].FreeMemory) * capacity)
@@ -1082,7 +1047,6 @@ func greedyFit(layers []uint64, gpus discover.GpuInfoList, capacity float32, req
			freeSpace = uint64(float32(gpus[device].FreeMemory) * capacity)
		}
	}

	return gpuLayers
}
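findBestFit keeps its shape under the new device type: binary-search a capacity factor in [0,1], using greedyFit as the feasibility oracle. A compact illustration of that loop, with a stand-in fits predicate instead of the real layer packing:

```go
// Binary search for the smallest capacity factor that still fits the target
// layer count; a fixed iteration count converges well past the granularity
// that matters for memory sizes.
var low, high float32 = 0, 1
best := high
for range 20 {
	mid := (low + high) / 2
	if fits(mid) { // stand-in for: greedyFit(layers, gpus, mid, requested) places all layers
		best = mid
		high = mid
	} else {
		low = mid
	}
}
_ = best // smallest workable factor; small factors force even balancing across GPUs
```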
@@ -1581,14 +1545,16 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
}

type EmbeddingRequest struct {
	Content string `json:"content"`
	Content  string `json:"content"`
	Truncate bool   `json:"truncate"`
}

type EmbeddingResponse struct {
	Embedding []float32 `json:"embedding"`
	Embedding       []float32 `json:"embedding"`
	PromptEvalCount int       `json:"prompt_eval_count"`
}

func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, error) {
func (s *llmServer) Embedding(ctx context.Context, input string, truncate bool) ([]float32, int, error) {
	logutil.Trace("embedding request", "input", input)

	if err := s.sem.Acquire(ctx, 1); err != nil {
@@ -1597,51 +1563,54 @@ func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, err
		} else {
			slog.Error("Failed to acquire semaphore", "error", err)
		}
		return nil, err
		return nil, 0, err
	}
	defer s.sem.Release(1)

	// Make sure the server is ready
	status, err := s.getServerStatusRetry(ctx)
	if err != nil {
		return nil, err
		return nil, 0, err
	} else if status != ServerStatusReady {
		return nil, fmt.Errorf("unexpected server status: %s", status)
		return nil, 0, fmt.Errorf("unexpected server status: %s", status)
	}

	data, err := json.Marshal(EmbeddingRequest{Content: input})
	data, err := json.Marshal(EmbeddingRequest{Content: input, Truncate: truncate})
	if err != nil {
		return nil, fmt.Errorf("error marshaling embed data: %w", err)
		return nil, 0, fmt.Errorf("error marshaling embed data: %w", err)
	}

	r, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data))
	if err != nil {
		return nil, fmt.Errorf("error creating embed request: %w", err)
		return nil, 0, fmt.Errorf("error creating embed request: %w", err)
	}
	r.Header.Set("Content-Type", "application/json")

	resp, err := http.DefaultClient.Do(r)
	if err != nil {
		return nil, fmt.Errorf("do embedding request: %w", err)
		return nil, 0, fmt.Errorf("do embedding request: %w", err)
	}
	defer resp.Body.Close()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("error reading embed response: %w", err)
		return nil, 0, fmt.Errorf("error reading embed response: %w", err)
	}

	if resp.StatusCode >= 400 {
		log.Printf("llm embedding error: %s", body)
		return nil, fmt.Errorf("%s", body)
		return nil, 0, api.StatusError{
			StatusCode:   resp.StatusCode,
			ErrorMessage: string(body),
		}
	}

	var e EmbeddingResponse
	if err := json.Unmarshal(body, &e); err != nil {
		return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
		return nil, 0, fmt.Errorf("unmarshal tokenize response: %w", err)
	}

	return e.Embedding, nil
	return e.Embedding, e.PromptEvalCount, nil
}
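A hypothetical caller of the widened Embedding signature, showing the new truncate flag and the prompt-eval count now surfaced alongside the vector:

```go
// Request an embedding, allowing the runner to truncate over-long input,
// and record how many tokens were actually evaluated.
embedding, promptTokens, err := server.Embedding(ctx, "some input text", true)
if err != nil {
	return err
}
slog.Debug("embedded", "dims", len(embedding), "prompt_eval_count", promptTokens)
```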
type TokenizeRequest struct {
@@ -1814,7 +1783,7 @@ func (s *ollamaServer) VRAMByGPU(id ml.DeviceID) uint64 {
}

func (s *ollamaServer) GetDeviceInfos(ctx context.Context) []ml.DeviceInfo {
	devices, err := discover.GetDevicesFromRunner(ctx, s)
	devices, err := ml.GetDevicesFromRunner(ctx, s)
	if err != nil {
		if s.cmd != nil && s.cmd.ProcessState == nil {
			// Still running but hit an error, log
@@ -8,7 +8,6 @@ import (
	"testing"

	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/discover"
	"github.com/ollama/ollama/format"
	"github.com/ollama/ollama/ml"
	"golang.org/x/sync/semaphore"
@@ -20,6 +19,8 @@ func TestLLMServerFitGPU(t *testing.T) {
		free int
	}

	minMemory := 457 * format.MebiByte

	tests := []struct {
		name string
		gpus []gpu
@@ -37,91 +38,91 @@ func TestLLMServerFitGPU(t *testing.T) {
		},
		{
			name: "Full single GPU",
			gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256 * format.MebiByte}},
			gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256*format.MebiByte + minMemory}},
			layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
			numGPU: -1,
			expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{0, 1, 2}}},
		},
		{
			name: "Partial single GPU",
			gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256 * format.MebiByte}},
			gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256*format.MebiByte + minMemory}},
			layers: []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte},
			numGPU: -1,
			expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{1, 2}}},
		},
		{
			name: "Single GPU with numGPU 1",
			gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256 * format.MebiByte}},
			gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256*format.MebiByte + minMemory}},
			layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
			numGPU: 1,
			expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{1}}},
		},
		{
			name: "Single GPU with numGPU 0",
			gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256 * format.MebiByte}},
			gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256*format.MebiByte + minMemory}},
			layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
			numGPU: 0,
			expected: ml.GPULayersList{},
		},
		{
			name: "Single GPU with numGPU 999",
			gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256 * format.MebiByte}},
			gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256*format.MebiByte + minMemory}},
			layers: []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte},
			numGPU: 999,
			expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{0, 1, 2, 3}}},
		},
		{
			name: "Multi GPU fits on one",
			gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128 * format.MebiByte}, {id: ml.DeviceID{ID: "gpu1"}, free: 256 * format.MebiByte}},
			gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128*format.MebiByte + minMemory}, {id: ml.DeviceID{ID: "gpu1"}, free: 256*format.MebiByte + minMemory}},
			layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
			numGPU: -1,
			expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{0, 1, 2}}},
		},
		{
			name: "Multi GPU split",
			gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128 * format.MebiByte}, {id: ml.DeviceID{ID: "gpu1"}, free: 256 * format.MebiByte}},
			gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128*format.MebiByte + minMemory}, {id: ml.DeviceID{ID: "gpu1"}, free: 256*format.MebiByte + minMemory}},
			layers: []int{256 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
			numGPU: -1,
			expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{0}}, {DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{1, 2}}},
		},
		{
			name: "Multi GPU partial",
			gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128 * format.MebiByte}, {id: ml.DeviceID{ID: "gpu1"}, free: 256 * format.MebiByte}},
			gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128*format.MebiByte + minMemory}, {id: ml.DeviceID{ID: "gpu1"}, free: 256*format.MebiByte + minMemory}},
			layers: []int{256 * format.MebiByte, 256 * format.MebiByte, 50 * format.MebiByte},
			numGPU: -1,
			expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{1}}},
		},
		{
			name: "Multi GPU numGPU 1",
			gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128 * format.MebiByte}, {id: ml.DeviceID{ID: "gpu1"}, free: 256 * format.MebiByte}},
			gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128*format.MebiByte + minMemory}, {id: ml.DeviceID{ID: "gpu1"}, free: 256*format.MebiByte + minMemory}},
			layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
			numGPU: 1,
			expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{1}}},
		},
		{
			name: "Multi GPU numGPU 2",
			gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128 * format.MebiByte}, {id: ml.DeviceID{ID: "gpu1"}, free: 256 * format.MebiByte}},
			gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128*format.MebiByte + minMemory}, {id: ml.DeviceID{ID: "gpu1"}, free: 256*format.MebiByte + minMemory}},
			layers: []int{256 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte},
			numGPU: 2,
			expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{0}}, {DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{1}}},
		},
		{
			name: "Multi GPU numGPU 999",
			gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128 * format.MebiByte}, {id: ml.DeviceID{ID: "gpu1"}, free: 256 * format.MebiByte}},
			gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 128*format.MebiByte + minMemory}, {id: ml.DeviceID{ID: "gpu1"}, free: 256*format.MebiByte + minMemory}},
			layers: []int{256 * format.MebiByte, 256 * format.MebiByte, 50 * format.MebiByte},
			numGPU: 999,
			expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1"}, Layers: []int{0, 1}}, {DeviceID: ml.DeviceID{ID: "gpu0"}, Layers: []int{2}}},
		},
		{
			name: "Multi GPU different libraries",
			gpus: []gpu{{id: ml.DeviceID{Library: "CUDA", ID: "gpu0"}, free: 128 * format.MebiByte}, {id: ml.DeviceID{Library: "ROCm", ID: "gpu1"}, free: 256 * format.MebiByte}},
			gpus: []gpu{{id: ml.DeviceID{Library: "CUDA", ID: "gpu0"}, free: 128*format.MebiByte + minMemory}, {id: ml.DeviceID{Library: "ROCm", ID: "gpu1"}, free: 256*format.MebiByte + minMemory}},
			layers: []int{128 * format.MebiByte, 128 * format.MebiByte, 50 * format.MebiByte},
			numGPU: -1,
			expected: ml.GPULayersList{{DeviceID: ml.DeviceID{ID: "gpu1", Library: "ROCm"}, Layers: []int{0, 1}}},
		},
		{
			name: "requireFull",
			gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256 * format.MebiByte}},
			gpus: []gpu{{id: ml.DeviceID{ID: "gpu0"}, free: 256*format.MebiByte + minMemory}},
			layers: []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte},
			numGPU: -1,
			requireFull: true,
@@ -139,12 +140,12 @@ func TestLLMServerFitGPU(t *testing.T) {

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var systemInfo discover.SystemInfo
			systemInfo.System.TotalMemory = format.GibiByte
			systemInfo.System.FreeMemory = 512 * format.MebiByte
			systemInfo.System.FreeSwap = 256 * format.MebiByte
			var systemInfo ml.SystemInfo
			systemInfo.TotalMemory = format.GibiByte
			systemInfo.FreeMemory = 512 * format.MebiByte
			systemInfo.FreeSwap = 256 * format.MebiByte

			gpus := make(discover.GpuInfoList, len(tt.gpus))
			gpus := make([]ml.DeviceInfo, len(tt.gpus))
			for i := range tt.gpus {
				gpus[i].DeviceID = tt.gpus[i].id
				gpus[i].FreeMemory = uint64(tt.gpus[i].free)
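Every test device now adds minMemory (457 MiB, matching the fixed per-GPU minimum used elsewhere in this diff) on top of its layer budget, so the usable space per device is unchanged. A quick check of the arithmetic the fixtures rely on:

```go
// Usable layer space is free minus the per-device minimum; the fixtures
// inflate free by exactly that minimum, keeping the original 256 MiB budget.
free := uint64(256*format.MebiByte) + uint64(457*format.MebiByte)
minimum := uint64(457 * format.MebiByte)
fmt.Println(format.HumanBytes2(free - minimum)) // the 256 MiB layer budget
```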
@@ -7,6 +7,7 @@ import (
	"io"
	"math/rand"
	"net/http"
	"strings"

	"github.com/gin-gonic/gin"

@@ -44,7 +45,8 @@ type RetrieveWriter struct {

type EmbedWriter struct {
	BaseWriter
	model string
	model          string
	encodingFormat string
}

func (w *BaseWriter) writeError(data []byte) (int, error) {
@@ -254,7 +256,7 @@ func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
	}

	w.ResponseWriter.Header().Set("Content-Type", "application/json")
	err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToEmbeddingList(w.model, embedResponse))
	err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToEmbeddingList(w.model, embedResponse, w.encodingFormat))
	if err != nil {
		return 0, err
	}
@@ -348,6 +350,14 @@ func EmbeddingsMiddleware() gin.HandlerFunc {
			return
		}

		// Validate encoding_format parameter
		if req.EncodingFormat != "" {
			if !strings.EqualFold(req.EncodingFormat, "float") && !strings.EqualFold(req.EncodingFormat, "base64") {
				c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, fmt.Sprintf("Invalid value for 'encoding_format' = %s. Supported values: ['float', 'base64'].", req.EncodingFormat)))
				return
			}
		}

		if req.Input == "" {
			req.Input = []string{""}
		}
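For encoding_format=base64, OpenAI-compatible servers emit the embedding as the base64 of the raw little-endian float32 bytes. A minimal sketch of that conversion (the helper name is hypothetical; the actual change lives inside openai.ToEmbeddingList):

```go
package main

import (
	"bytes"
	"encoding/base64"
	"encoding/binary"
)

// encodeBase64 turns []float32 into the payload the base64 tests below
// verify: 4 little-endian bytes per value, then standard base64.
func encodeBase64(embedding []float32) string {
	buf := new(bytes.Buffer)
	_ = binary.Write(buf, binary.LittleEndian, embedding) // fixed-size data: cannot fail
	return base64.StdEncoding.EncodeToString(buf.Bytes())
}
```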
@@ -371,8 +381,9 @@ func EmbeddingsMiddleware() gin.HandlerFunc {
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
w := &EmbedWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
model: req.Model,
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
model: req.Model,
|
||||
encodingFormat: req.EncodingFormat,
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
|
||||
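The middleware change above validates encoding_format case-insensitively and threads it through to the response writer. For orientation, a minimal client-side sketch of a request that would exercise it; the endpoint path and model name are assumptions for illustration, not part of this diff:

// Hypothetical client sketch: request base64-encoded embeddings from the
// OpenAI-compatible endpoint. Accepted encoding_format values are "float"
// (default) and "base64", matched case-insensitively per the validation above.
package main

import (
    "bytes"
    "encoding/json"
    "fmt"
    "net/http"
)

func main() {
    body, _ := json.Marshal(map[string]any{
        "model":           "all-minilm", // placeholder model name
        "input":           []string{"hello", "world"},
        "encoding_format": "base64",
    })
    resp, err := http.Post("http://localhost:11434/v1/embeddings", "application/json", bytes.NewReader(body))
    if err != nil {
        panic(err)
    }
    defer resp.Body.Close()
    fmt.Println(resp.Status) // unsupported formats yield 400 invalid_request_error
}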
220 middleware/openai_encoding_format_test.go Normal file
@@ -0,0 +1,220 @@
package middleware

import (
    "encoding/base64"
    "encoding/json"
    "net/http"
    "net/http/httptest"
    "strings"
    "testing"

    "github.com/gin-gonic/gin"
    "github.com/ollama/ollama/api"
    "github.com/ollama/ollama/openai"
)

func TestEmbeddingsMiddleware_EncodingFormats(t *testing.T) {
    testCases := []struct {
        name           string
        encodingFormat string
        expectType     string // "array" or "string"
        verifyBase64   bool
    }{
        {"float format", "float", "array", false},
        {"base64 format", "base64", "string", true},
        {"default format", "", "array", false},
    }

    gin.SetMode(gin.TestMode)

    endpoint := func(c *gin.Context) {
        resp := api.EmbedResponse{
            Embeddings:      [][]float32{{0.1, -0.2, 0.3}},
            PromptEvalCount: 5,
        }
        c.JSON(http.StatusOK, resp)
    }

    router := gin.New()
    router.Use(EmbeddingsMiddleware())
    router.Handle(http.MethodPost, "/api/embed", endpoint)

    for _, tc := range testCases {
        t.Run(tc.name, func(t *testing.T) {
            body := `{"input": "test", "model": "test-model"`
            if tc.encodingFormat != "" {
                body += `, "encoding_format": "` + tc.encodingFormat + `"`
            }
            body += `}`

            req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(body))
            req.Header.Set("Content-Type", "application/json")

            resp := httptest.NewRecorder()
            router.ServeHTTP(resp, req)

            if resp.Code != http.StatusOK {
                t.Fatalf("expected status 200, got %d", resp.Code)
            }

            var result openai.EmbeddingList
            if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
                t.Fatalf("failed to unmarshal response: %v", err)
            }

            if len(result.Data) != 1 {
                t.Fatalf("expected 1 embedding, got %d", len(result.Data))
            }

            switch tc.expectType {
            case "array":
                if _, ok := result.Data[0].Embedding.([]interface{}); !ok {
                    t.Errorf("expected array, got %T", result.Data[0].Embedding)
                }
            case "string":
                embStr, ok := result.Data[0].Embedding.(string)
                if !ok {
                    t.Errorf("expected string, got %T", result.Data[0].Embedding)
                } else if tc.verifyBase64 {
                    decoded, err := base64.StdEncoding.DecodeString(embStr)
                    if err != nil {
                        t.Errorf("invalid base64: %v", err)
                    } else if len(decoded) != 12 {
                        t.Errorf("expected 12 bytes, got %d", len(decoded))
                    }
                }
            }
        })
    }
}

func TestEmbeddingsMiddleware_BatchWithBase64(t *testing.T) {
    gin.SetMode(gin.TestMode)

    endpoint := func(c *gin.Context) {
        resp := api.EmbedResponse{
            Embeddings: [][]float32{
                {0.1, 0.2},
                {0.3, 0.4},
                {0.5, 0.6},
            },
            PromptEvalCount: 10,
        }
        c.JSON(http.StatusOK, resp)
    }

    router := gin.New()
    router.Use(EmbeddingsMiddleware())
    router.Handle(http.MethodPost, "/api/embed", endpoint)

    body := `{
        "input": ["hello", "world", "test"],
        "model": "test-model",
        "encoding_format": "base64"
    }`

    req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(body))
    req.Header.Set("Content-Type", "application/json")

    resp := httptest.NewRecorder()
    router.ServeHTTP(resp, req)

    if resp.Code != http.StatusOK {
        t.Fatalf("expected status 200, got %d", resp.Code)
    }

    var result openai.EmbeddingList
    if err := json.Unmarshal(resp.Body.Bytes(), &result); err != nil {
        t.Fatalf("failed to unmarshal response: %v", err)
    }

    if len(result.Data) != 3 {
        t.Fatalf("expected 3 embeddings, got %d", len(result.Data))
    }

    // All should be base64 strings
    for i := range 3 {
        embeddingStr, ok := result.Data[i].Embedding.(string)
        if !ok {
            t.Errorf("embedding %d: expected string, got %T", i, result.Data[i].Embedding)
            continue
        }

        // Verify it's valid base64
        if _, err := base64.StdEncoding.DecodeString(embeddingStr); err != nil {
            t.Errorf("embedding %d: invalid base64: %v", i, err)
        }

        // Check index
        if result.Data[i].Index != i {
            t.Errorf("embedding %d: expected index %d, got %d", i, i, result.Data[i].Index)
        }
    }
}

func TestEmbeddingsMiddleware_InvalidEncodingFormat(t *testing.T) {
    gin.SetMode(gin.TestMode)

    endpoint := func(c *gin.Context) {
        c.Status(http.StatusOK)
    }

    router := gin.New()
    router.Use(EmbeddingsMiddleware())
    router.Handle(http.MethodPost, "/api/embed", endpoint)

    testCases := []struct {
        name           string
        encodingFormat string
        shouldFail     bool
    }{
        {"valid: float", "float", false},
        {"valid: base64", "base64", false},
        {"valid: FLOAT (uppercase)", "FLOAT", false},
        {"valid: BASE64 (uppercase)", "BASE64", false},
        {"valid: Float (mixed)", "Float", false},
        {"valid: Base64 (mixed)", "Base64", false},
        {"invalid: json", "json", true},
        {"invalid: hex", "hex", true},
        {"invalid: invalid_format", "invalid_format", true},
    }

    for _, tc := range testCases {
        t.Run(tc.name, func(t *testing.T) {
            body := `{
                "input": "test",
                "model": "test-model",
                "encoding_format": "` + tc.encodingFormat + `"
            }`

            req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(body))
            req.Header.Set("Content-Type", "application/json")

            resp := httptest.NewRecorder()
            router.ServeHTTP(resp, req)

            if tc.shouldFail {
                if resp.Code != http.StatusBadRequest {
                    t.Errorf("expected status 400, got %d", resp.Code)
                }

                var errResp openai.ErrorResponse
                if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
                    t.Fatalf("failed to unmarshal error response: %v", err)
                }

                if errResp.Error.Type != "invalid_request_error" {
                    t.Errorf("expected error type 'invalid_request_error', got %q", errResp.Error.Type)
                }

                if !strings.Contains(errResp.Error.Message, "encoding_format") {
                    t.Errorf("expected error message to mention encoding_format, got %q", errResp.Error.Message)
                }
            } else {
                if resp.Code != http.StatusOK {
                    t.Errorf("expected status 200, got %d: %s", resp.Code, resp.Body.String())
                }
            }
        })
    }
}
12 ml/backend/ggml/ggml/src/ggml-backend-reg.cpp vendored
@@ -118,6 +118,18 @@ static dl_handle * dl_load_library(const fs::path & path) {
    SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);

    HMODULE handle = LoadLibraryW(path.wstring().c_str());
    if (!handle) {
        DWORD error_code = GetLastError();
        std::string msg;
        LPSTR lpMsgBuf = NULL;
        DWORD bufLen = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
                                      NULL, error_code, MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&lpMsgBuf, 0, NULL);
        if (bufLen) {
            msg = lpMsgBuf;
            LocalFree(lpMsgBuf);
            GGML_LOG_INFO("%s unable to load library %s: %s\n", __func__, path_str(path).c_str(), msg.c_str());
        }
    }

    SetErrorMode(old_mode);

@@ -4159,7 +4159,6 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
    if (!initialized) {
        ggml_backend_cuda_reg_context * ctx = new ggml_backend_cuda_reg_context;
        int driverVersion = 0;
        CUDA_CHECK(cudaDriverGetVersion(&driverVersion));

        for (int i = 0; i < ggml_cuda_info().device_count; i++) {
            ggml_backend_cuda_device_context * dev_ctx = new ggml_backend_cuda_device_context;
@@ -4177,6 +4176,9 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {

            dev_ctx->major = prop.major;
            dev_ctx->minor = prop.minor;
            if (driverVersion == 0) {
                CUDA_CHECK(cudaDriverGetVersion(&driverVersion));
            }
            dev_ctx->driver_major = driverVersion / 1000;
            dev_ctx->driver_minor = (driverVersion - (dev_ctx->driver_major * 1000)) / 10;
            dev_ctx->integrated = prop.integrated;

215 ml/device.go
@@ -3,15 +3,21 @@ package ml

import (
    "context"
    "encoding/binary"
    "encoding/json"
    "fmt"
    "hash/maphash"
    "io"
    "log/slog"
    "net/http"
    "runtime"
    "slices"
    "sort"
    "strconv"
    "strings"
    "time"

    "github.com/ollama/ollama/format"
    "github.com/ollama/ollama/logutil"
)

// GPULayers is a set of layers to be allocated on a single GPU
@@ -282,6 +288,20 @@ type DeviceInfo struct {
    LibraryPath []string
}

type SystemInfo struct {
    // ThreadCount is the optimal number of threads to use for inference
    ThreadCount int `json:"threads,omitempty"`

    // TotalMemory is the total amount of system memory
    TotalMemory uint64 `json:"total_memory,omitempty"`

    // FreeMemory is the amount of memory currently available on the system for loading models
    FreeMemory uint64 `json:"free_memory,omitempty"`

    // FreeSwap is the amount of system swap space reported as available
    FreeSwap uint64 `json:"free_swap,omitempty"`
}

func (d DeviceInfo) Compute() string {
    // AMD gfx is encoded into the major minor in hex form
    if strings.EqualFold(d.Library, "ROCm") {
@@ -294,6 +314,71 @@ func (d DeviceInfo) Driver() string {
    return strconv.Itoa(d.DriverMajor) + "." + strconv.Itoa(d.DriverMinor)
}

// MinimumMemory reports the amount of memory that should be set aside
// on the device for overhead (e.g. VRAM consumed by context structures independent
// of model allocations)
func (d DeviceInfo) MinimumMemory() uint64 {
    if d.Library == "Metal" {
        return 512 * format.MebiByte
    }
    return 457 * format.MebiByte
}

// Sort by Free Space.
// iGPUs are reported first, thus Reverse() yields the largest discrete GPU first
type ByFreeMemory []DeviceInfo

func (a ByFreeMemory) Len() int      { return len(a) }
func (a ByFreeMemory) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a ByFreeMemory) Less(i, j int) bool {
    if a[i].Integrated && !a[j].Integrated {
        return true
    } else if !a[i].Integrated && a[j].Integrated {
        return false
    }
    return a[i].FreeMemory < a[j].FreeMemory
}

func ByLibrary(l []DeviceInfo) [][]DeviceInfo {
    resp := [][]DeviceInfo{}
    libs := []string{}
    for _, info := range l {
        found := false
        requested := info.Library
        for i, lib := range libs {
            if lib == requested {
                resp[i] = append(resp[i], info)
                found = true
                break
            }
        }
        if !found {
            libs = append(libs, requested)
            resp = append(resp, []DeviceInfo{info})
        }
    }
    return resp
}

func LibraryPaths(l []DeviceInfo) []string {
    var gpuLibs []string
    for _, gpu := range l {
        for _, dir := range gpu.LibraryPath {
            needed := true
            for _, existing := range gpuLibs {
                if dir == existing {
                    needed = false
                    break
                }
            }
            if needed {
                gpuLibs = append(gpuLibs, dir)
            }
        }
    }
    return gpuLibs
}

type DeviceComparison int

const (
@@ -336,3 +421,133 @@ func (a DeviceInfo) IsBetter(b DeviceInfo) bool {
    sort.Sort(sort.Reverse(sort.StringSlice(cmp)))
    return cmp[0] == bLibSplit[1]
}

// For each GPU, check if it does NOT support flash attention
func FlashAttentionSupported(l []DeviceInfo) bool {
    for _, gpu := range l {
        supportsFA := gpu.Library == "cpu" ||
            gpu.Name == "Metal" || gpu.Library == "Metal" ||
            (gpu.Library == "CUDA" && gpu.DriverMajor >= 7 && !(gpu.ComputeMajor == 7 && gpu.ComputeMinor == 2)) ||
            gpu.Library == "ROCm"

        if !supportsFA {
            return false
        }
    }
    return true
}

// Given the list of GPUs this instantiation is targeted for,
// figure out the visible devices environment variables
func GetVisibleDevicesEnv(l []DeviceInfo) map[string]string {
    if len(l) == 0 {
        return nil
    }
    env := map[string]string{}
    for _, d := range l {
        d.updateVisibleDevicesEnv(env)
    }
    return env
}

func (d DeviceInfo) updateVisibleDevicesEnv(env map[string]string) {
    var envVar string
    switch d.Library {
    case "ROCm":
        envVar = "ROCR_VISIBLE_DEVICES"
        if runtime.GOOS != "linux" {
            envVar = "HIP_VISIBLE_DEVICES"
        }
    case "Vulkan":
        envVar = "GGML_VK_VISIBLE_DEVICES"
    default:
        return
    }
    v, existing := env[envVar]
    if existing {
        v = v + ","
    }
    if d.FilteredID != "" {
        v = v + d.FilteredID
    } else {
        v = v + d.ID
    }
    env[envVar] = v
}

type BaseRunner interface {
    // GetPort returns the localhost port number the runner is running on
    GetPort() int

    // HasExited indicates if the runner is no longer running. This can be used during
    // bootstrap to detect if a given filtered device is incompatible and triggered an assert
    HasExited() bool
}

type RunnerDiscovery interface {
    BaseRunner

    // GetDeviceInfos will perform a query of the underlying device libraries
    // for device identification and free VRAM information
    // During bootstrap scenarios, this routine may take seconds to complete
    GetDeviceInfos(ctx context.Context) []DeviceInfo
}

type FilteredRunnerDiscovery interface {
    RunnerDiscovery

    // GetActiveDeviceIDs returns the filtered set of devices actively in
    // use by this runner for running models. If the runner is a bootstrap runner, no devices
    // will be active yet so no device IDs are returned.
    // This routine will not query the underlying device and will return immediately
    GetActiveDeviceIDs() []DeviceID
}

func GetDevicesFromRunner(ctx context.Context, runner BaseRunner) ([]DeviceInfo, error) {
    var moreDevices []DeviceInfo
    port := runner.GetPort()
    tick := time.Tick(10 * time.Millisecond)
    for {
        select {
        case <-ctx.Done():
            return nil, fmt.Errorf("failed to finish discovery before timeout")
        case <-tick:
            r, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d/info", port), nil)
            if err != nil {
                return nil, fmt.Errorf("failed to create request: %w", err)
            }
            r.Header.Set("Content-Type", "application/json")

            resp, err := http.DefaultClient.Do(r)
            if err != nil {
                // slog.Warn("failed to send request", "error", err)
                if runner.HasExited() {
                    return nil, fmt.Errorf("runner crashed")
                }
                continue
            }
            defer resp.Body.Close()

            if resp.StatusCode == http.StatusNotFound {
                // old runner, fall back to bootstrapping model
                return nil, fmt.Errorf("llamarunner free vram reporting not supported")
            }

            body, err := io.ReadAll(resp.Body)
            if err != nil {
                slog.Warn("failed to read response", "error", err)
                continue
            }
            if resp.StatusCode != 200 {
                logutil.Trace("runner failed to discover free VRAM", "status", resp.StatusCode, "response", body)
                return nil, fmt.Errorf("runner error: %s", string(body))
            }

            if err := json.Unmarshal(body, &moreDevices); err != nil {
                slog.Warn("unmarshal encode response", "error", err)
                continue
            }
            return moreDevices, nil
        }
    }
}

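As a usage note for the helpers above, here is a minimal sketch of assumed caller code (not part of this diff): it sorts devices so the largest discrete GPU comes first, then derives that device's visible-devices environment. The device list would come from runner discovery in practice.

// Illustrative only, assuming the ml types and helpers shown above.
package main

import (
    "fmt"
    "sort"

    "github.com/ollama/ollama/ml"
)

func bestDevice(devices []ml.DeviceInfo) (ml.DeviceInfo, map[string]string, bool) {
    if len(devices) == 0 {
        return ml.DeviceInfo{}, nil, false
    }
    // iGPUs sort first, so Reverse() puts the largest discrete GPU at index 0.
    sort.Sort(sort.Reverse(ml.ByFreeMemory(devices)))
    return devices[0], ml.GetVisibleDevicesEnv(devices[:1]), true
}

func main() {
    var devices []ml.DeviceInfo // populated by runner discovery in practice
    if best, env, ok := bestDevice(devices); ok {
        fmt.Println(best.Library, best.FreeMemory, env) // e.g. ROCR_VISIBLE_DEVICES on Linux/ROCm
    }
}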
@@ -2,7 +2,6 @@ package gemma3

import (
    "github.com/ollama/ollama/fs"
    "github.com/ollama/ollama/kvcache"
    "github.com/ollama/ollama/ml"
    "github.com/ollama/ollama/ml/nn"
    "github.com/ollama/ollama/ml/nn/pooling"
@@ -53,10 +52,5 @@ func newEmbedModel(c fs.Config) (model.Model, error) {
        poolingType: pooling.Type(c.Uint("pooling_type", 0)),
    }

    m.Cache = kvcache.NewWrapperCache(
        kvcache.NewSWACache(int32(c.Uint("attention.sliding_window")), m.Shift),
        kvcache.NewCausalCache(m.Shift),
    )

    return m, nil
}

@@ -182,16 +182,18 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
    for i, layer := range m.Layers {
        // gemma alternates between the sliding window (local) and causal (global)
        // kv cache every 6 layers
        cacheType := cacheTypeSWA
        if (i+1)%gemmaGlobalCacheCount == 0 {
            cacheType = cacheTypeCausal
        }
        cache.SetLayer(i)
        wc := cache.(*kvcache.WrapperCache)
        wc.SetLayerType(cacheType)
        if cache != nil {
            cacheType := cacheTypeSWA
            if (i+1)%gemmaGlobalCacheCount == 0 {
                cacheType = cacheTypeCausal
            }
            cache.SetLayer(i)
            wc := cache.(*kvcache.WrapperCache)
            wc.SetLayerType(cacheType)

        if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
            causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
            if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
                causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
            }
        }

        var lastLayerOutputs ml.Tensor

@@ -65,7 +65,7 @@ func (m *TextModel) Forward(ctx ml.Context, batch input.Batch, cache kvcache.Cac
    cache.(*kvcache.WrapperCache).SetLayerType(layerType)

    // inputPerLayer = inputsPerLayer[:, i, :]
    inputPerLayer := inputsPerLayer.View(ctx, i*inputsPerLayer.Stride(1), inputsPerLayer.Dim(0), inputsPerLayer.Stride(2), inputsPerLayer.Dim(2))
    inputPerLayer := inputsPerLayer.View(ctx, i*inputsPerLayer.Stride(1), inputsPerLayer.Dim(0), inputsPerLayer.Stride(2), inputsPerLayer.Dim(2)).Contiguous(ctx)
    hiddenStates = layer.Forward(ctx, hiddenStates, inputPerLayer, positions, one, cache, i >= firstSharedKeyValue, ropeBase, float64(m.activationSparsityScale[i]), &m.TextOptions)
}

@@ -45,6 +45,9 @@ func ParserForName(name string) Parser {
    case "qwen3-vl-instruct":
        parser := &Qwen3VLParser{hasThinkingSupport: false}
        return parser
    case "qwen3-vl-thinking":
        parser := &Qwen3VLParser{hasThinkingSupport: true}
        return parser
    case "passthrough":
        return &PassthroughParser{}
    case "harmony":

@@ -22,7 +22,6 @@ const (
    thinkingCloseTag = "</think>"
)

// TODO(gguo): add a field for isThinking
type Qwen3VLParser struct {
    state  qwenParserState
    buffer strings.Builder
@@ -34,21 +33,28 @@ func (p *Qwen3VLParser) HasToolSupport() bool {
    return true
}

// TODO(gguo): changes this to reference an objects param
func (p *Qwen3VLParser) HasThinkingSupport() bool {
    return p.hasThinkingSupport
}

func (p *Qwen3VLParser) initialState() qwenParserState {
    if p.HasThinkingSupport() { // has thinking, start from collecting thinking content
        return CollectingThinkingContent
func (p *Qwen3VLParser) setInitialState(lastMessage *api.Message) {
    prefill := lastMessage != nil && lastMessage.Role == "assistant"
    if !p.HasThinkingSupport() {
        p.state = CollectingContent
        return
    }
    return CollectingContent

    if prefill && lastMessage.Content != "" {
        p.state = CollectingContent
        return
    }

    p.state = CollectingThinkingContent
}

func (p *Qwen3VLParser) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
    p.tools = tools
    p.state = p.initialState()
    p.setInitialState(lastMessage)
    return tools
}

@@ -63,7 +69,8 @@ func (p *Qwen3VLParser) Add(s string, done bool) (content string, thinking strin
    events := p.parseEvents()

    var toolCalls []api.ToolCall
    var sb strings.Builder
    var contentSb strings.Builder
    var thinkingSb strings.Builder
    for _, event := range events {
        switch event := event.(type) {
        case qwenEventRawToolCall:
@@ -74,15 +81,15 @@ func (p *Qwen3VLParser) Add(s string, done bool) (content string, thinking strin
            }
            toolCalls = append(toolCalls, toolCall)
        case qwenEventThinkingContent:
            sb.WriteString(event.content)
            thinkingSb.WriteString(event.content)
        case qwenEventContent:
            // TODO(drifkin): if the same turn contains multiple interleaved content
            // events, we naively append them together here.
            sb.WriteString(event.content)
            contentSb.WriteString(event.content)
        }
    }

    return sb.String(), "", toolCalls, nil
    return contentSb.String(), thinkingSb.String(), toolCalls, nil
}

func (p *Qwen3VLParser) parseEvents() []qwenEvent {
@@ -155,7 +162,7 @@ func (p *Qwen3VLParser) eat() ([]qwenEvent, bool) {
    case CollectingToolContent:
        if strings.Contains(p.buffer.String(), toolCloseTag) {
            split := strings.SplitN(p.buffer.String(), toolCloseTag, 2)
            before := split[0]
            before := split[0] // do we also need to do it to tool calls?
            if len(before) == 0 {
                slog.Warn("qwen tool call closing tag found but no content before it")
            }
@@ -169,13 +176,11 @@ func (p *Qwen3VLParser) eat() ([]qwenEvent, bool) {
        } else {
            return events, false
        }
    case CollectingThinkingContent: // so we want to hip the unambiguous stuff
    case CollectingThinkingContent:
        if strings.Contains(p.buffer.String(), thinkingCloseTag) {
            split := strings.SplitN(p.buffer.String(), thinkingCloseTag, 2)
            before := split[0]
            if len(before) == 0 {
                slog.Warn("qwen tool call closing tag found but no content before it")
            }
            // before := split[0]
            before := strings.TrimRightFunc(split[0], unicode.IsSpace)
            after := strings.TrimLeftFunc(split[1], unicode.IsSpace)
            if len(before) > 0 {
                events = append(events, qwenEventThinkingContent{content: before})
@@ -184,7 +189,7 @@ func (p *Qwen3VLParser) eat() ([]qwenEvent, bool) {
            p.buffer.WriteString(after)
            p.state = CollectingContent
            return events, true
        } else if overlapLen := overlap(p.buffer.String(), thinkingCloseTag); overlapLen > 0 { // we see part of a close thinking tag
        } else if overlapLen := overlap(p.buffer.String(), thinkingCloseTag); overlapLen > 0 {
            beforePartialTag := p.buffer.String()[:len(p.buffer.String())-overlapLen]
            trailingWhitespaceLen := trailingWhitespaceLen(beforePartialTag)
            ambiguousStart := len(beforePartialTag) - trailingWhitespaceLen

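The partial-tag branch above relies on an overlap helper that is not shown in this hunk. Assuming the usual semantics (the length of the longest suffix of the buffer that is also a prefix of the close tag), a minimal sketch:

// Sketch of the assumed overlap semantics used above: the longest suffix of s
// that is a prefix of tag, so "abc</th" overlaps "</think>" by 4. The real
// helper lives elsewhere in the parsers package; this is illustrative.
func overlapSketch(s, tag string) int {
    for n := min(len(s), len(tag)); n > 0; n-- {
        if strings.HasSuffix(s, tag[:n]) {
            return n
        }
    }
    return 0
}

This is what lets the parser hold back an ambiguous tail like "</th" instead of emitting it as thinking content, then resolve it once more bytes arrive.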
@@ -344,3 +344,205 @@ func TestQwen3VLThinkingToolParser(t *testing.T) {
        }
    }
}

func TestQwen3VLParserState(t *testing.T) {
    cases := []struct {
        desc        string
        hasThinking bool
        last        *api.Message
        wantState   qwenParserState
    }{
        {
            desc:        "no thinking support => CollectingContent",
            hasThinking: false,
            last:        nil,
            wantState:   CollectingContent,
        },
        {
            desc:        "thinking support, no last message => CollectingThinkingContent",
            hasThinking: true,
            last:        nil,
            wantState:   CollectingThinkingContent,
        },
        {
            desc:        "thinking support, last assistant with empty content => CollectingThinkingContent",
            hasThinking: true,
            last:        &api.Message{Role: "assistant", Content: ""},
            wantState:   CollectingThinkingContent,
        },
        {
            desc:        "thinking support, last assistant with content => CollectingContent",
            hasThinking: true,
            last:        &api.Message{Role: "assistant", Content: "hello"},
            wantState:   CollectingContent,
        },
        {
            desc:        "thinking support, last is user => CollectingThinkingContent",
            hasThinking: true,
            last:        &api.Message{Role: "user", Content: "hi"},
            wantState:   CollectingThinkingContent,
        },
    }

    for _, tc := range cases {
        parser := Qwen3VLParser{hasThinkingSupport: tc.hasThinking}
        parser.Init(nil, tc.last)
        if parser.state != tc.wantState {
            t.Errorf("%s: got state %v, want %v", tc.desc, parser.state, tc.wantState)
        }
    }
}

func TestQwen3VLThinkingParserWithThinkingPrefill(t *testing.T) {
    type step struct {
        input      string
        wantEvents []qwenEvent
    }

    cases := []struct {
        desc  string
        steps []step
        only  bool
    }{
        {
            desc: "thinking prefill",
            steps: []step{
                {input: "abc</think>", wantEvents: []qwenEvent{qwenEventThinkingContent{content: "abc"}}},
            },
        },
        {
            desc: "thinking prefill with content",
            steps: []step{
                {input: "abc</th", wantEvents: []qwenEvent{qwenEventThinkingContent{content: "abc"}}},
                {input: "ink> def", wantEvents: []qwenEvent{qwenEventContent{content: "def"}}},
            },
        },
        {
            desc: "thinking prefill with fakeout",
            steps: []step{
                {input: "abc</think", wantEvents: []qwenEvent{qwenEventThinkingContent{content: "abc"}}},
                {input: " fakeout </think", wantEvents: []qwenEvent{qwenEventThinkingContent{content: "</think fakeout"}}},
                {input: ">", wantEvents: []qwenEvent{}},
            },
        },
        {
            desc: "thinking prefill with spaces",
            steps: []step{
                {input: " </think> starting content", wantEvents: []qwenEvent{qwenEventContent{content: "starting content"}}},
            },
        },
    }
    last := &api.Message{Role: "assistant", Thinking: "i am thinking"} // so if there is thinking the test is still thinking

    for _, tc := range cases {
        t.Run(tc.desc, func(t *testing.T) {
            parser := Qwen3VLParser{hasThinkingSupport: true}
            parser.Init([]api.Tool{}, last)

            for i, step := range tc.steps {
                parser.buffer.WriteString(step.input)
                gotEvents := parser.parseEvents()

                if len(gotEvents) == 0 && len(step.wantEvents) == 0 {
                    // avoid deep equal on empty vs. nil slices
                    continue
                }

                if !reflect.DeepEqual(gotEvents, step.wantEvents) {
                    t.Errorf("step %d: input %q: got events %#v, want %#v", i, step.input, gotEvents, step.wantEvents)
                }
            }
        })
    }
}

func TestQwen3VLThinkingParserWithNonThinkingPrefill(t *testing.T) {
    type step struct {
        input      string
        wantEvents []qwenEvent
    }

    cases := []struct {
        desc  string
        steps []step
        only  bool
    }{
        {
            desc: "thinking prefill",
            steps: []step{
                {input: "abc</think>", wantEvents: []qwenEvent{qwenEventContent{content: "abc</think>"}}},
            },
        },
        {
            desc: "thinking prefill with content",
            steps: []step{
                {input: "abc</th", wantEvents: []qwenEvent{qwenEventContent{content: "abc</th"}}},
                {input: "ink> def", wantEvents: []qwenEvent{qwenEventContent{content: "ink> def"}}},
            },
        },
        {
            desc: "thinking prefill with fakeout",
            steps: []step{
                {input: "abc</think", wantEvents: []qwenEvent{qwenEventContent{content: "abc</think"}}},
                {input: " fakeout </think", wantEvents: []qwenEvent{qwenEventContent{content: " fakeout </think"}}},
                {input: ">", wantEvents: []qwenEvent{qwenEventContent{content: ">"}}},
            },
        },
        {
            desc: "thinking prefill with spaces",
            steps: []step{
                {input: " </think> starting content", wantEvents: []qwenEvent{qwenEventContent{content: " </think> starting content"}}},
            },
        },
    }
    last := &api.Message{Role: "assistant", Thinking: "i am thinking", Content: "i am content"} // so if there is thinking the test is still thinking

    for _, tc := range cases {
        t.Run(tc.desc, func(t *testing.T) {
            parser := Qwen3VLParser{hasThinkingSupport: true}
            parser.Init([]api.Tool{}, last)

            for i, step := range tc.steps {
                parser.buffer.WriteString(step.input)
                gotEvents := parser.parseEvents()

                if len(gotEvents) == 0 && len(step.wantEvents) == 0 {
                    // avoid deep equal on empty vs. nil slices
                    continue
                }

                if !reflect.DeepEqual(gotEvents, step.wantEvents) {
                    t.Errorf("step %d: input %q: got events %#v, want %#v", i, step.input, gotEvents, step.wantEvents)
                }
            }
        })
    }
}

func TestQwen3VLThinkingParserStreamingAssistantPrefillContent(t *testing.T) {
    // last message is assistant with content ⇒ start in CollectingContent
    last := &api.Message{Role: "assistant", Content: "has content"}
    parser := Qwen3VLParser{hasThinkingSupport: true}
    parser.Init([]api.Tool{}, last)

    type step struct {
        input      string
        wantEvents []qwenEvent
    }

    steps := []step{
        {input: "abc</think>", wantEvents: []qwenEvent{qwenEventContent{content: "abc</think>"}}},
        {input: "<tool_call>{\"name\": \"x\", \"arguments\": {}}</tool_call>", wantEvents: []qwenEvent{qwenEventRawToolCall{raw: "{\"name\": \"x\", \"arguments\": {}}"}}},
    }

    for i, s := range steps {
        parser.buffer.WriteString(s.input)
        gotEvents := parser.parseEvents()
        if len(gotEvents) == 0 && len(s.wantEvents) == 0 {
            continue
        }
        if !reflect.DeepEqual(gotEvents, s.wantEvents) {
            t.Fatalf("step %d: input %q: got %#v, want %#v", i, s.input, gotEvents, s.wantEvents)
        }
    }
}

@@ -48,13 +48,22 @@ func marshalWithSpaces(v any) ([]byte, error) {

type Qwen3VLRenderer struct {
    isThinking bool

    useImgTags bool
}

func (r *Qwen3VLRenderer) renderContent(content api.Message, doVisionCount bool) string {
func (r *Qwen3VLRenderer) renderContent(content api.Message) string {
    // This assumes all images are at the front of the message - same assumption as ollama/ollama/runner.go
    var subSb strings.Builder
    for range content.Images {
        subSb.WriteString("<|vision_start|><|image_pad|><|vision_end|>")
        // TODO: (jmorganca): how to render this is different for different
        // model backends, and so we should eventually parameterize this or
        // only output a placeholder such as [img]
        if r.useImgTags {
            subSb.WriteString("[img]")
        } else {
            subSb.WriteString("<|vision_start|><|image_pad|><|vision_end|>")
        }
    }
    // TODO: support videos

@@ -88,7 +97,7 @@ func (r *Qwen3VLRenderer) Render(messages []api.Message, tools []api.Tool, _ *ap
        message := messages[i]
        if multiStepTool && message.Role == "user" {
            // Check if content starts with <tool_response> and ends with </tool_response>
            content := r.renderContent(message, true)
            content := r.renderContent(message)
            if !(strings.HasPrefix(content, "<tool_response>") && strings.HasSuffix(content, "</tool_response>")) {
                multiStepTool = false
                lastQueryIndex = i
@@ -97,7 +106,7 @@ func (r *Qwen3VLRenderer) Render(messages []api.Message, tools []api.Tool, _ *ap
    }

    for i, message := range messages {
        content := r.renderContent(message, true)
        content := r.renderContent(message)

        lastMessage := i == len(messages)-1
        prefill := lastMessage && message.Role == "assistant"

@@ -9,11 +9,12 @@ import (

func TestQwen3VLNonThinkingRenderer(t *testing.T) {
    tests := []struct {
        name     string
        msgs     []api.Message
        images   []api.ImageData
        tools    []api.Tool
        expected string
        name       string
        msgs       []api.Message
        images     []api.ImageData
        tools      []api.Tool
        useImgTags bool
        expected   string
    }{
        {
            name: "prefill",
@@ -90,6 +91,18 @@ I'll check the weather in San Francisco for you.<think>Speak poetry after the fi
            expected: `<|im_start|>user
<|vision_start|><|image_pad|><|vision_end|>Describe this image.<|im_end|>
<|im_start|>assistant
Let me analyze this image.`,
        },
        {
            name: "Image with image tags",
            msgs: []api.Message{
                {Role: "user", Content: "Describe this image.", Images: []api.ImageData{api.ImageData("img2")}},
                {Role: "assistant", Content: "Let me analyze this image."},
            },
            useImgTags: true,
            expected: `<|im_start|>user
[img]Describe this image.<|im_end|>
<|im_start|>assistant
Let me analyze this image.`,
        },
        {
@@ -102,7 +115,18 @@ Let me analyze this image.`,
<|im_start|>assistant
`,
        },

        {
            name: "Multiple images with image tags",
            msgs: []api.Message{
                {Role: "user", Content: "Describe these images.", Images: []api.ImageData{api.ImageData("img1"), api.ImageData("img2")}},
                {Role: "assistant", Content: "Let me analyze this image."},
            },
            useImgTags: true,
            expected: `<|im_start|>user
[img][img]Describe these images.<|im_end|>
<|im_start|>assistant
Let me analyze this image.`,
        },
        // // NOTE: solved with #12518: https://github.com/ollama/ollama/compare/main...drifkin/stable-tool-args
        // {
        //     name: "with tools and response",
@@ -485,7 +509,7 @@ I'll check.
    }
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            rendered, err := (&Qwen3VLRenderer{false}).Render(tt.msgs, tt.tools, nil)
            rendered, err := (&Qwen3VLRenderer{isThinking: false, useImgTags: tt.useImgTags}).Render(tt.msgs, tt.tools, nil)
            if err != nil {
                t.Fatal(err)
            }

@@ -323,7 +323,7 @@ Speak poetry after the first sentence.</think><think>Speak poetry after the seco
    }
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            rendered, err := (&Qwen3VLRenderer{true}).Render(tt.msgs, tt.tools, nil)
            rendered, err := (&Qwen3VLRenderer{isThinking: true}).Render(tt.msgs, tt.tools, nil)
            if err != nil {
                t.Fatal(err)
            }

@@ -17,6 +17,11 @@ type (
    }
)

// RenderImgTags is a global flag that tells renderers to use [img] tags
// for images. This is set by the Ollama server package on init, or left as
// false for other environments where renderers are used
var RenderImgTags bool

func (r *RendererRegistry) Register(name string, renderer RendererConstructor) {
    r.renderers[name] = renderer
}
@@ -46,7 +51,10 @@ func rendererForName(name string) Renderer {
        renderer := &Qwen3CoderRenderer{}
        return renderer
    case "qwen3-vl-instruct":
        renderer := &Qwen3VLRenderer{false}
        renderer := &Qwen3VLRenderer{isThinking: false, useImgTags: RenderImgTags}
        return renderer
    case "qwen3-vl-thinking":
        renderer := &Qwen3VLRenderer{isThinking: true, useImgTags: RenderImgTags}
        return renderer
    default:
        return nil

@@ -2,7 +2,9 @@
package openai

import (
    "bytes"
    "encoding/base64"
    "encoding/binary"
    "encoding/json"
    "errors"
    "fmt"
@@ -73,9 +75,10 @@ type JsonSchema struct {
}

type EmbedRequest struct {
    Input      any    `json:"input"`
    Model      string `json:"model"`
    Dimensions int    `json:"dimensions,omitempty"`
    Input          any    `json:"input"`
    Model          string `json:"model"`
    Dimensions     int    `json:"dimensions,omitempty"`
    EncodingFormat string `json:"encoding_format,omitempty"` // "float" or "base64"
}

type StreamOptions struct {
@@ -181,9 +184,9 @@ type Model struct {
}

type Embedding struct {
    Object    string    `json:"object"`
    Embedding []float32 `json:"embedding"`
    Index     int       `json:"index"`
    Object    string `json:"object"`
    Embedding any    `json:"embedding"` // Can be []float32 (float format) or string (base64 format)
    Index     int    `json:"index"`
}

type ListCompletion struct {
@@ -377,13 +380,21 @@ func ToListCompletion(r api.ListResponse) ListCompletion {
}

// ToEmbeddingList converts an api.EmbedResponse to EmbeddingList
func ToEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
// encodingFormat can be "float", "base64", or empty (defaults to "float")
func ToEmbeddingList(model string, r api.EmbedResponse, encodingFormat string) EmbeddingList {
    if r.Embeddings != nil {
        var data []Embedding
        for i, e := range r.Embeddings {
            var embedding any
            if strings.EqualFold(encodingFormat, "base64") {
                embedding = floatsToBase64(e)
            } else {
                embedding = e
            }

            data = append(data, Embedding{
                Object:    "embedding",
                Embedding: e,
                Embedding: embedding,
                Index:     i,
            })
        }
@@ -402,6 +413,13 @@ func ToEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
    return EmbeddingList{}
}

// floatsToBase64 encodes a []float32 to a base64 string
func floatsToBase64(floats []float32) string {
    var buf bytes.Buffer
    binary.Write(&buf, binary.LittleEndian, floats)
    return base64.StdEncoding.EncodeToString(buf.Bytes())
}

// ToModel converts an api.ShowResponse to Model
func ToModel(r api.ShowResponse, m string) Model {
    return Model{

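floatsToBase64 above packs the raw little-endian float32 bytes and base64-encodes them, which is the layout OpenAI-style clients expect. As a hedged round-trip sketch (not part of this diff), the inverse decodes such a string back to floats; the sample string comes from the test table in the next file:

// Illustrative inverse of floatsToBase64: base64 -> raw bytes ->
// little-endian float32 values.
package main

import (
    "encoding/base64"
    "encoding/binary"
    "fmt"
    "math"
)

func base64ToFloats(s string) ([]float32, error) {
    raw, err := base64.StdEncoding.DecodeString(s)
    if err != nil {
        return nil, err
    }
    out := make([]float32, len(raw)/4)
    for i := range out {
        out[i] = math.Float32frombits(binary.LittleEndian.Uint32(raw[i*4:]))
    }
    return out, nil
}

func main() {
    // "zczMPc3MTL6amZk+" encodes {0.1, -0.2, 0.3} per the tests below.
    v, err := base64ToFloats("zczMPc3MTL6amZk+")
    fmt.Println(v, err) // [0.1 -0.2 0.3] <nil>
}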
139 openai/openai_encoding_format_test.go Normal file
@@ -0,0 +1,139 @@
package openai

import (
    "encoding/base64"
    "math"
    "testing"

    "github.com/ollama/ollama/api"
)

func TestToEmbeddingList(t *testing.T) {
    testCases := []struct {
        name         string
        embeddings   [][]float32
        format       string
        expectType   string // "float" or "base64"
        expectBase64 []string
        expectCount  int
        promptEval   int
    }{
        {"float format", [][]float32{{0.1, -0.2, 0.3}}, "float", "float", nil, 1, 10},
        {"base64 format", [][]float32{{0.1, -0.2, 0.3}}, "base64", "base64", []string{"zczMPc3MTL6amZk+"}, 1, 5},
        {"default to float", [][]float32{{0.1, -0.2, 0.3}}, "", "float", nil, 1, 0},
        {"invalid defaults to float", [][]float32{{0.1, -0.2, 0.3}}, "invalid", "float", nil, 1, 0},
        {"multiple embeddings", [][]float32{{0.1, 0.2}, {0.3, 0.4}, {0.5, 0.6}}, "base64", "base64", []string{"zczMPc3MTD4=", "mpmZPs3MzD4=", "AAAAP5qZGT8="}, 3, 0},
        {"empty embeddings", nil, "float", "", nil, 0, 0},
    }

    for _, tc := range testCases {
        t.Run(tc.name, func(t *testing.T) {
            resp := api.EmbedResponse{
                Embeddings:      tc.embeddings,
                PromptEvalCount: tc.promptEval,
            }

            result := ToEmbeddingList("test-model", resp, tc.format)

            if tc.expectCount == 0 {
                if len(result.Data) != 0 {
                    t.Errorf("expected 0 embeddings, got %d", len(result.Data))
                }
                return
            }

            if len(result.Data) != tc.expectCount {
                t.Fatalf("expected %d embeddings, got %d", tc.expectCount, len(result.Data))
            }

            if result.Model != "test-model" {
                t.Errorf("expected model 'test-model', got %q", result.Model)
            }

            // Check type of first embedding
            switch tc.expectType {
            case "float":
                if _, ok := result.Data[0].Embedding.([]float32); !ok {
                    t.Errorf("expected []float32, got %T", result.Data[0].Embedding)
                }
            case "base64":
                for i, data := range result.Data {
                    embStr, ok := data.Embedding.(string)
                    if !ok {
                        t.Errorf("embedding %d: expected string, got %T", i, data.Embedding)
                        continue
                    }

                    // Verify it's valid base64
                    if _, err := base64.StdEncoding.DecodeString(embStr); err != nil {
                        t.Errorf("embedding %d: invalid base64: %v", i, err)
                    }

                    // Compare against expected base64 string if provided
                    if tc.expectBase64 != nil && i < len(tc.expectBase64) {
                        if embStr != tc.expectBase64[i] {
                            t.Errorf("embedding %d: expected base64 %q, got %q", i, tc.expectBase64[i], embStr)
                        }
                    }
                }
            }

            // Check indices
            for i := range result.Data {
                if result.Data[i].Index != i {
                    t.Errorf("embedding %d: expected index %d, got %d", i, i, result.Data[i].Index)
                }
            }

            if tc.promptEval > 0 && result.Usage.PromptTokens != tc.promptEval {
                t.Errorf("expected %d prompt tokens, got %d", tc.promptEval, result.Usage.PromptTokens)
            }
        })
    }
}

func TestFloatsToBase64(t *testing.T) {
    floats := []float32{0.1, -0.2, 0.3, -0.4, 0.5}

    result := floatsToBase64(floats)

    // Verify it's valid base64
    decoded, err := base64.StdEncoding.DecodeString(result)
    if err != nil {
        t.Fatalf("failed to decode base64: %v", err)
    }

    // Check length
    expectedBytes := len(floats) * 4
    if len(decoded) != expectedBytes {
        t.Errorf("expected %d bytes, got %d", expectedBytes, len(decoded))
    }

    // Decode and verify values
    for i, expected := range floats {
        offset := i * 4
        bits := uint32(decoded[offset]) |
            uint32(decoded[offset+1])<<8 |
            uint32(decoded[offset+2])<<16 |
            uint32(decoded[offset+3])<<24
        decodedFloat := math.Float32frombits(bits)

        if math.Abs(float64(decodedFloat-expected)) > 1e-6 {
            t.Errorf("float[%d]: expected %f, got %f", i, expected, decodedFloat)
        }
    }
}

func TestFloatsToBase64_EmptySlice(t *testing.T) {
    result := floatsToBase64([]float32{})

    // Should return valid base64 for empty slice
    decoded, err := base64.StdEncoding.DecodeString(result)
    if err != nil {
        t.Fatalf("failed to decode base64: %v", err)
    }

    if len(decoded) != 0 {
        t.Errorf("expected 0 bytes, got %d", len(decoded))
    }
}
@@ -7,6 +7,7 @@ import (
    "errors"
    "fmt"
    "io"
    "maps"
    "os"
    "strings"
    "testing"
@@ -799,7 +800,10 @@ func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) (string,
    }
    defer f.Close()

    if err := ggml.WriteGGUF(f, kv, ti); err != nil {
    base := map[string]any{"general.architecture": "test"}
    maps.Copy(base, kv)

    if err := ggml.WriteGGUF(f, base, ti); err != nil {
        t.Fatal(err)
    }
    // Calculate sha256 of file

@@ -384,6 +384,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
    defer s.mu.Unlock()

    var batch *llama.Batch
    var numOutputs int

    seqIdx := s.nextSeq - 1
    for range s.seqs {
@@ -446,7 +447,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
            break
        }

        batch.Add(input.token, input.embed, len(seq.cache.Inputs)+len(seq.pendingInputs), i+1 == len(seq.inputs), seq.cache.Id)
        output := i+1 == len(seq.inputs)
        batch.Add(input.token, input.embed, len(seq.cache.Inputs)+len(seq.pendingInputs), output, seq.cache.Id)
        if output {
            numOutputs++
        }

        seq.pendingInputs = append(seq.pendingInputs, input)
        seq.iBatch = batch.NumTokens() - 1
    }
@@ -463,6 +469,10 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
        return fmt.Errorf("failed to decode batch: %w", err)
    }

    if numOutputs > 0 {
        s.lc.Synchronize()
    }

    for i, seq := range s.seqs {
        if seq == nil {
            continue
@@ -476,10 +486,10 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)

        // don't sample prompt processing
        if len(seq.inputs) != 0 {
            seq.processingDuration += time.Since(t)
            continue
        }

        s.lc.Synchronize()
        seq.numDecoded++
        if seq.numDecoded > 1 {
            seq.generationDuration += time.Since(t)
@@ -697,8 +707,15 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {

    w.Header().Set("Content-Type", "application/json")

    seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{embedding: true})
    seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{
        embedding: true,
        truncate:  req.Truncate,
    })
    if err != nil {
        if errors.Is(err, errorInputTooLong) {
            http.Error(w, err.Error(), http.StatusBadRequest)
            return
        }
        http.Error(w, fmt.Sprintf("Failed to create new sequence: %v", err), http.StatusInternalServerError)
        return
    }
@@ -741,7 +758,8 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
    embedding := <-seq.embedding

    if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
        Embedding: embedding,
        Embedding:       embedding,
        PromptEvalCount: seq.numPromptInputs,
    }); err != nil {
        http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
    }

@@ -946,8 +946,15 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
    }

    w.Header().Set("Content-Type", "application/json")
    seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{embedding: true})
    seq, err := s.NewSequence(req.Content, nil, NewSequenceParams{
        embedding: true,
        truncate:  req.Truncate,
    })
    if err != nil {
        if errors.Is(err, errorInputTooLong) {
            http.Error(w, err.Error(), http.StatusBadRequest)
            return
        }
        http.Error(w, fmt.Sprintf("failed to create new sequence: %v", err), http.StatusInternalServerError)
        return
    }
@@ -988,7 +995,8 @@ func (s *Server) embeddings(w http.ResponseWriter, r *http.Request) {
    }

    if err := json.NewEncoder(w).Encode(&llm.EmbeddingResponse{
        Embedding: <-seq.embedding,
        Embedding:       <-seq.embedding,
        PromptEvalCount: seq.numPromptInputs,
    }); err != nil {
        http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError)
    }

@@ -84,11 +84,11 @@ function buildCPU() {
        Remove-Item -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}"
        New-Item "${script:SRC_DIR}\dist\windows-${script:ARCH}\lib\ollama\" -ItemType Directory -ea 0

        & cmake --fresh --preset CPU --install-prefix $script:DIST_DIR
        & cmake -B build\cpu --preset CPU --install-prefix $script:DIST_DIR
        if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
        & cmake --build --preset CPU --config Release --parallel $script:JOBS
        & cmake --build build\cpu --target ggml-cpu --config Release --parallel $script:JOBS
        if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
        & cmake --install build --component CPU --strip
        & cmake --install build\cpu --component CPU --strip
        if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
    }
}
@@ -105,11 +105,11 @@ function buildCUDA11() {
        $hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V11")) { $x=$hashEnv[$_]; if (test-path -literalpath "$x\bin\nvcc.exe" ) { $cuda=$x} }}
        write-host "Building CUDA v11 backend libraries $cuda"
        $env:CUDAToolkit_ROOT=$cuda
        & cmake --fresh --preset "CUDA 11" -T cuda="$cuda" -DCMAKE_CUDA_COMPILER="$cuda\bin\nvcc.exe" -G "Visual Studio 16 2019" --install-prefix $script:DIST_DIR -DOLLAMA_RUNNER_DIR="cuda_v11"
        & cmake -B build\cuda_v11 --preset "CUDA 11" -T cuda="$cuda" -DCMAKE_CUDA_COMPILER="$cuda\bin\nvcc.exe" -G "Visual Studio 16 2019" --install-prefix $script:DIST_DIR -DOLLAMA_RUNNER_DIR="cuda_v11"
        if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
        & cmake --build --preset "CUDA 11" --config Release --parallel $script:JOBS
        & cmake --build build\cuda_v11 --target ggml-cuda --config Release --parallel $script:JOBS
        if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
        & cmake --install build --component "CUDA" --strip
        & cmake --install build\cuda_v11 --component "CUDA" --strip
        if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
    }
}
@@ -124,11 +124,11 @@ function buildCUDA12() {
        $hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V12_8")) { $x=$hashEnv[$_]; if (test-path -literalpath "$x\bin\nvcc.exe" ) { $cuda=$x} }}
        write-host "Building CUDA v12 backend libraries $cuda"
        $env:CUDAToolkit_ROOT=$cuda
        & cmake --fresh --preset "CUDA 12" -T cuda="$cuda" --install-prefix $script:DIST_DIR -DOLLAMA_RUNNER_DIR="cuda_v12"
        & cmake -B build\cuda_v12 --preset "CUDA 12" -T cuda="$cuda" --install-prefix $script:DIST_DIR -DOLLAMA_RUNNER_DIR="cuda_v12"
        if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
        & cmake --build --preset "CUDA 12" --config Release --parallel $script:JOBS
        & cmake --build build\cuda_v12 --target ggml-cuda --config Release --parallel $script:JOBS
        if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
        & cmake --install build --component "CUDA" --strip
        & cmake --install build\cuda_v12 --component "CUDA" --strip
        if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
    }
}
@@ -143,11 +143,11 @@ function buildCUDA13() {
        $hashEnv.Keys | foreach { if ($_.Contains("CUDA_PATH_V13")) { $x=$hashEnv[$_]; if (test-path -literalpath "$x\bin\nvcc.exe" ) { $cuda=$x} }}
        $env:CUDAToolkit_ROOT=$cuda
        write-host "Building CUDA v13 backend libraries $cuda"
        & cmake --fresh --preset "CUDA 13" -T cuda="$cuda" --install-prefix $script:DIST_DIR -DOLLAMA_RUNNER_DIR="cuda_v13"
        & cmake -B build\cuda_v13 --preset "CUDA 13" -T cuda="$cuda" --install-prefix $script:DIST_DIR -DOLLAMA_RUNNER_DIR="cuda_v13"
        if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
        & cmake --build --preset "CUDA 13" --config Release --parallel $script:JOBS
        & cmake --build build\cuda_v13 --target ggml-cuda --config Release --parallel $script:JOBS
        if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
        & cmake --install build --component "CUDA" --strip
        & cmake --install build\cuda_v13 --component "CUDA" --strip
        if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
    }
}
@@ -165,7 +165,7 @@ function buildROCm() {
        $env:HIPCXX="${env:HIP_PATH}\bin\clang++.exe"
        $env:HIP_PLATFORM="amd"
        $env:CMAKE_PREFIX_PATH="${env:HIP_PATH}"
        & cmake --fresh --preset "ROCm 6" -G Ninja -DOLLAMA_RUNNER_DIR="rocm" `
        & cmake --fresh -B build\rocm --preset "ROCm 6" -G Ninja -DOLLAMA_RUNNER_DIR="rocm" `
            -DCMAKE_C_COMPILER=clang `
            -DCMAKE_CXX_COMPILER=clang++ `
            -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" `
@@ -175,9 +175,9 @@ function buildROCm() {
        $env:HIPCXX=""
        $env:HIP_PLATFORM=""
        $env:CMAKE_PREFIX_PATH=""
        & cmake --build --preset "ROCm 6" --config Release --parallel $script:JOBS
        & cmake --build build\rocm --target ggml-hip --config Release --parallel $script:JOBS
        if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
        & cmake --install build --component "HIP" --strip
        & cmake --install build\rocm --component "HIP" --strip
        if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
        Remove-Item -Path $script:DIST_DIR\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue
    }

@@ -119,6 +119,27 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||
if err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
|
||||
if err == nil && !remote && (config.Renderer == "" || config.Parser == "") {
|
||||
manifest, mErr := ParseNamedManifest(fromName)
|
||||
if mErr == nil && manifest.Config.Digest != "" {
|
||||
configPath, pErr := GetBlobsPath(manifest.Config.Digest)
|
||||
if pErr == nil {
|
||||
if cfgFile, fErr := os.Open(configPath); fErr == nil {
|
||||
var baseConfig ConfigV2
|
||||
if decErr := json.NewDecoder(cfgFile).Decode(&baseConfig); decErr == nil {
|
||||
if config.Renderer == "" {
|
||||
config.Renderer = baseConfig.Renderer
|
||||
}
|
||||
if config.Parser == "" {
|
||||
config.Parser = baseConfig.Parser
|
||||
}
|
||||
}
|
||||
cfgFile.Close()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if r.Files != nil {
|
||||
baseLayers, err = convertModelFromFiles(r.Files, baseLayers, false, fn)
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"os/signal"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
@@ -39,6 +40,7 @@ import (
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/middleware"
|
||||
"github.com/ollama/ollama/model/parsers"
|
||||
"github.com/ollama/ollama/model/renderers"
|
||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||
"github.com/ollama/ollama/server/internal/registry"
|
||||
"github.com/ollama/ollama/template"
|
||||
@@ -91,6 +93,9 @@ func init() {
|
||||
}
|
||||
|
||||
gin.SetMode(mode)
|
||||
|
||||
// Tell renderers to use [img] tags
|
||||
renderers.RenderImgTags = true
|
||||
}
|
||||
|
||||
var (
|
||||
@@ -285,6 +290,12 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
contentType := "application/json; charset=utf-8"
|
||||
if req.Stream != nil && *req.Stream {
|
||||
contentType = "application/x-ndjson"
|
||||
}
|
||||
c.Header("Content-Type", contentType)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@@ -649,7 +660,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return
}

r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
r, _, _, err := s.scheduleRunner(c.Request.Context(), name.String(), []model.Capability{}, req.Options, req.KeepAlive)
if err != nil {
handleScheduleError(c, req.Model, err)
return
@@ -662,61 +673,12 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return
}

kvData, _, err := getModelData(m.ModelPath, false)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}

var count int
for i, s := range input {
tokens, err := r.Tokenize(c.Request.Context(), s)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}

ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
if len(tokens) > ctxLen {
if !truncate {
c.JSON(http.StatusBadRequest, gin.H{"error": "input exceeds maximum context length"})
return
}

if bos := kvData.Uint("tokenizer.ggml.bos_token_id"); tokens[0] != int(bos) && kvData.Bool("add_bos_token", true) {
ctxLen--
}

if eos := kvData.Uint("tokenizer.ggml.eos_token_id"); tokens[len(tokens)-1] != int(eos) && kvData.Bool("add_eos_token", true) {
ctxLen--
}

slog.Info("", "ctxLen", ctxLen, "tokenCount", len(tokens))
if ctxLen <= 0 {
// return error if the truncated input would be empty or just special tokens
c.JSON(http.StatusBadRequest, gin.H{"error": "input after truncation exceeds maximum context length"})
return
}

tokens = tokens[:ctxLen]

s, err = r.Detokenize(c.Request.Context(), tokens)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}

count += len(tokens)

input[i] = s
}

var g errgroup.Group
embeddings := make([][]float32, len(input))
var totalTokens uint64
for i, text := range input {
g.Go(func() error {
embedding, err := r.Embedding(c.Request.Context(), text)
embedding, tokenCount, err := r.Embedding(c.Request.Context(), text, truncate)
if err != nil {
return err
}
@@ -726,12 +688,18 @@ func (s *Server) EmbedHandler(c *gin.Context) {
embedding = normalize(embedding[:req.Dimensions])
}
embeddings[i] = embedding
atomic.AddUint64(&totalTokens, uint64(tokenCount))
return nil
})
}

if err := g.Wait(); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())})
var serr api.StatusError
if errors.As(err, &serr) {
c.AbortWithStatusJSON(serr.StatusCode, gin.H{"error": strings.TrimSpace(serr.ErrorMessage)})
} else {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())})
}
return
}

@@ -740,7 +708,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
Embeddings: embeddings,
TotalDuration: time.Since(checkpointStart),
LoadDuration: checkpointLoaded.Sub(checkpointStart),
PromptEvalCount: count,
PromptEvalCount: int(totalTokens),
}
c.JSON(http.StatusOK, resp)
}
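After this change, per-input tokenization and truncation move out of the handler and into the runner, which now reports the token count back. A minimal sketch of the resulting fan-out shape, assuming only the one runner method the diff calls (interface and helper names here are illustrative, not the repo's):

```go
package sketch

import (
	"context"
	"sync/atomic"

	"golang.org/x/sync/errgroup"
)

// embedder stands in for the one runner method this hunk calls; the
// signature follows the diff: (embedding, tokenCount, err).
type embedder interface {
	Embedding(ctx context.Context, text string, truncate bool) ([]float32, int, error)
}

// embedAll fans out one goroutine per input and sums the token counts
// the runner reports, mirroring the handler's new structure. Each
// loop iteration gets fresh i/text variables (Go 1.22+ semantics), so
// capturing them in the closure is safe, as in the diff itself.
func embedAll(ctx context.Context, r embedder, input []string, truncate bool) ([][]float32, int, error) {
	var g errgroup.Group
	embeddings := make([][]float32, len(input))
	var totalTokens uint64
	for i, text := range input {
		g.Go(func() error {
			embedding, tokenCount, err := r.Embedding(ctx, text, truncate)
			if err != nil {
				return err
			}
			embeddings[i] = embedding
			atomic.AddUint64(&totalTokens, uint64(tokenCount))
			return nil
		})
	}
	if err := g.Wait(); err != nil {
		return nil, 0, err
	}
	return embeddings, int(totalTokens), nil
}
```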
@@ -786,7 +754,7 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return
}

embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
embedding, _, err := r.Embedding(c.Request.Context(), req.Prompt, true)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": strings.TrimSpace(err.Error())})
return
@@ -1870,10 +1838,14 @@ func (s *Server) ChatHandler(c *gin.Context) {
req.Options = map[string]any{}
}

msgs := append(m.Messages, req.Messages...)
if req.Messages[0].Role != "system" && m.System != "" {
msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...)
var msgs []api.Message
if len(req.Messages) > 0 {
msgs = append(m.Messages, req.Messages...)
if req.Messages[0].Role != "system" && m.System != "" {
msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...)
}
}

msgs = filterThinkTags(msgs, m)
req.Messages = msgs

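The guard added above matters because the old code indexed req.Messages[0] unconditionally; on a chat request with no messages that is an index-out-of-range panic (my reading of the removed line). A self-contained sketch of the guarded merge, with types reduced to what the logic needs:

```go
package main

import "fmt"

type message struct{ Role, Content string }

// mergeMessages sketches the guarded merge: model messages plus
// request messages, with a system message prepended only when the
// request doesn't open with one. With no request messages, it simply
// returns an empty slice instead of panicking.
func mergeMessages(modelMsgs, reqMsgs []message, system string) []message {
	var msgs []message
	if len(reqMsgs) > 0 {
		msgs = append(modelMsgs, reqMsgs...)
		if reqMsgs[0].Role != "system" && system != "" {
			msgs = append([]message{{Role: "system", Content: system}}, msgs...)
		}
	}
	return msgs
}

func main() {
	fmt.Println(mergeMessages(nil, nil, "be brief")) // [] and no panic
}
```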
@@ -1924,6 +1896,12 @@ func (s *Server) ChatHandler(c *gin.Context) {
return
}

contentType := "application/json; charset=utf-8"
if req.Stream != nil && *req.Stream {
contentType = "application/x-ndjson"
}
c.Header("Content-Type", contentType)

return
}

@@ -7,6 +7,7 @@ import (
"encoding/json"
"fmt"
"io"
"maps"
"net/http"
"net/http/httptest"
"os"
@@ -17,6 +18,8 @@ import (
"testing"

"github.com/gin-gonic/gin"
gocmp "github.com/google/go-cmp/cmp"
gocmpopts "github.com/google/go-cmp/cmp/cmpopts"

"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
@@ -38,7 +41,10 @@ func createBinFile(t *testing.T, kv map[string]any, ti []*ggml.Tensor) (string,
}
defer f.Close()

if err := ggml.WriteGGUF(f, kv, ti); err != nil {
base := map[string]any{"general.architecture": "test"}
maps.Copy(base, kv)

if err := ggml.WriteGGUF(f, base, ti); err != nil {
t.Fatal(err)
}
// Calculate sha256 of file
@@ -102,8 +108,8 @@ func checkFileExists(t *testing.T, p string, expect []string) {
t.Fatal(err)
}

if !slices.Equal(actual, expect) {
t.Fatalf("expected slices to be equal %v", actual)
if diff := gocmp.Diff(expect, actual, gocmpopts.SortSlices(strings.Compare), gocmpopts.EquateEmpty()); diff != "" {
t.Errorf("file exists mismatch (-want +got):\n%s", diff)
}
}

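The createBinFile change above layers caller-supplied keys over a default general.architecture, so tests only pass the keys they care about. A short sketch of that merge order (the default value is the one the diff uses; everything else is illustrative):

```go
package sketch

import "maps"

// defaultKV layers caller-provided keys over a default architecture.
// maps.Copy(dst, src) copies src into dst, overwriting dst's entries
// on key collision, so an explicit kv value always wins.
func defaultKV(kv map[string]any) map[string]any {
	base := map[string]any{"general.architecture": "test"}
	maps.Copy(base, kv)
	return base
}
```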
@@ -133,8 +139,8 @@ func TestCreateFromBin(t *testing.T) {
})

checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116"),
filepath.Join(p, "blobs", "sha256-6bcdb8859d417753645538d7bbfbd7ca91a3f0c191aef5379c53c05e86b669dd"),
filepath.Join(p, "blobs", "sha256-89a2116c3a82d6a97f59f748d86ed4417214353fd178ee54df418fde32495fad"),
})
}

@@ -177,11 +183,77 @@ func TestCreateFromModel(t *testing.T) {
})

checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116"),
filepath.Join(p, "blobs", "sha256-6bcdb8859d417753645538d7bbfbd7ca91a3f0c191aef5379c53c05e86b669dd"),
filepath.Join(p, "blobs", "sha256-89a2116c3a82d6a97f59f748d86ed4417214353fd178ee54df418fde32495fad"),
})
}

func TestCreateFromModelInheritsRendererParser(t *testing.T) {
gin.SetMode(gin.TestMode)

p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
var s Server

const (
renderer = "custom-renderer"
parser = "custom-parser"
)

_, digest := createBinFile(t, nil, nil)

w := createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "base",
Files: map[string]string{"base.gguf": digest},
Renderer: renderer,
Parser: parser,
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}

w = createRequest(t, s.CreateHandler, api.CreateRequest{
Name: "child",
From: "base",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}

manifest, err := ParseNamedManifest(model.ParseName("child"))
if err != nil {
t.Fatalf("parse manifest: %v", err)
}
if manifest.Config.Digest == "" {
t.Fatalf("unexpected empty config digest for child manifest")
}

configPath, err := GetBlobsPath(manifest.Config.Digest)
if err != nil {
t.Fatalf("config blob path: %v", err)
}

cfgFile, err := os.Open(configPath)
if err != nil {
t.Fatalf("open config blob: %v", err)
}
defer cfgFile.Close()

var cfg ConfigV2
if err := json.NewDecoder(cfgFile).Decode(&cfg); err != nil {
t.Fatalf("decode config: %v", err)
}

if cfg.Renderer != renderer {
t.Fatalf("expected renderer %q, got %q", renderer, cfg.Renderer)
}
if cfg.Parser != parser {
t.Fatalf("expected parser %q, got %q", parser, cfg.Parser)
}
}

func TestCreateRemovesLayers(t *testing.T) {
gin.SetMode(gin.TestMode)

@@ -206,9 +278,9 @@ func TestCreateRemovesLayers(t *testing.T) {
})

checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-89a2116c3a82d6a97f59f748d86ed4417214353fd178ee54df418fde32495fad"),
filepath.Join(p, "blobs", "sha256-b507b9c2f6ca642bffcd06665ea7c91f235fd32daeefdf875a0f938db05fb315"),
filepath.Join(p, "blobs", "sha256-bc80b03733773e0728011b2f4adf34c458b400e1aad48cb28d61170f3a2ad2d6"),
filepath.Join(p, "blobs", "sha256-f6e7e4b28e0b1d0c635f2d465bd248c5387c3e75b61a48c4374192b26d832a56"),
})

w = createRequest(t, s.CreateHandler, api.CreateRequest{
@@ -227,8 +299,8 @@ func TestCreateRemovesLayers(t *testing.T) {
})

checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-8f2c2167d789c6b2302dff965160fa5029f6a24096d262c1cbb469f21a045382"),
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-136bf7c76bac2ec09d6617885507d37829e04b41acc47687d45e512b544e893a"),
filepath.Join(p, "blobs", "sha256-89a2116c3a82d6a97f59f748d86ed4417214353fd178ee54df418fde32495fad"),
filepath.Join(p, "blobs", "sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"),
})
}
@@ -257,8 +329,8 @@ func TestCreateUnsetsSystem(t *testing.T) {
})

checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-8585df945d1069bc78b79bd10bb73ba07fbc29b0f5479a31a601c0d12731416e"),
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-0a666d113e8e0a3d27e9c7bd136a0bdfb6241037db50729d81568451ebfdbde8"),
filepath.Join(p, "blobs", "sha256-89a2116c3a82d6a97f59f748d86ed4417214353fd178ee54df418fde32495fad"),
filepath.Join(p, "blobs", "sha256-f29e82a8284dbdf5910b1555580ff60b04238b8da9d5e51159ada67a4d0d5851"),
})

@@ -278,8 +350,8 @@ func TestCreateUnsetsSystem(t *testing.T) {
})

checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116"),
filepath.Join(p, "blobs", "sha256-6bcdb8859d417753645538d7bbfbd7ca91a3f0c191aef5379c53c05e86b669dd"),
filepath.Join(p, "blobs", "sha256-89a2116c3a82d6a97f59f748d86ed4417214353fd178ee54df418fde32495fad"),
})
}

@@ -312,8 +384,8 @@ func TestCreateMergeParameters(t *testing.T) {

checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-1d0ad71299d48c2fb7ae2b98e683643e771f8a5b72be34942af90d97a91c1e37"),
filepath.Join(p, "blobs", "sha256-4a384beaf47a9cbe452dfa5ab70eea691790f3b35a832d12933a1996685bf2b6"),
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-6d6e36c1f90fc7deefc33a7300aa21ad4b67c506e33ecdeddfafa98147e60bbf"),
filepath.Join(p, "blobs", "sha256-89a2116c3a82d6a97f59f748d86ed4417214353fd178ee54df418fde32495fad"),
})

// in order to merge parameters, the second model must be created FROM the first
@@ -354,9 +426,9 @@ func TestCreateMergeParameters(t *testing.T) {

checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-1d0ad71299d48c2fb7ae2b98e683643e771f8a5b72be34942af90d97a91c1e37"),
filepath.Join(p, "blobs", "sha256-4a384beaf47a9cbe452dfa5ab70eea691790f3b35a832d12933a1996685bf2b6"),
filepath.Join(p, "blobs", "sha256-4cd9d4ba6b734d9b4cbd1e5caa60374c00722e993fce5e1e2d15a33698f71187"),
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-6d6e36c1f90fc7deefc33a7300aa21ad4b67c506e33ecdeddfafa98147e60bbf"),
filepath.Join(p, "blobs", "sha256-89a2116c3a82d6a97f59f748d86ed4417214353fd178ee54df418fde32495fad"),
filepath.Join(p, "blobs", "sha256-bbdce269dabe013033632238b4b2d1e02fac2f97787c5e895f4da84e09cccd5d"),
filepath.Join(p, "blobs", "sha256-e29a7b3c47287a2489c895d21fe413c20f859a85d20e749492f52a838e36e1ba"),
})

@@ -398,9 +470,9 @@ func TestCreateMergeParameters(t *testing.T) {
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-12f58bb75cb3042d69a7e013ab87fb3c3c7088f50ddc62f0c77bd332f0d44d35"),
filepath.Join(p, "blobs", "sha256-1d0ad71299d48c2fb7ae2b98e683643e771f8a5b72be34942af90d97a91c1e37"),
filepath.Join(p, "blobs", "sha256-257aa726584f24970a4f240765e75a7169bfbe7f4966c1f04513d6b6c860583a"),
filepath.Join(p, "blobs", "sha256-4a384beaf47a9cbe452dfa5ab70eea691790f3b35a832d12933a1996685bf2b6"),
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-6d6e36c1f90fc7deefc33a7300aa21ad4b67c506e33ecdeddfafa98147e60bbf"),
filepath.Join(p, "blobs", "sha256-89a2116c3a82d6a97f59f748d86ed4417214353fd178ee54df418fde32495fad"),
filepath.Join(p, "blobs", "sha256-9443591d14be23c1e33d101934d76ad03bdb0715fe0879e8b0d1819e7bb063dd"),
})

actual, err = os.ReadFile(filepath.Join(p, "blobs", "sha256-12f58bb75cb3042d69a7e013ab87fb3c3c7088f50ddc62f0c77bd332f0d44d35"))
@@ -456,8 +528,8 @@ func TestCreateReplacesMessages(t *testing.T) {

checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-298baeaf6928a60cf666d88d64a1ba606feb43a2865687c39e40652e407bffc4"),
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-e0e27d47045063ccb167ae852c51d49a98eab33fabaee4633fdddf97213e40b5"),
filepath.Join(p, "blobs", "sha256-89a2116c3a82d6a97f59f748d86ed4417214353fd178ee54df418fde32495fad"),
filepath.Join(p, "blobs", "sha256-c84aee28f2af350596f674de51d2a802ea782653ef2930a21d48bd43d5cd5317"),
})

w = createRequest(t, s.CreateHandler, api.CreateRequest{
@@ -491,11 +563,11 @@ func TestCreateReplacesMessages(t *testing.T) {

// Old layers will not have been pruned
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-09cfac3e6a637e25cb41aa85c24c110dc17ba89634de7df141b564dd2da4168b"),
filepath.Join(p, "blobs", "sha256-298baeaf6928a60cf666d88d64a1ba606feb43a2865687c39e40652e407bffc4"),
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-89a2116c3a82d6a97f59f748d86ed4417214353fd178ee54df418fde32495fad"),
filepath.Join(p, "blobs", "sha256-a60ecc9da299ec7ede453f99236e5577fd125e143689b646d9f0ddc9971bf4db"),
filepath.Join(p, "blobs", "sha256-e0e27d47045063ccb167ae852c51d49a98eab33fabaee4633fdddf97213e40b5"),
filepath.Join(p, "blobs", "sha256-f4e2c3690efef1b4b63ba1e1b2744ffeb6a7438a0110b86596069f6d9999c80b"),
filepath.Join(p, "blobs", "sha256-c84aee28f2af350596f674de51d2a802ea782653ef2930a21d48bd43d5cd5317"),
})

type message struct {
@@ -550,9 +622,9 @@ func TestCreateTemplateSystem(t *testing.T) {
})

checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-2b5e330885117c82f3fd75169ea323e141070a2947c11ddb9f79ee0b01c589c1"),
filepath.Join(p, "blobs", "sha256-0a04d979734167da3b80811a1874d734697f366a689f3912589b99d2e86e7ad1"),
filepath.Join(p, "blobs", "sha256-4c5f51faac758fecaff8db42f0b7382891a4d0c0bb885f7b86be88c814a7cc86"),
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-89a2116c3a82d6a97f59f748d86ed4417214353fd178ee54df418fde32495fad"),
filepath.Join(p, "blobs", "sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"),
})

@@ -714,8 +786,8 @@ func TestCreateLicenses(t *testing.T) {

checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-2af71558e438db0b73a20beab92dc278a94e1bbe974c00c1a33e3ab62d53a608"),
filepath.Join(p, "blobs", "sha256-79a39c37536ddee29cbadd5d5e2dcba8ed7f03e431f626ff38432c1c866bb7e2"),
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-89a2116c3a82d6a97f59f748d86ed4417214353fd178ee54df418fde32495fad"),
filepath.Join(p, "blobs", "sha256-a762f214df0d96c9a7b82f96da98d99ceb2776c88e3ea7ffa09d1e5835516ec6"),
filepath.Join(p, "blobs", "sha256-e5dcffe836b6ec8a58e492419b550e65fb8cbdc308503979e5dacb33ac7ea3b7"),
})

@@ -761,9 +833,9 @@ func TestCreateDetectTemplate(t *testing.T) {

checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-0d79f567714c62c048378f2107fb332dabee0135d080c302d884317da9433cc5"),
filepath.Join(p, "blobs", "sha256-3322a0c650c758b7386ff55629d27d07c07b6c3d3515e259dc3e5598c41e9f4e"),
filepath.Join(p, "blobs", "sha256-35360843d0c84fb1506952a131bbef13cd2bb4a541251f22535170c05b56e672"),
filepath.Join(p, "blobs", "sha256-553c4a3f747b3d22a4946875f1cc8ed011c2930d83f864a0c7265f9ec0a20413"),
filepath.Join(p, "blobs", "sha256-de3959f841e9ef6b4b6255fa41cb9e0a45da89c3066aa72bdd07a4747f848990"),
filepath.Join(p, "blobs", "sha256-a56c12acca8068cb6c335e237da6643e8a802a92959a63ad5bd17828e3b5e9b0"),
})
})

@@ -780,8 +852,8 @@ func TestCreateDetectTemplate(t *testing.T) {
}

checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116"),
filepath.Join(p, "blobs", "sha256-6bcdb8859d417753645538d7bbfbd7ca91a3f0c191aef5379c53c05e86b669dd"),
filepath.Join(p, "blobs", "sha256-89a2116c3a82d6a97f59f748d86ed4417214353fd178ee54df418fde32495fad"),
})
})
}

@@ -9,9 +9,9 @@ import (

"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/discover"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml"
)

func TestGenerateDebugRenderOnly(t *testing.T) {
@@ -30,16 +30,16 @@ func TestGenerateDebugRenderOnly(t *testing.T) {

s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: getGpuFn,
getCpuFn: getCpuFn,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
// add small delay to simulate loading
time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{
@@ -223,16 +223,16 @@ func TestChatDebugRenderOnly(t *testing.T) {

s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: getGpuFn,
getCpuFn: getCpuFn,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
// add small delay to simulate loading
time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{

@@ -47,9 +47,9 @@ func TestDelete(t *testing.T) {
})

checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-8f2c2167d789c6b2302dff965160fa5029f6a24096d262c1cbb469f21a045382"),
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116"),
filepath.Join(p, "blobs", "sha256-136bf7c76bac2ec09d6617885507d37829e04b41acc47687d45e512b544e893a"),
filepath.Join(p, "blobs", "sha256-6bcdb8859d417753645538d7bbfbd7ca91a3f0c191aef5379c53c05e86b669dd"),
filepath.Join(p, "blobs", "sha256-89a2116c3a82d6a97f59f748d86ed4417214353fd178ee54df418fde32495fad"),
filepath.Join(p, "blobs", "sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"),
})

@@ -64,8 +64,8 @@ func TestDelete(t *testing.T) {
})

checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-8f2c2167d789c6b2302dff965160fa5029f6a24096d262c1cbb469f21a045382"),
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-136bf7c76bac2ec09d6617885507d37829e04b41acc47687d45e512b544e893a"),
filepath.Join(p, "blobs", "sha256-89a2116c3a82d6a97f59f748d86ed4417214353fd178ee54df418fde32495fad"),
filepath.Join(p, "blobs", "sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"),
})

@@ -12,9 +12,9 @@ import (
"github.com/google/go-cmp/cmp"

"github.com/ollama/ollama/api"
"github.com/ollama/ollama/discover"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml"
)

// TestGenerateWithBuiltinRenderer tests that api/generate uses built-in renderers
@@ -35,16 +35,16 @@ func TestGenerateWithBuiltinRenderer(t *testing.T) {

s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: getGpuFn,
getCpuFn: getCpuFn,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{
llama: &mock,
@@ -219,16 +219,16 @@ func TestGenerateWithDebugRenderOnly(t *testing.T) {

s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: getGpuFn,
getCpuFn: getCpuFn,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{
llama: &mock,

@@ -6,6 +6,8 @@ import (
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"sync"
"testing"
@@ -15,9 +17,9 @@ import (
"github.com/google/go-cmp/cmp"

"github.com/ollama/ollama/api"
"github.com/ollama/ollama/discover"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml"
)

type mockRunner struct {
@@ -46,12 +48,92 @@ func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error
return
}

func newMockServer(mock *mockRunner) func(discover.GpuInfoList, string, *ggml.GGML, []string, []string, api.Options, int) (llm.LlamaServer, error) {
return func(_ discover.GpuInfoList, _ string, _ *ggml.GGML, _, _ []string, _ api.Options, _ int) (llm.LlamaServer, error) {
func newMockServer(mock *mockRunner) func(ml.SystemInfo, []ml.DeviceInfo, string, *ggml.GGML, []string, []string, api.Options, int) (llm.LlamaServer, error) {
return func(_ ml.SystemInfo, _ []ml.DeviceInfo, _ string, _ *ggml.GGML, _, _ []string, _ api.Options, _ int) (llm.LlamaServer, error) {
return mock, nil
}
}

func TestGenerateChatRemote(t *testing.T) {
gin.SetMode(gin.TestMode)

rs := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != http.MethodPost {
t.Errorf("Expected POST request, got %s", r.Method)
}
if r.URL.Path != "/api/chat" {
t.Errorf("Expected path '/api/chat', got %s", r.URL.Path)
}

w.WriteHeader(http.StatusOK)
w.Header().Set("Content-Type", "application/json")
resp := api.ChatResponse{
Model: "test",
Done: true,
DoneReason: "load",
}
if err := json.NewEncoder(w).Encode(&resp); err != nil {
t.Fatal(err)
}
}))
defer rs.Close()

p, err := url.Parse(rs.URL)
if err != nil {
t.Fatal(err)
}

t.Setenv("OLLAMA_REMOTES", p.Hostname())
s := Server{}
w := createRequest(t, s.CreateHandler, api.CreateRequest{
Model: "test-cloud",
RemoteHost: rs.URL,
From: "test",
Info: map[string]any{
"capabilities": []string{"completion", "thinking"},
},
Stream: &stream,
})

if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}

t.Run("missing messages", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test-cloud",
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}

var actual api.ChatResponse
if err := json.NewDecoder(w.Body).Decode(&actual); err != nil {
t.Fatal(err)
}

if actual.Model != "test-cloud" {
t.Errorf("expected model test-cloud, got %s", actual.Model)
}

if actual.RemoteModel != "test" {
t.Errorf("expected remote model test, got %s", actual.RemoteModel)
}

if actual.RemoteHost != rs.URL {
t.Errorf("expected remote host '%s', got %s", rs.URL, actual.RemoteHost)
}

if !actual.Done {
t.Errorf("expected done true, got false")
}

if actual.DoneReason != "load" {
t.Errorf("expected done reason load, got %s", actual.DoneReason)
}
})
}

func TestGenerateChat(t *testing.T) {
gin.SetMode(gin.TestMode)

@@ -68,16 +150,16 @@ func TestGenerateChat(t *testing.T) {

s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: getGpuFn,
getCpuFn: getCpuFn,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
// add small delay to simulate loading
time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{
@@ -679,16 +761,16 @@ func TestGenerate(t *testing.T) {

s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: getGpuFn,
getCpuFn: getCpuFn,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
// add small delay to simulate loading
time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{
@@ -1104,16 +1186,16 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {

s := &Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(mock),
getGpuFn: getGpuFn,
getCpuFn: getCpuFn,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(mock),
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{llama: mock}
return false

@@ -14,9 +14,9 @@ import (

"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/discover"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/ml"
)

func getTestTools() []api.Tool {
@@ -268,16 +268,16 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) {

s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: getGpuFn,
getCpuFn: getCpuFn,
reschedDelay: 100 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 100 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
req.successCh <- &runnerRef{
llama: &mock,
}
@@ -419,16 +419,16 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) {

s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: getGpuFn,
getCpuFn: getCpuFn,
reschedDelay: 100 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 100 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
req.successCh <- &runnerRef{
llama: &mock,
}
@@ -601,16 +601,16 @@ func TestChatHarmonyParserStreaming(t *testing.T) {

s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: getGpuFn,
getCpuFn: getCpuFn,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool {
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: getGpuFn,
getSystemInfoFn: getSystemInfoFn,
waitForRecovery: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, _ *ggml.GGML, _ ml.SystemInfo, _ []ml.DeviceInfo, _ bool) bool {
req.successCh <- &runnerRef{
llama: &mock,
}

141
server/sched.go
@@ -5,12 +5,9 @@ import (
"errors"
"fmt"
"log/slog"
"os"
"reflect"
"runtime"
"slices"
"sort"
"strconv"
"strings"
"sync"
"time"
@@ -52,11 +49,11 @@ type Scheduler struct {
activeLoading llm.LlamaServer
loaded map[string]*runnerRef

loadFn func(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList, requireFull bool) bool
newServerFn func(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error)
getGpuFn func(ctx context.Context, runners []discover.FilteredRunnerDiscovery) discover.GpuInfoList
getCpuFn func() discover.GpuInfo
reschedDelay time.Duration
loadFn func(req *LlmRequest, f *ggml.GGML, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) bool
newServerFn func(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error)
getGpuFn func(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.DeviceInfo
getSystemInfoFn func() ml.SystemInfo
waitForRecovery time.Duration
}

// Default automatic value for number of models we allow per GPU
@@ -69,15 +66,15 @@ var ErrMaxQueue = errors.New("server busy, please try again. maximum pending re
func InitScheduler(ctx context.Context) *Scheduler {
maxQueue := envconfig.MaxQueue()
sched := &Scheduler{
pendingReqCh: make(chan *LlmRequest, maxQueue),
finishedReqCh: make(chan *LlmRequest, maxQueue),
expiredCh: make(chan *runnerRef, maxQueue),
unloadedCh: make(chan any, maxQueue),
loaded: make(map[string]*runnerRef),
newServerFn: llm.NewLlamaServer,
getGpuFn: discover.GetGPUInfo,
getCpuFn: discover.GetCPUInfo,
reschedDelay: 250 * time.Millisecond,
pendingReqCh: make(chan *LlmRequest, maxQueue),
finishedReqCh: make(chan *LlmRequest, maxQueue),
expiredCh: make(chan *runnerRef, maxQueue),
unloadedCh: make(chan any, maxQueue),
loaded: make(map[string]*runnerRef),
newServerFn: llm.NewLlamaServer,
getGpuFn: discover.GPUDevices,
getSystemInfoFn: discover.GetSystemInfo,
waitForRecovery: 5 * time.Second,
}
sched.loadFn = sched.load
return sched
@@ -131,6 +128,8 @@ func (s *Scheduler) Run(ctx context.Context) {
}

func (s *Scheduler) processPending(ctx context.Context) {
maxRunners := envconfig.MaxRunners()

for {
select {
case <-ctx.Done():
@@ -150,7 +149,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
s.loadedMu.Lock()
runner := s.loaded[pending.model.ModelPath]
loadedCount := len(s.loaded)
runnersSnapshot := make([]discover.FilteredRunnerDiscovery, 0, len(s.loaded))
runnersSnapshot := make([]ml.FilteredRunnerDiscovery, 0, len(s.loaded))
for _, r := range s.loaded {
runnersSnapshot = append(runnersSnapshot, r)
}
@@ -165,39 +164,29 @@ func (s *Scheduler) processPending(ctx context.Context) {
pending.useLoadedRunner(runner, s.finishedReqCh)
break
}
} else if envconfig.MaxRunners() > 0 && loadedCount >= int(envconfig.MaxRunners()) {
} else if maxRunners > 0 && loadedCount >= int(maxRunners) {
slog.Debug("max runners achieved, unloading one to make room", "runner_count", loadedCount)
runnerToExpire = s.findRunnerToUnload()
} else {
// Either no models are loaded or below envconfig.MaxRunners
// Get a refreshed GPU list
var gpus discover.GpuInfoList
var gpus []ml.DeviceInfo
if pending.opts.NumGPU == 0 {
gpus = discover.GpuInfoList{s.getCpuFn()}
gpus = []ml.DeviceInfo{}
} else {
gpus = s.getGpuFn(ctx, runnersSnapshot)
}

if envconfig.MaxRunners() <= 0 {
// No user specified MaxRunners, so figure out what automatic setting to use
// If all GPUs have reliable free memory reporting, defaultModelsPerGPU * the number of GPUs
// if any GPU has unreliable free memory reporting, 1x the number of GPUs
allReliable := true
for _, gpu := range gpus {
if gpu.UnreliableFreeMemory {
allReliable = false
break
}
}
if allReliable {
// HACK
os.Setenv("OLLAMA_MAX_LOADED_MODELS", strconv.Itoa(defaultModelsPerGPU*len(gpus)))
slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", envconfig.MaxRunners(), "gpu_count", len(gpus))
systemInfo := s.getSystemInfoFn()
if maxRunners <= 0 {
// No user specified MaxRunners, so figure out what automatic setting to use for the next load attempt
if pending.opts.NumGPU == 0 {
// Need to get actual GPU list to set the correct default max models
g := s.getGpuFn(ctx, runnersSnapshot)
maxRunners = uint(defaultModelsPerGPU * max(len(g), 1))
} else {
// HACK
os.Setenv("OLLAMA_MAX_LOADED_MODELS", strconv.Itoa(len(gpus)))
slog.Info("one or more GPUs detected that are unable to accurately report free memory - disabling default concurrency")
maxRunners = uint(defaultModelsPerGPU * max(len(gpus), 1))
}
slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", maxRunners, "gpu_count", len(gpus))
}

// Load model for fitting
@@ -213,14 +202,14 @@ func (s *Scheduler) processPending(ctx context.Context) {
if loadedCount == 0 {
// No models loaded. Load the model but prefer the best fit.
slog.Debug("loading first model", "model", pending.model.ModelPath)
s.loadFn(pending, ggml, gpus, false)
s.loadFn(pending, ggml, systemInfo, gpus, false)
break
}

// More than one loaded model, so we have to see if the
// new one fits

needEvict := s.loadFn(pending, ggml, gpus, true)
needEvict := s.loadFn(pending, ggml, systemInfo, gpus, true)
if !needEvict {
slog.Debug("new model fits with existing models, loading")
break
@@ -351,7 +340,7 @@ func (s *Scheduler) processCompleted(ctx context.Context) {
runner.refMu.Unlock()
} else {
slog.Debug("starting background wait for VRAM recovery", "runner", runner)
runnersSnapshot := make([]discover.FilteredRunnerDiscovery, 0, len(s.loaded))
runnersSnapshot := make([]ml.FilteredRunnerDiscovery, 0, len(s.loaded))
for _, r := range s.loaded {
runnersSnapshot = append(runnersSnapshot, r)
}
@@ -393,7 +382,7 @@ func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *Llm

// load creates a new model based on req and loads it. If requireFull is true then the model must be loaded fully onto GPUs
// (if any). Returns whether the scheduler needs to evict a model to make this one fit.
func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList, requireFull bool) bool {
func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) bool {
numParallel := max(int(envconfig.NumParallel()), 1)

// Embedding models should always be loaded with parallel=1
@@ -418,7 +407,7 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoLis

if llama == nil {
var err error
llama, err = s.newServerFn(gpus, req.model.ModelPath, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, numParallel)
llama, err = s.newServerFn(systemInfo, gpus, req.model.ModelPath, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, numParallel)
if err != nil {
// some older models are not compatible with newer versions of llama.cpp
// show a generalized compatibility error until there is a better way to
@@ -441,9 +430,16 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoLis

s.loadedMu.Unlock()

gpuIDs, err := llama.Load(req.ctx, gpus, requireFull)
gpuIDs, err := llama.Load(req.ctx, systemInfo, gpus, requireFull)
if err != nil {
if errors.Is(err, llm.ErrLoadRequiredFull) {
if !requireFull {
// No other models loaded, yet we still don't fit, so report an error
slog.Info("model is too large for system memory", "requireFull", requireFull)
s.activeLoading.Close()
s.activeLoading = nil
req.errCh <- err
}
return true
}

@@ -454,6 +450,20 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoLis
return false
}

// Determine if we have discrete GPUs which we should monitor VRAM usage on during shutdown
discreteGPUs := false
iGPUScan:
for _, devid := range gpuIDs {
for _, dev := range gpus {
if dev.DeviceID == devid {
if !dev.Integrated {
discreteGPUs = true
break iGPUScan
}
}
}
}

runner := &runnerRef{
model: req.model,
modelPath: req.model.ModelPath,
@@ -461,6 +471,7 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoLis
Options: &req.opts,
sessionDuration: sessionDuration,
gpus: gpuIDs,
discreteGPUs: discreteGPUs,
vramSize: llama.VRAMSize(),
totalSize: llama.TotalSize(),
loading: true,
@@ -508,7 +519,10 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoLis
return false
}

func (s *Scheduler) updateFreeSpace(allGpus discover.GpuInfoList) {
func (s *Scheduler) updateFreeSpace(allGpus []ml.DeviceInfo) {
if len(allGpus) == 0 {
return
}
predMap := map[ml.DeviceID]uint64{} // Sum up the total predicted usage per GPU for all runners
s.loadedMu.Lock()
runners := make([]*runnerRef, 0, len(s.loaded))
@@ -552,12 +566,13 @@ type runnerRef struct {
refMu sync.Mutex
refCount uint // prevent unloading if > 0

llama llm.LlamaServer
pid int
loading bool // True only during initial load, then false forever
gpus []ml.DeviceID // Recorded at time of provisioning
vramSize uint64
totalSize uint64
llama llm.LlamaServer
pid int
loading bool // True only during initial load, then false forever
gpus []ml.DeviceID // Recorded at time of provisioning
discreteGPUs bool // True if all devices are discrete GPUs - used to skip VRAM recovery check for iGPUs
vramSize uint64
totalSize uint64

sessionDuration time.Duration
expireTimer *time.Timer
@@ -625,14 +640,12 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool
// a before and after GPU memory allocation. The returned channel
// will be notified when we're done waiting, or have timed out and should
// proceed anyway
func (s *Scheduler) waitForVRAMRecovery(runner *runnerRef, runners []discover.FilteredRunnerDiscovery) chan any {
func (s *Scheduler) waitForVRAMRecovery(runner *runnerRef, runners []ml.FilteredRunnerDiscovery) chan any {
finished := make(chan any, 1)

// CPU or Metal don't need checking, so no waiting required
// windows can page VRAM, only cuda currently can report accurate used vram usage
if len(runner.gpus) == 0 ||
(len(runner.gpus) == 1 && (runner.gpus[0].Library == "cpu" || runner.gpus[0].Library == "Metal")) ||
(runtime.GOOS == "windows" && runner.gpus[0].Library != "CUDA") {
// CPU, Metal and iGPUs don't need checking, so no waiting required
if len(runner.gpus) == 0 || !runner.discreteGPUs ||
(len(runner.gpus) == 1 && runner.gpus[0].Library == "Metal") {
finished <- struct{}{}
slog.Debug("no need to wait for VRAM recovery", "runner", runner)
return finished
@@ -650,8 +663,8 @@ func (s *Scheduler) waitForVRAMRecovery(runner *runnerRef, runners []discover.Fi
freeMemoryNow := freeMemoryBefore

go func() {
// typical convergence is 0.5-1.5s - If it takes more than 5 seconds to discover and converge, let the scheduler estimate VRAM usage
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
// typical convergence is 0.5-1.5s - If it takes too long to discover and converge, let the scheduler estimate VRAM usage
ctx, cancel := context.WithTimeout(context.Background(), s.waitForRecovery)
defer cancel()
ticker := time.NewTicker(250 * time.Millisecond)
defer ticker.Stop()
@@ -666,7 +679,11 @@ func (s *Scheduler) waitForVRAMRecovery(runner *runnerRef, runners []discover.Fi
totalMemoryNow += gpu.TotalMemory
freeMemoryNow += gpu.FreeMemory
}
logutil.Trace("gpu VRAM convergence", "percent", int(max(float32(freeMemoryNow-freeMemoryBefore), 0.0)/float32(runner.vramSize)*100))
if freeMemoryNow > freeMemoryBefore {
logutil.Trace("gpu VRAM convergence", "percent", int(float32(freeMemoryNow-freeMemoryBefore)/float32(runner.vramSize)*100))
} else {
logutil.Trace("gpu VRAM convergence", "percent", 0)
}
// If we're within ~75% of the estimated memory usage recovered, bail out
if float32(freeMemoryNow-freeMemoryBefore) > float32(runner.vramSize)*0.75 {
slog.Debug(fmt.Sprintf("gpu VRAM free memory converged after %0.2f seconds", time.Since(start).Seconds()), "free_before", format.HumanBytes2(freeMemoryBefore), "free_now", format.HumanBytes2(freeMemoryNow), "runner", runner)

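The guard added around the trace call matters because the memory counters are unsigned: with uint64 operands, freeMemoryNow - freeMemoryBefore wraps to a huge positive value whenever free memory dips below the baseline, so the old max(..., 0.0) clamp never fired (the wraparound happened before the float32 conversion). A small self-contained illustration of that fix, with made-up numbers:

```go
package main

import "fmt"

// convergencePercent mirrors the new trace logic: compare the uint64
// values first, and only subtract when the result cannot wrap around.
func convergencePercent(now, before, vramSize uint64) int {
	if now > before {
		return int(float32(now-before) / float32(vramSize) * 100)
	}
	return 0
}

func main() {
	// free memory briefly dips below the baseline (illustrative values):
	// the old expression would have computed float32(2^64 - 2) first.
	fmt.Println(convergencePercent(10, 12, 100)) // prints 0
}
```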
@@ -13,7 +13,6 @@ import (

"github.com/ollama/ollama/api"
"github.com/ollama/ollama/app/lifecycle"
"github.com/ollama/ollama/discover"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llm"
@@ -26,7 +25,7 @@ func TestMain(m *testing.M) {
os.Exit(m.Run())
}

func TestInitScheduler(t *testing.T) {
func TestSchedInit(t *testing.T) {
ctx, done := context.WithCancel(t.Context())
defer done()
s := InitScheduler(ctx)
@@ -35,10 +34,11 @@ func TestInitScheduler(t *testing.T) {
s.loadedMu.Unlock()
}

func TestLoad(t *testing.T) {
func TestSchedLoad(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 20*time.Millisecond)
defer done()
s := InitScheduler(ctx)
s.waitForRecovery = 10 * time.Millisecond
var f *ggml.GGML // value not used in tests
req := &LlmRequest{
ctx: ctx,
@@ -49,11 +49,12 @@ func TestLoad(t *testing.T) {
sessionDuration: &api.Duration{Duration: 2 * time.Second},
}
// Fail to load model first
s.newServerFn = func(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
s.newServerFn = func(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
return nil, errors.New("something failed to load model blah")
}
gpus := discover.GpuInfoList{}
s.load(req, f, gpus, false)
gpus := []ml.DeviceInfo{}
systemInfo := ml.SystemInfo{}
s.load(req, f, systemInfo, gpus, false)
require.Empty(t, req.successCh)
require.Len(t, req.errCh, 1)
s.loadedMu.Lock()
@@ -63,11 +64,11 @@ func TestLoad(t *testing.T) {
require.Contains(t, err.Error(), "this model may be incompatible")

server := &mockLlm{vramSize: 10, vramByGPU: map[ml.DeviceID]uint64{}}
s.newServerFn = func(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
s.newServerFn = func(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
server.modelPath = model
return server, nil
}
s.load(req, f, gpus, false)
s.load(req, f, systemInfo, gpus, false)
select {
case err := <-req.errCh:
require.NoError(t, err)
@@ -81,7 +82,7 @@ func TestLoad(t *testing.T) {

req.model.ModelPath = "dummy_model_path"
server.waitResp = errors.New("wait failure")
s.load(req, f, gpus, false)
s.load(req, f, systemInfo, gpus, false)
select {
case err := <-req.errCh:
require.Contains(t, err.Error(), "wait failure")
@@ -105,7 +106,7 @@ type reqBundle struct {
f *ggml.GGML
}

func (scenario *reqBundle) newServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
func (scenario *reqBundle) newServer(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
scenario.srv.modelPath = model
return scenario.srv, nil
}
@@ -151,28 +152,29 @@ func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, vra
return b
}

func getGpuFn(ctx context.Context, runners []discover.FilteredRunnerDiscovery) discover.GpuInfoList {
func getGpuFn(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.DeviceInfo {
slog.Info("test getGpuFn called", "runners", runners)
g := discover.GpuInfo{DeviceID: ml.DeviceID{Library: "metal"}}
g := ml.DeviceInfo{DeviceID: ml.DeviceID{Library: "Metal"}}
g.TotalMemory = 24 * format.GigaByte
g.FreeMemory = 12 * format.GigaByte
return []discover.GpuInfo{g}
return []ml.DeviceInfo{g}
}

func getCpuFn() discover.GpuInfo {
slog.Info("test getCpuFn called")
g := discover.GpuInfo{DeviceID: ml.DeviceID{Library: "cpu"}}
g.TotalMemory = 32 * format.GigaByte
g.FreeMemory = 26 * format.GigaByte
return g
func getSystemInfoFn() ml.SystemInfo {
slog.Info("test getSystemInfoFn called")
return ml.SystemInfo{
TotalMemory: 32 * format.GigaByte,
FreeMemory: 26 * format.GigaByte,
}
}

func TestRequestsSameModelSameRequest(t *testing.T) {
func TestSchedRequestsSameModelSameRequest(t *testing.T) {
ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
defer done()
s := InitScheduler(ctx)
s.waitForRecovery = 10 * time.Millisecond
s.getGpuFn = getGpuFn
s.getCpuFn = getCpuFn
s.getSystemInfoFn = getSystemInfoFn
a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond}, nil)
b := newScenarioRequest(t, ctx, "ollama-model-1", 11, &api.Duration{Duration: 0}, nil)
b.req.model = a.req.model
@@ -210,12 +212,13 @@ func TestRequestsSameModelSameRequest(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestsSimpleReloadSameModel(t *testing.T) {
|
||||
func TestSchedRequestsSimpleReloadSameModel(t *testing.T) {
|
||||
ctx, done := context.WithTimeout(t.Context(), 5000*time.Millisecond)
|
||||
defer done()
|
||||
s := InitScheduler(ctx)
|
||||
s.waitForRecovery = 10 * time.Millisecond
|
||||
s.getGpuFn = getGpuFn
|
||||
s.getCpuFn = getCpuFn
|
||||
s.getSystemInfoFn = getSystemInfoFn
|
||||
a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond}, nil)
|
||||
b := newScenarioRequest(t, ctx, "ollama-model-1", 20, &api.Duration{Duration: 5 * time.Millisecond}, nil)
|
||||
tmpModel := *a.req.model
|
||||
@@ -248,12 +251,12 @@ func TestRequestsSimpleReloadSameModel(t *testing.T) {
|
||||
a.ctxDone()
|
||||
// Report recovered VRAM usage
|
||||
time.Sleep(1 * time.Millisecond)
|
||||
s.getGpuFn = func(ctx context.Context, runners []discover.FilteredRunnerDiscovery) discover.GpuInfoList {
|
||||
slog.Info("XXX altered getGpuFn called")
|
||||
g := discover.GpuInfo{DeviceID: ml.DeviceID{Library: "metal"}}
|
||||
s.getGpuFn = func(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.DeviceInfo {
|
||||
slog.Info("altered getGpuFn called")
|
||||
g := ml.DeviceInfo{DeviceID: ml.DeviceID{Library: "Metal"}}
|
||||
g.TotalMemory = 24 * format.GigaByte
|
||||
g.FreeMemory = 24 * format.GigaByte
|
||||
return []discover.GpuInfo{g}
|
||||
return []ml.DeviceInfo{g}
|
||||
}
|
||||
select {
|
||||
case resp := <-b.req.successCh:
|
||||
@@ -267,26 +270,27 @@ func TestRequestsSimpleReloadSameModel(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestRequestsMultipleLoadedModels(t *testing.T) {
|
||||
ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
|
||||
func TestSchedRequestsMultipleLoadedModels(t *testing.T) {
|
||||
slog.Info("TestRequestsMultipleLoadedModels")
|
||||
ctx, done := context.WithTimeout(t.Context(), 1000*time.Millisecond)
|
||||
defer done()
|
||||
s := InitScheduler(ctx)
|
||||
s.getGpuFn = getGpuFn // 1 metal GPU
|
||||
s.getCpuFn = getCpuFn // 1 CPU
|
||||
s.waitForRecovery = 10 * time.Millisecond
|
||||
s.getGpuFn = getGpuFn // 1 Metal GPU
|
||||
s.getSystemInfoFn = getSystemInfoFn
|
||||
|
||||
	// Multiple loaded models
	a := newScenarioRequest(t, ctx, "model-a-1g-gpu", 1*format.GigaByte, nil, map[ml.DeviceID]uint64{{Library: "metal"}: 1 * format.GigaByte})
	a := newScenarioRequest(t, ctx, "model-a-1g-gpu", 1*format.GigaByte, nil, map[ml.DeviceID]uint64{{Library: "Metal"}: 1 * format.GigaByte})
	a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
	b := newScenarioRequest(t, ctx, "model-b-10g-gpu", 10*format.GigaByte, nil, map[ml.DeviceID]uint64{{Library: "metal"}: 10 * format.GigaByte})
	b := newScenarioRequest(t, ctx, "model-b-10g-gpu", 10*format.GigaByte, nil, map[ml.DeviceID]uint64{{Library: "Metal"}: 10 * format.GigaByte})
	b.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
	c := newScenarioRequest(t, ctx, "model-c-10g-cpu", 10*format.GigaByte, nil, nil /* No GPU load */)
	c.req.opts.NumGPU = 0 // CPU load, will be allowed
	c.req.sessionDuration = &api.Duration{Duration: 10 * time.Millisecond} // longer than b to cause the scheduler to favor unloading b over c
	d := newScenarioRequest(t, ctx, "model-d-10g-gpu", 13*format.GigaByte, nil, map[ml.DeviceID]uint64{{Library: "metal"}: 13 * format.GigaByte}) // Needs prior unloaded
	d := newScenarioRequest(t, ctx, "model-d-10g-gpu", 13*format.GigaByte, nil, map[ml.DeviceID]uint64{{Library: "Metal"}: 13 * format.GigaByte}) // Needs prior unloaded

	t.Setenv("OLLAMA_MAX_LOADED_MODELS", "1")
	s.newServerFn = a.newServer
	slog.Info("a")
	slog.Info("Loading A")
	s.pendingReqCh <- a.req
	s.Run(ctx)
	select {
@@ -305,7 +309,7 @@ func TestRequestsMultipleLoadedModels(t *testing.T) {

	t.Setenv("OLLAMA_MAX_LOADED_MODELS", "0")
	s.newServerFn = b.newServer
	slog.Info("b")
	slog.Info("Loading B")
	s.pendingReqCh <- b.req
	select {
	case resp := <-b.req.successCh:
@@ -323,7 +327,7 @@ func TestRequestsMultipleLoadedModels(t *testing.T) {

	// This is a CPU load with NumGPU = 0 so it should load
	s.newServerFn = c.newServer
	slog.Info("c")
	slog.Info("Loading C")
	s.pendingReqCh <- c.req
	select {
	case resp := <-c.req.successCh:
@@ -333,6 +337,7 @@ func TestRequestsMultipleLoadedModels(t *testing.T) {
	case err := <-c.req.errCh:
		t.Fatal(err.Error())
	case <-ctx.Done():
		slog.Info("FAIL: scheduler state", "s.loaded", s.loaded)
		t.Fatal("timeout")
	}
	s.loadedMu.Lock()
@@ -357,11 +362,11 @@ func TestRequestsMultipleLoadedModels(t *testing.T) {
	b.ctxDone()
	// Report recovered VRAM usage so scheduler will finish waiting and unload
	time.Sleep(1 * time.Millisecond)
	s.getGpuFn = func(ctx context.Context, runners []discover.FilteredRunnerDiscovery) discover.GpuInfoList {
		g := discover.GpuInfo{DeviceID: ml.DeviceID{Library: "metal"}}
	s.getGpuFn = func(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.DeviceInfo {
		g := ml.DeviceInfo{DeviceID: ml.DeviceID{Library: "Metal"}}
		g.TotalMemory = 24 * format.GigaByte
		g.FreeMemory = 24 * format.GigaByte
		return []discover.GpuInfo{g}
		return []ml.DeviceInfo{g}
	}
	select {
	case resp := <-d.req.successCh:
@@ -389,7 +394,7 @@ closeWait:
	s.loadedMu.Unlock()
}

func TestGetRunner(t *testing.T) {
func TestSchedGetRunner(t *testing.T) {
	ctx, done := context.WithTimeout(t.Context(), 3*time.Second)
	defer done()

@@ -398,8 +403,9 @@ func TestGetRunner(t *testing.T) {
	c := newScenarioRequest(t, ctx, "ollama-model-1c", 10, &api.Duration{Duration: 2 * time.Millisecond}, nil)
	t.Setenv("OLLAMA_MAX_QUEUE", "1")
	s := InitScheduler(ctx)
	s.waitForRecovery = 10 * time.Millisecond
	s.getGpuFn = getGpuFn
	s.getCpuFn = getCpuFn
	s.getSystemInfoFn = getSystemInfoFn
	s.newServerFn = a.newServer
	slog.Info("a")
	successCh1a, errCh1a := s.GetRunner(a.ctx, a.req.model, a.req.opts, a.req.sessionDuration)
@@ -442,10 +448,11 @@ func TestGetRunner(t *testing.T) {
	b.ctxDone()
}

func TestExpireRunner(t *testing.T) {
func TestSchedExpireRunner(t *testing.T) {
	ctx, done := context.WithTimeout(t.Context(), 20*time.Millisecond)
	defer done()
	s := InitScheduler(ctx)
	s.waitForRecovery = 10 * time.Millisecond
	req := &LlmRequest{
		ctx:   ctx,
		model: &Model{ModelPath: "foo"},
@@ -456,13 +463,14 @@ func TestExpireRunner(t *testing.T) {
	}

	var f *ggml.GGML
	gpus := discover.GpuInfoList{}
	gpus := []ml.DeviceInfo{}
	systemInfo := ml.SystemInfo{}
	server := &mockLlm{vramSize: 10, vramByGPU: map[ml.DeviceID]uint64{}}
	s.newServerFn = func(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
	s.newServerFn = func(systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
		server.modelPath = model
		return server, nil
	}
	s.load(req, f, gpus, false)
	s.load(req, f, systemInfo, gpus, false)

	select {
	case err := <-req.errCh:
@@ -490,19 +498,16 @@ func TestExpireRunner(t *testing.T) {
	}

// TODO - add one scenario that triggers the bogus finished event with positive ref count
func TestPrematureExpired(t *testing.T) {
	ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
func TestSchedPrematureExpired(t *testing.T) {
	ctx, done := context.WithTimeout(t.Context(), 1000*time.Millisecond)
	defer done()

	// Same model, same request
	scenario1a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, nil, nil)
	scenario1a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, &api.Duration{Duration: 100 * time.Millisecond}, nil)
	s := InitScheduler(ctx)
	s.getGpuFn = func(ctx context.Context, runners []discover.FilteredRunnerDiscovery) discover.GpuInfoList {
		g := discover.GpuInfo{DeviceID: ml.DeviceID{Library: "metal"}}
		g.TotalMemory = 24 * format.GigaByte
		g.FreeMemory = 12 * format.GigaByte
		return []discover.GpuInfo{g}
	}
	s.waitForRecovery = 10 * time.Millisecond
	s.getGpuFn = getGpuFn
	s.getSystemInfoFn = getSystemInfoFn
	s.newServerFn = scenario1a.newServer
	successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration)
	require.Len(t, s.pendingReqCh, 1)
@@ -537,7 +542,7 @@ func TestPrematureExpired(t *testing.T) {
	time.Sleep(5 * time.Millisecond)
}

func TestUseLoadedRunner(t *testing.T) {
func TestSchedUseLoadedRunner(t *testing.T) {
	ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
	req := &LlmRequest{
		ctx: ctx,
@@ -564,10 +569,10 @@ func TestUseLoadedRunner(t *testing.T) {
	require.Equal(t, req, fin)
}

func TestUpdateFreeSpace(t *testing.T) {
func TestSchedUpdateFreeSpace(t *testing.T) {
	ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
	defer done()
	gpus := discover.GpuInfoList{
	gpus := []ml.DeviceInfo{
		{
			DeviceID: ml.DeviceID{
				ID: "1",
@@ -597,6 +602,7 @@ func TestUpdateFreeSpace(t *testing.T) {
	r2 := &runnerRef{llama: llm2, gpus: gpuIDs, numParallel: 1}

	s := InitScheduler(ctx)
	s.waitForRecovery = 10 * time.Millisecond
	s.loadedMu.Lock()
	s.loaded["a"] = r1
	s.loaded["b"] = r2
@@ -607,7 +613,7 @@ func TestUpdateFreeSpace(t *testing.T) {
	require.Equal(t, uint64(2000-50-75), gpus[1].FreeMemory)
}

func TestFindRunnerToUnload(t *testing.T) {
func TestSchedFindRunnerToUnload(t *testing.T) {
	ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
	defer done()

@@ -615,6 +621,7 @@ func TestFindRunnerToUnload(t *testing.T) {
	r2 := &runnerRef{sessionDuration: 2, numParallel: 1}

	s := InitScheduler(ctx)
	s.waitForRecovery = 10 * time.Millisecond
	s.loadedMu.Lock()
	s.loaded["a"] = r1
	s.loaded["b"] = r2
@@ -627,7 +634,7 @@ func TestFindRunnerToUnload(t *testing.T) {
	require.Equal(t, r1, resp)
}

func TestNeedsReload(t *testing.T) {
func TestSchedNeedsReload(t *testing.T) {
	ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
	defer done()

@@ -674,13 +681,14 @@ func TestNeedsReload(t *testing.T) {
	require.False(t, resp)
}

func TestUnloadAllRunners(t *testing.T) {
func TestSchedUnloadAllRunners(t *testing.T) {
	ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond)
	defer done()

	llm1 := &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}}
	llm2 := &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}}
	s := InitScheduler(ctx)
	s.waitForRecovery = 10 * time.Millisecond
	s.unloadAllRunners()

	r1 := &runnerRef{llama: llm1, numParallel: 1}
@@ -696,7 +704,7 @@ func TestUnloadAllRunners(t *testing.T) {
	require.True(t, llm2.closeCalled)
}

func TestUnload(t *testing.T) {
func TestSchedUnload(t *testing.T) {
	llm1 := &mockLlm{vramByGPU: map[ml.DeviceID]uint64{}}
	r1 := &runnerRef{llama: llm1, numParallel: 1}
	r2 := &runnerRef{model: &Model{AdapterPaths: []string{"A"}}, numParallel: 1}
@@ -706,13 +714,14 @@ func TestUnload(t *testing.T) {
	require.Nil(t, r2.model)
}

func TestAlreadyCanceled(t *testing.T) {
func TestSchedAlreadyCanceled(t *testing.T) {
	ctx, done := context.WithTimeout(t.Context(), 500*time.Millisecond)
	defer done()
	dctx, done2 := context.WithCancel(ctx)
	done2()
	scenario1a := newScenarioRequest(t, dctx, "ollama-model-1", 10, &api.Duration{Duration: 0}, nil)
	s := InitScheduler(ctx)
	s.waitForRecovery = 10 * time.Millisecond
	slog.Info("scenario1a")
	s.pendingReqCh <- scenario1a.req
	require.Len(t, s.pendingReqCh, 1)
@@ -745,8 +754,12 @@ func (s *mockLlm) ModelPath() string {
	return s.modelPath
}

func (s *mockLlm) Load(ctx context.Context, gpus discover.GpuInfoList, requireFull bool) ([]ml.DeviceID, error) {
func (s *mockLlm) Load(ctx context.Context, systemInfo ml.SystemInfo, gpus []ml.DeviceInfo, requireFull bool) ([]ml.DeviceID, error) {
	if requireFull {
		if len(gpus) == 0 {
			slog.Info("mockLlm.Load CPU based load")
			return nil, nil
		}
		for _, g := range gpus {
			if g.FreeMemory >= s.vramSize {
				return []ml.DeviceID{g.DeviceID}, nil
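The mock's placement rule under requireFull reads: an empty device list is treated as a CPU-only load, otherwise the first GPU whose FreeMemory covers the mock's vramSize is claimed. A hedged usage sketch, with a hypothetical device (not from this diff):

	// Illustrative call: one Metal device with ample free memory, so the
	// mock should return that device's ID; requireFull=true demands a full fit.
	ids, err := server.Load(ctx, ml.SystemInfo{}, []ml.DeviceInfo{
		{DeviceID: ml.DeviceID{Library: "Metal"}, FreeMemory: 16 * format.GigaByte},
	}, true)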
@@ -767,8 +780,8 @@ func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn
	return s.completionResp
}

func (s *mockLlm) Embedding(ctx context.Context, input string) ([]float32, error) {
	return s.embeddingResp, s.embeddingRespErr
func (s *mockLlm) Embedding(ctx context.Context, input string, truncate bool) ([]float32, int, error) {
	return s.embeddingResp, 0, s.embeddingRespErr
}

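The Embedding contract widens here: it now takes a truncate flag and returns a token count alongside the vector (the mock reports 0 tokens). A sketch of a call under the new shape, hedged since only the mock side of the interface is visible in this diff:

	// Hypothetical caller: truncate=true presumably clips input that exceeds
	// the context window; the int is the number of tokens actually embedded.
	vec, tokens, err := server.Embedding(ctx, "hello world", true)
	if err == nil {
		slog.Info("embedded", "dims", len(vec), "tokens", tokens)
	}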
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {

@@ -125,7 +125,7 @@ func (p *Parser) parseToolCall() *api.ToolCall {
	}

	var args map[string]any
	if found, i := findArguments(p.buffer); found == nil {
	if found, i := findArguments(tool, p.buffer); found == nil {
		return nil
	} else {
		args = found
@@ -219,7 +219,7 @@ func findTool(tools []api.Tool, buf []byte) (*api.Tool, int) {
// objects for functions that have all-optional parameters
// e.g. `{"name": "get_conditions", "arguments": {}}` will work but
// `{"name": "get_conditions"}` will not currently work
func findArguments(buffer []byte) (map[string]any, int) {
func findArguments(tool *api.Tool, buffer []byte) (map[string]any, int) {
	if len(buffer) == 0 {
		return nil, 0
	}
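Passing the tool into findArguments lets the parser accept payloads keyed by the function's own name, not just the generic "arguments"/"parameters" wrappers. Per the comments above and the new test cases below, all three of these shapes should now resolve to the same argument map; a sketch assuming a tool named get_temperature:

	tool := &api.Tool{Function: api.ToolFunction{Name: "get_temperature"}}
	for _, buf := range [][]byte{
		[]byte(`{"name": "get_temperature", "arguments": {"format": "fahrenheit"}}`),
		[]byte(`{"get_temperature": {"format": "fahrenheit"}}`),
		[]byte(`{"get_temperature": "{\"format\": \"fahrenheit\"}"}`),
	} {
		args, _ := findArguments(tool, buf)
		_ = args // each should yield map[string]any{"format": "fahrenheit"}
	}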
@@ -269,27 +269,30 @@ func findArguments(buffer []byte) (map[string]any, int) {

	var findObject func(obj map[string]any) (map[string]any, bool)
	findObject = func(obj map[string]any) (map[string]any, bool) {
	findMap := func(name string, obj map[string]any) (map[string]any, bool) {
		if args, ok := obj[name].(map[string]any); ok {
			return args, true
		}
		if argsStr, ok := obj[name].(string); ok {
			var argsData map[string]interface{}
			if err := json.Unmarshal([]byte(argsStr), &argsData); err == nil {
				return argsData, ok
			}
		}
		return nil, false
	}
	if _, hasName := obj["name"]; hasName {
		if args, ok := obj["arguments"].(map[string]any); ok {
		if args, ok := findMap("arguments", obj); ok {
			return args, true
		}
		if argsStr, ok := obj["arguments"].(string); ok {
			var argsData map[string]interface{}
			if err := json.Unmarshal([]byte(argsStr), &argsData); err == nil {
				return argsData, ok
			}
		}
		if args, ok := obj["parameters"].(map[string]any); ok {
		if args, ok := findMap("parameters", obj); ok {
			return args, true
		}
		if argsStr, ok := obj["parameters"].(string); ok {
			var argsData map[string]interface{}
			if err := json.Unmarshal([]byte(argsStr), &argsData); err == nil {
				return argsData, ok
			}
		}
		return nil, true
	}
	if args, ok := findMap(tool.Function.Name, obj); ok {
		return args, true
	}

	for _, v := range obj {
		switch child := v.(type) {

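findMap collapses three copies of the same object-or-stringified-object type switch into one helper, and the tool's own name becomes a final lookup key. An illustrative reading, treating the closure as if it were standalone (inputs hypothetical):

	// A stringified JSON object under the key is unmarshalled transparently,
	// so both forms normalize to the same map:
	obj1 := map[string]any{"arguments": map[string]any{"city": "Paris"}}
	obj2 := map[string]any{"arguments": `{"city": "Paris"}`}
	// findMap("arguments", obj1) and findMap("arguments", obj2) both return
	// (map[string]any{"city": "Paris"}, true).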
@@ -1033,6 +1033,7 @@ func TestFindArguments(t *testing.T) {
		name   string
		buffer []byte
		want   map[string]any
		tool   string
	}{
		{
			name: "empty string",
@@ -1290,11 +1291,29 @@ func TestFindArguments(t *testing.T) {
				"location": "San Francisco, CA",
			},
		},
		{
			name:   "simple tool call",
			tool:   "get_temperature",
			buffer: []byte(`{"get_temperature": {"format": "fahrenheit", "location": "San Francisco, CA"}}`),
			want: map[string]any{
				"format":   "fahrenheit",
				"location": "San Francisco, CA",
			},
		},
		{
			name:   "stringified simple tool call",
			tool:   "get_temperature",
			buffer: []byte(`{"get_temperature": "{\"format\": \"fahrenheit\", \"location\": \"San Francisco, CA\"}"}`),
			want: map[string]any{
				"format":   "fahrenheit",
				"location": "San Francisco, CA",
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, _ := findArguments(tt.buffer)
			got, _ := findArguments(&api.Tool{Function: api.ToolFunction{Name: tt.tool}}, tt.buffer)

			if diff := cmp.Diff(got, tt.want); diff != "" {
				t.Errorf("findArguments() args mismatch (-got +want):\n%s", diff)