Compare commits
21 Commits
v0.12.4-rc
...
v0.12.5-rc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3d32249c74 | ||
|
|
d681cd7c29 | ||
|
|
47298fce39 | ||
|
|
4a48937ef1 | ||
|
|
967a82f52f | ||
|
|
bbbc73d637 | ||
|
|
15e3611d3d | ||
|
|
77060d462c | ||
|
|
1b91d4dda1 | ||
|
|
7d965258ce | ||
|
|
6a62b894c7 | ||
|
|
90d429f5a8 | ||
|
|
1fc35f1260 | ||
|
|
aa45f7ce27 | ||
|
|
4e5d862ec4 | ||
|
|
303be9304c | ||
|
|
bd15eba4e4 | ||
|
|
bc71278670 | ||
|
|
918231931c | ||
|
|
04c1849878 | ||
|
|
2c2f4deaa9 |
@@ -936,7 +936,7 @@ func (t *ThinkValue) UnmarshalJSON(data []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("think must be a boolean or string (\"high\", \"medium\", \"low\")")
|
||||
return fmt.Errorf("think must be a boolean or string (\"high\", \"medium\", \"low\", true, or false)")
|
||||
}
|
||||
|
||||
// MarshalJSON implements json.Marshaler
|
||||
|
||||
@@ -2,11 +2,12 @@ package discover
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/format"
|
||||
@@ -60,17 +61,14 @@ func devInfoToInfoList(devs []ml.DeviceInfo) GpuInfoList {
|
||||
DependencyPath: dev.LibraryPath,
|
||||
DriverMajor: dev.DriverMajor,
|
||||
DriverMinor: dev.DriverMinor,
|
||||
ComputeMajor: dev.ComputeMajor,
|
||||
ComputeMinor: dev.ComputeMinor,
|
||||
}
|
||||
if dev.Library == "CUDA" || dev.Library == "ROCm" {
|
||||
info.MinimumMemory = 457 * format.MebiByte
|
||||
}
|
||||
if dev.Library == "ROCm" {
|
||||
info.Compute = fmt.Sprintf("gfx%x%02x", dev.ComputeMajor, dev.ComputeMinor)
|
||||
if rocmDir != "" {
|
||||
info.DependencyPath = append(info.DependencyPath, rocmDir)
|
||||
}
|
||||
} else {
|
||||
info.Compute = fmt.Sprintf("%d.%d", dev.ComputeMajor, dev.ComputeMinor)
|
||||
if dev.Library == "ROCm" && rocmDir != "" {
|
||||
info.DependencyPath = append(info.DependencyPath, rocmDir)
|
||||
}
|
||||
resp = append(resp, info)
|
||||
}
|
||||
@@ -146,3 +144,35 @@ func GetSystemInfo() SystemInfo {
|
||||
GPUs: gpus,
|
||||
}
|
||||
}
|
||||
|
||||
func cudaJetpack() string {
|
||||
if runtime.GOARCH == "arm64" && runtime.GOOS == "linux" {
|
||||
if CudaTegra != "" {
|
||||
ver := strings.Split(CudaTegra, ".")
|
||||
if len(ver) > 0 {
|
||||
return "jetpack" + ver[0]
|
||||
}
|
||||
} else if data, err := os.ReadFile("/etc/nv_tegra_release"); err == nil {
|
||||
r := regexp.MustCompile(` R(\d+) `)
|
||||
m := r.FindSubmatch(data)
|
||||
if len(m) != 2 {
|
||||
slog.Info("Unexpected format for /etc/nv_tegra_release. Set JETSON_JETPACK to select version")
|
||||
} else {
|
||||
if l4t, err := strconv.Atoi(string(m[1])); err == nil {
|
||||
// Note: mapping from L4t -> JP is inconsistent (can't just subtract 30)
|
||||
// https://developer.nvidia.com/embedded/jetpack-archive
|
||||
switch l4t {
|
||||
case 35:
|
||||
return "jetpack5"
|
||||
case 36:
|
||||
return "jetpack6"
|
||||
default:
|
||||
// Newer Jetson systems use the SBSU runtime
|
||||
slog.Debug("unrecognized L4T version", "nv_tegra_release", string(data))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -78,6 +78,8 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev
|
||||
}
|
||||
|
||||
slog.Info("discovering available GPUs...")
|
||||
requested := envconfig.LLMLibrary()
|
||||
jetpack := cudaJetpack()
|
||||
|
||||
// For our initial discovery pass, we gather all the known GPUs through
|
||||
// all the libraries that were detected. This pass may include GPUs that
|
||||
@@ -86,6 +88,14 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev
|
||||
// times concurrently leading to memory contention
|
||||
for dir := range libDirs {
|
||||
var dirs []string
|
||||
if dir != "" {
|
||||
if requested != "" && filepath.Base(dir) != requested {
|
||||
slog.Debug("skipping available library at users request", "requested", requested, "libDir", dir)
|
||||
continue
|
||||
} else if jetpack != "" && filepath.Base(dir) != "cuda_"+jetpack {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if dir == "" {
|
||||
dirs = []string{LibOllamaPath}
|
||||
} else {
|
||||
@@ -442,15 +452,18 @@ func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs []s
|
||||
cmd.Stderr = os.Stderr
|
||||
}
|
||||
// cmd.SysProcAttr = llm.LlamaServerSysProcAttr // circular dependency - bring back once refactored
|
||||
cmd.Env = append(cmd.Env, "OLLAMA_LIBRARY_PATH="+strings.Join(ollamaLibDirs, string(filepath.ListSeparator)))
|
||||
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
|
||||
pathNeeded := true
|
||||
ollamaPathNeeded := true
|
||||
extraDone := make([]bool, len(extraEnvs))
|
||||
for i := range cmd.Env {
|
||||
cmp := strings.SplitN(cmd.Env[i], "=", 2)
|
||||
if strings.EqualFold(cmp[0], pathEnv) {
|
||||
cmd.Env[i] = pathEnv + "=" + pathEnvVal
|
||||
pathNeeded = false
|
||||
} else if strings.EqualFold(cmp[0], "OLLAMA_LIBRARY_PATH") {
|
||||
cmd.Env[i] = "OLLAMA_LIBRARY_PATH=" + strings.Join(ollamaLibDirs, string(filepath.ListSeparator))
|
||||
ollamaPathNeeded = false
|
||||
} else {
|
||||
for j := range extraEnvs {
|
||||
if extraDone[j] {
|
||||
@@ -467,6 +480,9 @@ func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs []s
|
||||
if pathNeeded {
|
||||
cmd.Env = append(cmd.Env, pathEnv+"="+pathEnvVal)
|
||||
}
|
||||
if ollamaPathNeeded {
|
||||
cmd.Env = append(cmd.Env, "OLLAMA_LIBRARY_PATH="+strings.Join(ollamaLibDirs, string(filepath.ListSeparator)))
|
||||
}
|
||||
for i := range extraDone {
|
||||
if !extraDone[i] {
|
||||
cmd.Env = append(cmd.Env, extraEnvs[i])
|
||||
|
||||
@@ -37,9 +37,10 @@ type GpuInfo struct { // TODO better name maybe "InferenceProcessor"?
|
||||
UnreliableFreeMemory bool
|
||||
|
||||
// GPU information
|
||||
filterID string // AMD Workaround: The numeric ID of the device used to filter out other devices
|
||||
Name string `json:"name"` // user friendly name if available
|
||||
Compute string `json:"compute"` // Compute Capability or gfx
|
||||
filterID string // AMD Workaround: The numeric ID of the device used to filter out other devices
|
||||
Name string `json:"name"` // user friendly name if available
|
||||
ComputeMajor int `json:"compute_major"` // Compute Capability or gfx
|
||||
ComputeMinor int `json:"compute_minor"`
|
||||
|
||||
// Driver Information - TODO no need to put this on each GPU
|
||||
DriverMajor int `json:"driver_major,omitempty"`
|
||||
@@ -173,7 +174,7 @@ func (l GpuInfoList) FlashAttentionSupported() bool {
|
||||
for _, gpu := range l {
|
||||
supportsFA := gpu.Library == "cpu" ||
|
||||
gpu.Name == "Metal" || gpu.Library == "Metal" ||
|
||||
(gpu.Library == "CUDA" && gpu.DriverMajor >= 7) ||
|
||||
(gpu.Library == "CUDA" && gpu.DriverMajor >= 7 && !(gpu.ComputeMajor == 7 && gpu.ComputeMinor == 2)) || // We don't have kernels for Jetson Xavier
|
||||
gpu.Library == "ROCm"
|
||||
|
||||
if !supportsFA {
|
||||
|
||||
@@ -38,26 +38,14 @@ Join the [Discord](https://discord.gg/ollama) for help interpreting the logs.
|
||||
|
||||
## LLM libraries
|
||||
|
||||
Ollama includes multiple LLM libraries compiled for different GPUs and CPU vector features. Ollama tries to pick the best one based on the capabilities of your system. If this autodetection has problems, or you run into other problems (e.g. crashes in your GPU) you can workaround this by forcing a specific LLM library. `cpu_avx2` will perform the best, followed by `cpu_avx` and the slowest but most compatible is `cpu`. Rosetta emulation under MacOS will work with the `cpu` library.
|
||||
|
||||
In the server log, you will see a message that looks something like this (varies from release to release):
|
||||
|
||||
```
|
||||
Dynamic LLM libraries [rocm_v6 cpu cpu_avx cpu_avx2 cuda_v12 rocm_v5]
|
||||
```
|
||||
Ollama includes multiple LLM libraries compiled for different GPU libraries and versions. Ollama tries to pick the best one based on the capabilities of your system. If this autodetection has problems, or you run into other problems (e.g. crashes in your GPU) you can workaround this by forcing a specific LLM library.
|
||||
|
||||
**Experimental LLM Library Override**
|
||||
|
||||
You can set OLLAMA_LLM_LIBRARY to any of the available LLM libraries to bypass autodetection, so for example, if you have a CUDA card, but want to force the CPU LLM library with AVX2 vector support, use:
|
||||
You can set OLLAMA_LLM_LIBRARY to any of the available LLM libraries to limit autodetection, so for example, if you have both CUDA and AMD GPUs, but want to force the CUDA v13 only, use:
|
||||
|
||||
```shell
|
||||
OLLAMA_LLM_LIBRARY="cpu_avx2" ollama serve
|
||||
```
|
||||
|
||||
You can see what features your CPU has with the following.
|
||||
|
||||
```shell
|
||||
cat /proc/cpuinfo| grep flags | head -1
|
||||
OLLAMA_LLM_LIBRARY="cuda_v13" ollama serve
|
||||
```
|
||||
|
||||
## Installing older or pre-release versions on Linux
|
||||
|
||||
@@ -243,7 +243,6 @@ func (kv KV) OllamaEngineRequired() bool {
|
||||
"gemma3",
|
||||
"gemma3n",
|
||||
"mistral3",
|
||||
"qwen3",
|
||||
"qwen3moe",
|
||||
"llama4",
|
||||
"mllama",
|
||||
|
||||
@@ -17,16 +17,21 @@ func TestBlueSky(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
// Set up the test data
|
||||
req := api.GenerateRequest{
|
||||
Model: smol,
|
||||
Prompt: blueSkyPrompt,
|
||||
req := api.ChatRequest{
|
||||
Model: smol,
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: blueSkyPrompt,
|
||||
},
|
||||
},
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
},
|
||||
}
|
||||
GenerateTestHelper(ctx, t, req, blueSkyExpected)
|
||||
ChatTestHelper(ctx, t, req, blueSkyExpected)
|
||||
}
|
||||
|
||||
func TestUnicode(t *testing.T) {
|
||||
@@ -34,10 +39,15 @@ func TestUnicode(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
|
||||
defer cancel()
|
||||
// Set up the test data
|
||||
req := api.GenerateRequest{
|
||||
req := api.ChatRequest{
|
||||
// DeepSeek has a Unicode tokenizer regex, making it a unicode torture test
|
||||
Model: "deepseek-coder-v2:16b-lite-instruct-q2_K", // TODO is there an ollama-engine model we can switch to and keep the coverage?
|
||||
Prompt: "天空为什么是蓝色的?", // Why is the sky blue?
|
||||
Model: "deepseek-coder-v2:16b-lite-instruct-q2_K", // TODO is there an ollama-engine model we can switch to and keep the coverage?
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "天空为什么是蓝色的?", // Why is the sky blue?
|
||||
},
|
||||
},
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
@@ -57,9 +67,14 @@ func TestUnicode(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load model %s: %s", req.Model, err)
|
||||
}
|
||||
defer func() {
|
||||
// best effort unload once we're done with the model
|
||||
client.Generate(ctx, &api.GenerateRequest{Model: req.Model, KeepAlive: &api.Duration{Duration: 0}}, func(rsp api.GenerateResponse) error { return nil })
|
||||
}()
|
||||
|
||||
skipIfNotGPULoaded(ctx, t, client, req.Model, 100)
|
||||
|
||||
DoGenerate(ctx, t, client, req, []string{
|
||||
DoChat(ctx, t, client, req, []string{
|
||||
"散射", // scattering
|
||||
"频率", // frequency
|
||||
}, 120*time.Second, 120*time.Second)
|
||||
@@ -69,9 +84,14 @@ func TestExtendedUnicodeOutput(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
// Set up the test data
|
||||
req := api.GenerateRequest{
|
||||
Model: "gemma2:2b",
|
||||
Prompt: "Output some smily face emoji",
|
||||
req := api.ChatRequest{
|
||||
Model: "gemma2:2b",
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Output some smily face emoji",
|
||||
},
|
||||
},
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
@@ -83,7 +103,7 @@ func TestExtendedUnicodeOutput(t *testing.T) {
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
DoGenerate(ctx, t, client, req, []string{"😀", "😊", "😁", "😂", "😄", "😃"}, 120*time.Second, 120*time.Second)
|
||||
DoChat(ctx, t, client, req, []string{"😀", "😊", "😁", "😂", "😄", "😃"}, 120*time.Second, 120*time.Second)
|
||||
}
|
||||
|
||||
func TestUnicodeModelDir(t *testing.T) {
|
||||
@@ -108,14 +128,19 @@ func TestUnicodeModelDir(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
req := api.GenerateRequest{
|
||||
Model: smol,
|
||||
Prompt: blueSkyPrompt,
|
||||
req := api.ChatRequest{
|
||||
Model: smol,
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: blueSkyPrompt,
|
||||
},
|
||||
},
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
},
|
||||
}
|
||||
GenerateTestHelper(ctx, t, req, blueSkyExpected)
|
||||
ChatTestHelper(ctx, t, req, blueSkyExpected)
|
||||
}
|
||||
|
||||
@@ -20,9 +20,9 @@ import (
|
||||
)
|
||||
|
||||
// Send multiple requests in parallel (concurrently) to a single model and ensure responses are expected
|
||||
func TestConcurrentGenerate(t *testing.T) {
|
||||
func TestConcurrentChat(t *testing.T) {
|
||||
// Assumes all requests have the same model
|
||||
req, resp := GenerateRequests()
|
||||
req, resp := ChatRequests()
|
||||
numParallel := int(envconfig.NumParallel() + 1)
|
||||
iterLimit := 3
|
||||
|
||||
@@ -57,7 +57,7 @@ func TestConcurrentGenerate(t *testing.T) {
|
||||
slog.Info("Starting", "thread", i, "iter", j)
|
||||
// On slower GPUs it can take a while to process the concurrent requests
|
||||
// so we allow a much longer initial timeout
|
||||
DoGenerate(ctx, t, client, req[k], resp[k], 120*time.Second, 20*time.Second)
|
||||
DoChat(ctx, t, client, req[k], resp[k], 120*time.Second, 20*time.Second)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
@@ -163,7 +163,7 @@ chooseModels:
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
reqs, resps := GenerateRequests()
|
||||
reqs, resps := ChatRequests()
|
||||
for j := 0; j < 3; j++ {
|
||||
if time.Now().Sub(started) > softTimeout {
|
||||
slog.Info("exceeded soft timeout, winding down test")
|
||||
@@ -171,8 +171,8 @@ chooseModels:
|
||||
}
|
||||
k := r.Int() % len(reqs)
|
||||
reqs[k].Model = chosenModels[i]
|
||||
slog.Info("Starting", "model", reqs[k].Model, "iteration", j, "request", reqs[k].Prompt)
|
||||
DoGenerate(ctx, t, client, reqs[k], resps[k],
|
||||
slog.Info("Starting", "model", reqs[k].Model, "iteration", j, "request", reqs[k].Messages[0].Content)
|
||||
DoChat(ctx, t, client, reqs[k], resps[k],
|
||||
120*time.Second, // Be extra patient for the model to load initially
|
||||
10*time.Second, // Once results start streaming, fail if they stall
|
||||
)
|
||||
|
||||
@@ -21,9 +21,14 @@ func TestLongInputContext(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
// Set up the test data
|
||||
req := api.GenerateRequest{
|
||||
Model: smol,
|
||||
Prompt: "Oh, don’t speak to me of Austria. Perhaps I don’t understand things, but Austria never has wished, and does not wish, for war. She is betraying us! Russia alone must save Europe. Our gracious sovereign recognizes his high vocation and will be true to it. That is the one thing I have faith in! Our good and wonderful sovereign has to perform the noblest role on earth, and he is so virtuous and noble that God will not forsake him. He will fulfill his vocation and crush the hydra of revolution, which has become more terrible than ever in the person of this murderer and villain! We alone must avenge the blood of the just one.... Whom, I ask you, can we rely on?... England with her commercial spirit will not and cannot understand the Emperor Alexander’s loftiness of soul. She has refused to evacuate Malta. She wanted to find, and still seeks, some secret motive in our actions. What answer did Novosíltsev get? None. The English have not understood and cannot understand the self-abnegation of our Emperor who wants nothing for himself, but only desires the good of mankind. And what have they promised? Nothing! And what little they have promised they will not perform! Prussia has always declared that Buonaparte is invincible, and that all Europe is powerless before him.... And I don’t believe a word that Hardenburg says, or Haugwitz either. This famous Prussian neutrality is just a trap. I have faith only in God and the lofty destiny of our adored monarch. He will save Europe! What country is this referring to?",
|
||||
req := api.ChatRequest{
|
||||
Model: smol,
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Oh, don’t speak to me of Austria. Perhaps I don’t understand things, but Austria never has wished, and does not wish, for war. She is betraying us! Russia alone must save Europe. Our gracious sovereign recognizes his high vocation and will be true to it. That is the one thing I have faith in! Our good and wonderful sovereign has to perform the noblest role on earth, and he is so virtuous and noble that God will not forsake him. He will fulfill his vocation and crush the hydra of revolution, which has become more terrible than ever in the person of this murderer and villain! We alone must avenge the blood of the just one.... Whom, I ask you, can we rely on?... England with her commercial spirit will not and cannot understand the Emperor Alexander’s loftiness of soul. She has refused to evacuate Malta. She wanted to find, and still seeks, some secret motive in our actions. What answer did Novosíltsev get? None. The English have not understood and cannot understand the self-abnegation of our Emperor who wants nothing for himself, but only desires the good of mankind. And what have they promised? Nothing! And what little they have promised they will not perform! Prussia has always declared that Buonaparte is invincible, and that all Europe is powerless before him.... And I don’t believe a word that Hardenburg says, or Haugwitz either. This famous Prussian neutrality is just a trap. I have faith only in God and the lofty destiny of our adored monarch. He will save Europe! What country is this referring to?",
|
||||
},
|
||||
},
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
@@ -36,7 +41,7 @@ func TestLongInputContext(t *testing.T) {
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatalf("PullIfMissing failed: %v", err)
|
||||
}
|
||||
DoGenerate(ctx, t, client, req, []string{"russia", "german", "france", "england", "austria", "prussia", "europe", "individuals", "coalition", "conflict"}, 120*time.Second, 10*time.Second)
|
||||
DoChat(ctx, t, client, req, []string{"russia", "german", "france", "england", "austria", "prussia", "europe", "individuals", "coalition", "conflict"}, 120*time.Second, 10*time.Second)
|
||||
}
|
||||
|
||||
func TestContextExhaustion(t *testing.T) {
|
||||
@@ -48,9 +53,14 @@ func TestContextExhaustion(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
// Set up the test data
|
||||
req := api.GenerateRequest{
|
||||
Model: smol,
|
||||
Prompt: "Write me a story in english with a lot of emojis",
|
||||
req := api.ChatRequest{
|
||||
Model: smol,
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Write me a story in english with a lot of emojis",
|
||||
},
|
||||
},
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"temperature": 0,
|
||||
@@ -63,12 +73,12 @@ func TestContextExhaustion(t *testing.T) {
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatalf("PullIfMissing failed: %v", err)
|
||||
}
|
||||
DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived", "sunny", "cloudy", "clear", "water", "time", "travel", "world"}, 120*time.Second, 10*time.Second)
|
||||
DoChat(ctx, t, client, req, []string{"once", "upon", "lived", "sunny", "cloudy", "clear", "water", "time", "travel", "world"}, 120*time.Second, 10*time.Second)
|
||||
}
|
||||
|
||||
// Send multiple generate requests with prior context and ensure the response is coherant and expected
|
||||
func TestParallelGenerateWithHistory(t *testing.T) {
|
||||
modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
|
||||
modelOverride := "gpt-oss:20b"
|
||||
req, resp := GenerateRequests()
|
||||
numParallel := 2
|
||||
iterLimit := 2
|
||||
@@ -155,7 +165,7 @@ func TestGenerateWithHistory(t *testing.T) {
|
||||
|
||||
// Send multiple chat requests with prior context and ensure the response is coherant and expected
|
||||
func TestParallelChatWithHistory(t *testing.T) {
|
||||
modelOverride := ollamaEngineChatModels[0] // Most recent ollama engine model
|
||||
modelOverride := "gpt-oss:20b"
|
||||
req, resp := ChatRequests()
|
||||
numParallel := 2
|
||||
iterLimit := 2
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
|
||||
// First run of this scenario on a target system will take a long time to download
|
||||
// ~1.5TB of models. Set a sufficiently large -timeout for your network speed
|
||||
func TestLibraryModelsGenerate(t *testing.T) {
|
||||
func TestLibraryModelsChat(t *testing.T) {
|
||||
softTimeout, hardTimeout := getTimeouts(t)
|
||||
slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||
@@ -43,9 +43,14 @@ func TestLibraryModelsGenerate(t *testing.T) {
|
||||
t.Skip(fmt.Sprintf("Skipping %s architecture %s != %s", model, arch, targetArch))
|
||||
}
|
||||
}
|
||||
req := api.GenerateRequest{
|
||||
Model: model,
|
||||
Prompt: blueSkyPrompt,
|
||||
req := api.ChatRequest{
|
||||
Model: model,
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: blueSkyPrompt,
|
||||
},
|
||||
},
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]interface{}{
|
||||
"temperature": 0.1,
|
||||
@@ -58,13 +63,13 @@ func TestLibraryModelsGenerate(t *testing.T) {
|
||||
anyResp = []string{"select", "from"}
|
||||
} else if model == "granite3-guardian" || model == "shieldgemma" || model == "llama-guard3" || model == "bespoke-minicheck" {
|
||||
anyResp = []string{"yes", "no", "safe", "unsafe"}
|
||||
} else if model == "openthinker" || model == "nexusraven" {
|
||||
} else if model == "openthinker" {
|
||||
anyResp = []string{"plugin", "im_sep", "components", "function call"}
|
||||
} else if model == "starcoder" || model == "starcoder2" || model == "magicoder" || model == "deepseek-coder" {
|
||||
req.Prompt = "def fibonacci():"
|
||||
req.Messages[0].Content = "def fibonacci():"
|
||||
anyResp = []string{"f(n)", "sequence", "n-1", "main()", "__main__", "while"}
|
||||
}
|
||||
DoGenerate(ctx, t, client, req, anyResp, 120*time.Second, 30*time.Second)
|
||||
DoChat(ctx, t, client, req, anyResp, 120*time.Second, 30*time.Second)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -34,17 +34,22 @@ func TestVisionModels(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req := api.GenerateRequest{
|
||||
Model: v.model,
|
||||
Prompt: "what does the text in this image say?",
|
||||
req := api.ChatRequest{
|
||||
Model: v.model,
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "what does the text in this image say?",
|
||||
Images: []api.ImageData{
|
||||
image,
|
||||
},
|
||||
},
|
||||
},
|
||||
Stream: &stream,
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
Images: []api.ImageData{
|
||||
image,
|
||||
},
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
@@ -56,8 +61,15 @@ func TestVisionModels(t *testing.T) {
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// Preload to skip if we're less than 80% on GPU to avoid extremely slow tests
|
||||
err = client.Generate(ctx, &api.GenerateRequest{Model: req.Model}, func(response api.GenerateResponse) error { return nil })
|
||||
if err != nil {
|
||||
t.Fatalf("failed to load model %s: %s", req.Model, err)
|
||||
}
|
||||
skipIfNotGPULoaded(ctx, t, client, req.Model, 80)
|
||||
|
||||
// llava models on CPU can be quite slow to start
|
||||
DoGenerate(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
|
||||
DoChat(ctx, t, client, req, []string{resp}, 240*time.Second, 30*time.Second)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ import (
|
||||
"github.com/ollama/ollama/format"
|
||||
)
|
||||
|
||||
func TestModelsGenerate(t *testing.T) {
|
||||
func TestModelsChat(t *testing.T) {
|
||||
softTimeout, hardTimeout := getTimeouts(t)
|
||||
slog.Info("Setting timeouts", "soft", softTimeout, "hard", hardTimeout)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), hardTimeout)
|
||||
@@ -66,15 +66,23 @@ func TestModelsGenerate(t *testing.T) {
|
||||
}
|
||||
}
|
||||
// TODO - fiddle with context size
|
||||
req := api.GenerateRequest{
|
||||
Model: model,
|
||||
Prompt: blueSkyPrompt,
|
||||
req := api.ChatRequest{
|
||||
Model: model,
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: blueSkyPrompt,
|
||||
},
|
||||
},
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]interface{}{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
},
|
||||
}
|
||||
DoGenerate(ctx, t, client, req, blueSkyExpected, 120*time.Second, 30*time.Second)
|
||||
DoChat(ctx, t, client, req, blueSkyExpected, 120*time.Second, 30*time.Second)
|
||||
// best effort unload once we're done with the model
|
||||
client.Generate(ctx, &api.GenerateRequest{Model: req.Model, KeepAlive: &api.Duration{Duration: 0}}, func(rsp api.GenerateResponse) error { return nil })
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -128,8 +136,9 @@ func TestModelsEmbed(t *testing.T) {
|
||||
}
|
||||
}
|
||||
req := api.EmbeddingRequest{
|
||||
Model: model,
|
||||
Prompt: "why is the sky blue?",
|
||||
Model: model,
|
||||
Prompt: "why is the sky blue?",
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Options: map[string]interface{}{
|
||||
"temperature": 0,
|
||||
"seed": 123,
|
||||
@@ -139,6 +148,10 @@ func TestModelsEmbed(t *testing.T) {
|
||||
if err != nil {
|
||||
t.Fatalf("embeddings call failed %s", err)
|
||||
}
|
||||
defer func() {
|
||||
// best effort unload once we're done with the model
|
||||
client.Generate(ctx, &api.GenerateRequest{Model: req.Model, KeepAlive: &api.Duration{Duration: 0}}, func(rsp api.GenerateResponse) error { return nil })
|
||||
}()
|
||||
if len(resp.Embedding) == 0 {
|
||||
t.Errorf("zero length embedding response")
|
||||
}
|
||||
|
||||
@@ -173,9 +173,14 @@ func doModelPerfTest(t *testing.T, chatModels []string) {
|
||||
slog.Info("skipping long prompt", "model", model, "num_ctx", numCtx, "gpu_percent", gpuPercent)
|
||||
continue
|
||||
}
|
||||
req := api.GenerateRequest{
|
||||
Model: model,
|
||||
Prompt: tc.prompt,
|
||||
req := api.ChatRequest{
|
||||
Model: model,
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: tc.prompt,
|
||||
},
|
||||
},
|
||||
KeepAlive: &api.Duration{Duration: 20 * time.Second}, // long enough to ensure a ps returns
|
||||
Options: map[string]interface{}{
|
||||
"temperature": 0,
|
||||
@@ -184,7 +189,7 @@ func doModelPerfTest(t *testing.T, chatModels []string) {
|
||||
},
|
||||
}
|
||||
atLeastOne := false
|
||||
var resp api.GenerateResponse
|
||||
var resp api.ChatResponse
|
||||
|
||||
stream := false
|
||||
req.Stream = &stream
|
||||
@@ -198,7 +203,7 @@ func doModelPerfTest(t *testing.T, chatModels []string) {
|
||||
)
|
||||
defer cancel()
|
||||
|
||||
err = client.Generate(genCtx, &req, func(rsp api.GenerateResponse) error {
|
||||
err = client.Chat(genCtx, &req, func(rsp api.ChatResponse) error {
|
||||
resp = rsp
|
||||
return nil
|
||||
})
|
||||
@@ -214,13 +219,13 @@ func doModelPerfTest(t *testing.T, chatModels []string) {
|
||||
}
|
||||
loaded = true
|
||||
for _, expResp := range tc.anyResp {
|
||||
if strings.Contains(strings.ToLower(resp.Response), expResp) {
|
||||
if strings.Contains(strings.ToLower(resp.Message.Content), expResp) {
|
||||
atLeastOne = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !atLeastOne {
|
||||
t.Fatalf("response didn't contain expected values: ctx:%d expected:%v response:%s ", numCtx, tc.anyResp, resp.Response)
|
||||
t.Fatalf("response didn't contain expected values: ctx:%d expected:%v response:%s ", numCtx, tc.anyResp, resp.Message.Content)
|
||||
}
|
||||
models, err := client.ListRunning(ctx)
|
||||
if err != nil {
|
||||
|
||||
@@ -74,9 +74,14 @@ func TestQuantization(t *testing.T) {
|
||||
}
|
||||
|
||||
stream := true
|
||||
genReq := api.GenerateRequest{
|
||||
Model: newName,
|
||||
Prompt: blueSkyPrompt,
|
||||
chatReq := api.ChatRequest{
|
||||
Model: newName,
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: blueSkyPrompt,
|
||||
},
|
||||
},
|
||||
KeepAlive: &api.Duration{Duration: 3 * time.Second},
|
||||
Options: map[string]any{
|
||||
"seed": 42,
|
||||
@@ -91,8 +96,8 @@ func TestQuantization(t *testing.T) {
|
||||
reqCtx, reqCancel := context.WithCancel(ctx)
|
||||
atLeastOne := false
|
||||
var buf bytes.Buffer
|
||||
genfn := func(response api.GenerateResponse) error {
|
||||
buf.Write([]byte(response.Response))
|
||||
chatfn := func(response api.ChatResponse) error {
|
||||
buf.Write([]byte(response.Message.Content))
|
||||
fullResp := strings.ToLower(buf.String())
|
||||
for _, resp := range blueSkyExpected {
|
||||
if strings.Contains(fullResp, resp) {
|
||||
@@ -108,14 +113,14 @@ func TestQuantization(t *testing.T) {
|
||||
done := make(chan int)
|
||||
var genErr error
|
||||
go func() {
|
||||
genErr = client.Generate(reqCtx, &genReq, genfn)
|
||||
genErr = client.Chat(reqCtx, &chatReq, chatfn)
|
||||
done <- 0
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
if genErr != nil && !atLeastOne {
|
||||
t.Fatalf("failed with %s request prompt %s ", genReq.Model, genReq.Prompt)
|
||||
t.Fatalf("failed with %s request prompt %s ", chatReq.Model, chatReq.Messages[0].Content)
|
||||
}
|
||||
case <-ctx.Done():
|
||||
t.Error("outer test context done while waiting for generate")
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
@@ -24,7 +25,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/app/lifecycle"
|
||||
"github.com/ollama/ollama/format"
|
||||
)
|
||||
|
||||
@@ -38,6 +38,7 @@ var (
|
||||
|
||||
// Note: add newer models at the top of the list to test them first
|
||||
ollamaEngineChatModels = []string{
|
||||
"qwen3-coder:30b",
|
||||
"gpt-oss:20b",
|
||||
"gemma3n:e2b",
|
||||
"mistral-small3.2:latest",
|
||||
@@ -46,6 +47,7 @@ var (
|
||||
"qwen2.5-coder:latest",
|
||||
"qwen2.5vl:3b",
|
||||
"qwen3:0.6b", // dense
|
||||
"qwen3:1.7b", // dense
|
||||
"qwen3:30b", // MOE
|
||||
"gemma3:1b",
|
||||
"llama3.1:latest",
|
||||
@@ -265,16 +267,16 @@ var (
|
||||
"Explain the physics involved in them. Be breif in your reply",
|
||||
"Explain the chemistry involved in them. Be breif in your reply",
|
||||
"What are common myths related to them? Be brief in your reply",
|
||||
"What are common fairytales related to them? Be brief in your reply",
|
||||
"Can they form if there is no rain? Be breif in your reply",
|
||||
"Can they form if there are no clouds? Be breif in your reply",
|
||||
"Do they happen on other planets? Be brief in your reply",
|
||||
}
|
||||
rainbowExpected = []string{"water", "droplet", "mist", "glow", "refract", "reflect", "scatter", "wave", "color", "spectrum", "raindrop", "atmosphere", "frequency", "end", "gold", "fortune", "blessing", "prosperity", "magic", "shower", "sky", "shimmer", "light", "storm", "sunny"}
|
||||
rainbowExpected = []string{"water", "droplet", "mist", "glow", "refract", "reflect", "scatter", "particles", "wave", "color", "spectrum", "raindrop", "atmosphere", "frequency", "shower", "sky", "shimmer", "light", "storm", "sunny", "sunburst", "phenomenon", "mars", "venus", "jupiter"}
|
||||
)
|
||||
|
||||
func init() {
|
||||
lifecycle.InitLogging()
|
||||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||||
slog.SetDefault(logger)
|
||||
custom := os.Getenv("OLLAMA_TEST_DEFAULT_MODEL")
|
||||
if custom != "" {
|
||||
slog.Info("setting default test model to " + custom)
|
||||
@@ -335,6 +337,7 @@ func GetTestEndpoint() (*api.Client, string) {
|
||||
|
||||
var serverMutex sync.Mutex
|
||||
var serverReady bool
|
||||
var serverLogFile string
|
||||
|
||||
func startServer(t *testing.T, ctx context.Context, ollamaHost string) error {
|
||||
// Make sure the server has been built
|
||||
@@ -361,8 +364,9 @@ func startServer(t *testing.T, ctx context.Context, ollamaHost string) error {
|
||||
t.Setenv("OLLAMA_HOST", ollamaHost)
|
||||
}
|
||||
|
||||
logDir := t.TempDir()
|
||||
slog.Info("starting server", "url", ollamaHost)
|
||||
done, err := lifecycle.SpawnServer(ctx, "../ollama")
|
||||
done, err := SpawnServer(ctx, "../ollama", logDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to start server: %w", err)
|
||||
}
|
||||
@@ -385,6 +389,36 @@ func startServer(t *testing.T, ctx context.Context, ollamaHost string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func SpawnServer(ctx context.Context, command, logDir string) (chan int, error) {
|
||||
done := make(chan int)
|
||||
fp, err := os.CreateTemp(logDir, "ollama-server-*.log")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create log file: %w", err)
|
||||
}
|
||||
serverLogFile = fp.Name()
|
||||
|
||||
cmd := exec.CommandContext(ctx, command, "serve")
|
||||
cmd.Stderr = fp
|
||||
cmd.Stdout = fp
|
||||
|
||||
go func() {
|
||||
slog.Info("starting server...")
|
||||
if err := cmd.Run(); err != nil {
|
||||
// "signal: killed" expected
|
||||
if !strings.Contains(err.Error(), "signal") {
|
||||
slog.Info("failed to run server", "error", err)
|
||||
}
|
||||
}
|
||||
var code int
|
||||
if cmd.ProcessState != nil {
|
||||
code = cmd.ProcessState.ExitCode()
|
||||
}
|
||||
slog.Info("server exited")
|
||||
done <- code
|
||||
}()
|
||||
return done, nil
|
||||
}
|
||||
|
||||
func PullIfMissing(ctx context.Context, client *api.Client, modelName string) error {
|
||||
slog.Info("checking status of model", "model", modelName)
|
||||
showReq := &api.ShowRequest{Name: modelName}
|
||||
@@ -445,12 +479,6 @@ func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, strin
|
||||
client, testEndpoint := GetTestEndpoint()
|
||||
if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
|
||||
serverProcMutex.Lock()
|
||||
fp, err := os.CreateTemp("", "ollama-server-*.log")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate log file: %s", err)
|
||||
}
|
||||
lifecycle.ServerLogFile = fp.Name()
|
||||
fp.Close()
|
||||
if err := startServer(t, ctx, testEndpoint); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@@ -478,36 +506,32 @@ func InitServerConnection(ctx context.Context, t *testing.T) (*api.Client, strin
|
||||
if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
|
||||
defer serverProcMutex.Unlock()
|
||||
if t.Failed() {
|
||||
fp, err := os.Open(lifecycle.ServerLogFile)
|
||||
fp, err := os.Open(serverLogFile)
|
||||
if err != nil {
|
||||
slog.Error("failed to open server log", "logfile", lifecycle.ServerLogFile, "error", err)
|
||||
slog.Error("failed to open server log", "logfile", serverLogFile, "error", err)
|
||||
return
|
||||
}
|
||||
defer fp.Close()
|
||||
data, err := io.ReadAll(fp)
|
||||
if err != nil {
|
||||
slog.Error("failed to read server log", "logfile", lifecycle.ServerLogFile, "error", err)
|
||||
slog.Error("failed to read server log", "logfile", serverLogFile, "error", err)
|
||||
return
|
||||
}
|
||||
slog.Warn("SERVER LOG FOLLOWS")
|
||||
os.Stderr.Write(data)
|
||||
slog.Warn("END OF SERVER")
|
||||
}
|
||||
err := os.Remove(lifecycle.ServerLogFile)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func GenerateTestHelper(ctx context.Context, t *testing.T, genReq api.GenerateRequest, anyResp []string) {
|
||||
func ChatTestHelper(ctx context.Context, t *testing.T, req api.ChatRequest, anyResp []string) {
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
if err := PullIfMissing(ctx, client, genReq.Model); err != nil {
|
||||
if err := PullIfMissing(ctx, client, req.Model); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
DoGenerate(ctx, t, client, genReq, anyResp, 30*time.Second, 10*time.Second)
|
||||
DoChat(ctx, t, client, req, anyResp, 30*time.Second, 10*time.Second)
|
||||
}
|
||||
|
||||
func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq api.GenerateRequest, anyResp []string, initialTimeout, streamTimeout time.Duration) []int {
|
||||
@@ -726,8 +750,14 @@ func skipIfNotGPULoaded(ctx context.Context, t *testing.T, client *api.Client, m
|
||||
loaded := []string{}
|
||||
for _, m := range models.Models {
|
||||
loaded = append(loaded, m.Name)
|
||||
if m.Name != model {
|
||||
continue
|
||||
if strings.Contains(model, ":") {
|
||||
if m.Name != model {
|
||||
continue
|
||||
}
|
||||
} else if strings.Contains(m.Name, ":") {
|
||||
if !strings.HasPrefix(m.Name, model+":") {
|
||||
continue
|
||||
}
|
||||
}
|
||||
gpuPercent := 0
|
||||
switch {
|
||||
|
||||
@@ -160,7 +160,15 @@ func (c *Causal) Init(backend ml.Backend, dtype ml.DType, maxSequences, capacity
|
||||
if c.swaMemorySize == 0 {
|
||||
c.swaMemorySize = c.swaWindowSize
|
||||
}
|
||||
if int(c.swaMemorySize) > capacity {
|
||||
// We will allocate space in the cache for the stop token, which won't be part of a follow on
|
||||
// sequence, so allocate an extra token of storage to ensure that we can jump back without
|
||||
// causing a cache break. As an optimization, only do this when we have parallel sequences
|
||||
// because the extra token will live in the batch buffer and won't get overwritten if we
|
||||
// only have a single sequence.
|
||||
if c.swaMemorySize != math.MaxInt32 && maxSequences > 1 {
|
||||
c.swaMemorySize = max(c.swaMemorySize, c.swaWindowSize+1)
|
||||
}
|
||||
if int(c.swaMemorySize) >= capacity {
|
||||
c.swaMemorySize = math.MaxInt32
|
||||
}
|
||||
|
||||
@@ -214,7 +222,6 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) e
|
||||
c.curLoc, err = c.findStartLoc()
|
||||
}
|
||||
if err != nil {
|
||||
slog.Warn("unable to find a kv cache slot", "cache", c)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -288,23 +295,44 @@ func (c *Causal) updateSlidingWindow() {
|
||||
return
|
||||
}
|
||||
|
||||
type lowestPosition struct {
|
||||
pos int32
|
||||
curBatch bool
|
||||
}
|
||||
|
||||
// create a map of unique sequences to the lowest position in that sequence
|
||||
lowestPos := make(map[int]int32)
|
||||
lowestPos := make(map[int]lowestPosition)
|
||||
for i := range c.curPositions {
|
||||
seq := c.curSequences[i]
|
||||
|
||||
pos, ok := lowestPos[seq]
|
||||
lowest, ok := lowestPos[seq]
|
||||
if !ok {
|
||||
pos = c.curPositions[i]
|
||||
} else if c.curPositions[i] < pos {
|
||||
pos = c.curPositions[i]
|
||||
lowest = lowestPosition{pos: c.curPositions[i], curBatch: true}
|
||||
} else if c.curPositions[i] < lowest.pos {
|
||||
lowest.pos = c.curPositions[i]
|
||||
}
|
||||
|
||||
lowestPos[seq] = pos
|
||||
lowestPos[seq] = lowest
|
||||
}
|
||||
|
||||
// for any sequences are not part of this batch, clean up any tokens
|
||||
// that are no longer needed after the processing of the previous
|
||||
// batch
|
||||
for seq, seqRange := range c.cellRanges {
|
||||
if _, ok := lowestPos[seq]; !ok {
|
||||
var last int32
|
||||
for i := seqRange.min; i <= seqRange.max; i++ {
|
||||
if slices.Contains(c.cells[i].sequences, seq) {
|
||||
last = max(last, c.cells[i].pos)
|
||||
}
|
||||
}
|
||||
|
||||
lowestPos[seq] = lowestPosition{pos: last + 1, curBatch: false}
|
||||
}
|
||||
}
|
||||
|
||||
// delete any entries that are beyond the window of the oldest position in the sequence
|
||||
for seq, pos := range lowestPos {
|
||||
for seq, lowest := range lowestPos {
|
||||
oldRange, ok := c.cellRanges[seq]
|
||||
if !ok {
|
||||
continue
|
||||
@@ -314,13 +342,13 @@ func (c *Causal) updateSlidingWindow() {
|
||||
|
||||
for i := oldRange.min; i <= oldRange.max; i++ {
|
||||
if slices.Contains(c.cells[i].sequences, seq) {
|
||||
if c.cells[i].pos < pos-c.swaMemorySize {
|
||||
if c.cells[i].pos < lowest.pos-c.swaMemorySize {
|
||||
c.cells[i].sequences = slices.DeleteFunc(c.cells[i].sequences, func(s int) bool { return s == seq })
|
||||
} else {
|
||||
newRange.min = min(newRange.min, i)
|
||||
newRange.max = max(newRange.max, i)
|
||||
}
|
||||
if c.cells[i].pos >= pos-c.swaWindowSize {
|
||||
if lowest.curBatch && c.cells[i].pos >= lowest.pos-c.swaWindowSize {
|
||||
c.curCellRange.min = min(c.curCellRange.min, i)
|
||||
c.curCellRange.max = max(c.curCellRange.max, i)
|
||||
}
|
||||
@@ -657,9 +685,11 @@ func (c *Causal) CanResume(seq int, pos int32) bool {
|
||||
|
||||
// for sliding window, check that the window of the new sequence is contained in
|
||||
// the window of what we are storing
|
||||
var first int32 = math.MaxInt32
|
||||
var last int32 = -1
|
||||
for i := seqRange.min; i <= seqRange.max; i++ {
|
||||
if slices.Contains(c.cells[i].sequences, seq) {
|
||||
first = min(first, c.cells[i].pos)
|
||||
last = max(last, c.cells[i].pos)
|
||||
}
|
||||
}
|
||||
@@ -668,10 +698,8 @@ func (c *Causal) CanResume(seq int, pos int32) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
lastWindowStart := max(0, last-c.swaMemorySize)
|
||||
posWindowStart := max(0, pos-c.swaWindowSize)
|
||||
|
||||
return posWindowStart >= lastWindowStart
|
||||
return posWindowStart >= first && pos <= last+1
|
||||
}
|
||||
|
||||
func (c *Causal) shift(seq int, beginIndex, offset int32) error {
|
||||
|
||||
@@ -96,6 +96,86 @@ func TestSWA(t *testing.T) {
|
||||
testCache(t, backend, cache, tests)
|
||||
}
|
||||
|
||||
func TestSWASeparateBatches(t *testing.T) {
|
||||
backend := &testBackend{}
|
||||
cache := NewSWACache(1, nil)
|
||||
defer cache.Close()
|
||||
|
||||
cache.Init(backend, ml.DTypeF16, 2, 16, 2)
|
||||
|
||||
x := float32(math.Inf(-1))
|
||||
|
||||
tests := []testCase{
|
||||
{
|
||||
name: "First seq 0",
|
||||
in: []float32{1, 2},
|
||||
inShape: []int{1, 1, 2},
|
||||
seqs: []int{0, 0},
|
||||
pos: []int32{0, 1},
|
||||
expected: []float32{1, 2},
|
||||
expectedShape: []int{1, 1, 2},
|
||||
expectedMask: []float32{
|
||||
0, x,
|
||||
0, 0,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Second seq 0",
|
||||
in: []float32{3, 4},
|
||||
inShape: []int{1, 1, 2},
|
||||
seqs: []int{0, 0},
|
||||
pos: []int32{2, 3},
|
||||
expected: []float32{2, 3, 4},
|
||||
expectedShape: []int{1, 1, 3},
|
||||
expectedMask: []float32{
|
||||
0, 0, x,
|
||||
x, 0, 0,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "First seq 1",
|
||||
in: []float32{5, 6},
|
||||
inShape: []int{1, 1, 2},
|
||||
seqs: []int{1, 1},
|
||||
pos: []int32{0, 1},
|
||||
expected: []float32{5, 6},
|
||||
expectedShape: []int{1, 1, 2},
|
||||
expectedMask: []float32{
|
||||
0, x,
|
||||
0, 0,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Second seq 1",
|
||||
in: []float32{7, 8},
|
||||
inShape: []int{1, 1, 2},
|
||||
seqs: []int{1, 1},
|
||||
pos: []int32{2, 3},
|
||||
expected: []float32{6, 3, 4, 7, 8},
|
||||
expectedShape: []int{1, 1, 5},
|
||||
expectedMask: []float32{
|
||||
0, x, x, 0, x,
|
||||
x, x, x, 0, 0,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "Third seq 0",
|
||||
in: []float32{9, 10},
|
||||
inShape: []int{1, 1, 2},
|
||||
seqs: []int{0, 0},
|
||||
pos: []int32{4, 5},
|
||||
expected: []float32{9, 10, 3, 4},
|
||||
expectedShape: []int{1, 1, 4},
|
||||
expectedMask: []float32{
|
||||
0, x, x, 0,
|
||||
0, 0, x, x,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
testCache(t, backend, cache, tests)
|
||||
}
|
||||
|
||||
func TestSWAMem(t *testing.T) {
|
||||
backend := &testBackend{}
|
||||
cache := NewSWAMemCache(1, 3, nil)
|
||||
@@ -431,15 +511,15 @@ func TestCanResume(t *testing.T) {
|
||||
defer context.Close()
|
||||
|
||||
err := cache.StartForward(context, input.Batch{
|
||||
Positions: []int32{0, 1, 2, 3},
|
||||
Sequences: []int{0, 0, 0, 0},
|
||||
Positions: []int32{0, 1, 2, 3, 4},
|
||||
Sequences: []int{0, 0, 0, 0, 0},
|
||||
}, false)
|
||||
if err != nil {
|
||||
t.Fatalf("StartForward failed: %v", err)
|
||||
}
|
||||
|
||||
cache.SetLayer(0)
|
||||
tensor := context.FromFloatSlice([]float32{1, 2, 3, 4}, 1, 1, 4)
|
||||
tensor := context.FromFloatSlice([]float32{1, 2, 3, 4, 5}, 1, 1, 5)
|
||||
cache.Put(context, tensor, tensor)
|
||||
|
||||
// with window size 4, nothing has slid out of the window yet
|
||||
@@ -455,18 +535,21 @@ func TestCanResume(t *testing.T) {
|
||||
if !cache.CanResume(0, 3) {
|
||||
t.Errorf("CanResume(0, 3) = false, want true (latest position)")
|
||||
}
|
||||
if !cache.CanResume(0, 4) {
|
||||
t.Errorf("CanResume(0, 4) = false, want true (latest position)")
|
||||
}
|
||||
|
||||
// shift window by adding position 4
|
||||
// shift window by adding position 5
|
||||
err = cache.StartForward(context, input.Batch{
|
||||
Positions: []int32{4, 5},
|
||||
Sequences: []int{0, 0},
|
||||
Positions: []int32{5},
|
||||
Sequences: []int{0},
|
||||
}, false)
|
||||
if err != nil {
|
||||
t.Fatalf("StartForward failed: %v", err)
|
||||
}
|
||||
|
||||
cache.SetLayer(0)
|
||||
tensor = context.FromFloatSlice([]float32{5, 6}, 1, 1, 2)
|
||||
tensor = context.FromFloatSlice([]float32{6}, 1, 1, 1)
|
||||
cache.Put(context, tensor, tensor)
|
||||
|
||||
// only the latest position has overlapping windows
|
||||
@@ -503,28 +586,28 @@ func TestCanResumeSWAMem(t *testing.T) {
|
||||
defer context.Close()
|
||||
|
||||
err := cache.StartForward(context, input.Batch{
|
||||
Positions: []int32{0, 1, 2, 3, 4, 5},
|
||||
Sequences: []int{0, 0, 0, 0, 0, 0},
|
||||
Positions: []int32{0, 1, 2, 3, 4, 5, 6},
|
||||
Sequences: []int{0, 0, 0, 0, 0, 0, 0},
|
||||
}, false)
|
||||
if err != nil {
|
||||
t.Fatalf("StartForward failed: %v", err)
|
||||
}
|
||||
|
||||
cache.SetLayer(0)
|
||||
tensor := context.FromFloatSlice([]float32{1, 2, 3, 4, 5, 6}, 1, 1, 6)
|
||||
tensor := context.FromFloatSlice([]float32{1, 2, 3, 4, 5, 6, 7}, 1, 1, 7)
|
||||
cache.Put(context, tensor, tensor)
|
||||
|
||||
// shift window by adding position 6
|
||||
// shift window by adding position 7
|
||||
err = cache.StartForward(context, input.Batch{
|
||||
Positions: []int32{6, 7},
|
||||
Sequences: []int{0, 0},
|
||||
Positions: []int32{7},
|
||||
Sequences: []int{0},
|
||||
}, false)
|
||||
if err != nil {
|
||||
t.Fatalf("StartForward failed: %v", err)
|
||||
}
|
||||
|
||||
cache.SetLayer(0)
|
||||
tensor = context.FromFloatSlice([]float32{7, 8}, 1, 1, 2)
|
||||
tensor = context.FromFloatSlice([]float32{8}, 1, 1, 1)
|
||||
cache.Put(context, tensor, tensor)
|
||||
|
||||
// only the latest position has overlapping windows
|
||||
|
||||
@@ -266,11 +266,18 @@ func estimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
|
||||
}
|
||||
// Only include GPUs that can fit the graph, gpu minimum, the layer buffer and at least more layer
|
||||
if gpus[i].FreeMemory < overhead+gzo+max(graphPartialOffload, graphFullOffload)+gpus[i].MinimumMemory+2*layerSize {
|
||||
var compute string
|
||||
if gpus[i].Library == "ROCm" {
|
||||
compute = fmt.Sprintf("gfx%x%02x", gpus[i].ComputeMajor, gpus[i].ComputeMinor)
|
||||
} else {
|
||||
compute = fmt.Sprintf("%d.%d", gpus[i].ComputeMajor, gpus[i].ComputeMinor)
|
||||
}
|
||||
|
||||
slog.Debug("gpu has too little memory to allocate any layers",
|
||||
"id", gpus[i].ID,
|
||||
"library", gpus[i].Library,
|
||||
"variant", gpus[i].Variant,
|
||||
"compute", gpus[i].Compute,
|
||||
"compute", compute,
|
||||
"driver", fmt.Sprintf("%d.%d", gpus[i].DriverMajor, gpus[i].DriverMinor),
|
||||
"name", gpus[i].Name,
|
||||
"total", format.HumanBytes2(gpus[i].TotalMemory),
|
||||
|
||||
@@ -359,20 +359,22 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
|
||||
s.cmd.Stderr = s.status
|
||||
s.cmd.SysProcAttr = LlamaServerSysProcAttr
|
||||
|
||||
s.cmd.Env = append(s.cmd.Env, "OLLAMA_LIBRARY_PATH="+strings.Join(ggmlPaths, string(filepath.ListSeparator)))
|
||||
|
||||
// Always filter down the set of GPUs in case there are any unsupported devices that might crash
|
||||
envWorkarounds := gpus.GetVisibleDevicesEnv()
|
||||
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
|
||||
|
||||
// Update or add the path variable with our adjusted version
|
||||
pathNeeded := true
|
||||
ollamaPathNeeded := true
|
||||
envWorkaroundDone := make([]bool, len(envWorkarounds))
|
||||
for i := range s.cmd.Env {
|
||||
cmp := strings.SplitN(s.cmd.Env[i], "=", 2)
|
||||
if strings.EqualFold(cmp[0], pathEnv) {
|
||||
s.cmd.Env[i] = pathEnv + "=" + pathEnvVal
|
||||
pathNeeded = false
|
||||
} else if strings.EqualFold(cmp[0], "OLLAMA_LIBRARY_PATH") {
|
||||
s.cmd.Env[i] = "OLLAMA_LIBRARY_PATH=" + strings.Join(ggmlPaths, string(filepath.ListSeparator))
|
||||
ollamaPathNeeded = false
|
||||
} else if len(envWorkarounds) != 0 {
|
||||
for j, kv := range envWorkarounds {
|
||||
tmp := strings.SplitN(kv, "=", 2)
|
||||
@@ -386,6 +388,9 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a
|
||||
if pathNeeded {
|
||||
s.cmd.Env = append(s.cmd.Env, pathEnv+"="+pathEnvVal)
|
||||
}
|
||||
if ollamaPathNeeded {
|
||||
s.cmd.Env = append(s.cmd.Env, "OLLAMA_LIBRARY_PATH="+strings.Join(ggmlPaths, string(filepath.ListSeparator)))
|
||||
}
|
||||
for i, done := range envWorkaroundDone {
|
||||
if !done {
|
||||
s.cmd.Env = append(s.cmd.Env, envWorkarounds[i])
|
||||
@@ -1481,7 +1486,10 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
serverReq.Header.Set("Content-Type", "application/json")
|
||||
|
||||
res, err := http.DefaultClient.Do(serverReq)
|
||||
if err != nil {
|
||||
if err != nil && errors.Is(err, context.Canceled) {
|
||||
// client closed connection
|
||||
return err
|
||||
} else if err != nil {
|
||||
slog.Error("post predict", "error", err)
|
||||
return errors.New("model runner has unexpectedly stopped, this may be due to resource limitations or an internal error, check ollama server logs for details")
|
||||
}
|
||||
|
||||
424
middleware/openai.go
Normal file
424
middleware/openai.go
Normal file
@@ -0,0 +1,424 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/openai"
|
||||
)
|
||||
|
||||
type BaseWriter struct {
|
||||
gin.ResponseWriter
|
||||
}
|
||||
|
||||
type ChatWriter struct {
|
||||
stream bool
|
||||
streamOptions *openai.StreamOptions
|
||||
id string
|
||||
toolCallSent bool
|
||||
BaseWriter
|
||||
}
|
||||
|
||||
type CompleteWriter struct {
|
||||
stream bool
|
||||
streamOptions *openai.StreamOptions
|
||||
id string
|
||||
BaseWriter
|
||||
}
|
||||
|
||||
type ListWriter struct {
|
||||
BaseWriter
|
||||
}
|
||||
|
||||
type RetrieveWriter struct {
|
||||
BaseWriter
|
||||
model string
|
||||
}
|
||||
|
||||
type EmbedWriter struct {
|
||||
BaseWriter
|
||||
model string
|
||||
}
|
||||
|
||||
func (w *BaseWriter) writeError(data []byte) (int, error) {
|
||||
var serr api.StatusError
|
||||
err := json.Unmarshal(data, &serr)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w.ResponseWriter).Encode(openai.NewError(http.StatusInternalServerError, serr.Error()))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *ChatWriter) writeResponse(data []byte) (int, error) {
|
||||
var chatResponse api.ChatResponse
|
||||
err := json.Unmarshal(data, &chatResponse)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// chat chunk
|
||||
if w.stream {
|
||||
c := openai.ToChunk(w.id, chatResponse, w.toolCallSent)
|
||||
d, err := json.Marshal(c)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if !w.toolCallSent && len(c.Choices) > 0 && len(c.Choices[0].Delta.ToolCalls) > 0 {
|
||||
w.toolCallSent = true
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
|
||||
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if chatResponse.Done {
|
||||
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
|
||||
u := openai.ToUsage(chatResponse)
|
||||
c.Usage = &u
|
||||
c.Choices = []openai.ChunkChoice{}
|
||||
d, err := json.Marshal(c)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
// chat completion
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToChatCompletion(w.id, chatResponse))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *ChatWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
|
||||
var generateResponse api.GenerateResponse
|
||||
err := json.Unmarshal(data, &generateResponse)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// completion chunk
|
||||
if w.stream {
|
||||
c := openai.ToCompleteChunk(w.id, generateResponse)
|
||||
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
|
||||
c.Usage = &openai.Usage{}
|
||||
}
|
||||
d, err := json.Marshal(c)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
|
||||
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if generateResponse.Done {
|
||||
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
|
||||
u := openai.ToUsageGenerate(generateResponse)
|
||||
c.Usage = &u
|
||||
c.Choices = []openai.CompleteChunkChoice{}
|
||||
d, err := json.Marshal(c)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
// completion
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToCompletion(w.id, generateResponse))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *CompleteWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
func (w *ListWriter) writeResponse(data []byte) (int, error) {
|
||||
var listResponse api.ListResponse
|
||||
err := json.Unmarshal(data, &listResponse)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToListCompletion(listResponse))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *ListWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
|
||||
var showResponse api.ShowResponse
|
||||
err := json.Unmarshal(data, &showResponse)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// retrieve completion
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToModel(showResponse, w.model))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *RetrieveWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
|
||||
var embedResponse api.EmbedResponse
|
||||
err := json.Unmarshal(data, &embedResponse)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w.ResponseWriter).Encode(openai.ToEmbeddingList(w.model, embedResponse))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *EmbedWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
func ListMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
w := &ListWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func RetrieveMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
w := &RetrieveWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
model: c.Param("model"),
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func CompletionsMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var req openai.CompletionRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
genReq, err := openai.FromCompleteRequest(req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(&b).Encode(genReq); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
w := &CompleteWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
stream: req.Stream,
|
||||
id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
|
||||
streamOptions: req.StreamOptions,
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func EmbeddingsMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var req openai.EmbedRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if req.Input == "" {
|
||||
req.Input = []string{""}
|
||||
}
|
||||
|
||||
if req.Input == nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "invalid input"))
|
||||
return
|
||||
}
|
||||
|
||||
if v, ok := req.Input.([]any); ok && len(v) == 0 {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "invalid input"))
|
||||
return
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input, Dimensions: req.Dimensions}); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
w := &EmbedWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
model: req.Model,
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func ChatMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var req openai.ChatCompletionRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Messages) == 0 {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, "[] is too short - 'messages'"))
|
||||
return
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
|
||||
chatReq, err := openai.FromChatRequest(req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, openai.NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, openai.NewError(http.StatusInternalServerError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
w := &ChatWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
stream: req.Stream,
|
||||
id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
|
||||
streamOptions: req.StreamOptions,
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
928
middleware/openai_test.go
Normal file
928
middleware/openai_test.go
Normal file
@@ -0,0 +1,928 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/openai"
|
||||
)
|
||||
|
||||
const (
|
||||
prefix = `data:image/jpeg;base64,`
|
||||
image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
|
||||
)
|
||||
|
||||
var (
|
||||
False = false
|
||||
True = true
|
||||
)
|
||||
|
||||
func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
bodyBytes, _ := io.ReadAll(c.Request.Body)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
err := json.Unmarshal(bodyBytes, capturedRequest)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request")
|
||||
}
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func TestChatMiddleware(t *testing.T) {
|
||||
type testCase struct {
|
||||
name string
|
||||
body string
|
||||
req api.ChatRequest
|
||||
err openai.ErrorResponse
|
||||
}
|
||||
|
||||
var capturedRequest *api.ChatRequest
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
name: "chat handler",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Hello",
|
||||
},
|
||||
},
|
||||
Options: map[string]any{
|
||||
"temperature": 1.0,
|
||||
"top_p": 1.0,
|
||||
},
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "chat handler with options",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
],
|
||||
"stream": true,
|
||||
"max_tokens": 999,
|
||||
"seed": 123,
|
||||
"stop": ["\n", "stop"],
|
||||
"temperature": 3.0,
|
||||
"frequency_penalty": 4.0,
|
||||
"presence_penalty": 5.0,
|
||||
"top_p": 6.0,
|
||||
"response_format": {"type": "json_object"}
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Hello",
|
||||
},
|
||||
},
|
||||
Options: map[string]any{
|
||||
"num_predict": 999.0, // float because JSON doesn't distinguish between float and int
|
||||
"seed": 123.0,
|
||||
"stop": []any{"\n", "stop"},
|
||||
"temperature": 3.0,
|
||||
"frequency_penalty": 4.0,
|
||||
"presence_penalty": 5.0,
|
||||
"top_p": 6.0,
|
||||
},
|
||||
Format: json.RawMessage(`"json"`),
|
||||
Stream: &True,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "chat handler with streaming usage",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
],
|
||||
"stream": true,
|
||||
"stream_options": {"include_usage": true},
|
||||
"max_tokens": 999,
|
||||
"seed": 123,
|
||||
"stop": ["\n", "stop"],
|
||||
"temperature": 3.0,
|
||||
"frequency_penalty": 4.0,
|
||||
"presence_penalty": 5.0,
|
||||
"top_p": 6.0,
|
||||
"response_format": {"type": "json_object"}
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Hello",
|
||||
},
|
||||
},
|
||||
Options: map[string]any{
|
||||
"num_predict": 999.0, // float because JSON doesn't distinguish between float and int
|
||||
"seed": 123.0,
|
||||
"stop": []any{"\n", "stop"},
|
||||
"temperature": 3.0,
|
||||
"frequency_penalty": 4.0,
|
||||
"presence_penalty": 5.0,
|
||||
"top_p": 6.0,
|
||||
},
|
||||
Format: json.RawMessage(`"json"`),
|
||||
Stream: &True,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "chat handler with image content",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Hello"
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "` + prefix + image + `"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Hello",
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Images: []api.ImageData{
|
||||
func() []byte {
|
||||
img, _ := base64.StdEncoding.DecodeString(image)
|
||||
return img
|
||||
}(),
|
||||
},
|
||||
},
|
||||
},
|
||||
Options: map[string]any{
|
||||
"temperature": 1.0,
|
||||
"top_p": 1.0,
|
||||
},
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "chat handler with tools",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What's the weather like in Paris Today?"},
|
||||
{"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
|
||||
]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "What's the weather like in Paris Today?",
|
||||
},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_weather",
|
||||
Arguments: map[string]any{
|
||||
"location": "Paris, France",
|
||||
"format": "celsius",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Options: map[string]any{
|
||||
"temperature": 1.0,
|
||||
"top_p": 1.0,
|
||||
},
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "chat handler with tools and content",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What's the weather like in Paris Today?"},
|
||||
{"role": "assistant", "content": "Let's see what the weather is like in Paris", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
|
||||
]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "What's the weather like in Paris Today?",
|
||||
},
|
||||
{
|
||||
Role: "assistant",
|
||||
Content: "Let's see what the weather is like in Paris",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_weather",
|
||||
Arguments: map[string]any{
|
||||
"location": "Paris, France",
|
||||
"format": "celsius",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Options: map[string]any{
|
||||
"temperature": 1.0,
|
||||
"top_p": 1.0,
|
||||
},
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "chat handler with tools and empty content",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What's the weather like in Paris Today?"},
|
||||
{"role": "assistant", "content": "", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
|
||||
]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "What's the weather like in Paris Today?",
|
||||
},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_weather",
|
||||
Arguments: map[string]any{
|
||||
"location": "Paris, France",
|
||||
"format": "celsius",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Options: map[string]any{
|
||||
"temperature": 1.0,
|
||||
"top_p": 1.0,
|
||||
},
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "chat handler with tools and thinking content",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What's the weather like in Paris Today?"},
|
||||
{"role": "assistant", "reasoning": "Let's see what the weather is like in Paris", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
|
||||
]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "What's the weather like in Paris Today?",
|
||||
},
|
||||
{
|
||||
Role: "assistant",
|
||||
Thinking: "Let's see what the weather is like in Paris",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_weather",
|
||||
Arguments: map[string]any{
|
||||
"location": "Paris, France",
|
||||
"format": "celsius",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Options: map[string]any{
|
||||
"temperature": 1.0,
|
||||
"top_p": 1.0,
|
||||
},
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool response with call ID",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What's the weather like in Paris Today?"},
|
||||
{"role": "assistant", "tool_calls": [{"id": "id_abc", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]},
|
||||
{"role": "tool", "tool_call_id": "id_abc", "content": "The weather in Paris is 20 degrees Celsius"}
|
||||
]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "What's the weather like in Paris Today?",
|
||||
},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_weather",
|
||||
Arguments: map[string]any{
|
||||
"location": "Paris, France",
|
||||
"format": "celsius",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: "tool",
|
||||
Content: "The weather in Paris is 20 degrees Celsius",
|
||||
ToolName: "get_current_weather",
|
||||
},
|
||||
},
|
||||
Options: map[string]any{
|
||||
"temperature": 1.0,
|
||||
"top_p": 1.0,
|
||||
},
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tool response with name",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What's the weather like in Paris Today?"},
|
||||
{"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]},
|
||||
{"role": "tool", "name": "get_current_weather", "content": "The weather in Paris is 20 degrees Celsius"}
|
||||
]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "What's the weather like in Paris Today?",
|
||||
},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_weather",
|
||||
Arguments: map[string]any{
|
||||
"location": "Paris, France",
|
||||
"format": "celsius",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Role: "tool",
|
||||
Content: "The weather in Paris is 20 degrees Celsius",
|
||||
ToolName: "get_current_weather",
|
||||
},
|
||||
},
|
||||
Options: map[string]any{
|
||||
"temperature": 1.0,
|
||||
"top_p": 1.0,
|
||||
},
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "chat handler with streaming tools",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What's the weather like in Paris?"}
|
||||
],
|
||||
"stream": true,
|
||||
"tools": [{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get the current weather",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"required": ["location"],
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state"
|
||||
},
|
||||
"unit": {
|
||||
"type": "string",
|
||||
"enum": ["celsius", "fahrenheit"]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "What's the weather like in Paris?",
|
||||
},
|
||||
},
|
||||
Tools: []api.Tool{
|
||||
{
|
||||
Type: "function",
|
||||
Function: api.ToolFunction{
|
||||
Name: "get_weather",
|
||||
Description: "Get the current weather",
|
||||
Parameters: struct {
|
||||
Type string `json:"type"`
|
||||
Defs any `json:"$defs,omitempty"`
|
||||
Items any `json:"items,omitempty"`
|
||||
Required []string `json:"required"`
|
||||
Properties map[string]api.ToolProperty `json:"properties"`
|
||||
}{
|
||||
Type: "object",
|
||||
Required: []string{"location"},
|
||||
Properties: map[string]api.ToolProperty{
|
||||
"location": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Description: "The city and state",
|
||||
},
|
||||
"unit": {
|
||||
Type: api.PropertyType{"string"},
|
||||
Enum: []any{"celsius", "fahrenheit"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Options: map[string]any{
|
||||
"temperature": 1.0,
|
||||
"top_p": 1.0,
|
||||
},
|
||||
Stream: &True,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "chat handler error forwarding",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{"role": "user", "content": 2}
|
||||
]
|
||||
}`,
|
||||
err: openai.ErrorResponse{
|
||||
Error: openai.Error{
|
||||
Message: "invalid message content type: float64",
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
endpoint := func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
}
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(ChatMiddleware(), captureRequestMiddleware(&capturedRequest))
|
||||
router.Handle(http.MethodPost, "/api/chat", endpoint)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(tc.body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
defer func() { capturedRequest = nil }()
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
var errResp openai.ErrorResponse
|
||||
if resp.Code != http.StatusOK {
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return
|
||||
}
|
||||
if diff := cmp.Diff(&tc.req, capturedRequest); diff != "" {
|
||||
t.Fatalf("requests did not match: %+v", diff)
|
||||
}
|
||||
if diff := cmp.Diff(tc.err, errResp); diff != "" {
|
||||
t.Fatalf("errors did not match for %s:\n%s", tc.name, diff)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompletionsMiddleware(t *testing.T) {
|
||||
type testCase struct {
|
||||
name string
|
||||
body string
|
||||
req api.GenerateRequest
|
||||
err openai.ErrorResponse
|
||||
}
|
||||
|
||||
var capturedRequest *api.GenerateRequest
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
name: "completions handler",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"prompt": "Hello",
|
||||
"temperature": 0.8,
|
||||
"stop": ["\n", "stop"],
|
||||
"suffix": "suffix"
|
||||
}`,
|
||||
req: api.GenerateRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "Hello",
|
||||
Options: map[string]any{
|
||||
"frequency_penalty": 0.0,
|
||||
"presence_penalty": 0.0,
|
||||
"temperature": 0.8,
|
||||
"top_p": 1.0,
|
||||
"stop": []any{"\n", "stop"},
|
||||
},
|
||||
Suffix: "suffix",
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "completions handler stream",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"prompt": "Hello",
|
||||
"stream": true,
|
||||
"temperature": 0.8,
|
||||
"stop": ["\n", "stop"],
|
||||
"suffix": "suffix"
|
||||
}`,
|
||||
req: api.GenerateRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "Hello",
|
||||
Options: map[string]any{
|
||||
"frequency_penalty": 0.0,
|
||||
"presence_penalty": 0.0,
|
||||
"temperature": 0.8,
|
||||
"top_p": 1.0,
|
||||
"stop": []any{"\n", "stop"},
|
||||
},
|
||||
Suffix: "suffix",
|
||||
Stream: &True,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "completions handler stream with usage",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"prompt": "Hello",
|
||||
"stream": true,
|
||||
"stream_options": {"include_usage": true},
|
||||
"temperature": 0.8,
|
||||
"stop": ["\n", "stop"],
|
||||
"suffix": "suffix"
|
||||
}`,
|
||||
req: api.GenerateRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "Hello",
|
||||
Options: map[string]any{
|
||||
"frequency_penalty": 0.0,
|
||||
"presence_penalty": 0.0,
|
||||
"temperature": 0.8,
|
||||
"top_p": 1.0,
|
||||
"stop": []any{"\n", "stop"},
|
||||
},
|
||||
Suffix: "suffix",
|
||||
Stream: &True,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "completions handler error forwarding",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"prompt": "Hello",
|
||||
"temperature": null,
|
||||
"stop": [1, 2],
|
||||
"suffix": "suffix"
|
||||
}`,
|
||||
err: openai.ErrorResponse{
|
||||
Error: openai.Error{
|
||||
Message: "invalid type for 'stop' field: float64",
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
endpoint := func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
}
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest))
|
||||
router.Handle(http.MethodPost, "/api/generate", endpoint)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
var errResp openai.ErrorResponse
|
||||
if resp.Code != http.StatusOK {
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
|
||||
t.Fatal("requests did not match")
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(tc.err, errResp) {
|
||||
t.Fatal("errors did not match")
|
||||
}
|
||||
|
||||
capturedRequest = nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmbeddingsMiddleware(t *testing.T) {
|
||||
type testCase struct {
|
||||
name string
|
||||
body string
|
||||
req api.EmbedRequest
|
||||
err openai.ErrorResponse
|
||||
}
|
||||
|
||||
var capturedRequest *api.EmbedRequest
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
name: "embed handler single input",
|
||||
body: `{
|
||||
"input": "Hello",
|
||||
"model": "test-model"
|
||||
}`,
|
||||
req: api.EmbedRequest{
|
||||
Input: "Hello",
|
||||
Model: "test-model",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "embed handler batch input",
|
||||
body: `{
|
||||
"input": ["Hello", "World"],
|
||||
"model": "test-model"
|
||||
}`,
|
||||
req: api.EmbedRequest{
|
||||
Input: []any{"Hello", "World"},
|
||||
Model: "test-model",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "embed handler error forwarding",
|
||||
body: `{
|
||||
"model": "test-model"
|
||||
}`,
|
||||
err: openai.ErrorResponse{
|
||||
Error: openai.Error{
|
||||
Message: "invalid input",
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
endpoint := func(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
}
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest))
|
||||
router.Handle(http.MethodPost, "/api/embed", endpoint)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tc.body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
var errResp openai.ErrorResponse
|
||||
if resp.Code != http.StatusOK {
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
|
||||
t.Fatal("requests did not match")
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(tc.err, errResp) {
|
||||
t.Fatal("errors did not match")
|
||||
}
|
||||
|
||||
capturedRequest = nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestListMiddleware(t *testing.T) {
|
||||
type testCase struct {
|
||||
name string
|
||||
endpoint func(c *gin.Context)
|
||||
resp string
|
||||
}
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
name: "list handler",
|
||||
endpoint: func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, api.ListResponse{
|
||||
Models: []api.ListModelResponse{
|
||||
{
|
||||
Name: "test-model",
|
||||
ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
|
||||
},
|
||||
},
|
||||
})
|
||||
},
|
||||
resp: `{
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"id": "test-model",
|
||||
"object": "model",
|
||||
"created": 1686935002,
|
||||
"owned_by": "library"
|
||||
}
|
||||
]
|
||||
}`,
|
||||
},
|
||||
{
|
||||
name: "list handler empty output",
|
||||
endpoint: func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, api.ListResponse{})
|
||||
},
|
||||
resp: `{
|
||||
"object": "list",
|
||||
"data": null
|
||||
}`,
|
||||
},
|
||||
}
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
for _, tc := range testCases {
|
||||
router := gin.New()
|
||||
router.Use(ListMiddleware())
|
||||
router.Handle(http.MethodGet, "/api/tags", tc.endpoint)
|
||||
req, _ := http.NewRequest(http.MethodGet, "/api/tags", nil)
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
var expected, actual map[string]any
|
||||
err := json.Unmarshal([]byte(tc.resp), &expected)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to unmarshal expected response: %v", err)
|
||||
}
|
||||
|
||||
err = json.Unmarshal(resp.Body.Bytes(), &actual)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to unmarshal actual response: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetrieveMiddleware(t *testing.T) {
|
||||
type testCase struct {
|
||||
name string
|
||||
endpoint func(c *gin.Context)
|
||||
resp string
|
||||
}
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
name: "retrieve handler",
|
||||
endpoint: func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, api.ShowResponse{
|
||||
ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
|
||||
})
|
||||
},
|
||||
resp: `{
|
||||
"id":"test-model",
|
||||
"object":"model",
|
||||
"created":1686935002,
|
||||
"owned_by":"library"}
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "retrieve handler error forwarding",
|
||||
endpoint: func(c *gin.Context) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "model not found"})
|
||||
},
|
||||
resp: `{
|
||||
"error": {
|
||||
"code": null,
|
||||
"message": "model not found",
|
||||
"param": null,
|
||||
"type": "api_error"
|
||||
}
|
||||
}`,
|
||||
},
|
||||
}
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
for _, tc := range testCases {
|
||||
router := gin.New()
|
||||
router.Use(RetrieveMiddleware())
|
||||
router.Handle(http.MethodGet, "/api/show/:model", tc.endpoint)
|
||||
req, _ := http.NewRequest(http.MethodGet, "/api/show/test-model", nil)
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
var expected, actual map[string]any
|
||||
err := json.Unmarshal([]byte(tc.resp), &expected)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to unmarshal expected response: %v", err)
|
||||
}
|
||||
|
||||
err = json.Unmarshal(resp.Body.Bytes(), &actual)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to unmarshal actual response: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -73,7 +73,7 @@ func (p ImageProcessor) bestResolution(img image.Point, possibleResolutions []im
|
||||
for i, res := range possibleResolutions {
|
||||
scaleW := float64(res.X) / float64(w)
|
||||
scaleH := float64(res.Y) / float64(h)
|
||||
scale := math.Min(scaleW, scaleH)
|
||||
scale := min(scaleW, scaleH)
|
||||
|
||||
scales[i] = scale
|
||||
}
|
||||
@@ -124,11 +124,11 @@ func (p ImageProcessor) maxResolution(imageRes, targetRes image.Point) image.Poi
|
||||
if scaleW < scaleH {
|
||||
newRes = image.Point{
|
||||
targetRes.X,
|
||||
int(math.Min(math.Floor(float64(imageRes.Y)*scaleW), float64(targetRes.Y))),
|
||||
int(min(math.Floor(float64(imageRes.Y)*scaleW), float64(targetRes.Y))),
|
||||
}
|
||||
} else {
|
||||
newRes = image.Point{
|
||||
int(math.Min(math.Floor(float64(imageRes.X)*scaleH), float64(targetRes.X))),
|
||||
int(min(math.Floor(float64(imageRes.X)*scaleH), float64(targetRes.X))),
|
||||
targetRes.Y,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,7 +53,7 @@ func (p ImageProcessor) fitToCanvas(imageSize, canvasSize image.Point) image.Poi
|
||||
tw := min(max(imageSize.X, p.imageSize), canvasSize.X)
|
||||
th := min(max(imageSize.Y, p.imageSize), canvasSize.Y)
|
||||
|
||||
r := math.Min(
|
||||
r := min(
|
||||
float64(tw)/float64(imageSize.X),
|
||||
float64(th)/float64(imageSize.Y),
|
||||
)
|
||||
@@ -89,10 +89,10 @@ func (p ImageProcessor) optimalTiledCanvas(imageSize image.Point) image.Point {
|
||||
if minUpscale == 0 {
|
||||
minUpscale = s
|
||||
} else {
|
||||
minUpscale = math.Min(minUpscale, s)
|
||||
minUpscale = min(minUpscale, s)
|
||||
}
|
||||
} else {
|
||||
maxDownscale = math.Max(maxDownscale, s)
|
||||
maxDownscale = max(maxDownscale, s)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
488
openai/openai.go
488
openai/openai.go
@@ -1,21 +1,18 @@
|
||||
// openai package provides middleware for partial compatibility with the OpenAI REST API
|
||||
// openai package provides core transformation logic for partial compatibility with the OpenAI REST API
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
@@ -86,7 +83,7 @@ type StreamOptions struct {
|
||||
}
|
||||
|
||||
type Reasoning struct {
|
||||
Effort *string `json:"effort,omitempty"`
|
||||
Effort string `json:"effort,omitempty"`
|
||||
}
|
||||
|
||||
type ChatCompletionRequest struct {
|
||||
@@ -220,11 +217,12 @@ func NewError(code int, message string) ErrorResponse {
|
||||
return ErrorResponse{Error{Type: etype, Message: message}}
|
||||
}
|
||||
|
||||
func toUsage(r api.ChatResponse) Usage {
|
||||
// ToUsage converts an api.ChatResponse to Usage
|
||||
func ToUsage(r api.ChatResponse) Usage {
|
||||
return Usage{
|
||||
PromptTokens: r.PromptEvalCount,
|
||||
CompletionTokens: r.EvalCount,
|
||||
TotalTokens: r.PromptEvalCount + r.EvalCount,
|
||||
PromptTokens: r.Metrics.PromptEvalCount,
|
||||
CompletionTokens: r.Metrics.EvalCount,
|
||||
TotalTokens: r.Metrics.PromptEvalCount + r.Metrics.EvalCount,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -256,7 +254,8 @@ func toToolCalls(tc []api.ToolCall) []ToolCall {
|
||||
return toolCalls
|
||||
}
|
||||
|
||||
func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
||||
// ToChatCompletion converts an api.ChatResponse to ChatCompletion
|
||||
func ToChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
||||
toolCalls := toToolCalls(r.Message.ToolCalls)
|
||||
return ChatCompletion{
|
||||
Id: id,
|
||||
@@ -276,12 +275,13 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
|
||||
}
|
||||
return nil
|
||||
}(r.DoneReason),
|
||||
}}, Usage: toUsage(r),
|
||||
}}, Usage: ToUsage(r),
|
||||
DebugInfo: r.DebugInfo,
|
||||
}
|
||||
}
|
||||
|
||||
func toChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChunk {
|
||||
// ToChunk converts an api.ChatResponse to ChatCompletionChunk
|
||||
func ToChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChunk {
|
||||
toolCalls := toToolCalls(r.Message.ToolCalls)
|
||||
return ChatCompletionChunk{
|
||||
Id: id,
|
||||
@@ -305,15 +305,17 @@ func toChunk(id string, r api.ChatResponse, toolCallSent bool) ChatCompletionChu
|
||||
}
|
||||
}
|
||||
|
||||
func toUsageGenerate(r api.GenerateResponse) Usage {
|
||||
// ToUsageGenerate converts an api.GenerateResponse to Usage
|
||||
func ToUsageGenerate(r api.GenerateResponse) Usage {
|
||||
return Usage{
|
||||
PromptTokens: r.PromptEvalCount,
|
||||
CompletionTokens: r.EvalCount,
|
||||
TotalTokens: r.PromptEvalCount + r.EvalCount,
|
||||
PromptTokens: r.Metrics.PromptEvalCount,
|
||||
CompletionTokens: r.Metrics.EvalCount,
|
||||
TotalTokens: r.Metrics.PromptEvalCount + r.Metrics.EvalCount,
|
||||
}
|
||||
}
|
||||
|
||||
func toCompletion(id string, r api.GenerateResponse) Completion {
|
||||
// ToCompletion converts an api.GenerateResponse to Completion
|
||||
func ToCompletion(id string, r api.GenerateResponse) Completion {
|
||||
return Completion{
|
||||
Id: id,
|
||||
Object: "text_completion",
|
||||
@@ -330,11 +332,12 @@ func toCompletion(id string, r api.GenerateResponse) Completion {
|
||||
return nil
|
||||
}(r.DoneReason),
|
||||
}},
|
||||
Usage: toUsageGenerate(r),
|
||||
Usage: ToUsageGenerate(r),
|
||||
}
|
||||
}
|
||||
|
||||
func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
|
||||
// ToCompleteChunk converts an api.GenerateResponse to CompletionChunk
|
||||
func ToCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
|
||||
return CompletionChunk{
|
||||
Id: id,
|
||||
Object: "text_completion",
|
||||
@@ -354,7 +357,8 @@ func toCompleteChunk(id string, r api.GenerateResponse) CompletionChunk {
|
||||
}
|
||||
}
|
||||
|
||||
func toListCompletion(r api.ListResponse) ListCompletion {
|
||||
// ToListCompletion converts an api.ListResponse to ListCompletion
|
||||
func ToListCompletion(r api.ListResponse) ListCompletion {
|
||||
var data []Model
|
||||
for _, m := range r.Models {
|
||||
data = append(data, Model{
|
||||
@@ -371,7 +375,8 @@ func toListCompletion(r api.ListResponse) ListCompletion {
|
||||
}
|
||||
}
|
||||
|
||||
func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
|
||||
// ToEmbeddingList converts an api.EmbedResponse to EmbeddingList
|
||||
func ToEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
|
||||
if r.Embeddings != nil {
|
||||
var data []Embedding
|
||||
for i, e := range r.Embeddings {
|
||||
@@ -396,7 +401,8 @@ func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
|
||||
return EmbeddingList{}
|
||||
}
|
||||
|
||||
func toModel(r api.ShowResponse, m string) Model {
|
||||
// ToModel converts an api.ShowResponse to Model
|
||||
func ToModel(r api.ShowResponse, m string) Model {
|
||||
return Model{
|
||||
Id: m,
|
||||
Object: "model",
|
||||
@@ -405,7 +411,8 @@ func toModel(r api.ShowResponse, m string) Model {
|
||||
}
|
||||
}
|
||||
|
||||
func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
||||
// FromChatRequest converts a ChatCompletionRequest to api.ChatRequest
|
||||
func FromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
||||
var messages []api.Message
|
||||
for _, msg := range r.Messages {
|
||||
toolName := ""
|
||||
@@ -560,13 +567,23 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
|
||||
}
|
||||
|
||||
var think *api.ThinkValue
|
||||
var effort string
|
||||
|
||||
if r.Reasoning != nil {
|
||||
think = &api.ThinkValue{
|
||||
Value: *r.Reasoning.Effort,
|
||||
}
|
||||
effort = r.Reasoning.Effort
|
||||
} else if r.ReasoningEffort != nil {
|
||||
think = &api.ThinkValue{
|
||||
Value: *r.ReasoningEffort,
|
||||
effort = *r.ReasoningEffort
|
||||
}
|
||||
|
||||
if effort != "" {
|
||||
if !slices.Contains([]string{"high", "medium", "low", "none"}, effort) {
|
||||
return nil, fmt.Errorf("invalid reasoning value: '%s' (must be \"high\", \"medium\", \"low\", or \"none\")", effort)
|
||||
}
|
||||
|
||||
if effort == "none" {
|
||||
think = &api.ThinkValue{Value: false}
|
||||
} else {
|
||||
think = &api.ThinkValue{Value: effort}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -609,7 +626,8 @@ func fromCompletionToolCall(toolCalls []ToolCall) ([]api.ToolCall, error) {
|
||||
return apiToolCalls, nil
|
||||
}
|
||||
|
||||
func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
|
||||
// FromCompleteRequest converts a CompletionRequest to api.GenerateRequest
|
||||
func FromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
|
||||
options := make(map[string]any)
|
||||
|
||||
switch stop := r.Stop.(type) {
|
||||
@@ -660,413 +678,3 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
|
||||
DebugRenderOnly: r.DebugRenderOnly,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type BaseWriter struct {
|
||||
gin.ResponseWriter
|
||||
}
|
||||
|
||||
type ChatWriter struct {
|
||||
stream bool
|
||||
streamOptions *StreamOptions
|
||||
id string
|
||||
toolCallSent bool
|
||||
BaseWriter
|
||||
}
|
||||
|
||||
type CompleteWriter struct {
|
||||
stream bool
|
||||
streamOptions *StreamOptions
|
||||
id string
|
||||
BaseWriter
|
||||
}
|
||||
|
||||
type ListWriter struct {
|
||||
BaseWriter
|
||||
}
|
||||
|
||||
type RetrieveWriter struct {
|
||||
BaseWriter
|
||||
model string
|
||||
}
|
||||
|
||||
type EmbedWriter struct {
|
||||
BaseWriter
|
||||
model string
|
||||
}
|
||||
|
||||
func (w *BaseWriter) writeError(data []byte) (int, error) {
|
||||
var serr api.StatusError
|
||||
err := json.Unmarshal(data, &serr)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w.ResponseWriter).Encode(NewError(http.StatusInternalServerError, serr.Error()))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *ChatWriter) writeResponse(data []byte) (int, error) {
|
||||
var chatResponse api.ChatResponse
|
||||
err := json.Unmarshal(data, &chatResponse)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// chat chunk
|
||||
if w.stream {
|
||||
c := toChunk(w.id, chatResponse, w.toolCallSent)
|
||||
d, err := json.Marshal(c)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if !w.toolCallSent && len(c.Choices) > 0 && len(c.Choices[0].Delta.ToolCalls) > 0 {
|
||||
w.toolCallSent = true
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
|
||||
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if chatResponse.Done {
|
||||
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
|
||||
u := toUsage(chatResponse)
|
||||
c.Usage = &u
|
||||
c.Choices = []ChunkChoice{}
|
||||
d, err := json.Marshal(c)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
// chat completion
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w.ResponseWriter).Encode(toChatCompletion(w.id, chatResponse))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *ChatWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
func (w *CompleteWriter) writeResponse(data []byte) (int, error) {
|
||||
var generateResponse api.GenerateResponse
|
||||
err := json.Unmarshal(data, &generateResponse)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// completion chunk
|
||||
if w.stream {
|
||||
c := toCompleteChunk(w.id, generateResponse)
|
||||
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
|
||||
c.Usage = &Usage{}
|
||||
}
|
||||
d, err := json.Marshal(c)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
|
||||
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if generateResponse.Done {
|
||||
if w.streamOptions != nil && w.streamOptions.IncludeUsage {
|
||||
u := toUsageGenerate(generateResponse)
|
||||
c.Usage = &u
|
||||
c.Choices = []CompleteChunkChoice{}
|
||||
d, err := json.Marshal(c)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
// completion
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w.ResponseWriter).Encode(toCompletion(w.id, generateResponse))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *CompleteWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
func (w *ListWriter) writeResponse(data []byte) (int, error) {
|
||||
var listResponse api.ListResponse
|
||||
err := json.Unmarshal(data, &listResponse)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w.ResponseWriter).Encode(toListCompletion(listResponse))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *ListWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
func (w *RetrieveWriter) writeResponse(data []byte) (int, error) {
|
||||
var showResponse api.ShowResponse
|
||||
err := json.Unmarshal(data, &showResponse)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// retrieve completion
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w.ResponseWriter).Encode(toModel(showResponse, w.model))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *RetrieveWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
|
||||
var embedResponse api.EmbedResponse
|
||||
err := json.Unmarshal(data, &embedResponse)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "application/json")
|
||||
err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(data), nil
|
||||
}
|
||||
|
||||
func (w *EmbedWriter) Write(data []byte) (int, error) {
|
||||
code := w.ResponseWriter.Status()
|
||||
if code != http.StatusOK {
|
||||
return w.writeError(data)
|
||||
}
|
||||
|
||||
return w.writeResponse(data)
|
||||
}
|
||||
|
||||
func ListMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
w := &ListWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func RetrieveMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(api.ShowRequest{Name: c.Param("model")}); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
// response writer
|
||||
w := &RetrieveWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
model: c.Param("model"),
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func CompletionsMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var req CompletionRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
genReq, err := fromCompleteRequest(req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(&b).Encode(genReq); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
w := &CompleteWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
stream: req.Stream,
|
||||
id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
|
||||
streamOptions: req.StreamOptions,
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func EmbeddingsMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var req EmbedRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if req.Input == "" {
|
||||
req.Input = []string{""}
|
||||
}
|
||||
|
||||
if req.Input == nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
|
||||
return
|
||||
}
|
||||
|
||||
if v, ok := req.Input.([]any); ok && len(v) == 0 {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
|
||||
return
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input, Dimensions: req.Dimensions}); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
w := &EmbedWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
model: req.Model,
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
func ChatMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var req ChatCompletionRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if len(req.Messages) == 0 {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "[] is too short - 'messages'"))
|
||||
return
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
|
||||
chatReq, err := fromChatRequest(req)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Request.Body = io.NopCloser(&b)
|
||||
|
||||
w := &ChatWriter{
|
||||
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
|
||||
stream: req.Stream,
|
||||
id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
|
||||
streamOptions: req.StreamOptions,
|
||||
}
|
||||
|
||||
c.Writer = w
|
||||
|
||||
c.Next()
|
||||
}
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -82,10 +82,10 @@ type Sequence struct {
|
||||
doneReason llm.DoneReason
|
||||
|
||||
// Metrics
|
||||
startProcessingTime time.Time
|
||||
startGenerationTime time.Time
|
||||
numDecoded int
|
||||
numPromptInputs int
|
||||
processingDuration time.Duration
|
||||
generationDuration time.Duration
|
||||
numDecoded int
|
||||
numPromptInputs int
|
||||
}
|
||||
|
||||
type NewSequenceParams struct {
|
||||
@@ -99,8 +99,6 @@ type NewSequenceParams struct {
|
||||
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
|
||||
s.ready.Wait()
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
inputs, err := s.inputs(prompt, images)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to process inputs: %w", err)
|
||||
@@ -142,18 +140,17 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
||||
}
|
||||
|
||||
return &Sequence{
|
||||
inputs: inputs,
|
||||
numPromptInputs: len(inputs),
|
||||
startProcessingTime: startTime,
|
||||
numPredict: params.numPredict,
|
||||
pendingResponses: make([]string, 0),
|
||||
responses: make(chan string, 100),
|
||||
quit: make(chan bool, 1),
|
||||
embedding: make(chan []float32, 1),
|
||||
samplingCtx: sc,
|
||||
embeddingOnly: params.embedding,
|
||||
stop: params.stop,
|
||||
numKeep: params.numKeep,
|
||||
inputs: inputs,
|
||||
numPromptInputs: len(inputs),
|
||||
numPredict: params.numPredict,
|
||||
pendingResponses: make([]string, 0),
|
||||
responses: make(chan string, 100),
|
||||
quit: make(chan bool, 1),
|
||||
embedding: make(chan []float32, 1),
|
||||
samplingCtx: sc,
|
||||
embeddingOnly: params.embedding,
|
||||
stop: params.stop,
|
||||
numKeep: params.numKeep,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -438,8 +435,8 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
||||
return nil
|
||||
}
|
||||
|
||||
err := s.lc.Decode(batch)
|
||||
if err != nil {
|
||||
t := time.Now()
|
||||
if err := s.lc.Decode(batch); err != nil {
|
||||
return fmt.Errorf("failed to decode batch: %w", err)
|
||||
}
|
||||
|
||||
@@ -459,9 +456,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
|
||||
continue
|
||||
}
|
||||
|
||||
seq.numDecoded += 1
|
||||
if seq.numDecoded == 1 {
|
||||
seq.startGenerationTime = time.Now()
|
||||
s.lc.Synchronize()
|
||||
seq.numDecoded++
|
||||
if seq.numDecoded > 1 {
|
||||
seq.generationDuration += time.Since(t)
|
||||
} else {
|
||||
seq.processingDuration += time.Since(t)
|
||||
}
|
||||
|
||||
// if done processing the prompt, generate an embedding and return
|
||||
@@ -646,9 +646,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
Done: true,
|
||||
DoneReason: seq.doneReason,
|
||||
PromptEvalCount: seq.numPromptInputs,
|
||||
PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
|
||||
PromptEvalDuration: seq.processingDuration,
|
||||
EvalCount: seq.numDecoded,
|
||||
EvalDuration: time.Since(seq.startGenerationTime),
|
||||
EvalDuration: seq.generationDuration,
|
||||
}); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
@@ -91,10 +91,11 @@ type Sequence struct {
|
||||
doneReason llm.DoneReason
|
||||
|
||||
// Metrics
|
||||
startProcessingTime time.Time
|
||||
startGenerationTime time.Time
|
||||
numPredicted int
|
||||
numPromptInputs int
|
||||
startedAt, lastUpdatedAt time.Time
|
||||
processingDuration time.Duration
|
||||
samplingDuration time.Duration
|
||||
numPredicted int
|
||||
numPromptInputs int
|
||||
}
|
||||
|
||||
type NewSequenceParams struct {
|
||||
@@ -108,8 +109,6 @@ type NewSequenceParams struct {
|
||||
func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSequenceParams) (*Sequence, error) {
|
||||
s.ready.Wait()
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
inputs, ctxs, mmStore, err := s.inputs(prompt, images)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to process inputs: %w", err)
|
||||
@@ -164,20 +163,19 @@ func (s *Server) NewSequence(prompt string, images []llm.ImageData, params NewSe
|
||||
// TODO(jessegross): Ingest cached history for grammar
|
||||
|
||||
return &Sequence{
|
||||
ctxs: ctxs,
|
||||
mmStore: mmStore,
|
||||
inputs: inputs,
|
||||
numPromptInputs: len(inputs),
|
||||
startProcessingTime: startTime,
|
||||
numPredict: params.numPredict,
|
||||
pendingResponses: make([]string, 0),
|
||||
responses: make(chan string, 100),
|
||||
quit: make(chan bool, 1),
|
||||
embedding: make(chan []float32, 1),
|
||||
sampler: params.sampler,
|
||||
embeddingOnly: params.embedding,
|
||||
stop: params.stop,
|
||||
numKeep: params.numKeep,
|
||||
ctxs: ctxs,
|
||||
mmStore: mmStore,
|
||||
inputs: inputs,
|
||||
numPromptInputs: len(inputs),
|
||||
numPredict: params.numPredict,
|
||||
pendingResponses: make([]string, 0),
|
||||
responses: make(chan string, 100),
|
||||
quit: make(chan bool, 1),
|
||||
embedding: make(chan []float32, 1),
|
||||
sampler: params.sampler,
|
||||
embeddingOnly: params.embedding,
|
||||
stop: params.stop,
|
||||
numKeep: params.numKeep,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -408,7 +406,7 @@ func (s *Server) run(ctx context.Context) {
|
||||
|
||||
supportsAsync := pooling.Type(s.model.Backend().Config().Uint("pooling_type")) == pooling.TypeNone
|
||||
|
||||
var activeBatch batchState
|
||||
var previousBatch batchState
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@@ -417,16 +415,18 @@ func (s *Server) run(ctx context.Context) {
|
||||
panic(err)
|
||||
default:
|
||||
var err error
|
||||
activeBatch, err = s.forwardBatch(activeBatch)
|
||||
nextBatch, err := s.forwardBatch(previousBatch)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
if supportsAsync {
|
||||
go s.computeBatch(activeBatch)
|
||||
go s.computeBatch(nextBatch)
|
||||
} else {
|
||||
s.computeBatch(activeBatch)
|
||||
s.computeBatch(nextBatch)
|
||||
}
|
||||
|
||||
previousBatch = nextBatch
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -562,6 +562,13 @@ func (s *Server) forwardBatch(pendingBatch batchState) (nextBatch batchState, er
|
||||
seq.inputs = seq.inputs[len(seq.pendingInputs):]
|
||||
}
|
||||
|
||||
startedAt := time.Now()
|
||||
for i := range nextBatch.seqs {
|
||||
if nextBatch.seqs[i] != nil && nextBatch.seqs[i].startedAt.IsZero() {
|
||||
nextBatch.seqs[i].startedAt = startedAt
|
||||
}
|
||||
}
|
||||
|
||||
if resumeSeq != -1 {
|
||||
s.nextSeq = resumeSeq
|
||||
} else {
|
||||
@@ -682,6 +689,7 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
activeBatch.modelOutput)
|
||||
|
||||
outputs := activeBatch.modelOutput.Floats()
|
||||
t := time.Now()
|
||||
|
||||
logutil.Trace("computeBatch: logits ready", "batchID", activeBatch.id)
|
||||
|
||||
@@ -694,8 +702,10 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
continue
|
||||
}
|
||||
|
||||
seq.lastUpdatedAt = t
|
||||
if seq.numPredicted == 1 {
|
||||
seq.startGenerationTime = time.Now()
|
||||
seq.processingDuration = seq.lastUpdatedAt.Sub(seq.startedAt)
|
||||
seq.startedAt = seq.lastUpdatedAt
|
||||
}
|
||||
|
||||
// if done processing the prompt, generate an embedding and return
|
||||
@@ -774,6 +784,13 @@ func (s *Server) computeBatch(activeBatch batchState) {
|
||||
s.removeSequence(i, llm.DoneReasonConnectionClosed)
|
||||
}
|
||||
}
|
||||
|
||||
samplingDuration := time.Since(t)
|
||||
for i, seq := range s.seqs {
|
||||
if seq != nil && nextBatchTokens[i] != nil {
|
||||
s.seqs[i].samplingDuration += samplingDuration
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -887,9 +904,9 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
Done: true,
|
||||
DoneReason: seq.doneReason,
|
||||
PromptEvalCount: seq.numPromptInputs,
|
||||
PromptEvalDuration: seq.startGenerationTime.Sub(seq.startProcessingTime),
|
||||
PromptEvalDuration: seq.processingDuration,
|
||||
EvalCount: seq.numPredicted,
|
||||
EvalDuration: time.Since(seq.startGenerationTime),
|
||||
EvalDuration: seq.lastUpdatedAt.Sub(seq.startedAt) - seq.samplingDuration,
|
||||
}); err != nil {
|
||||
http.Error(w, fmt.Sprintf("failed to encode final response: %v", err), http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
@@ -179,7 +179,7 @@ function buildROCm() {
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
& cmake --install build --component "HIP" --strip
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
rm -f $script:DIST_DIR\lib\ollama\rocm\rocblas\library\*gfx906*
|
||||
Remove-Item -Path $script:DIST_DIR\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
247
server/routes.go
247
server/routes.go
@@ -37,8 +37,8 @@ import (
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/logutil"
|
||||
"github.com/ollama/ollama/middleware"
|
||||
"github.com/ollama/ollama/model/parsers"
|
||||
"github.com/ollama/ollama/openai"
|
||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||
"github.com/ollama/ollama/server/internal/registry"
|
||||
"github.com/ollama/ollama/template"
|
||||
@@ -330,12 +330,18 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
if req.Suffix != "" {
|
||||
caps = append(caps, model.CapabilityInsert)
|
||||
}
|
||||
if req.Think != nil && req.Think.Bool() {
|
||||
|
||||
modelCaps := m.Capabilities()
|
||||
if slices.Contains(modelCaps, model.CapabilityThinking) {
|
||||
caps = append(caps, model.CapabilityThinking)
|
||||
// TODO(drifkin): consider adding a warning if it's false and the model
|
||||
// doesn't support thinking. It's not strictly required, but it can be a
|
||||
// hint that the user is on an older qwen3/r1 model that doesn't have an
|
||||
// updated template supporting thinking
|
||||
if req.Think == nil {
|
||||
req.Think = &api.ThinkValue{Value: true}
|
||||
}
|
||||
} else {
|
||||
if req.Think != nil && req.Think.Bool() {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
|
||||
@@ -1449,11 +1455,11 @@ func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
|
||||
r.POST("/api/embeddings", s.EmbeddingsHandler)
|
||||
|
||||
// Inference (OpenAI compatibility)
|
||||
r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
|
||||
r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
|
||||
r.POST("/v1/embeddings", openai.EmbeddingsMiddleware(), s.EmbedHandler)
|
||||
r.GET("/v1/models", openai.ListMiddleware(), s.ListHandler)
|
||||
r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowHandler)
|
||||
r.POST("/v1/chat/completions", middleware.ChatMiddleware(), s.ChatHandler)
|
||||
r.POST("/v1/completions", middleware.CompletionsMiddleware(), s.GenerateHandler)
|
||||
r.POST("/v1/embeddings", middleware.EmbeddingsMiddleware(), s.EmbedHandler)
|
||||
r.GET("/v1/models", middleware.ListMiddleware(), s.ListHandler)
|
||||
r.GET("/v1/models/:model", middleware.RetrieveMiddleware(), s.ShowHandler)
|
||||
|
||||
if rc != nil {
|
||||
// wrap old with new
|
||||
@@ -1871,8 +1877,18 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
if len(req.Tools) > 0 {
|
||||
caps = append(caps, model.CapabilityTools)
|
||||
}
|
||||
if req.Think != nil && req.Think.Bool() {
|
||||
|
||||
modelCaps := m.Capabilities()
|
||||
if slices.Contains(modelCaps, model.CapabilityThinking) {
|
||||
caps = append(caps, model.CapabilityThinking)
|
||||
if req.Think == nil {
|
||||
req.Think = &api.ThinkValue{Value: true}
|
||||
}
|
||||
} else {
|
||||
if req.Think != nil && req.Think.Bool() {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support thinking", req.Model)})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), name.String(), caps, req.Options, req.KeepAlive)
|
||||
@@ -1967,88 +1983,167 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
toolParser = tools.NewParser(m.Template.Template, req.Tools)
|
||||
}
|
||||
|
||||
type structuredOutputsState int
|
||||
const (
|
||||
structuredOutputsState_None structuredOutputsState = iota
|
||||
structuredOutputsState_ReadyToApply
|
||||
structuredOutputsState_Applying
|
||||
)
|
||||
|
||||
ch := make(chan any)
|
||||
go func() {
|
||||
defer close(ch)
|
||||
|
||||
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
Images: images,
|
||||
Format: req.Format,
|
||||
Options: opts,
|
||||
}, func(r llm.CompletionResponse) {
|
||||
res := api.ChatResponse{
|
||||
Model: req.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Message: api.Message{Role: "assistant", Content: r.Content},
|
||||
Done: r.Done,
|
||||
Metrics: api.Metrics{
|
||||
PromptEvalCount: r.PromptEvalCount,
|
||||
PromptEvalDuration: r.PromptEvalDuration,
|
||||
EvalCount: r.EvalCount,
|
||||
EvalDuration: r.EvalDuration,
|
||||
},
|
||||
}
|
||||
if r.Done {
|
||||
res.DoneReason = r.DoneReason.String()
|
||||
res.TotalDuration = time.Since(checkpointStart)
|
||||
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||
structuredOutputsState := structuredOutputsState_None
|
||||
|
||||
for {
|
||||
var tb strings.Builder
|
||||
|
||||
currentFormat := req.Format
|
||||
// structured outputs via double request is enabled when:
|
||||
// 1. the model supports the thinking capability and
|
||||
// 2. it uses a built-in parser or our generic thinking parser
|
||||
|
||||
// Note that the current approach does not work for (potential future)
|
||||
// non-thinking models that emit anything before actual content. This
|
||||
// current approach uses the transition from parsed thinking content to
|
||||
// parsed non-thinking content as the signal to turn constraining on
|
||||
|
||||
if req.Format != nil && structuredOutputsState == structuredOutputsState_None && ((builtinParser != nil || thinkingState != nil) && slices.Contains(m.Capabilities(), model.CapabilityThinking)) {
|
||||
currentFormat = nil
|
||||
}
|
||||
|
||||
if builtinParser != nil {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser input", "parser", m.Config.Parser, "content", r.Content)
|
||||
|
||||
content, thinking, toolCalls, err := builtinParser.Add(r.Content, r.Done)
|
||||
if err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
return
|
||||
// sets up new context given parent context per request
|
||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||
err := r.Completion(ctx, llm.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
Images: images,
|
||||
Format: currentFormat,
|
||||
Options: opts,
|
||||
}, func(r llm.CompletionResponse) {
|
||||
res := api.ChatResponse{
|
||||
Model: req.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Message: api.Message{Role: "assistant", Content: r.Content},
|
||||
Done: r.Done,
|
||||
Metrics: api.Metrics{
|
||||
PromptEvalCount: r.PromptEvalCount,
|
||||
PromptEvalDuration: r.PromptEvalDuration,
|
||||
EvalCount: r.EvalCount,
|
||||
EvalDuration: r.EvalDuration,
|
||||
},
|
||||
}
|
||||
if r.Done {
|
||||
res.DoneReason = r.DoneReason.String()
|
||||
res.TotalDuration = time.Since(checkpointStart)
|
||||
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
|
||||
}
|
||||
|
||||
res.Message.Content = content
|
||||
res.Message.Thinking = thinking
|
||||
res.Message.ToolCalls = toolCalls
|
||||
if builtinParser != nil {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser input", "parser", m.Config.Parser, "content", r.Content)
|
||||
|
||||
if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || r.Done {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser output", "parser", m.Config.Parser, "content", content, "thinking", thinking, "toolCalls", toolCalls, "done", r.Done)
|
||||
ch <- res
|
||||
} else {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser empty output", "parser", m.Config.Parser)
|
||||
}
|
||||
content, thinking, toolCalls, err := builtinParser.Add(r.Content, r.Done)
|
||||
if err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
return
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if thinkingState != nil {
|
||||
thinkingContent, remainingContent := thinkingState.AddContent(res.Message.Content)
|
||||
if thinkingContent == "" && remainingContent == "" && !r.Done {
|
||||
// need to accumulate more to decide what to send
|
||||
return
|
||||
}
|
||||
res.Message.Content = remainingContent
|
||||
res.Message.Thinking = thinkingContent
|
||||
}
|
||||
|
||||
if len(req.Tools) > 0 {
|
||||
toolCalls, content := toolParser.Add(res.Message.Content)
|
||||
if len(content) > 0 {
|
||||
res.Message.Content = content
|
||||
} else if len(toolCalls) > 0 {
|
||||
res.Message.Thinking = thinking
|
||||
res.Message.ToolCalls = toolCalls
|
||||
res.Message.Content = ""
|
||||
} else if res.Message.Thinking != "" {
|
||||
// don't return
|
||||
} else {
|
||||
if r.Done {
|
||||
res.Message.Content = toolParser.Content()
|
||||
|
||||
tb.WriteString(thinking)
|
||||
// we are now receiving content from the model - we should start applying structured outputs
|
||||
if structuredOutputsState == structuredOutputsState_None && req.Format != nil && tb.String() != "" && res.Message.Content != "" {
|
||||
structuredOutputsState = structuredOutputsState_ReadyToApply
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
|
||||
if res.Message.Content != "" || res.Message.Thinking != "" || len(res.Message.ToolCalls) > 0 || r.Done {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser output", "parser", m.Config.Parser, "content", content, "thinking", thinking, "toolCalls", toolCalls, "done", r.Done)
|
||||
ch <- res
|
||||
} else {
|
||||
slog.Log(context.TODO(), logutil.LevelTrace, "builtin parser empty output", "parser", m.Config.Parser)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if thinkingState != nil {
|
||||
thinkingContent, remainingContent := thinkingState.AddContent(res.Message.Content)
|
||||
if thinkingContent == "" && remainingContent == "" && !r.Done {
|
||||
// need to accumulate more to decide what to send
|
||||
return
|
||||
}
|
||||
res.Message.Thinking = thinkingContent
|
||||
tb.WriteString(thinkingContent)
|
||||
// emit the collected thinking text before restarting with structured outputs and clear unstructured content
|
||||
// to avoid leaking mixed tokens like "</think>Hello"
|
||||
if structuredOutputsState == structuredOutputsState_None && req.Format != nil && tb.String() != "" && remainingContent != "" {
|
||||
structuredOutputsState = structuredOutputsState_ReadyToApply
|
||||
res.Message.Content = ""
|
||||
ch <- res
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
res.Message.Content = remainingContent
|
||||
}
|
||||
|
||||
if len(req.Tools) > 0 {
|
||||
toolCalls, content := toolParser.Add(res.Message.Content)
|
||||
if len(content) > 0 {
|
||||
res.Message.Content = content
|
||||
} else if len(toolCalls) > 0 {
|
||||
res.Message.ToolCalls = toolCalls
|
||||
res.Message.Content = ""
|
||||
} else if res.Message.Thinking != "" {
|
||||
// don't return
|
||||
} else {
|
||||
if r.Done {
|
||||
res.Message.Content = toolParser.Content()
|
||||
ch <- res
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
ch <- res
|
||||
})
|
||||
if err != nil {
|
||||
if structuredOutputsState == structuredOutputsState_ReadyToApply && strings.Contains(err.Error(), "context canceled") && c.Request.Context().Err() == nil {
|
||||
// only ignores error if it's a context cancellation due to setting structured outputs
|
||||
} else {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
ch <- res
|
||||
}); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
// ignored structured outputs cancellation falls through to here, start a new request with the structured outputs and updated prompt. use the
|
||||
if structuredOutputsState == structuredOutputsState_ReadyToApply {
|
||||
structuredOutputsState = structuredOutputsState_Applying
|
||||
msg := api.Message{
|
||||
Role: "assistant",
|
||||
Thinking: tb.String(),
|
||||
}
|
||||
|
||||
msgs = append(msgs, msg)
|
||||
prompt, _, err = chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, processedTools, req.Think)
|
||||
if err != nil {
|
||||
slog.Error("chat prompt error applying structured outputs", "error", err)
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
return
|
||||
}
|
||||
// force constraining by terminating thinking header, the parser is already at this state
|
||||
// when the last message is thinking, the rendered for gpt-oss cannot disambiguate between having the
|
||||
// model continue thinking or ending thinking and outputting the final message.
|
||||
// TODO(parthsareen): consider adding prefill disambiguation logic to the renderer for structured outputs.
|
||||
if shouldUseHarmony(m) || (builtinParser != nil && m.Config.Parser == "harmony") {
|
||||
prompt += "<|end|><|start|>assistant<|channel|>final<|message|>"
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
}()
|
||||
|
||||
|
||||
@@ -158,11 +158,26 @@ func TestGenerateChat(t *testing.T) {
|
||||
t.Errorf("expected status 400, got %d", w.Code)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(w.Body.String(), `{"error":"registry.ollama.ai/library/test:latest does not support thinking"}`); diff != "" {
|
||||
if diff := cmp.Diff(w.Body.String(), `{"error":"\"test\" does not support thinking"}`); diff != "" {
|
||||
t.Errorf("mismatch (-got +want):\n%s", diff)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("model can't think but think set false", func(t *testing.T) {
|
||||
think := false
|
||||
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "test",
|
||||
Messages: []api.Message{
|
||||
{Role: "user", Content: "Hello!"},
|
||||
},
|
||||
Think: &api.ThinkValue{Value: think},
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("missing model", func(t *testing.T) {
|
||||
w := createRequest(t, s.ChatHandler, api.ChatRequest{})
|
||||
if w.Code != http.StatusBadRequest {
|
||||
@@ -1120,13 +1135,6 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
|
||||
"The answer is 4.",
|
||||
true)
|
||||
|
||||
testChatRequest(t, "thinking disabled but template still adds think tag",
|
||||
"Simple question",
|
||||
" My thoughts </think> The answer.",
|
||||
"",
|
||||
" My thoughts </think> The answer.",
|
||||
false)
|
||||
|
||||
// Test streaming response with template-added <think>
|
||||
t.Run("streaming with thinking", func(t *testing.T) {
|
||||
var wg sync.WaitGroup
|
||||
@@ -1198,4 +1206,238 @@ func TestChatWithPromptEndingInThinkTag(t *testing.T) {
|
||||
t.Errorf("expected content %q, got %q", "Based on my analysis, the solution is straightforward.", got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("structured outputs restart non-stream", func(t *testing.T) {
|
||||
var (
|
||||
requestsMu sync.Mutex
|
||||
requests []llm.CompletionRequest
|
||||
wg sync.WaitGroup
|
||||
)
|
||||
|
||||
wg.Add(2)
|
||||
|
||||
format := json.RawMessage(`{"type":"object","properties":{"answer":{"type":"string"}}}`)
|
||||
|
||||
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
|
||||
defer wg.Done()
|
||||
|
||||
requestsMu.Lock()
|
||||
requests = append(requests, r)
|
||||
callNum := len(requests)
|
||||
requestsMu.Unlock()
|
||||
|
||||
switch callNum {
|
||||
case 1:
|
||||
fn(llm.CompletionResponse{
|
||||
Content: " I am thinking through this problem. </think> {\"answer\":\"42\"}",
|
||||
Done: false,
|
||||
PromptEvalCount: 1,
|
||||
PromptEvalDuration: 1,
|
||||
})
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("timeout waiting for structured outputs cancellation")
|
||||
return nil
|
||||
}
|
||||
case 2:
|
||||
fn(llm.CompletionResponse{
|
||||
Content: `{"answer":"42"}`,
|
||||
Done: true,
|
||||
DoneReason: llm.DoneReasonStop,
|
||||
PromptEvalCount: 1,
|
||||
PromptEvalDuration: 1,
|
||||
EvalCount: 1,
|
||||
EvalDuration: 1,
|
||||
})
|
||||
return nil
|
||||
default:
|
||||
t.Fatalf("unexpected number of completion calls: %d", callNum)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
think := true
|
||||
streamRequest := false
|
||||
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "test-thinking",
|
||||
Messages: []api.Message{{Role: "user", Content: "Please respond in JSON."}},
|
||||
Think: &api.ThinkValue{Value: think},
|
||||
Stream: &streamRequest,
|
||||
Format: format,
|
||||
})
|
||||
|
||||
wg.Wait()
|
||||
mock.CompletionFn = nil
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if len(requests) != 2 {
|
||||
t.Fatalf("expected two completion calls, got %d", len(requests))
|
||||
}
|
||||
|
||||
if requests[0].Format != nil {
|
||||
t.Errorf("expected first completion format to be nil, got %q", requests[0].Format)
|
||||
}
|
||||
|
||||
if !bytes.Equal([]byte(format), []byte(requests[1].Format)) {
|
||||
t.Errorf("expected second completion format to match original format")
|
||||
}
|
||||
|
||||
var resp api.ChatResponse
|
||||
if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if resp.Message.Thinking != "I am thinking through this problem. " {
|
||||
t.Errorf("expected thinking %q, got %q", "I am thinking through this problem. ", resp.Message.Thinking)
|
||||
}
|
||||
|
||||
if resp.Message.Content != `{"answer":"42"}` {
|
||||
t.Errorf("expected content %q, got %q", `{"answer":"42"}`, resp.Message.Content)
|
||||
}
|
||||
|
||||
if !resp.Done {
|
||||
t.Errorf("expected response to be done")
|
||||
}
|
||||
|
||||
if resp.DoneReason != "stop" {
|
||||
t.Errorf("expected done reason stop, got %s", resp.DoneReason)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("structured outputs restart streaming", func(t *testing.T) {
|
||||
var (
|
||||
requestsMu sync.Mutex
|
||||
requests []llm.CompletionRequest
|
||||
wg sync.WaitGroup
|
||||
)
|
||||
|
||||
wg.Add(2)
|
||||
|
||||
format := json.RawMessage(`{"type":"object","properties":{"answer":{"type":"string"}}}`)
|
||||
|
||||
mock.CompletionFn = func(ctx context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
|
||||
defer wg.Done()
|
||||
|
||||
requestsMu.Lock()
|
||||
requests = append(requests, r)
|
||||
callNum := len(requests)
|
||||
requestsMu.Unlock()
|
||||
|
||||
switch callNum {
|
||||
case 1:
|
||||
fn(llm.CompletionResponse{
|
||||
Content: " I am thinking through this problem. </think> {\"answer\":\"42\"}",
|
||||
Done: false,
|
||||
PromptEvalCount: 1,
|
||||
PromptEvalDuration: 1,
|
||||
})
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(time.Second):
|
||||
t.Fatalf("timeout waiting for structured outputs cancellation")
|
||||
return nil
|
||||
}
|
||||
case 2:
|
||||
fn(llm.CompletionResponse{
|
||||
Content: `{"answer":"42"}`,
|
||||
Done: true,
|
||||
DoneReason: llm.DoneReasonStop,
|
||||
PromptEvalCount: 1,
|
||||
PromptEvalDuration: 1,
|
||||
EvalCount: 1,
|
||||
EvalDuration: 1,
|
||||
})
|
||||
return nil
|
||||
default:
|
||||
t.Fatalf("unexpected number of completion calls: %d", callNum)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
think := true
|
||||
streamRequest := true
|
||||
w := createRequest(t, s.ChatHandler, api.ChatRequest{
|
||||
Model: "test-thinking",
|
||||
Messages: []api.Message{{Role: "user", Content: "Please respond in JSON."}},
|
||||
Think: &api.ThinkValue{Value: think},
|
||||
Stream: &streamRequest,
|
||||
Format: format,
|
||||
})
|
||||
|
||||
wg.Wait()
|
||||
mock.CompletionFn = nil
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
if len(requests) != 2 {
|
||||
t.Fatalf("expected two completion calls, got %d", len(requests))
|
||||
}
|
||||
|
||||
if requests[0].Format != nil {
|
||||
t.Errorf("expected first completion format to be nil, got %q", requests[0].Format)
|
||||
}
|
||||
|
||||
if !bytes.Equal([]byte(format), []byte(requests[1].Format)) {
|
||||
t.Errorf("expected second completion format to match original format")
|
||||
}
|
||||
|
||||
decoder := json.NewDecoder(w.Body)
|
||||
var events []api.ChatResponse
|
||||
for {
|
||||
var event api.ChatResponse
|
||||
if err := decoder.Decode(&event); err == io.EOF {
|
||||
break
|
||||
} else if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
events = append(events, event)
|
||||
if event.Done {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(events) < 2 {
|
||||
t.Fatalf("expected at least two streaming events, got %d", len(events))
|
||||
}
|
||||
|
||||
first := events[0]
|
||||
if first.Message.Thinking != "I am thinking through this problem. " {
|
||||
t.Errorf("expected first event thinking %q, got %q", "I am thinking through this problem. ", first.Message.Thinking)
|
||||
}
|
||||
|
||||
if first.Message.Content != "" {
|
||||
t.Errorf("expected first event content to be empty, got %q", first.Message.Content)
|
||||
}
|
||||
|
||||
if first.Done {
|
||||
t.Error("expected first event to be non-terminal")
|
||||
}
|
||||
|
||||
last := events[len(events)-1]
|
||||
if last.Message.Thinking != "" {
|
||||
t.Errorf("expected final event thinking to be empty, got %q", last.Message.Thinking)
|
||||
}
|
||||
|
||||
if last.Message.Content != `{"answer":"42"}` {
|
||||
t.Errorf("expected final event content %q, got %q", `{"answer":"42"}`, last.Message.Content)
|
||||
}
|
||||
|
||||
if !last.Done {
|
||||
t.Error("expected final event to be done")
|
||||
}
|
||||
|
||||
if last.DoneReason != "stop" {
|
||||
t.Errorf("expected final done reason stop, got %s", last.DoneReason)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -229,8 +229,9 @@ func (s *Scheduler) processPending(ctx context.Context) {
|
||||
}
|
||||
|
||||
if runnerToExpire == nil {
|
||||
// Shouildn't happen
|
||||
slog.Error("runner to expire was nil!")
|
||||
// While we were performing load calculations, the loaded runner(s) unloaded in parallel
|
||||
// so findRunnerToUnload returned no runners. We'll try again and the loadedCount should be zero
|
||||
slog.Debug("runner to expire was nil, retrying")
|
||||
continue
|
||||
}
|
||||
// Trigger an expiration to unload once it's done
|
||||
|
||||
Reference in New Issue
Block a user