Compare commits

..

30 Commits

Author SHA1 Message Date
Roy Han
781585d9bd return 204 for cross-origin OPTIONS 2024-08-12 11:41:36 -07:00
Roy Han
b84a54be05 return 405 for bad method 2024-08-12 11:41:36 -07:00
royjhan
01d544d373 OpenAI: Simplify input output in testing (#5858)
* simplify input output

* direct comp

* in line image

* rm error pointer type

* update response testing

* lint
2024-08-12 10:33:34 -07:00
Josh
1dc3ef3aa9 Revert "server: speed up single gguf creates (#5898)" (#6323)
This reverts commit 8aac22438e.
2024-08-12 09:57:51 -07:00
Josh
8aac22438e server: speed up single gguf creates (#5898) 2024-08-12 09:28:55 -07:00
Jeffrey Morgan
15c2d8fe14 server: parallelize embeddings in API web handler instead of in subprocess runner (#6220)
For simplicity, perform parallelization of embedding requests in the API handler instead of offloading this to the subprocess runner. This keeps the scheduling story simpler as it builds on existing parallel requests, similar to existing text completion functionality.
2024-08-11 11:57:10 -07:00
Daniel Hiltgen
25906d72d1 llm: prevent loading too large models on windows (#5926)
Don't allow loading models that would lead to memory exhaustion (across vram, system memory and disk paging). This check was already applied on Linux but should also be applied on Windows as well.
2024-08-11 11:30:20 -07:00
CognitiveTech
023451ce47 add integration obook-summary (#6305) 2024-08-10 18:43:08 -07:00
Jesse Gross
9b53e39d8e Merge pull request #6258 from coolljt0725/fix_typo
server/download.go: Fix a typo in log
2024-08-09 17:19:48 -07:00
Michael Yang
97fae2df95 Merge pull request #6235 from Nicholas42/fix_line_endings
Set *.png and *.ico to be treated as binary files.
2024-08-09 17:06:30 -07:00
Michael Yang
160d9d4900 Merge pull request #6171 from ollama/mxyng/remove-temp
removeall to remove non-empty temp dirs
2024-08-09 15:47:13 -07:00
Nicholas Schwab
d4e6407464 Restrict text files with explicit line feeds to *.go.
This partially reverts b732beba6a. It
seems like explicitly setting all files to use line feeds was done due
to issues with the go linter, hence it can be restricted to those files
(https://github.com/ollama/ollama/pull/6235#issuecomment-2278745953).
2024-08-09 23:14:13 +02:00
Daniel Hiltgen
b7f7d8cd15 Merge pull request #6291 from dhiltgen/no_sparse_fail
Don't hard fail on sparse setup error
2024-08-09 12:30:25 -07:00
Daniel Hiltgen
2fa1db4345 Don't hard fail on sparse setup error
It seems this can fail in some casees, but proceed
with the download anyway.
2024-08-09 12:16:19 -07:00
Daniel Hiltgen
71b0945fc6 Merge pull request #6290 from dhiltgen/intel_npe
Harden intel boostrap for nil pointers
2024-08-09 12:14:42 -07:00
Daniel Hiltgen
5bca2e60a7 Harden intel boostrap for nil pointers 2024-08-09 11:31:38 -07:00
Nicholas42
67472e0e89 Also flag *.icns as binary 2024-08-09 13:41:20 +02:00
Daniel Hiltgen
e9aa5117c4 Merge pull request #6133 from dhiltgen/cuda_repo
Adjust arm cuda repo paths
2024-08-08 12:33:35 -07:00
Daniel Hiltgen
2473bdba5e Merge pull request #6182 from dhiltgen/more_patterns
Catch one more error log
2024-08-08 12:33:17 -07:00
Jesse Gross
7d1c0047fa Merge pull request #6247 from ollama/jessegross/layers
Store layers inside manifests consistently as values.
2024-08-08 10:46:43 -07:00
Jitang Lei
7b61eba471 server/download.go: Fix a typo in log
Signed-off-by: Jitang Lei <leijitang@outlook.com>
2024-08-08 20:28:01 +08:00
Jesse Gross
7edaf6e7e8 manifest: Store layers inside manifests consistently as values.
Commit 1829fb61 ("manifest: Fix crash on startup when trying to clean up
unused files (#5840)") changed the config layer stored in manifests
from a pointer to a value. This was done in order to avoid potential
nil pointer dereferences after it is deserialized from JSON in the
event that the field is missing.

This changes the Layers slice to also be stored by value. This enables
consistency in handling across the two objects.
2024-08-07 17:03:06 -07:00
Jesse Gross
97ec8cfd4e image: Clarify argument to WriteManifest is config
When creating a model the config layer is appended to the list of
layers and then the last layer is used as the config when writing the
manifest. This change directly uses the config layer to write the
manifest. There is no behavior change but it is less error prone.
2024-08-07 16:58:42 -07:00
royjhan
5b3a21b578 add metrics to docs (#6079) 2024-08-07 14:43:44 -07:00
Kyle Kelley
ad0c19dde4 Use llama3.1 in tools example (#5985)
* Use llama3.1 in tools example

* Update api.md
2024-08-07 17:20:50 -04:00
Nicholas Schwab
ce67706037 Set *.png and *.ico to be treated as binary files.
The change b732beba6 makes all files text files and sets lf as eol. This
will automatically change all files to have lf if they are touched by
git (e.g. via git status). This change cannot be stashed and makes it
hard to work with the repo (rebase and checkout don't really work). See
also #6183.

Here, we set the offending files (*.png and *.ico, but that might be
more in the future) to be treated as binary files and not be changed by
git.
2024-08-07 18:20:11 +02:00
Daniel Hiltgen
04210aa6dd Catch one more error log 2024-08-05 09:28:07 -07:00
Michael Yang
43f9d92008 close pid file 2024-08-05 00:41:16 -07:00
Michael Yang
ed6c8bfe57 removeall to remove non-empty temp dirs 2024-08-05 00:41:16 -07:00
Daniel Hiltgen
df3802a65f Adjust arm cuda repo paths
Ubuntu distros fail to install cuda drivers since aarch64 isn't valid
2024-08-01 17:22:25 -07:00
31 changed files with 532 additions and 1099 deletions

3
.gitattributes vendored
View File

@@ -1,2 +1,3 @@
llm/ext_server/* linguist-vendored
* text eol=lf
* text=auto
*.go text eol=lf

5
.gitmodules vendored
View File

@@ -1,7 +1,4 @@
[submodule "llama.cpp"]
path = llm/llama.cpp
url = https://github.com/ggerganov/llama.cpp.git
shallow = true
[submodule "llm/whisper.cpp"]
path = llm/whisper.cpp
url = git@github.com:ggerganov/whisper.cpp.git
shallow = true

View File

@@ -325,6 +325,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [tlm](https://github.com/yusufcanb/tlm)
- [podman-ollama](https://github.com/ericcurtin/podman-ollama)
- [gollama](https://github.com/sammcj/gollama)
- [Ollama eBook Summary](https://github.com/cognitivetech/ollama-ebook-summary/)
### Database

View File

@@ -36,13 +36,6 @@ func (e StatusError) Error() string {
// ImageData represents the raw binary data of an image file.
type ImageData []byte
type WhisperRequest struct {
Model string `json:"model,omitempty"`
Audio string `json:"audio,omitempty"`
Transcribe bool `json:"transcribe,omitempty"`
KeepAlive *Duration `json:"keep_alive,omitempty"`
}
// GenerateRequest describes a request sent by [Client.Generate]. While you
// have to specify the Model and Prompt fields, all the other fields have
// reasonable defaults for basic uses.
@@ -87,8 +80,6 @@ type GenerateRequest struct {
// Options lists model-specific options. For example, temperature can be
// set through this field, if the model supports it.
Options map[string]interface{} `json:"options"`
Speech *WhisperRequest `json:"speech,omitempty"`
}
// ChatRequest describes a request sent by [Client.Chat].
@@ -114,10 +105,6 @@ type ChatRequest struct {
// Options lists model-specific options.
Options map[string]interface{} `json:"options"`
Speech *WhisperRequest `json:"speech,omitempty"`
RunSpeech bool `json:"run_speech,omitempty"`
}
type Tools []Tool
@@ -140,7 +127,6 @@ type Message struct {
Content string `json:"content"`
Images []ImageData `json:"images,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
Audio string `json:"audio,omitempty"`
}
func (m *Message) UnmarshalJSON(b []byte) error {
@@ -464,11 +450,6 @@ type GenerateResponse struct {
Metrics
}
type WhisperCompletion struct {
Text string `json:"text"`
Error string `json:"error,omitempty"`
}
// ModelDetails provides details about a model.
type ModelDetails struct {
ParentModel string `json:"parent_model"`

View File

@@ -38,7 +38,6 @@ import (
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/recorder"
"github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
@@ -381,14 +380,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
}
}
speech, err := cmd.Flags().GetBool("speech")
if err != nil {
return err
}
if speech {
return generateInteractiveAudio(cmd, opts)
}
return generateInteractive(cmd, opts)
}
return generate(cmd, opts)
@@ -871,7 +862,6 @@ type runOptions struct {
Options map[string]interface{}
MultiModal bool
KeepAlive *api.Duration
Audio bool
}
type displayResponseState struct {
@@ -980,10 +970,6 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
req.KeepAlive = opts.KeepAlive
}
if opts.Audio {
req.RunSpeech = true
}
if err := client.Chat(cancelCtx, req, fn); err != nil {
if errors.Is(err, context.Canceled) {
return nil, nil
@@ -1069,30 +1055,6 @@ func generate(cmd *cobra.Command, opts runOptions) error {
KeepAlive: opts.KeepAlive,
}
speech, err := cmd.Flags().GetBool("speech")
if err != nil {
return err
}
// create temp wav file with the recorder package
if speech {
tempFile, err := os.CreateTemp("", "recording-*.wav")
if err != nil {
return err
}
defer os.Remove(tempFile.Name())
fmt.Print("Speech Mode\n\n")
err = recorder.RecordAudio(tempFile)
if err != nil {
return err
}
request.Speech = &api.WhisperRequest{
Audio: tempFile.Name(),
}
}
if err := client.Generate(ctx, &request, fn); err != nil {
if errors.Is(err, context.Canceled) {
return nil
@@ -1300,7 +1262,6 @@ func NewCLI() *cobra.Command {
RunE: RunHandler,
}
runCmd.Flags().Bool("speech", false, "Speech to text mode")
runCmd.Flags().String("keepalive", "", "Duration to keep a model loaded (e.g. 5m)")
runCmd.Flags().Bool("verbose", false, "Show timings for response")
runCmd.Flags().Bool("insecure", false, "Use an insecure registry")

View File

@@ -20,7 +20,6 @@ import (
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/readline"
"github.com/ollama/ollama/recorder"
"github.com/ollama/ollama/types/errtypes"
)
@@ -52,40 +51,6 @@ func loadModel(cmd *cobra.Command, opts *runOptions) error {
return client.Chat(cmd.Context(), chatReq, func(api.ChatResponse) error { return nil })
}
func generateInteractiveAudio(cmd *cobra.Command, opts runOptions) error {
for {
p := progress.NewProgress(os.Stderr)
spinner := progress.NewSpinner("")
p.Add("", spinner)
// create temp wav file with the recorder package
tempFile, err := os.CreateTemp("", "recording-*.wav")
if err != nil {
return err
}
defer os.Remove(tempFile.Name())
err = recorder.RecordAudio(tempFile)
if err != nil {
return err
}
p.StopAndClear()
newMessage := api.Message{Role: "user", Audio: tempFile.Name()}
opts.Audio = true
opts.Messages = append(opts.Messages, newMessage)
assistant, err := chat(cmd, opts)
if err != nil {
return err
}
if assistant != nil {
opts.Messages = append(opts.Messages, *assistant)
}
}
}
func generateInteractive(cmd *cobra.Command, opts runOptions) error {
usage := func() {
fmt.Fprintln(os.Stderr, "Available Commands:")

View File

@@ -669,7 +669,7 @@ curl http://localhost:11434/api/chat -d '{
```
curl http://localhost:11434/api/chat -d '{
"model": "mistral",
"model": "llama3.1",
"messages": [
{
"role": "user",
@@ -708,7 +708,7 @@ curl http://localhost:11434/api/chat -d '{
```json
{
"model": "mistral:7b-instruct-v0.3-q4_K_M",
"model": "llama3.1",
"created_at": "2024-07-22T20:33:28.123648Z",
"message": {
"role": "assistant",
@@ -1175,7 +1175,10 @@ curl http://localhost:11434/api/embed -d '{
"embeddings": [[
0.010071029, -0.0017594862, 0.05007221, 0.04692972, 0.054916814,
0.008599704, 0.105441414, -0.025878139, 0.12958129, 0.031952348
]]
]],
"total_duration": 14143917,
"load_duration": 1019500,
"prompt_eval_count": 8
}
```

View File

@@ -1,83 +0,0 @@
# Speech to Text Prototype
### To run
`make {/path/to/whisper.cpp/server}`
- replace `whisperServer` in `routes.go` with path to server
## CLI
`./ollama run llama3 [PROMPT] --speech`
- processes voice audio with the provided prompt
`./ollama run llama3 --speech`
- enters interactive mode for continuous voice chat
- TODO: fix exiting interactive mode
Notes: uses default model
## api/generate
### Request fields
- `speech` (required):
- `audio` (required): path to audio file
- `model` (optional): path to whisper model, uses default if null
- `transcribe` (optional): if true, will transcribe and return the audio file
- `keep_alive`: (optional): sets how long the model is stored in memory (default: `5m`)
- `prompt` (optional): if not null, passed in with the transcribed audio
#### Transcription
```
curl http://localhost:11434/api/generate -d '{
"speech": {
"model": "/Users/royhan-ollama/.ollama/whisper/ggml-base.en.bin",
"audio": "/Users/royhan-ollama/ollama/llm/whisper.cpp/samples/jfk.wav",
"transcribe": true,
"keep_alive": "1m"
},
"stream": false
}' | jq
```
#### Response Generation
```
curl http://localhost:11434/api/generate -d '{
"model": "llama3",
"prompt": "What do you think about this quote?",
"speech": {
"model": "/Users/royhan-ollama/.ollama/whisper/ggml-base.en.bin",
"audio": "/Users/royhan-ollama/ollama/llm/whisper.cpp/samples/jfk.wav",
"keep_alive": "1m"
},
"stream": false
}' | jq
```
## api/chat
### Request fields
- `model` (required): language model to chat with
- `speech` (optional):
- `model` (optional): path to whisper model, uses default if null
- `keep_alive`: (optional): sets how long the model is stored in memory (default: `5m`)
- `run_speech` (optional): either this flag must be true or `speech` must be passed in for speech mode to run
- `messages`/`message`/`audio` (required): path to audio file
```
curl http://localhost:11434/api/chat -d '{
"model": "llama3",
"speech": {
"model": "/Users/royhan-ollama/.ollama/whisper/ggml-base.en.bin",
"keep_alive": "10m"
},
"messages": [
{
"role": "system",
"content": "You are a Canadian Nationalist"
},
{
"role": "user",
"content": "What do you think about this quote?",
"audio": "/Users/royhan-ollama/ollama/llm/whisper.cpp/samples/jfk.wav"
}
],
"stream": false
}' | jq
```

1
go.mod
View File

@@ -19,7 +19,6 @@ require (
github.com/agnivade/levenshtein v1.1.1
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
github.com/google/go-cmp v0.6.0
github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5
github.com/mattn/go-runewidth v0.0.14
github.com/nlpodyssey/gopickle v0.3.0
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c

2
go.sum
View File

@@ -115,8 +115,6 @@ github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y=
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5 h1:5AlozfqaVjGYGhms2OsdUyfdJME76E6rx5MdGpjzZpc=
github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5/go.mod h1:WY8R6YKlI2ZI3UyzFk7P6yGSuS+hFwNtEzrexRyD7Es=
github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=

View File

@@ -49,13 +49,9 @@ func PayloadsDir() (string, error) {
}
// Track our pid so we can clean up orphaned tmpdirs
pidFilePath := filepath.Join(tmpDir, "ollama.pid")
pidFile, err := os.OpenFile(pidFilePath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, os.ModePerm)
if err != nil {
return "", err
}
if _, err := pidFile.Write([]byte(strconv.Itoa(os.Getpid()))); err != nil {
return "", err
n := filepath.Join(tmpDir, "ollama.pid")
if err := os.WriteFile(n, []byte(strconv.Itoa(os.Getpid())), 0o644); err != nil {
return "", fmt.Errorf("failed to write pid file %s: %w", n, err)
}
// We create a distinct subdirectory for payloads within the tmpdir
@@ -67,37 +63,44 @@ func PayloadsDir() (string, error) {
// Best effort to clean up prior tmpdirs
func cleanupTmpDirs() {
dirs, err := filepath.Glob(filepath.Join(os.TempDir(), "ollama*"))
matches, err := filepath.Glob(filepath.Join(os.TempDir(), "ollama*", "ollama.pid"))
if err != nil {
return
}
for _, d := range dirs {
info, err := os.Stat(d)
if err != nil || !info.IsDir() {
for _, match := range matches {
raw, err := os.ReadFile(match)
if errors.Is(err, os.ErrNotExist) {
slog.Debug("not a ollama runtime directory, skipping", "path", match)
continue
}
raw, err := os.ReadFile(filepath.Join(d, "ollama.pid"))
if err != nil {
slog.Warn("failed to read ollama.pid", "path", d, "error", err)
// No pid, ignore this tmpdir
} else if err != nil {
slog.Warn("could not read ollama.pid, skipping", "path", match, "error", err)
continue
}
pid, err := strconv.Atoi(string(raw))
if err != nil {
slog.Warn("failed to parse pid", "path", d, "error", err)
slog.Warn("invalid pid, skipping", "path", match, "error", err)
continue
}
proc, err := os.FindProcess(pid)
if err == nil && !errors.Is(proc.Signal(syscall.Signal(0)), os.ErrProcessDone) {
slog.Warn("found running ollama", "pid", pid, "path", d)
// Another running ollama, ignore this tmpdir
p, err := os.FindProcess(pid)
if err == nil && !errors.Is(p.Signal(syscall.Signal(0)), os.ErrProcessDone) {
slog.Warn("process still running, skipping", "pid", pid, "path", match)
continue
}
if err := os.Remove(d); err != nil {
slog.Warn("unable to cleanup stale tmpdir", "path", d, "error", err)
if err := os.Remove(match); err != nil {
slog.Warn("could not cleanup stale pidfile", "path", match, "error", err)
}
runners := filepath.Join(filepath.Dir(match), "runners")
if err := os.RemoveAll(runners); err != nil {
slog.Warn("could not cleanup stale runners", "path", runners, "error", err)
}
if err := os.Remove(filepath.Dir(match)); err != nil {
slog.Warn("could not cleanup stale tmpdir", "path", filepath.Dir(match), "error", err)
}
}
}

View File

@@ -305,38 +305,41 @@ func GetGPUInfo() GpuInfoList {
// Intel
if envconfig.IntelGPU() {
oHandles = initOneAPIHandles()
// On windows we bundle the oneapi library one level above the runner dir
depPath = ""
if runtime.GOOS == "windows" && envconfig.RunnersDir() != "" {
depPath = filepath.Join(filepath.Dir(envconfig.RunnersDir()), "oneapi")
}
if oHandles != nil && oHandles.oneapi != nil {
for d := range oHandles.oneapi.num_drivers {
if oHandles.oneapi == nil {
// shouldn't happen
slog.Warn("nil oneapi handle with driver count", "count", int(oHandles.oneapi.num_drivers))
continue
// On windows we bundle the oneapi library one level above the runner dir
depPath = ""
if runtime.GOOS == "windows" && envconfig.RunnersDir() != "" {
depPath = filepath.Join(filepath.Dir(envconfig.RunnersDir()), "oneapi")
}
devCount := C.oneapi_get_device_count(*oHandles.oneapi, C.int(d))
for i := range devCount {
gpuInfo := OneapiGPUInfo{
GpuInfo: GpuInfo{
Library: "oneapi",
},
driverIndex: int(d),
gpuIndex: int(i),
for d := range oHandles.oneapi.num_drivers {
if oHandles.oneapi == nil {
// shouldn't happen
slog.Warn("nil oneapi handle with driver count", "count", int(oHandles.oneapi.num_drivers))
continue
}
devCount := C.oneapi_get_device_count(*oHandles.oneapi, C.int(d))
for i := range devCount {
gpuInfo := OneapiGPUInfo{
GpuInfo: GpuInfo{
Library: "oneapi",
},
driverIndex: int(d),
gpuIndex: int(i),
}
// TODO - split bootstrapping from updating free memory
C.oneapi_check_vram(*oHandles.oneapi, C.int(d), i, &memInfo)
// TODO - convert this to MinimumMemory based on testing...
var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend.
memInfo.free = C.uint64_t(totalFreeMem)
gpuInfo.TotalMemory = uint64(memInfo.total)
gpuInfo.FreeMemory = uint64(memInfo.free)
gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
gpuInfo.DependencyPath = depPath
oneapiGPUs = append(oneapiGPUs, gpuInfo)
}
// TODO - split bootstrapping from updating free memory
C.oneapi_check_vram(*oHandles.oneapi, C.int(d), i, &memInfo)
// TODO - convert this to MinimumMemory based on testing...
var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend.
memInfo.free = C.uint64_t(totalFreeMem)
gpuInfo.TotalMemory = uint64(memInfo.total)
gpuInfo.FreeMemory = uint64(memInfo.free)
gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
gpuInfo.DependencyPath = depPath
oneapiGPUs = append(oneapiGPUs, gpuInfo)
}
}
}

View File

@@ -1223,9 +1223,7 @@ struct llama_server_context
res.result_json = json
{
{"id", res.id},
{"embedding", std::vector<float>(embd, embd + n_embd)},
{"timings", slot.get_formated_timings()},
};
}
}
@@ -3194,41 +3192,17 @@ int main(int argc, char **argv) {
prompt = "";
}
if (prompt.size() == 1) {
prompt = prompt[0];
}
// create and queue the task
json responses;
{
const int id_task = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(id_task);
llama.request_completion(id_task, {{"prompt", prompt}}, true, -1);
const int task_id = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(task_id);
llama.request_completion(task_id, {{"prompt", prompt}}, true, -1);
// get the result
task_result result = llama.queue_results.recv(id_task);
llama.queue_results.remove_waiting_task_id(id_task);
if (result.error) {
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
}
// get the result
task_result result = llama.queue_results.recv(task_id);
llama.queue_results.remove_waiting_task_id(task_id);
responses = result.result_json.value("results", std::vector<json>{result.result_json});
std::sort(responses.begin(), responses.end(), [](const json& a, const json& b) {
return a["id"] < b["id"];
});
json embeddings = json::array();
int prompt_n = 0;
for (auto & elem : responses) {
embeddings.push_back(elem.at("embedding"));
prompt_n += elem.at("timings").at("prompt_n").get<int>();
}
// send the result
json embedding_res = json{{"embedding", embeddings}, {"prompt_n", prompt_n}};
return res.set_content(embedding_res.dump(), "application/json; charset=utf-8");
}
// send the result
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
});
// GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!?

View File

@@ -33,7 +33,7 @@ type LlamaServer interface {
Ping(ctx context.Context) error
WaitUntilRunning(ctx context.Context) error
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
Embed(ctx context.Context, input []string) (*EmbedResponse, error)
Embedding(ctx context.Context, input string) ([]float32, error)
Tokenize(ctx context.Context, content string) ([]int, error)
Detokenize(ctx context.Context, tokens []int) (string, error)
Close() error
@@ -125,8 +125,9 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
}
}
// On linux, over-allocating CPU memory will almost always result in an error
if runtime.GOOS == "linux" {
// On linux and windows, over-allocating CPU memory will almost always result in an error
// Darwin has fully dynamic swap so has no direct concept of free swap space
if runtime.GOOS != "darwin" {
systemMemoryRequired := estimate.TotalSize - estimate.VRAMSize
available := systemFreeMemory + systemSwapFreeMemory
if systemMemoryRequired > available {
@@ -882,24 +883,20 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
return nil
}
type EmbedRequest struct {
Content []string `json:"content"`
type EmbeddingRequest struct {
Content string `json:"content"`
}
type EmbedResponse struct {
Embedding [][]float32 `json:"embedding"`
PromptEvalCount int `json:"prompt_n"`
type EmbeddingResponse struct {
Embedding []float32 `json:"embedding"`
}
func (s *llmServer) Embed(ctx context.Context, input []string) (*EmbedResponse, error) {
// each input will use a slot, so we need to acquire the semaphore for
// the number of inputs up to numParallel
slots := int64(min(len(input), s.numParallel))
if err := s.sem.Acquire(ctx, slots); err != nil {
func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, error) {
if err := s.sem.Acquire(ctx, 1); err != nil {
slog.Error("Failed to acquire semaphore", "error", err)
return nil, err
}
defer s.sem.Release(slots)
defer s.sem.Release(1)
// Make sure the server is ready
status, err := s.getServerStatusRetry(ctx)
@@ -909,18 +906,18 @@ func (s *llmServer) Embed(ctx context.Context, input []string) (*EmbedResponse,
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
}
data, err := json.Marshal(EmbedRequest{Content: input})
data, err := json.Marshal(EmbeddingRequest{Content: input})
if err != nil {
return nil, fmt.Errorf("error marshaling embed data: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data))
r, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data))
if err != nil {
return nil, fmt.Errorf("error creating embed request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
r.Header.Set("Content-Type", "application/json")
resp, err := http.DefaultClient.Do(req)
resp, err := http.DefaultClient.Do(r)
if err != nil {
return nil, fmt.Errorf("do embedding request: %w", err)
}
@@ -936,12 +933,12 @@ func (s *llmServer) Embed(ctx context.Context, input []string) (*EmbedResponse,
return nil, fmt.Errorf("%s", body)
}
var e EmbedResponse
var e EmbeddingResponse
if err := json.Unmarshal(body, &e); err != nil {
return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
}
return &e, nil
return e.Embedding, nil
}
type TokenizeRequest struct {

View File

@@ -26,6 +26,7 @@ var errorPrefixes = []string{
"cudaMalloc failed",
"\"ERR\"",
"error loading model",
"GGML_ASSERT",
}
func (w *StatusWriter) Write(b []byte) (int, error) {

Submodule llm/whisper.cpp deleted from 6739eb83c3

View File

@@ -7,27 +7,22 @@ import (
"io"
"net/http"
"net/http/httptest"
"reflect"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/ollama/ollama/api"
)
const (
prefix = `data:image/jpeg;base64,`
image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
imageURL = prefix + image
prefix = `data:image/jpeg;base64,`
image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
)
func prepareRequest(req *http.Request, body any) {
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
}
var False = false
func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
return func(c *gin.Context) {
@@ -43,134 +38,136 @@ func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
func TestChatMiddleware(t *testing.T) {
type testCase struct {
Name string
Setup func(t *testing.T, req *http.Request)
Expected func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder)
name string
body string
req api.ChatRequest
err ErrorResponse
}
var capturedRequest *api.ChatRequest
testCases := []testCase{
{
Name: "chat handler",
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{{Role: "user", Content: "Hello"}},
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.Code)
}
if req.Messages[0].Role != "user" {
t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
}
if req.Messages[0].Content != "Hello" {
t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
}
name: "chat handler",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "Hello"}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{
Role: "user",
Content: "Hello",
},
},
Options: map[string]any{
"temperature": 1.0,
"top_p": 1.0,
},
Stream: &False,
},
},
{
Name: "chat handler with image content",
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{
{
Role: "user", Content: []map[string]any{
{"type": "text", "text": "Hello"},
{"type": "image_url", "image_url": map[string]string{"url": imageURL}},
name: "chat handler with image content",
body: `{
"model": "test-model",
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": "Hello"
},
{
"type": "image_url",
"image_url": {
"url": "` + prefix + image + `"
}
}
]
}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{
Role: "user",
Content: "Hello",
},
{
Role: "user",
Images: []api.ImageData{
func() []byte {
img, _ := base64.StdEncoding.DecodeString(image)
return img
}(),
},
},
},
Options: map[string]any{
"temperature": 1.0,
"top_p": 1.0,
},
Stream: &False,
},
},
{
name: "chat handler with tools",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": "What's the weather like in Paris Today?"},
{"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
]
}`,
req: api.ChatRequest{
Model: "test-model",
Messages: []api.Message{
{
Role: "user",
Content: "What's the weather like in Paris Today?",
},
{
Role: "assistant",
ToolCalls: []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: map[string]interface{}{
"location": "Paris, France",
"format": "celsius",
},
},
},
},
},
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.Code)
}
if req.Messages[0].Role != "user" {
t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
}
if req.Messages[0].Content != "Hello" {
t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
}
img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])
if req.Messages[1].Role != "user" {
t.Fatalf("expected 'user', got %s", req.Messages[1].Role)
}
if !bytes.Equal(req.Messages[1].Images[0], img) {
t.Fatalf("expected image encoding, got %s", req.Messages[1].Images[0])
}
},
Options: map[string]any{
"temperature": 1.0,
"top_p": 1.0,
},
Stream: &False,
},
},
{
Name: "chat handler with tools",
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{
{Role: "user", Content: "What's the weather like in Paris Today?"},
{Role: "assistant", ToolCalls: []ToolCall{{
ID: "id",
Type: "function",
Function: struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
}{
Name: "get_current_weather",
Arguments: "{\"location\": \"Paris, France\", \"format\": \"celsius\"}",
},
}}},
},
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
if resp.Code != 200 {
t.Fatalf("expected 200, got %d", resp.Code)
}
if req.Messages[0].Content != "What's the weather like in Paris Today?" {
t.Fatalf("expected What's the weather like in Paris Today?, got %s", req.Messages[0].Content)
}
if req.Messages[1].ToolCalls[0].Function.Arguments["location"] != "Paris, France" {
t.Fatalf("expected 'Paris, France', got %v", req.Messages[1].ToolCalls[0].Function.Arguments["location"])
}
if req.Messages[1].ToolCalls[0].Function.Arguments["format"] != "celsius" {
t.Fatalf("expected celsius, got %v", req.Messages[1].ToolCalls[0].Function.Arguments["format"])
}
},
},
{
Name: "chat handler error forwarding",
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{{Role: "user", Content: 2}},
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.Code)
}
if !strings.Contains(resp.Body.String(), "invalid message content type") {
t.Fatalf("error was not forwarded")
}
name: "chat handler error forwarding",
body: `{
"model": "test-model",
"messages": [
{"role": "user", "content": 2}
]
}`,
err: ErrorResponse{
Error: Error{
Message: "invalid message content type: float64",
Type: "invalid_request_error",
},
},
},
}
@@ -185,16 +182,26 @@ func TestChatMiddleware(t *testing.T) {
router.Handle(http.MethodPost, "/api/chat", endpoint)
for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/chat", nil)
tc.Setup(t, req)
t.Run(tc.name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(tc.body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
tc.Expected(t, capturedRequest, resp)
var errResp ErrorResponse
if resp.Code != http.StatusOK {
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatal(err)
}
}
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
t.Fatal("requests did not match")
}
if !reflect.DeepEqual(tc.err, errResp) {
t.Fatal("errors did not match")
}
capturedRequest = nil
})
}
@@ -202,71 +209,52 @@ func TestChatMiddleware(t *testing.T) {
func TestCompletionsMiddleware(t *testing.T) {
type testCase struct {
Name string
Setup func(t *testing.T, req *http.Request)
Expected func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder)
name string
body string
req api.GenerateRequest
err ErrorResponse
}
var capturedRequest *api.GenerateRequest
testCases := []testCase{
{
Name: "completions handler",
Setup: func(t *testing.T, req *http.Request) {
temp := float32(0.8)
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
Temperature: &temp,
Stop: []string{"\n", "stop"},
Suffix: "suffix",
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
if req.Prompt != "Hello" {
t.Fatalf("expected 'Hello', got %s", req.Prompt)
}
if req.Options["temperature"] != 1.6 {
t.Fatalf("expected 1.6, got %f", req.Options["temperature"])
}
stopTokens, ok := req.Options["stop"].([]any)
if !ok {
t.Fatalf("expected stop tokens to be a list")
}
if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
}
if req.Suffix != "suffix" {
t.Fatalf("expected 'suffix', got %s", req.Suffix)
}
name: "completions handler",
body: `{
"model": "test-model",
"prompt": "Hello",
"temperature": 0.8,
"stop": ["\n", "stop"],
"suffix": "suffix"
}`,
req: api.GenerateRequest{
Model: "test-model",
Prompt: "Hello",
Options: map[string]any{
"frequency_penalty": 0.0,
"presence_penalty": 0.0,
"temperature": 1.6,
"top_p": 1.0,
"stop": []any{"\n", "stop"},
},
Suffix: "suffix",
Stream: &False,
},
},
{
Name: "completions handler error forwarding",
Setup: func(t *testing.T, req *http.Request) {
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
Temperature: nil,
Stop: []int{1, 2},
Suffix: "suffix",
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.Code)
}
if !strings.Contains(resp.Body.String(), "invalid type for 'stop' field") {
t.Fatalf("error was not forwarded")
}
name: "completions handler error forwarding",
body: `{
"model": "test-model",
"prompt": "Hello",
"temperature": null,
"stop": [1, 2],
"suffix": "suffix"
}`,
err: ErrorResponse{
Error: Error{
Message: "invalid type for 'stop' field: float64",
Type: "invalid_request_error",
},
},
},
}
@@ -281,15 +269,27 @@ func TestCompletionsMiddleware(t *testing.T) {
router.Handle(http.MethodPost, "/api/generate", endpoint)
for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/generate", nil)
tc.Setup(t, req)
t.Run(tc.name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
tc.Expected(t, capturedRequest, resp)
var errResp ErrorResponse
if resp.Code != http.StatusOK {
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatal(err)
}
}
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
t.Fatal("requests did not match")
}
if !reflect.DeepEqual(tc.err, errResp) {
t.Fatal("errors did not match")
}
capturedRequest = nil
})
@@ -298,78 +298,47 @@ func TestCompletionsMiddleware(t *testing.T) {
func TestEmbeddingsMiddleware(t *testing.T) {
type testCase struct {
Name string
Setup func(t *testing.T, req *http.Request)
Expected func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder)
name string
body string
req api.EmbedRequest
err ErrorResponse
}
var capturedRequest *api.EmbedRequest
testCases := []testCase{
{
Name: "embed handler single input",
Setup: func(t *testing.T, req *http.Request) {
body := EmbedRequest{
Input: "Hello",
Model: "test-model",
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
if req.Input != "Hello" {
t.Fatalf("expected 'Hello', got %s", req.Input)
}
if req.Model != "test-model" {
t.Fatalf("expected 'test-model', got %s", req.Model)
}
name: "embed handler single input",
body: `{
"input": "Hello",
"model": "test-model"
}`,
req: api.EmbedRequest{
Input: "Hello",
Model: "test-model",
},
},
{
Name: "embed handler batch input",
Setup: func(t *testing.T, req *http.Request) {
body := EmbedRequest{
Input: []string{"Hello", "World"},
Model: "test-model",
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
input, ok := req.Input.([]any)
if !ok {
t.Fatalf("expected input to be a list")
}
if input[0].(string) != "Hello" {
t.Fatalf("expected 'Hello', got %s", input[0])
}
if input[1].(string) != "World" {
t.Fatalf("expected 'World', got %s", input[1])
}
if req.Model != "test-model" {
t.Fatalf("expected 'test-model', got %s", req.Model)
}
name: "embed handler batch input",
body: `{
"input": ["Hello", "World"],
"model": "test-model"
}`,
req: api.EmbedRequest{
Input: []any{"Hello", "World"},
Model: "test-model",
},
},
{
Name: "embed handler error forwarding",
Setup: func(t *testing.T, req *http.Request) {
body := EmbedRequest{
Model: "test-model",
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.Code)
}
if !strings.Contains(resp.Body.String(), "invalid input") {
t.Fatalf("error was not forwarded")
}
name: "embed handler error forwarding",
body: `{
"model": "test-model"
}`,
err: ErrorResponse{
Error: Error{
Message: "invalid input",
Type: "invalid_request_error",
},
},
},
}
@@ -384,116 +353,167 @@ func TestEmbeddingsMiddleware(t *testing.T) {
router.Handle(http.MethodPost, "/api/embed", endpoint)
for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/embed", nil)
tc.Setup(t, req)
t.Run(tc.name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tc.body))
req.Header.Set("Content-Type", "application/json")
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
tc.Expected(t, capturedRequest, resp)
var errResp ErrorResponse
if resp.Code != http.StatusOK {
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
t.Fatal(err)
}
}
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
t.Fatal("requests did not match")
}
if !reflect.DeepEqual(tc.err, errResp) {
t.Fatal("errors did not match")
}
capturedRequest = nil
})
}
}
func TestMiddlewareResponses(t *testing.T) {
func TestListMiddleware(t *testing.T) {
type testCase struct {
Name string
Method string
Path string
TestPath string
Handler func() gin.HandlerFunc
Endpoint func(c *gin.Context)
Setup func(t *testing.T, req *http.Request)
Expected func(t *testing.T, resp *httptest.ResponseRecorder)
name string
endpoint func(c *gin.Context)
resp string
}
testCases := []testCase{
{
Name: "list handler",
Method: http.MethodGet,
Path: "/api/tags",
TestPath: "/api/tags",
Handler: ListMiddleware,
Endpoint: func(c *gin.Context) {
name: "list handler",
endpoint: func(c *gin.Context) {
c.JSON(http.StatusOK, api.ListResponse{
Models: []api.ListModelResponse{
{
Name: "Test Model",
Name: "test-model",
ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
},
},
})
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
var listResp ListCompletion
if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
t.Fatal(err)
}
if listResp.Object != "list" {
t.Fatalf("expected list, got %s", listResp.Object)
}
if len(listResp.Data) != 1 {
t.Fatalf("expected 1, got %d", len(listResp.Data))
}
if listResp.Data[0].Id != "Test Model" {
t.Fatalf("expected Test Model, got %s", listResp.Data[0].Id)
}
},
resp: `{
"object": "list",
"data": [
{
"id": "test-model",
"object": "model",
"created": 1686935002,
"owned_by": "library"
}
]
}`,
},
{
Name: "retrieve model",
Method: http.MethodGet,
Path: "/api/show/:model",
TestPath: "/api/show/test-model",
Handler: RetrieveMiddleware,
Endpoint: func(c *gin.Context) {
c.JSON(http.StatusOK, api.ShowResponse{
ModifiedAt: time.Date(2024, 6, 17, 13, 45, 0, 0, time.UTC),
})
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
var retrieveResp Model
if err := json.NewDecoder(resp.Body).Decode(&retrieveResp); err != nil {
t.Fatal(err)
}
if retrieveResp.Object != "model" {
t.Fatalf("Expected object to be model, got %s", retrieveResp.Object)
}
if retrieveResp.Id != "test-model" {
t.Fatalf("Expected id to be test-model, got %s", retrieveResp.Id)
}
name: "list handler empty output",
endpoint: func(c *gin.Context) {
c.JSON(http.StatusOK, api.ListResponse{})
},
resp: `{
"object": "list",
"data": null
}`,
},
}
gin.SetMode(gin.TestMode)
router := gin.New()
for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
router = gin.New()
router.Use(tc.Handler())
router.Handle(tc.Method, tc.Path, tc.Endpoint)
req, _ := http.NewRequest(tc.Method, tc.TestPath, nil)
router := gin.New()
router.Use(ListMiddleware())
router.Handle(http.MethodGet, "/api/tags", tc.endpoint)
req, _ := http.NewRequest(http.MethodGet, "/api/tags", nil)
if tc.Setup != nil {
tc.Setup(t, req)
}
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
var expected, actual map[string]any
err := json.Unmarshal([]byte(tc.resp), &expected)
if err != nil {
t.Fatalf("failed to unmarshal expected response: %v", err)
}
assert.Equal(t, http.StatusOK, resp.Code)
err = json.Unmarshal(resp.Body.Bytes(), &actual)
if err != nil {
t.Fatalf("failed to unmarshal actual response: %v", err)
}
tc.Expected(t, resp)
})
if !reflect.DeepEqual(expected, actual) {
t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
}
}
}
func TestRetrieveMiddleware(t *testing.T) {
type testCase struct {
name string
endpoint func(c *gin.Context)
resp string
}
testCases := []testCase{
{
name: "retrieve handler",
endpoint: func(c *gin.Context) {
c.JSON(http.StatusOK, api.ShowResponse{
ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
})
},
resp: `{
"id":"test-model",
"object":"model",
"created":1686935002,
"owned_by":"library"}
`,
},
{
name: "retrieve handler error forwarding",
endpoint: func(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": "model not found"})
},
resp: `{
"error": {
"code": null,
"message": "model not found",
"param": null,
"type": "api_error"
}
}`,
},
}
gin.SetMode(gin.TestMode)
for _, tc := range testCases {
router := gin.New()
router.Use(RetrieveMiddleware())
router.Handle(http.MethodGet, "/api/show/:model", tc.endpoint)
req, _ := http.NewRequest(http.MethodGet, "/api/show/test-model", nil)
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
var expected, actual map[string]any
err := json.Unmarshal([]byte(tc.resp), &expected)
if err != nil {
t.Fatalf("failed to unmarshal expected response: %v", err)
}
err = json.Unmarshal(resp.Body.Bytes(), &actual)
if err != nil {
t.Fatalf("failed to unmarshal actual response: %v", err)
}
if !reflect.DeepEqual(expected, actual) {
t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
}
}
}

View File

@@ -1,137 +0,0 @@
package recorder
import (
"encoding/binary"
"fmt"
"os"
"os/signal"
"syscall"
"golang.org/x/sys/unix"
"golang.org/x/term"
"github.com/gordonklaus/portaudio"
)
const (
sampleRate = 16000
numChannels = 1
bitsPerSample = 16
)
func RecordAudio(f *os.File) error {
fmt.Print("Recording. Press any key to stop.\n\n")
sig := make(chan os.Signal, 1)
signal.Notify(sig, os.Interrupt, syscall.SIGTERM)
portaudio.Initialize()
defer portaudio.Terminate()
in := make([]int16, 64)
stream, err := portaudio.OpenDefaultStream(numChannels, 0, sampleRate, len(in), in)
if err != nil {
return err
}
defer stream.Close()
err = stream.Start()
if err != nil {
return err
}
// Write WAV header with placeholder sizes
writeWavHeader(f, sampleRate, numChannels, bitsPerSample)
var totalSamples uint32
// Set up terminal input reading
oldState, err := term.MakeRaw(int(os.Stdin.Fd()))
if err != nil {
return err
}
defer term.Restore(int(os.Stdin.Fd()), oldState)
// Create a channel to handle the stop signal
stop := make(chan struct{})
go func() {
_, err := unix.Read(int(os.Stdin.Fd()), make([]byte, 1))
if err != nil {
fmt.Println("Error reading from stdin:", err)
return
}
// Send signal to stop recording
stop <- struct{}{}
}()
loop:
for {
err = stream.Read()
if err != nil {
return err
}
err = binary.Write(f, binary.LittleEndian, in)
if err != nil {
return err
}
totalSamples += uint32(len(in))
select {
case <-stop:
break loop
case <-sig:
break loop
default:
}
}
err = stream.Stop()
if err != nil {
return err
}
// Update WAV header with actual sizes
updateWavHeader(f, totalSamples, numChannels, bitsPerSample)
return nil
}
func writeWavHeader(f *os.File, sampleRate int, numChannels int, bitsPerSample int) {
subchunk1Size := 16
audioFormat := 1
byteRate := sampleRate * numChannels * (bitsPerSample / 8)
blockAlign := numChannels * (bitsPerSample / 8)
// Write the RIFF header
f.Write([]byte("RIFF"))
binary.Write(f, binary.LittleEndian, uint32(0)) // Placeholder for file size
f.Write([]byte("WAVE"))
// Write the fmt subchunk
f.Write([]byte("fmt "))
binary.Write(f, binary.LittleEndian, uint32(subchunk1Size))
binary.Write(f, binary.LittleEndian, uint16(audioFormat))
binary.Write(f, binary.LittleEndian, uint16(numChannels))
binary.Write(f, binary.LittleEndian, uint32(sampleRate))
binary.Write(f, binary.LittleEndian, uint32(byteRate))
binary.Write(f, binary.LittleEndian, uint16(blockAlign))
binary.Write(f, binary.LittleEndian, uint16(bitsPerSample))
// Write the data subchunk header
f.Write([]byte("data"))
binary.Write(f, binary.LittleEndian, uint32(0)) // Placeholder for data size
}
func updateWavHeader(f *os.File, totalSamples uint32, numChannels int, bitsPerSample int) {
fileSize := 36 + (totalSamples * uint32(numChannels) * uint32(bitsPerSample/8))
dataSize := totalSamples * uint32(numChannels) * uint32(bitsPerSample/8)
// Seek to the start of the file and write updated sizes
f.Seek(4, 0)
binary.Write(f, binary.LittleEndian, uint32(fileSize))
f.Seek(40, 0)
binary.Write(f, binary.LittleEndian, uint32(dataSize))
}

View File

@@ -209,15 +209,15 @@ install_cuda_driver_yum() {
case $PACKAGE_MANAGER in
yum)
$SUDO $PACKAGE_MANAGER -y install yum-utils
if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo" >/dev/null ; then
$SUDO $PACKAGE_MANAGER-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo
if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-$1$2.repo" >/dev/null ; then
$SUDO $PACKAGE_MANAGER-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-$1$2.repo
else
error $CUDA_REPO_ERR_MSG
fi
;;
dnf)
if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo" >/dev/null ; then
$SUDO $PACKAGE_MANAGER config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo
if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-$1$2.repo" >/dev/null ; then
$SUDO $PACKAGE_MANAGER config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-$1$2.repo
else
error $CUDA_REPO_ERR_MSG
fi
@@ -245,8 +245,8 @@ install_cuda_driver_yum() {
# ref: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#debian
install_cuda_driver_apt() {
status 'Installing NVIDIA repository...'
if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-keyring_1.1-1_all.deb" >/dev/null ; then
curl -fsSL -o $TEMP_DIR/cuda-keyring.deb https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-keyring_1.1-1_all.deb
if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-keyring_1.1-1_all.deb" >/dev/null ; then
curl -fsSL -o $TEMP_DIR/cuda-keyring.deb https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-keyring_1.1-1_all.deb
else
error $CUDA_REPO_ERR_MSG
fi

View File

@@ -216,9 +216,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
return err
}
defer file.Close()
if err := setSparse(file); err != nil {
return err
}
setSparse(file)
_ = file.Truncate(b.Total)
@@ -235,7 +233,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
newOpts.CheckRedirect = func(req *http.Request, via []*http.Request) error {
if len(via) > 10 {
return errors.New("maxium redirects exceeded (10) for directURL")
return errors.New("maximum redirects exceeded (10) for directURL")
}
// if the hostname is the same, allow the redirect

View File

@@ -373,7 +373,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
var messages []*api.Message
parameters := make(map[string]any)
var layers []*Layer
var layers []Layer
for _, c := range modelfile.Commands {
mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
@@ -499,7 +499,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
if c.Name != "license" {
// replace
layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
layers = slices.DeleteFunc(layers, func(layer Layer) bool {
if layer.MediaType != mediatype {
return false
}
@@ -545,7 +545,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
}
var err2 error
layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
layers = slices.DeleteFunc(layers, func(layer Layer) bool {
switch layer.MediaType {
case "application/vnd.ollama.image.message":
// if there are new messages, remove the inherited ones
@@ -625,12 +625,12 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
return err
}
layer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
configLayer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
if err != nil {
return err
}
for _, layer := range append(layers, layer) {
for _, layer := range append(layers, configLayer) {
if layer.status != "" {
fn(api.ProgressResponse{Status: layer.status})
}
@@ -639,7 +639,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
old, _ := ParseNamedManifest(name)
fn(api.ProgressResponse{Status: "writing manifest"})
if err := WriteManifest(name, layer, layers); err != nil {
if err := WriteManifest(name, configLayer, layers); err != nil {
return err
}
@@ -839,10 +839,10 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
return err
}
var layers []*Layer
var layers []Layer
layers = append(layers, manifest.Layers...)
if manifest.Config.Digest != "" {
layers = append(layers, &manifest.Config)
layers = append(layers, manifest.Config)
}
for _, layer := range layers {
@@ -911,10 +911,10 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
return fmt.Errorf("pull model manifest: %s", err)
}
var layers []*Layer
var layers []Layer
layers = append(layers, manifest.Layers...)
if manifest.Config.Digest != "" {
layers = append(layers, &manifest.Config)
layers = append(layers, manifest.Config)
}
skipVerify := make(map[string]bool)

View File

@@ -16,15 +16,15 @@ type Layer struct {
status string
}
func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
func NewLayer(r io.Reader, mediatype string) (Layer, error) {
blobs, err := GetBlobsPath("")
if err != nil {
return nil, err
return Layer{}, err
}
temp, err := os.CreateTemp(blobs, "sha256-")
if err != nil {
return nil, err
return Layer{}, err
}
defer temp.Close()
defer os.Remove(temp.Name())
@@ -32,28 +32,28 @@ func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
sha256sum := sha256.New()
n, err := io.Copy(io.MultiWriter(temp, sha256sum), r)
if err != nil {
return nil, err
return Layer{}, err
}
if err := temp.Close(); err != nil {
return nil, err
return Layer{}, err
}
digest := fmt.Sprintf("sha256:%x", sha256sum.Sum(nil))
blob, err := GetBlobsPath(digest)
if err != nil {
return nil, err
return Layer{}, err
}
status := "using existing layer"
if _, err := os.Stat(blob); err != nil {
status = "creating new layer"
if err := os.Rename(temp.Name(), blob); err != nil {
return nil, err
return Layer{}, err
}
}
return &Layer{
return Layer{
MediaType: mediatype,
Digest: digest,
Size: n,
@@ -61,22 +61,22 @@ func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
}, nil
}
func NewLayerFromLayer(digest, mediatype, from string) (*Layer, error) {
func NewLayerFromLayer(digest, mediatype, from string) (Layer, error) {
if digest == "" {
return nil, errors.New("creating new layer from layer with empty digest")
return Layer{}, errors.New("creating new layer from layer with empty digest")
}
blob, err := GetBlobsPath(digest)
if err != nil {
return nil, err
return Layer{}, err
}
fi, err := os.Stat(blob)
if err != nil {
return nil, err
return Layer{}, err
}
return &Layer{
return Layer{
MediaType: mediatype,
Digest: digest,
Size: fi.Size(),
@@ -109,7 +109,7 @@ func (l *Layer) Remove() error {
}
for _, m := range ms {
for _, layer := range append(m.Layers, &m.Config) {
for _, layer := range append(m.Layers, m.Config) {
if layer.Digest == l.Digest {
// something is using this layer
return nil

View File

@@ -14,10 +14,10 @@ import (
)
type Manifest struct {
SchemaVersion int `json:"schemaVersion"`
MediaType string `json:"mediaType"`
Config Layer `json:"config"`
Layers []*Layer `json:"layers"`
SchemaVersion int `json:"schemaVersion"`
MediaType string `json:"mediaType"`
Config Layer `json:"config"`
Layers []Layer `json:"layers"`
filepath string
fi os.FileInfo
@@ -25,7 +25,7 @@ type Manifest struct {
}
func (m *Manifest) Size() (size int64) {
for _, layer := range append(m.Layers, &m.Config) {
for _, layer := range append(m.Layers, m.Config) {
size += layer.Size
}
@@ -46,7 +46,7 @@ func (m *Manifest) Remove() error {
}
func (m *Manifest) RemoveLayers() error {
for _, layer := range append(m.Layers, &m.Config) {
for _, layer := range append(m.Layers, m.Config) {
if layer.Digest != "" {
if err := layer.Remove(); errors.Is(err, os.ErrNotExist) {
slog.Debug("layer does not exist", "digest", layer.Digest)
@@ -95,7 +95,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
return &m, nil
}
func WriteManifest(name model.Name, config *Layer, layers []*Layer) error {
func WriteManifest(name model.Name, config Layer, layers []Layer) error {
manifests, err := GetManifestPath()
if err != nil {
return err
@@ -115,7 +115,7 @@ func WriteManifest(name model.Name, config *Layer, layers []*Layer) error {
m := Manifest{
SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
Config: *config,
Config: config,
Layers: layers,
}

View File

@@ -26,7 +26,7 @@ import (
var intermediateBlobs map[string]string = make(map[string]string)
type layerGGML struct {
*Layer
Layer
*llm.GGML
}

View File

@@ -10,23 +10,20 @@ import (
"io"
"log/slog"
"math"
"math/rand"
"mime/multipart"
"net"
"net/http"
"net/netip"
"os"
"os/exec"
"os/signal"
"path/filepath"
"slices"
"strconv"
"strings"
"syscall"
"time"
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
"golang.org/x/sync/errgroup"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
@@ -109,186 +106,6 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capabil
return runner.llama, model, &opts, nil
}
func (s *Server) runWhisperServer(c *gin.Context, portCh chan int, errCh chan error, speech *api.WhisperRequest) {
var modelPath string
if speech.Model == "" {
modelPath = "/Users/royhan-ollama/.ollama/whisper/ggml-base.en.bin"
} else {
modelPath = speech.Model
}
// default to 5 minutes
var sessionDuration time.Duration
if speech.KeepAlive != nil {
sessionDuration = speech.KeepAlive.Duration
} else {
sessionDuration = 5 * time.Minute
}
s.sched.whisperMu.Lock()
if s.sched.whisperLoaded[modelPath] != nil {
slog.Info(fmt.Sprintf("whisper server already running %s on port %d", modelPath, *s.sched.whisperLoaded[modelPath]))
portCh <- *s.sched.whisperLoaded[modelPath]
// Renew the expiration time
s.sched.whisperExpiresAt[modelPath] = time.Now().Add(sessionDuration)
s.sched.whisperMu.Unlock()
return
}
whisperServer := "/Users/royhan-ollama/.ollama/server"
// Find an available port for whisper
port := 0
params := []string{}
if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
var l *net.TCPListener
if l, err = net.ListenTCP("tcp", a); err == nil {
port = l.Addr().(*net.TCPAddr).Port
l.Close()
}
}
if port == 0 {
slog.Debug("ResolveTCPAddr failed")
port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
}
finalParams := append(params, "--port", strconv.Itoa(port), "--model", modelPath)
cmd := exec.Command(whisperServer, finalParams...)
slog.Info("starting whisper server", "cmd", cmd.String())
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
err := cmd.Start()
if err != nil {
slog.Error("failed to start whisper server", "error", err)
errCh <- err
return
}
// Wait for server connection
retries := 25
var connErr error
for range retries {
time.Sleep(50 * time.Millisecond)
conn, err := net.DialTimeout("tcp", fmt.Sprintf("localhost:%d", port), time.Second)
if err == nil {
conn.Close()
connErr = nil
break
}
connErr = err
}
if connErr != nil {
slog.Error("failed to connect to whisper server", "error", connErr)
errCh <- connErr
return
}
portCh <- port
s.sched.whisperLoaded[modelPath] = &port
s.sched.whisperExpiresAt[modelPath] = time.Now().Add(sessionDuration)
s.sched.whisperMu.Unlock()
// Wait for the whisper server to exit
defer func() {
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
for range ticker.C {
s.sched.whisperMu.Lock()
if time.Now().After(s.sched.whisperExpiresAt[modelPath]) {
slog.Info("exiting whisper server")
delete(s.sched.whisperLoaded, modelPath)
delete(s.sched.whisperExpiresAt, modelPath)
s.sched.whisperMu.Unlock()
if err := cmd.Process.Kill(); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err})
return
}
slog.Debug("whisper server stopped")
return
}
s.sched.whisperMu.Unlock()
}
}()
}
func whisperInference(c *gin.Context, filePath string, port int) (*api.WhisperCompletion, error) {
// Open the file
file, err := os.Open(filePath)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to open file"})
return nil, err
}
defer file.Close()
// Create a buffer to hold the multipart form data
buffer := &bytes.Buffer{}
writer := multipart.NewWriter(buffer)
// Add the file to the multipart form
part, err := writer.CreateFormFile("file", filepath.Base(filePath))
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to create form file"})
return nil, err
}
if _, err := io.Copy(part, file); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to copy file"})
return nil, err
}
// Add other fields to the form
if err := writer.WriteField("temperature", "0.0"); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to write field"})
return nil, err
}
// Close the writer to finalize the multipart form
if err := writer.Close(); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to close writer"})
return nil, err
}
endpoint := fmt.Sprintf("http://localhost:%s/inference", strconv.Itoa(port))
serverReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodPost, endpoint, buffer)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to create request"})
return nil, err
}
serverReq.Header.Set("Content-Type", writer.FormDataContentType())
res, err := http.DefaultClient.Do(serverReq)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to send request"})
return nil, err
}
defer res.Body.Close()
body, err := io.ReadAll(res.Body)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to read response"})
return nil, err
}
var w api.WhisperCompletion
if err := json.Unmarshal(body, &w); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to unmarshal response"})
return nil, err
}
if w.Error != "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": w.Error})
return nil, fmt.Errorf(w.Error)
}
return &w, nil
}
func (s *Server) GenerateHandler(c *gin.Context) {
checkpointStart := time.Now()
var req api.GenerateRequest
@@ -313,40 +130,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
caps = append(caps, CapabilityInsert)
}
if req.Speech != nil {
portCh := make(chan int, 1)
errCh := make(chan error, 1)
go s.runWhisperServer(c, portCh, errCh, req.Speech)
var port int
select {
case port = <-portCh:
case err := <-errCh:
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err})
return
}
w, err := whisperInference(c, req.Speech.Audio, port)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to generate completion"})
return
}
if req.Speech.Transcribe {
c.JSON(http.StatusOK, api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Response: w.Text,
Done: true,
DoneReason: "transcribe",
})
return
}
req.Prompt += "\n" + w.Text
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
@@ -564,6 +347,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
return
}
var count int
for i, s := range input {
tokens, err := r.Tokenize(c.Request.Context(), s)
if err != nil {
@@ -586,25 +370,36 @@ func (s *Server) EmbedHandler(c *gin.Context) {
}
}
count += len(tokens)
input[i] = s
}
embeddings, err := r.Embed(c.Request.Context(), input)
if err != nil {
slog.Error("embedding generation failed", "error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
return
var g errgroup.Group
embeddings := make([][]float32, len(input))
for i, text := range input {
g.Go(func() error {
embedding, err := r.Embedding(c.Request.Context(), text)
if err != nil {
return err
}
embeddings[i] = normalize(embedding)
return nil
})
}
for i, e := range embeddings.Embedding {
embeddings.Embedding[i] = normalize(e)
if err := g.Wait(); err != nil {
slog.Error("embedding generation failed", "error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Errorf("failed to generate embeddings: %v", err)})
return
}
resp := api.EmbedResponse{
Model: req.Model,
Embeddings: embeddings.Embedding,
Embeddings: embeddings,
TotalDuration: time.Since(checkpointStart),
LoadDuration: checkpointLoaded.Sub(checkpointStart),
PromptEvalCount: embeddings.PromptEvalCount,
PromptEvalCount: count,
}
c.JSON(http.StatusOK, resp)
}
@@ -648,21 +443,20 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return
}
embeddings, err := r.Embed(c.Request.Context(), []string{req.Prompt})
embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
if err != nil {
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
return
}
embedding := make([]float64, len(embeddings.Embedding[0]))
for i, v := range embeddings.Embedding[0] {
embedding[i] = float64(v)
var e []float64
for _, v := range embedding {
e = append(e, float64(v))
}
resp := api.EmbeddingResponse{
Embedding: embedding,
Embedding: e,
}
c.JSON(http.StatusOK, resp)
}
@@ -1249,6 +1043,11 @@ func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
if addr, err := netip.ParseAddr(host); err == nil {
if addr.IsLoopback() || addr.IsPrivate() || addr.IsUnspecified() || isLocalIP(addr) {
if c.Request.Method == http.MethodOptions {
c.AbortWithStatus(http.StatusNoContent)
return
}
c.Next()
return
}
@@ -1280,6 +1079,7 @@ func (s *Server) GenerateRoutes() http.Handler {
config.AllowOrigins = envconfig.Origins()
r := gin.Default()
r.HandleMethodNotAllowed = true
r.Use(
cors.New(config),
allowedHostsMiddleware(s.addr),
@@ -1514,37 +1314,6 @@ func (s *Server) ProcessHandler(c *gin.Context) {
c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
}
func processAudio(c *gin.Context, s *Server, msgs []api.Message, req *api.WhisperRequest) error {
slog.Info("processing audio")
if req == nil {
req = &api.WhisperRequest{}
}
portCh := make(chan int, 1)
errCh := make(chan error, 1)
go s.runWhisperServer(c, portCh, errCh, req)
var port int
select {
case port = <-portCh:
case err := <-errCh:
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return err
}
// could parallelize this
for i, msg := range msgs {
if msg.Audio != "" {
w, err := whisperInference(c, msg.Audio, port)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to generate completion"})
return err
}
msgs[i].Content += "\n" + w.Text
}
}
return nil
}
func (s *Server) ChatHandler(c *gin.Context) {
checkpointStart := time.Now()
@@ -1589,13 +1358,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...)
}
if req.Speech != nil || req.RunSpeech {
if err := processAudio(c, s, msgs, req.Speech); err != nil {
slog.Error("failed to process audio", "error", err)
return
}
}
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})

View File

@@ -98,7 +98,7 @@ func TestDeleteDuplicateLayers(t *testing.T) {
}
// create a manifest with duplicate layers
if err := WriteManifest(n, config, []*Layer{config}); err != nil {
if err := WriteManifest(n, config, []Layer{config}); err != nil {
t.Fatal(err)
}

View File

@@ -46,10 +46,6 @@ type Scheduler struct {
getGpuFn func() gpu.GpuInfoList
getCpuFn func() gpu.GpuInfoList
reschedDelay time.Duration
whisperLoaded map[string]*int
whisperExpiresAt map[string]time.Time
whisperMu sync.Mutex
}
// Default automatic value for number of models we allow per GPU
@@ -67,17 +63,15 @@ var ErrMaxQueue = errors.New("server busy, please try again. maximum pending re
func InitScheduler(ctx context.Context) *Scheduler {
maxQueue := envconfig.MaxQueue()
sched := &Scheduler{
pendingReqCh: make(chan *LlmRequest, maxQueue),
finishedReqCh: make(chan *LlmRequest, maxQueue),
expiredCh: make(chan *runnerRef, maxQueue),
unloadedCh: make(chan interface{}, maxQueue),
loaded: make(map[string]*runnerRef),
newServerFn: llm.NewLlamaServer,
getGpuFn: gpu.GetGPUInfo,
getCpuFn: gpu.GetCPUInfo,
reschedDelay: 250 * time.Millisecond,
whisperLoaded: make(map[string]*int),
whisperExpiresAt: make(map[string]time.Time),
pendingReqCh: make(chan *LlmRequest, maxQueue),
finishedReqCh: make(chan *LlmRequest, maxQueue),
expiredCh: make(chan *runnerRef, maxQueue),
unloadedCh: make(chan interface{}, maxQueue),
loaded: make(map[string]*runnerRef),
newServerFn: llm.NewLlamaServer,
getGpuFn: gpu.GetGPUInfo,
getCpuFn: gpu.GetCPUInfo,
reschedDelay: 250 * time.Millisecond,
}
sched.loadFn = sched.load
return sched
@@ -116,10 +110,6 @@ func (s *Scheduler) Run(ctx context.Context) {
go func() {
s.processCompleted(ctx)
}()
// go func() {
// could clean up whisper servers in init thread
// }
}
func (s *Scheduler) processPending(ctx context.Context) {

View File

@@ -708,8 +708,8 @@ type mockLlm struct {
pingResp error
waitResp error
completionResp error
embedResp *llm.EmbedResponse
embedRespErr error
embeddingResp []float32
embeddingRespErr error
tokenizeResp []int
tokenizeRespErr error
detokenizeResp string
@@ -727,8 +727,8 @@ func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn
return s.completionResp
}
func (s *mockLlm) Embed(ctx context.Context, input []string) (*llm.EmbedResponse, error) {
return s.embedResp, s.embedRespErr
func (s *mockLlm) Embedding(ctx context.Context, input string) ([]float32, error) {
return s.embeddingResp, s.embeddingRespErr
}
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {

View File

@@ -4,6 +4,5 @@ package server
import "os"
func setSparse(file *os.File) error {
return nil
func setSparse(*os.File) {
}

View File

@@ -6,8 +6,9 @@ import (
"golang.org/x/sys/windows"
)
func setSparse(file *os.File) error {
return windows.DeviceIoControl(
func setSparse(file *os.File) {
// exFat (and other FS types) don't support sparse files, so ignore errors
windows.DeviceIoControl( //nolint:errcheck
windows.Handle(file.Fd()), windows.FSCTL_SET_SPARSE,
nil, 0,
nil, 0,

View File

@@ -26,7 +26,7 @@ import (
var blobUploadManager sync.Map
type blobUpload struct {
*Layer
Layer
Total int64
Completed atomic.Int64
@@ -362,7 +362,7 @@ func (p *progressWriter) Rollback() {
p.written = 0
}
func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
func uploadBlob(ctx context.Context, mp ModelPath, layer Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
requestURL := mp.BaseURL()
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)