Compare commits
30 Commits
royh/whisp
...
royh/ep-me
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
781585d9bd | ||
|
|
b84a54be05 | ||
|
|
01d544d373 | ||
|
|
1dc3ef3aa9 | ||
|
|
8aac22438e | ||
|
|
15c2d8fe14 | ||
|
|
25906d72d1 | ||
|
|
023451ce47 | ||
|
|
9b53e39d8e | ||
|
|
97fae2df95 | ||
|
|
160d9d4900 | ||
|
|
d4e6407464 | ||
|
|
b7f7d8cd15 | ||
|
|
2fa1db4345 | ||
|
|
71b0945fc6 | ||
|
|
5bca2e60a7 | ||
|
|
67472e0e89 | ||
|
|
e9aa5117c4 | ||
|
|
2473bdba5e | ||
|
|
7d1c0047fa | ||
|
|
7b61eba471 | ||
|
|
7edaf6e7e8 | ||
|
|
97ec8cfd4e | ||
|
|
5b3a21b578 | ||
|
|
ad0c19dde4 | ||
|
|
ce67706037 | ||
|
|
04210aa6dd | ||
|
|
43f9d92008 | ||
|
|
ed6c8bfe57 | ||
|
|
df3802a65f |
3
.gitattributes
vendored
3
.gitattributes
vendored
@@ -1,2 +1,3 @@
|
||||
llm/ext_server/* linguist-vendored
|
||||
* text eol=lf
|
||||
* text=auto
|
||||
*.go text eol=lf
|
||||
|
||||
5
.gitmodules
vendored
5
.gitmodules
vendored
@@ -1,7 +1,4 @@
|
||||
[submodule "llama.cpp"]
|
||||
path = llm/llama.cpp
|
||||
url = https://github.com/ggerganov/llama.cpp.git
|
||||
shallow = true
|
||||
[submodule "llm/whisper.cpp"]
|
||||
path = llm/whisper.cpp
|
||||
url = git@github.com:ggerganov/whisper.cpp.git
|
||||
shallow = true
|
||||
@@ -325,6 +325,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [tlm](https://github.com/yusufcanb/tlm)
|
||||
- [podman-ollama](https://github.com/ericcurtin/podman-ollama)
|
||||
- [gollama](https://github.com/sammcj/gollama)
|
||||
- [Ollama eBook Summary](https://github.com/cognitivetech/ollama-ebook-summary/)
|
||||
|
||||
### Database
|
||||
|
||||
|
||||
19
api/types.go
19
api/types.go
@@ -36,13 +36,6 @@ func (e StatusError) Error() string {
|
||||
// ImageData represents the raw binary data of an image file.
|
||||
type ImageData []byte
|
||||
|
||||
type WhisperRequest struct {
|
||||
Model string `json:"model,omitempty"`
|
||||
Audio string `json:"audio,omitempty"`
|
||||
Transcribe bool `json:"transcribe,omitempty"`
|
||||
KeepAlive *Duration `json:"keep_alive,omitempty"`
|
||||
}
|
||||
|
||||
// GenerateRequest describes a request sent by [Client.Generate]. While you
|
||||
// have to specify the Model and Prompt fields, all the other fields have
|
||||
// reasonable defaults for basic uses.
|
||||
@@ -87,8 +80,6 @@ type GenerateRequest struct {
|
||||
// Options lists model-specific options. For example, temperature can be
|
||||
// set through this field, if the model supports it.
|
||||
Options map[string]interface{} `json:"options"`
|
||||
|
||||
Speech *WhisperRequest `json:"speech,omitempty"`
|
||||
}
|
||||
|
||||
// ChatRequest describes a request sent by [Client.Chat].
|
||||
@@ -114,10 +105,6 @@ type ChatRequest struct {
|
||||
|
||||
// Options lists model-specific options.
|
||||
Options map[string]interface{} `json:"options"`
|
||||
|
||||
Speech *WhisperRequest `json:"speech,omitempty"`
|
||||
|
||||
RunSpeech bool `json:"run_speech,omitempty"`
|
||||
}
|
||||
|
||||
type Tools []Tool
|
||||
@@ -140,7 +127,6 @@ type Message struct {
|
||||
Content string `json:"content"`
|
||||
Images []ImageData `json:"images,omitempty"`
|
||||
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
|
||||
Audio string `json:"audio,omitempty"`
|
||||
}
|
||||
|
||||
func (m *Message) UnmarshalJSON(b []byte) error {
|
||||
@@ -464,11 +450,6 @@ type GenerateResponse struct {
|
||||
Metrics
|
||||
}
|
||||
|
||||
type WhisperCompletion struct {
|
||||
Text string `json:"text"`
|
||||
Error string `json:"error,omitempty"`
|
||||
}
|
||||
|
||||
// ModelDetails provides details about a model.
|
||||
type ModelDetails struct {
|
||||
ParentModel string `json:"parent_model"`
|
||||
|
||||
39
cmd/cmd.go
39
cmd/cmd.go
@@ -38,7 +38,6 @@ import (
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/recorder"
|
||||
"github.com/ollama/ollama/server"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
@@ -381,14 +380,6 @@ func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
}
|
||||
}
|
||||
|
||||
speech, err := cmd.Flags().GetBool("speech")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if speech {
|
||||
return generateInteractiveAudio(cmd, opts)
|
||||
}
|
||||
return generateInteractive(cmd, opts)
|
||||
}
|
||||
return generate(cmd, opts)
|
||||
@@ -871,7 +862,6 @@ type runOptions struct {
|
||||
Options map[string]interface{}
|
||||
MultiModal bool
|
||||
KeepAlive *api.Duration
|
||||
Audio bool
|
||||
}
|
||||
|
||||
type displayResponseState struct {
|
||||
@@ -980,10 +970,6 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
|
||||
req.KeepAlive = opts.KeepAlive
|
||||
}
|
||||
|
||||
if opts.Audio {
|
||||
req.RunSpeech = true
|
||||
}
|
||||
|
||||
if err := client.Chat(cancelCtx, req, fn); err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil, nil
|
||||
@@ -1069,30 +1055,6 @@ func generate(cmd *cobra.Command, opts runOptions) error {
|
||||
KeepAlive: opts.KeepAlive,
|
||||
}
|
||||
|
||||
speech, err := cmd.Flags().GetBool("speech")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// create temp wav file with the recorder package
|
||||
if speech {
|
||||
tempFile, err := os.CreateTemp("", "recording-*.wav")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer os.Remove(tempFile.Name())
|
||||
|
||||
fmt.Print("Speech Mode\n\n")
|
||||
|
||||
err = recorder.RecordAudio(tempFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
request.Speech = &api.WhisperRequest{
|
||||
Audio: tempFile.Name(),
|
||||
}
|
||||
}
|
||||
if err := client.Generate(ctx, &request, fn); err != nil {
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return nil
|
||||
@@ -1300,7 +1262,6 @@ func NewCLI() *cobra.Command {
|
||||
RunE: RunHandler,
|
||||
}
|
||||
|
||||
runCmd.Flags().Bool("speech", false, "Speech to text mode")
|
||||
runCmd.Flags().String("keepalive", "", "Duration to keep a model loaded (e.g. 5m)")
|
||||
runCmd.Flags().Bool("verbose", false, "Show timings for response")
|
||||
runCmd.Flags().Bool("insecure", false, "Use an insecure registry")
|
||||
|
||||
@@ -20,7 +20,6 @@ import (
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/progress"
|
||||
"github.com/ollama/ollama/readline"
|
||||
"github.com/ollama/ollama/recorder"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
)
|
||||
|
||||
@@ -52,40 +51,6 @@ func loadModel(cmd *cobra.Command, opts *runOptions) error {
|
||||
return client.Chat(cmd.Context(), chatReq, func(api.ChatResponse) error { return nil })
|
||||
}
|
||||
|
||||
func generateInteractiveAudio(cmd *cobra.Command, opts runOptions) error {
|
||||
for {
|
||||
p := progress.NewProgress(os.Stderr)
|
||||
spinner := progress.NewSpinner("")
|
||||
p.Add("", spinner)
|
||||
|
||||
// create temp wav file with the recorder package
|
||||
tempFile, err := os.CreateTemp("", "recording-*.wav")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer os.Remove(tempFile.Name())
|
||||
|
||||
err = recorder.RecordAudio(tempFile)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p.StopAndClear()
|
||||
|
||||
newMessage := api.Message{Role: "user", Audio: tempFile.Name()}
|
||||
opts.Audio = true
|
||||
opts.Messages = append(opts.Messages, newMessage)
|
||||
|
||||
assistant, err := chat(cmd, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if assistant != nil {
|
||||
opts.Messages = append(opts.Messages, *assistant)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
usage := func() {
|
||||
fmt.Fprintln(os.Stderr, "Available Commands:")
|
||||
|
||||
@@ -669,7 +669,7 @@ curl http://localhost:11434/api/chat -d '{
|
||||
|
||||
```
|
||||
curl http://localhost:11434/api/chat -d '{
|
||||
"model": "mistral",
|
||||
"model": "llama3.1",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
@@ -708,7 +708,7 @@ curl http://localhost:11434/api/chat -d '{
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "mistral:7b-instruct-v0.3-q4_K_M",
|
||||
"model": "llama3.1",
|
||||
"created_at": "2024-07-22T20:33:28.123648Z",
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
@@ -1175,7 +1175,10 @@ curl http://localhost:11434/api/embed -d '{
|
||||
"embeddings": [[
|
||||
0.010071029, -0.0017594862, 0.05007221, 0.04692972, 0.054916814,
|
||||
0.008599704, 0.105441414, -0.025878139, 0.12958129, 0.031952348
|
||||
]]
|
||||
]],
|
||||
"total_duration": 14143917,
|
||||
"load_duration": 1019500,
|
||||
"prompt_eval_count": 8
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
@@ -1,83 +0,0 @@
|
||||
# Speech to Text Prototype
|
||||
|
||||
### To run
|
||||
`make {/path/to/whisper.cpp/server}`
|
||||
- replace `whisperServer` in `routes.go` with path to server
|
||||
|
||||
## CLI
|
||||
`./ollama run llama3 [PROMPT] --speech`
|
||||
- processes voice audio with the provided prompt
|
||||
|
||||
`./ollama run llama3 --speech`
|
||||
- enters interactive mode for continuous voice chat
|
||||
- TODO: fix exiting interactive mode
|
||||
|
||||
Notes: uses default model
|
||||
|
||||
|
||||
## api/generate
|
||||
### Request fields
|
||||
- `speech` (required):
|
||||
- `audio` (required): path to audio file
|
||||
- `model` (optional): path to whisper model, uses default if null
|
||||
- `transcribe` (optional): if true, will transcribe and return the audio file
|
||||
- `keep_alive`: (optional): sets how long the model is stored in memory (default: `5m`)
|
||||
- `prompt` (optional): if not null, passed in with the transcribed audio
|
||||
|
||||
#### Transcription
|
||||
```
|
||||
curl http://localhost:11434/api/generate -d '{
|
||||
"speech": {
|
||||
"model": "/Users/royhan-ollama/.ollama/whisper/ggml-base.en.bin",
|
||||
"audio": "/Users/royhan-ollama/ollama/llm/whisper.cpp/samples/jfk.wav",
|
||||
"transcribe": true,
|
||||
"keep_alive": "1m"
|
||||
},
|
||||
"stream": false
|
||||
}' | jq
|
||||
```
|
||||
|
||||
#### Response Generation
|
||||
```
|
||||
curl http://localhost:11434/api/generate -d '{
|
||||
"model": "llama3",
|
||||
"prompt": "What do you think about this quote?",
|
||||
"speech": {
|
||||
"model": "/Users/royhan-ollama/.ollama/whisper/ggml-base.en.bin",
|
||||
"audio": "/Users/royhan-ollama/ollama/llm/whisper.cpp/samples/jfk.wav",
|
||||
"keep_alive": "1m"
|
||||
},
|
||||
"stream": false
|
||||
}' | jq
|
||||
```
|
||||
|
||||
## api/chat
|
||||
### Request fields
|
||||
- `model` (required): language model to chat with
|
||||
- `speech` (optional):
|
||||
- `model` (optional): path to whisper model, uses default if null
|
||||
- `keep_alive`: (optional): sets how long the model is stored in memory (default: `5m`)
|
||||
- `run_speech` (optional): either this flag must be true or `speech` must be passed in for speech mode to run
|
||||
- `messages`/`message`/`audio` (required): path to audio file
|
||||
|
||||
```
|
||||
curl http://localhost:11434/api/chat -d '{
|
||||
"model": "llama3",
|
||||
"speech": {
|
||||
"model": "/Users/royhan-ollama/.ollama/whisper/ggml-base.en.bin",
|
||||
"keep_alive": "10m"
|
||||
},
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a Canadian Nationalist"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What do you think about this quote?",
|
||||
"audio": "/Users/royhan-ollama/ollama/llm/whisper.cpp/samples/jfk.wav"
|
||||
}
|
||||
],
|
||||
"stream": false
|
||||
}' | jq
|
||||
```
|
||||
1
go.mod
1
go.mod
@@ -19,7 +19,6 @@ require (
|
||||
github.com/agnivade/levenshtein v1.1.1
|
||||
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
|
||||
github.com/google/go-cmp v0.6.0
|
||||
github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5
|
||||
github.com/mattn/go-runewidth v0.0.14
|
||||
github.com/nlpodyssey/gopickle v0.3.0
|
||||
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c
|
||||
|
||||
2
go.sum
2
go.sum
@@ -115,8 +115,6 @@ github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y=
|
||||
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5 h1:5AlozfqaVjGYGhms2OsdUyfdJME76E6rx5MdGpjzZpc=
|
||||
github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5/go.mod h1:WY8R6YKlI2ZI3UyzFk7P6yGSuS+hFwNtEzrexRyD7Es=
|
||||
github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
|
||||
@@ -49,13 +49,9 @@ func PayloadsDir() (string, error) {
|
||||
}
|
||||
|
||||
// Track our pid so we can clean up orphaned tmpdirs
|
||||
pidFilePath := filepath.Join(tmpDir, "ollama.pid")
|
||||
pidFile, err := os.OpenFile(pidFilePath, os.O_CREATE|os.O_TRUNC|os.O_WRONLY, os.ModePerm)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if _, err := pidFile.Write([]byte(strconv.Itoa(os.Getpid()))); err != nil {
|
||||
return "", err
|
||||
n := filepath.Join(tmpDir, "ollama.pid")
|
||||
if err := os.WriteFile(n, []byte(strconv.Itoa(os.Getpid())), 0o644); err != nil {
|
||||
return "", fmt.Errorf("failed to write pid file %s: %w", n, err)
|
||||
}
|
||||
|
||||
// We create a distinct subdirectory for payloads within the tmpdir
|
||||
@@ -67,37 +63,44 @@ func PayloadsDir() (string, error) {
|
||||
|
||||
// Best effort to clean up prior tmpdirs
|
||||
func cleanupTmpDirs() {
|
||||
dirs, err := filepath.Glob(filepath.Join(os.TempDir(), "ollama*"))
|
||||
matches, err := filepath.Glob(filepath.Join(os.TempDir(), "ollama*", "ollama.pid"))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for _, d := range dirs {
|
||||
info, err := os.Stat(d)
|
||||
if err != nil || !info.IsDir() {
|
||||
|
||||
for _, match := range matches {
|
||||
raw, err := os.ReadFile(match)
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
slog.Debug("not a ollama runtime directory, skipping", "path", match)
|
||||
continue
|
||||
}
|
||||
raw, err := os.ReadFile(filepath.Join(d, "ollama.pid"))
|
||||
if err != nil {
|
||||
slog.Warn("failed to read ollama.pid", "path", d, "error", err)
|
||||
// No pid, ignore this tmpdir
|
||||
} else if err != nil {
|
||||
slog.Warn("could not read ollama.pid, skipping", "path", match, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
pid, err := strconv.Atoi(string(raw))
|
||||
if err != nil {
|
||||
slog.Warn("failed to parse pid", "path", d, "error", err)
|
||||
slog.Warn("invalid pid, skipping", "path", match, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
proc, err := os.FindProcess(pid)
|
||||
if err == nil && !errors.Is(proc.Signal(syscall.Signal(0)), os.ErrProcessDone) {
|
||||
slog.Warn("found running ollama", "pid", pid, "path", d)
|
||||
// Another running ollama, ignore this tmpdir
|
||||
p, err := os.FindProcess(pid)
|
||||
if err == nil && !errors.Is(p.Signal(syscall.Signal(0)), os.ErrProcessDone) {
|
||||
slog.Warn("process still running, skipping", "pid", pid, "path", match)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := os.Remove(d); err != nil {
|
||||
slog.Warn("unable to cleanup stale tmpdir", "path", d, "error", err)
|
||||
if err := os.Remove(match); err != nil {
|
||||
slog.Warn("could not cleanup stale pidfile", "path", match, "error", err)
|
||||
}
|
||||
|
||||
runners := filepath.Join(filepath.Dir(match), "runners")
|
||||
if err := os.RemoveAll(runners); err != nil {
|
||||
slog.Warn("could not cleanup stale runners", "path", runners, "error", err)
|
||||
}
|
||||
|
||||
if err := os.Remove(filepath.Dir(match)); err != nil {
|
||||
slog.Warn("could not cleanup stale tmpdir", "path", filepath.Dir(match), "error", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
61
gpu/gpu.go
61
gpu/gpu.go
@@ -305,38 +305,41 @@ func GetGPUInfo() GpuInfoList {
|
||||
// Intel
|
||||
if envconfig.IntelGPU() {
|
||||
oHandles = initOneAPIHandles()
|
||||
// On windows we bundle the oneapi library one level above the runner dir
|
||||
depPath = ""
|
||||
if runtime.GOOS == "windows" && envconfig.RunnersDir() != "" {
|
||||
depPath = filepath.Join(filepath.Dir(envconfig.RunnersDir()), "oneapi")
|
||||
}
|
||||
if oHandles != nil && oHandles.oneapi != nil {
|
||||
|
||||
for d := range oHandles.oneapi.num_drivers {
|
||||
if oHandles.oneapi == nil {
|
||||
// shouldn't happen
|
||||
slog.Warn("nil oneapi handle with driver count", "count", int(oHandles.oneapi.num_drivers))
|
||||
continue
|
||||
// On windows we bundle the oneapi library one level above the runner dir
|
||||
depPath = ""
|
||||
if runtime.GOOS == "windows" && envconfig.RunnersDir() != "" {
|
||||
depPath = filepath.Join(filepath.Dir(envconfig.RunnersDir()), "oneapi")
|
||||
}
|
||||
devCount := C.oneapi_get_device_count(*oHandles.oneapi, C.int(d))
|
||||
for i := range devCount {
|
||||
gpuInfo := OneapiGPUInfo{
|
||||
GpuInfo: GpuInfo{
|
||||
Library: "oneapi",
|
||||
},
|
||||
driverIndex: int(d),
|
||||
gpuIndex: int(i),
|
||||
|
||||
for d := range oHandles.oneapi.num_drivers {
|
||||
if oHandles.oneapi == nil {
|
||||
// shouldn't happen
|
||||
slog.Warn("nil oneapi handle with driver count", "count", int(oHandles.oneapi.num_drivers))
|
||||
continue
|
||||
}
|
||||
devCount := C.oneapi_get_device_count(*oHandles.oneapi, C.int(d))
|
||||
for i := range devCount {
|
||||
gpuInfo := OneapiGPUInfo{
|
||||
GpuInfo: GpuInfo{
|
||||
Library: "oneapi",
|
||||
},
|
||||
driverIndex: int(d),
|
||||
gpuIndex: int(i),
|
||||
}
|
||||
// TODO - split bootstrapping from updating free memory
|
||||
C.oneapi_check_vram(*oHandles.oneapi, C.int(d), i, &memInfo)
|
||||
// TODO - convert this to MinimumMemory based on testing...
|
||||
var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend.
|
||||
memInfo.free = C.uint64_t(totalFreeMem)
|
||||
gpuInfo.TotalMemory = uint64(memInfo.total)
|
||||
gpuInfo.FreeMemory = uint64(memInfo.free)
|
||||
gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
|
||||
gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
|
||||
gpuInfo.DependencyPath = depPath
|
||||
oneapiGPUs = append(oneapiGPUs, gpuInfo)
|
||||
}
|
||||
// TODO - split bootstrapping from updating free memory
|
||||
C.oneapi_check_vram(*oHandles.oneapi, C.int(d), i, &memInfo)
|
||||
// TODO - convert this to MinimumMemory based on testing...
|
||||
var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend.
|
||||
memInfo.free = C.uint64_t(totalFreeMem)
|
||||
gpuInfo.TotalMemory = uint64(memInfo.total)
|
||||
gpuInfo.FreeMemory = uint64(memInfo.free)
|
||||
gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
|
||||
gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
|
||||
gpuInfo.DependencyPath = depPath
|
||||
oneapiGPUs = append(oneapiGPUs, gpuInfo)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
42
llm/ext_server/server.cpp
vendored
42
llm/ext_server/server.cpp
vendored
@@ -1223,9 +1223,7 @@ struct llama_server_context
|
||||
|
||||
res.result_json = json
|
||||
{
|
||||
{"id", res.id},
|
||||
{"embedding", std::vector<float>(embd, embd + n_embd)},
|
||||
{"timings", slot.get_formated_timings()},
|
||||
};
|
||||
}
|
||||
}
|
||||
@@ -3194,41 +3192,17 @@ int main(int argc, char **argv) {
|
||||
prompt = "";
|
||||
}
|
||||
|
||||
if (prompt.size() == 1) {
|
||||
prompt = prompt[0];
|
||||
}
|
||||
|
||||
// create and queue the task
|
||||
json responses;
|
||||
{
|
||||
const int id_task = llama.queue_tasks.get_new_id();
|
||||
llama.queue_results.add_waiting_task_id(id_task);
|
||||
llama.request_completion(id_task, {{"prompt", prompt}}, true, -1);
|
||||
const int task_id = llama.queue_tasks.get_new_id();
|
||||
llama.queue_results.add_waiting_task_id(task_id);
|
||||
llama.request_completion(task_id, {{"prompt", prompt}}, true, -1);
|
||||
|
||||
// get the result
|
||||
task_result result = llama.queue_results.recv(id_task);
|
||||
llama.queue_results.remove_waiting_task_id(id_task);
|
||||
if (result.error) {
|
||||
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
|
||||
}
|
||||
// get the result
|
||||
task_result result = llama.queue_results.recv(task_id);
|
||||
llama.queue_results.remove_waiting_task_id(task_id);
|
||||
|
||||
responses = result.result_json.value("results", std::vector<json>{result.result_json});
|
||||
std::sort(responses.begin(), responses.end(), [](const json& a, const json& b) {
|
||||
return a["id"] < b["id"];
|
||||
});
|
||||
|
||||
json embeddings = json::array();
|
||||
|
||||
int prompt_n = 0;
|
||||
for (auto & elem : responses) {
|
||||
embeddings.push_back(elem.at("embedding"));
|
||||
prompt_n += elem.at("timings").at("prompt_n").get<int>();
|
||||
}
|
||||
|
||||
// send the result
|
||||
json embedding_res = json{{"embedding", embeddings}, {"prompt_n", prompt_n}};
|
||||
return res.set_content(embedding_res.dump(), "application/json; charset=utf-8");
|
||||
}
|
||||
// send the result
|
||||
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
|
||||
});
|
||||
|
||||
// GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!?
|
||||
|
||||
@@ -33,7 +33,7 @@ type LlamaServer interface {
|
||||
Ping(ctx context.Context) error
|
||||
WaitUntilRunning(ctx context.Context) error
|
||||
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
|
||||
Embed(ctx context.Context, input []string) (*EmbedResponse, error)
|
||||
Embedding(ctx context.Context, input string) ([]float32, error)
|
||||
Tokenize(ctx context.Context, content string) ([]int, error)
|
||||
Detokenize(ctx context.Context, tokens []int) (string, error)
|
||||
Close() error
|
||||
@@ -125,8 +125,9 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||
}
|
||||
}
|
||||
|
||||
// On linux, over-allocating CPU memory will almost always result in an error
|
||||
if runtime.GOOS == "linux" {
|
||||
// On linux and windows, over-allocating CPU memory will almost always result in an error
|
||||
// Darwin has fully dynamic swap so has no direct concept of free swap space
|
||||
if runtime.GOOS != "darwin" {
|
||||
systemMemoryRequired := estimate.TotalSize - estimate.VRAMSize
|
||||
available := systemFreeMemory + systemSwapFreeMemory
|
||||
if systemMemoryRequired > available {
|
||||
@@ -882,24 +883,20 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
||||
return nil
|
||||
}
|
||||
|
||||
type EmbedRequest struct {
|
||||
Content []string `json:"content"`
|
||||
type EmbeddingRequest struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type EmbedResponse struct {
|
||||
Embedding [][]float32 `json:"embedding"`
|
||||
PromptEvalCount int `json:"prompt_n"`
|
||||
type EmbeddingResponse struct {
|
||||
Embedding []float32 `json:"embedding"`
|
||||
}
|
||||
|
||||
func (s *llmServer) Embed(ctx context.Context, input []string) (*EmbedResponse, error) {
|
||||
// each input will use a slot, so we need to acquire the semaphore for
|
||||
// the number of inputs up to numParallel
|
||||
slots := int64(min(len(input), s.numParallel))
|
||||
if err := s.sem.Acquire(ctx, slots); err != nil {
|
||||
func (s *llmServer) Embedding(ctx context.Context, input string) ([]float32, error) {
|
||||
if err := s.sem.Acquire(ctx, 1); err != nil {
|
||||
slog.Error("Failed to acquire semaphore", "error", err)
|
||||
return nil, err
|
||||
}
|
||||
defer s.sem.Release(slots)
|
||||
defer s.sem.Release(1)
|
||||
|
||||
// Make sure the server is ready
|
||||
status, err := s.getServerStatusRetry(ctx)
|
||||
@@ -909,18 +906,18 @@ func (s *llmServer) Embed(ctx context.Context, input []string) (*EmbedResponse,
|
||||
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
|
||||
}
|
||||
|
||||
data, err := json.Marshal(EmbedRequest{Content: input})
|
||||
data, err := json.Marshal(EmbeddingRequest{Content: input})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshaling embed data: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data))
|
||||
r, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating embed request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
resp, err := http.DefaultClient.Do(r)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("do embedding request: %w", err)
|
||||
}
|
||||
@@ -936,12 +933,12 @@ func (s *llmServer) Embed(ctx context.Context, input []string) (*EmbedResponse,
|
||||
return nil, fmt.Errorf("%s", body)
|
||||
}
|
||||
|
||||
var e EmbedResponse
|
||||
var e EmbeddingResponse
|
||||
if err := json.Unmarshal(body, &e); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
|
||||
}
|
||||
|
||||
return &e, nil
|
||||
return e.Embedding, nil
|
||||
}
|
||||
|
||||
type TokenizeRequest struct {
|
||||
|
||||
@@ -26,6 +26,7 @@ var errorPrefixes = []string{
|
||||
"cudaMalloc failed",
|
||||
"\"ERR\"",
|
||||
"error loading model",
|
||||
"GGML_ASSERT",
|
||||
}
|
||||
|
||||
func (w *StatusWriter) Write(b []byte) (int, error) {
|
||||
|
||||
Submodule llm/whisper.cpp deleted from 6739eb83c3
@@ -7,27 +7,22 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
const (
|
||||
prefix = `data:image/jpeg;base64,`
|
||||
image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
|
||||
imageURL = prefix + image
|
||||
prefix = `data:image/jpeg;base64,`
|
||||
image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
|
||||
)
|
||||
|
||||
func prepareRequest(req *http.Request, body any) {
|
||||
bodyBytes, _ := json.Marshal(body)
|
||||
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
var False = false
|
||||
|
||||
func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
@@ -43,134 +38,136 @@ func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
|
||||
|
||||
func TestChatMiddleware(t *testing.T) {
|
||||
type testCase struct {
|
||||
Name string
|
||||
Setup func(t *testing.T, req *http.Request)
|
||||
Expected func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder)
|
||||
name string
|
||||
body string
|
||||
req api.ChatRequest
|
||||
err ErrorResponse
|
||||
}
|
||||
|
||||
var capturedRequest *api.ChatRequest
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
Name: "chat handler",
|
||||
Setup: func(t *testing.T, req *http.Request) {
|
||||
body := ChatCompletionRequest{
|
||||
Model: "test-model",
|
||||
Messages: []Message{{Role: "user", Content: "Hello"}},
|
||||
}
|
||||
prepareRequest(req, body)
|
||||
},
|
||||
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", resp.Code)
|
||||
}
|
||||
|
||||
if req.Messages[0].Role != "user" {
|
||||
t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
|
||||
}
|
||||
|
||||
if req.Messages[0].Content != "Hello" {
|
||||
t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
|
||||
}
|
||||
name: "chat handler",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Hello"}
|
||||
]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Hello",
|
||||
},
|
||||
},
|
||||
Options: map[string]any{
|
||||
"temperature": 1.0,
|
||||
"top_p": 1.0,
|
||||
},
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "chat handler with image content",
|
||||
Setup: func(t *testing.T, req *http.Request) {
|
||||
body := ChatCompletionRequest{
|
||||
Model: "test-model",
|
||||
Messages: []Message{
|
||||
{
|
||||
Role: "user", Content: []map[string]any{
|
||||
{"type": "text", "text": "Hello"},
|
||||
{"type": "image_url", "image_url": map[string]string{"url": imageURL}},
|
||||
name: "chat handler with image content",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Hello"
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "` + prefix + image + `"
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "Hello",
|
||||
},
|
||||
{
|
||||
Role: "user",
|
||||
Images: []api.ImageData{
|
||||
func() []byte {
|
||||
img, _ := base64.StdEncoding.DecodeString(image)
|
||||
return img
|
||||
}(),
|
||||
},
|
||||
},
|
||||
},
|
||||
Options: map[string]any{
|
||||
"temperature": 1.0,
|
||||
"top_p": 1.0,
|
||||
},
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "chat handler with tools",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{"role": "user", "content": "What's the weather like in Paris Today?"},
|
||||
{"role": "assistant", "tool_calls": [{"id": "id", "type": "function", "function": {"name": "get_current_weather", "arguments": "{\"location\": \"Paris, France\", \"format\": \"celsius\"}"}}]}
|
||||
]
|
||||
}`,
|
||||
req: api.ChatRequest{
|
||||
Model: "test-model",
|
||||
Messages: []api.Message{
|
||||
{
|
||||
Role: "user",
|
||||
Content: "What's the weather like in Paris Today?",
|
||||
},
|
||||
{
|
||||
Role: "assistant",
|
||||
ToolCalls: []api.ToolCall{
|
||||
{
|
||||
Function: api.ToolCallFunction{
|
||||
Name: "get_current_weather",
|
||||
Arguments: map[string]interface{}{
|
||||
"location": "Paris, France",
|
||||
"format": "celsius",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
prepareRequest(req, body)
|
||||
},
|
||||
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
|
||||
if resp.Code != http.StatusOK {
|
||||
t.Fatalf("expected 200, got %d", resp.Code)
|
||||
}
|
||||
|
||||
if req.Messages[0].Role != "user" {
|
||||
t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
|
||||
}
|
||||
|
||||
if req.Messages[0].Content != "Hello" {
|
||||
t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
|
||||
}
|
||||
|
||||
img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])
|
||||
|
||||
if req.Messages[1].Role != "user" {
|
||||
t.Fatalf("expected 'user', got %s", req.Messages[1].Role)
|
||||
}
|
||||
|
||||
if !bytes.Equal(req.Messages[1].Images[0], img) {
|
||||
t.Fatalf("expected image encoding, got %s", req.Messages[1].Images[0])
|
||||
}
|
||||
},
|
||||
Options: map[string]any{
|
||||
"temperature": 1.0,
|
||||
"top_p": 1.0,
|
||||
},
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
|
||||
{
|
||||
Name: "chat handler with tools",
|
||||
Setup: func(t *testing.T, req *http.Request) {
|
||||
body := ChatCompletionRequest{
|
||||
Model: "test-model",
|
||||
Messages: []Message{
|
||||
{Role: "user", Content: "What's the weather like in Paris Today?"},
|
||||
{Role: "assistant", ToolCalls: []ToolCall{{
|
||||
ID: "id",
|
||||
Type: "function",
|
||||
Function: struct {
|
||||
Name string `json:"name"`
|
||||
Arguments string `json:"arguments"`
|
||||
}{
|
||||
Name: "get_current_weather",
|
||||
Arguments: "{\"location\": \"Paris, France\", \"format\": \"celsius\"}",
|
||||
},
|
||||
}}},
|
||||
},
|
||||
}
|
||||
prepareRequest(req, body)
|
||||
},
|
||||
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
|
||||
if resp.Code != 200 {
|
||||
t.Fatalf("expected 200, got %d", resp.Code)
|
||||
}
|
||||
|
||||
if req.Messages[0].Content != "What's the weather like in Paris Today?" {
|
||||
t.Fatalf("expected What's the weather like in Paris Today?, got %s", req.Messages[0].Content)
|
||||
}
|
||||
|
||||
if req.Messages[1].ToolCalls[0].Function.Arguments["location"] != "Paris, France" {
|
||||
t.Fatalf("expected 'Paris, France', got %v", req.Messages[1].ToolCalls[0].Function.Arguments["location"])
|
||||
}
|
||||
|
||||
if req.Messages[1].ToolCalls[0].Function.Arguments["format"] != "celsius" {
|
||||
t.Fatalf("expected celsius, got %v", req.Messages[1].ToolCalls[0].Function.Arguments["format"])
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "chat handler error forwarding",
|
||||
Setup: func(t *testing.T, req *http.Request) {
|
||||
body := ChatCompletionRequest{
|
||||
Model: "test-model",
|
||||
Messages: []Message{{Role: "user", Content: 2}},
|
||||
}
|
||||
prepareRequest(req, body)
|
||||
},
|
||||
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
|
||||
if resp.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d", resp.Code)
|
||||
}
|
||||
|
||||
if !strings.Contains(resp.Body.String(), "invalid message content type") {
|
||||
t.Fatalf("error was not forwarded")
|
||||
}
|
||||
name: "chat handler error forwarding",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{"role": "user", "content": 2}
|
||||
]
|
||||
}`,
|
||||
err: ErrorResponse{
|
||||
Error: Error{
|
||||
Message: "invalid message content type: float64",
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -185,16 +182,26 @@ func TestChatMiddleware(t *testing.T) {
|
||||
router.Handle(http.MethodPost, "/api/chat", endpoint)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/chat", nil)
|
||||
|
||||
tc.Setup(t, req)
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/chat", strings.NewReader(tc.body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
tc.Expected(t, capturedRequest, resp)
|
||||
var errResp ErrorResponse
|
||||
if resp.Code != http.StatusOK {
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
|
||||
t.Fatal("requests did not match")
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(tc.err, errResp) {
|
||||
t.Fatal("errors did not match")
|
||||
}
|
||||
capturedRequest = nil
|
||||
})
|
||||
}
|
||||
@@ -202,71 +209,52 @@ func TestChatMiddleware(t *testing.T) {
|
||||
|
||||
func TestCompletionsMiddleware(t *testing.T) {
|
||||
type testCase struct {
|
||||
Name string
|
||||
Setup func(t *testing.T, req *http.Request)
|
||||
Expected func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder)
|
||||
name string
|
||||
body string
|
||||
req api.GenerateRequest
|
||||
err ErrorResponse
|
||||
}
|
||||
|
||||
var capturedRequest *api.GenerateRequest
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
Name: "completions handler",
|
||||
Setup: func(t *testing.T, req *http.Request) {
|
||||
temp := float32(0.8)
|
||||
body := CompletionRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "Hello",
|
||||
Temperature: &temp,
|
||||
Stop: []string{"\n", "stop"},
|
||||
Suffix: "suffix",
|
||||
}
|
||||
prepareRequest(req, body)
|
||||
},
|
||||
Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
|
||||
if req.Prompt != "Hello" {
|
||||
t.Fatalf("expected 'Hello', got %s", req.Prompt)
|
||||
}
|
||||
|
||||
if req.Options["temperature"] != 1.6 {
|
||||
t.Fatalf("expected 1.6, got %f", req.Options["temperature"])
|
||||
}
|
||||
|
||||
stopTokens, ok := req.Options["stop"].([]any)
|
||||
|
||||
if !ok {
|
||||
t.Fatalf("expected stop tokens to be a list")
|
||||
}
|
||||
|
||||
if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
|
||||
t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
|
||||
}
|
||||
|
||||
if req.Suffix != "suffix" {
|
||||
t.Fatalf("expected 'suffix', got %s", req.Suffix)
|
||||
}
|
||||
name: "completions handler",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"prompt": "Hello",
|
||||
"temperature": 0.8,
|
||||
"stop": ["\n", "stop"],
|
||||
"suffix": "suffix"
|
||||
}`,
|
||||
req: api.GenerateRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "Hello",
|
||||
Options: map[string]any{
|
||||
"frequency_penalty": 0.0,
|
||||
"presence_penalty": 0.0,
|
||||
"temperature": 1.6,
|
||||
"top_p": 1.0,
|
||||
"stop": []any{"\n", "stop"},
|
||||
},
|
||||
Suffix: "suffix",
|
||||
Stream: &False,
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "completions handler error forwarding",
|
||||
Setup: func(t *testing.T, req *http.Request) {
|
||||
body := CompletionRequest{
|
||||
Model: "test-model",
|
||||
Prompt: "Hello",
|
||||
Temperature: nil,
|
||||
Stop: []int{1, 2},
|
||||
Suffix: "suffix",
|
||||
}
|
||||
prepareRequest(req, body)
|
||||
},
|
||||
Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
|
||||
if resp.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d", resp.Code)
|
||||
}
|
||||
|
||||
if !strings.Contains(resp.Body.String(), "invalid type for 'stop' field") {
|
||||
t.Fatalf("error was not forwarded")
|
||||
}
|
||||
name: "completions handler error forwarding",
|
||||
body: `{
|
||||
"model": "test-model",
|
||||
"prompt": "Hello",
|
||||
"temperature": null,
|
||||
"stop": [1, 2],
|
||||
"suffix": "suffix"
|
||||
}`,
|
||||
err: ErrorResponse{
|
||||
Error: Error{
|
||||
Message: "invalid type for 'stop' field: float64",
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -281,15 +269,27 @@ func TestCompletionsMiddleware(t *testing.T) {
|
||||
router.Handle(http.MethodPost, "/api/generate", endpoint)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/generate", nil)
|
||||
|
||||
tc.Setup(t, req)
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/generate", strings.NewReader(tc.body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
tc.Expected(t, capturedRequest, resp)
|
||||
var errResp ErrorResponse
|
||||
if resp.Code != http.StatusOK {
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
|
||||
t.Fatal("requests did not match")
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(tc.err, errResp) {
|
||||
t.Fatal("errors did not match")
|
||||
}
|
||||
|
||||
capturedRequest = nil
|
||||
})
|
||||
@@ -298,78 +298,47 @@ func TestCompletionsMiddleware(t *testing.T) {
|
||||
|
||||
func TestEmbeddingsMiddleware(t *testing.T) {
|
||||
type testCase struct {
|
||||
Name string
|
||||
Setup func(t *testing.T, req *http.Request)
|
||||
Expected func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder)
|
||||
name string
|
||||
body string
|
||||
req api.EmbedRequest
|
||||
err ErrorResponse
|
||||
}
|
||||
|
||||
var capturedRequest *api.EmbedRequest
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
Name: "embed handler single input",
|
||||
Setup: func(t *testing.T, req *http.Request) {
|
||||
body := EmbedRequest{
|
||||
Input: "Hello",
|
||||
Model: "test-model",
|
||||
}
|
||||
prepareRequest(req, body)
|
||||
},
|
||||
Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
|
||||
if req.Input != "Hello" {
|
||||
t.Fatalf("expected 'Hello', got %s", req.Input)
|
||||
}
|
||||
|
||||
if req.Model != "test-model" {
|
||||
t.Fatalf("expected 'test-model', got %s", req.Model)
|
||||
}
|
||||
name: "embed handler single input",
|
||||
body: `{
|
||||
"input": "Hello",
|
||||
"model": "test-model"
|
||||
}`,
|
||||
req: api.EmbedRequest{
|
||||
Input: "Hello",
|
||||
Model: "test-model",
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "embed handler batch input",
|
||||
Setup: func(t *testing.T, req *http.Request) {
|
||||
body := EmbedRequest{
|
||||
Input: []string{"Hello", "World"},
|
||||
Model: "test-model",
|
||||
}
|
||||
prepareRequest(req, body)
|
||||
},
|
||||
Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
|
||||
input, ok := req.Input.([]any)
|
||||
|
||||
if !ok {
|
||||
t.Fatalf("expected input to be a list")
|
||||
}
|
||||
|
||||
if input[0].(string) != "Hello" {
|
||||
t.Fatalf("expected 'Hello', got %s", input[0])
|
||||
}
|
||||
|
||||
if input[1].(string) != "World" {
|
||||
t.Fatalf("expected 'World', got %s", input[1])
|
||||
}
|
||||
|
||||
if req.Model != "test-model" {
|
||||
t.Fatalf("expected 'test-model', got %s", req.Model)
|
||||
}
|
||||
name: "embed handler batch input",
|
||||
body: `{
|
||||
"input": ["Hello", "World"],
|
||||
"model": "test-model"
|
||||
}`,
|
||||
req: api.EmbedRequest{
|
||||
Input: []any{"Hello", "World"},
|
||||
Model: "test-model",
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "embed handler error forwarding",
|
||||
Setup: func(t *testing.T, req *http.Request) {
|
||||
body := EmbedRequest{
|
||||
Model: "test-model",
|
||||
}
|
||||
prepareRequest(req, body)
|
||||
},
|
||||
Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
|
||||
if resp.Code != http.StatusBadRequest {
|
||||
t.Fatalf("expected 400, got %d", resp.Code)
|
||||
}
|
||||
|
||||
if !strings.Contains(resp.Body.String(), "invalid input") {
|
||||
t.Fatalf("error was not forwarded")
|
||||
}
|
||||
name: "embed handler error forwarding",
|
||||
body: `{
|
||||
"model": "test-model"
|
||||
}`,
|
||||
err: ErrorResponse{
|
||||
Error: Error{
|
||||
Message: "invalid input",
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -384,116 +353,167 @@ func TestEmbeddingsMiddleware(t *testing.T) {
|
||||
router.Handle(http.MethodPost, "/api/embed", endpoint)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/embed", nil)
|
||||
|
||||
tc.Setup(t, req)
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
req, _ := http.NewRequest(http.MethodPost, "/api/embed", strings.NewReader(tc.body))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
tc.Expected(t, capturedRequest, resp)
|
||||
var errResp ErrorResponse
|
||||
if resp.Code != http.StatusOK {
|
||||
if err := json.Unmarshal(resp.Body.Bytes(), &errResp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
if capturedRequest != nil && !reflect.DeepEqual(tc.req, *capturedRequest) {
|
||||
t.Fatal("requests did not match")
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(tc.err, errResp) {
|
||||
t.Fatal("errors did not match")
|
||||
}
|
||||
|
||||
capturedRequest = nil
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestMiddlewareResponses(t *testing.T) {
|
||||
func TestListMiddleware(t *testing.T) {
|
||||
type testCase struct {
|
||||
Name string
|
||||
Method string
|
||||
Path string
|
||||
TestPath string
|
||||
Handler func() gin.HandlerFunc
|
||||
Endpoint func(c *gin.Context)
|
||||
Setup func(t *testing.T, req *http.Request)
|
||||
Expected func(t *testing.T, resp *httptest.ResponseRecorder)
|
||||
name string
|
||||
endpoint func(c *gin.Context)
|
||||
resp string
|
||||
}
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
Name: "list handler",
|
||||
Method: http.MethodGet,
|
||||
Path: "/api/tags",
|
||||
TestPath: "/api/tags",
|
||||
Handler: ListMiddleware,
|
||||
Endpoint: func(c *gin.Context) {
|
||||
name: "list handler",
|
||||
endpoint: func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, api.ListResponse{
|
||||
Models: []api.ListModelResponse{
|
||||
{
|
||||
Name: "Test Model",
|
||||
Name: "test-model",
|
||||
ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
|
||||
},
|
||||
},
|
||||
})
|
||||
},
|
||||
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
||||
var listResp ListCompletion
|
||||
if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if listResp.Object != "list" {
|
||||
t.Fatalf("expected list, got %s", listResp.Object)
|
||||
}
|
||||
|
||||
if len(listResp.Data) != 1 {
|
||||
t.Fatalf("expected 1, got %d", len(listResp.Data))
|
||||
}
|
||||
|
||||
if listResp.Data[0].Id != "Test Model" {
|
||||
t.Fatalf("expected Test Model, got %s", listResp.Data[0].Id)
|
||||
}
|
||||
},
|
||||
resp: `{
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"id": "test-model",
|
||||
"object": "model",
|
||||
"created": 1686935002,
|
||||
"owned_by": "library"
|
||||
}
|
||||
]
|
||||
}`,
|
||||
},
|
||||
{
|
||||
Name: "retrieve model",
|
||||
Method: http.MethodGet,
|
||||
Path: "/api/show/:model",
|
||||
TestPath: "/api/show/test-model",
|
||||
Handler: RetrieveMiddleware,
|
||||
Endpoint: func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, api.ShowResponse{
|
||||
ModifiedAt: time.Date(2024, 6, 17, 13, 45, 0, 0, time.UTC),
|
||||
})
|
||||
},
|
||||
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
|
||||
var retrieveResp Model
|
||||
if err := json.NewDecoder(resp.Body).Decode(&retrieveResp); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if retrieveResp.Object != "model" {
|
||||
t.Fatalf("Expected object to be model, got %s", retrieveResp.Object)
|
||||
}
|
||||
|
||||
if retrieveResp.Id != "test-model" {
|
||||
t.Fatalf("Expected id to be test-model, got %s", retrieveResp.Id)
|
||||
}
|
||||
name: "list handler empty output",
|
||||
endpoint: func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, api.ListResponse{})
|
||||
},
|
||||
resp: `{
|
||||
"object": "list",
|
||||
"data": null
|
||||
}`,
|
||||
},
|
||||
}
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
router := gin.New()
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
router = gin.New()
|
||||
router.Use(tc.Handler())
|
||||
router.Handle(tc.Method, tc.Path, tc.Endpoint)
|
||||
req, _ := http.NewRequest(tc.Method, tc.TestPath, nil)
|
||||
router := gin.New()
|
||||
router.Use(ListMiddleware())
|
||||
router.Handle(http.MethodGet, "/api/tags", tc.endpoint)
|
||||
req, _ := http.NewRequest(http.MethodGet, "/api/tags", nil)
|
||||
|
||||
if tc.Setup != nil {
|
||||
tc.Setup(t, req)
|
||||
}
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
var expected, actual map[string]any
|
||||
err := json.Unmarshal([]byte(tc.resp), &expected)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to unmarshal expected response: %v", err)
|
||||
}
|
||||
|
||||
assert.Equal(t, http.StatusOK, resp.Code)
|
||||
err = json.Unmarshal(resp.Body.Bytes(), &actual)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to unmarshal actual response: %v", err)
|
||||
}
|
||||
|
||||
tc.Expected(t, resp)
|
||||
})
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRetrieveMiddleware(t *testing.T) {
|
||||
type testCase struct {
|
||||
name string
|
||||
endpoint func(c *gin.Context)
|
||||
resp string
|
||||
}
|
||||
|
||||
testCases := []testCase{
|
||||
{
|
||||
name: "retrieve handler",
|
||||
endpoint: func(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, api.ShowResponse{
|
||||
ModifiedAt: time.Unix(int64(1686935002), 0).UTC(),
|
||||
})
|
||||
},
|
||||
resp: `{
|
||||
"id":"test-model",
|
||||
"object":"model",
|
||||
"created":1686935002,
|
||||
"owned_by":"library"}
|
||||
`,
|
||||
},
|
||||
{
|
||||
name: "retrieve handler error forwarding",
|
||||
endpoint: func(c *gin.Context) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "model not found"})
|
||||
},
|
||||
resp: `{
|
||||
"error": {
|
||||
"code": null,
|
||||
"message": "model not found",
|
||||
"param": null,
|
||||
"type": "api_error"
|
||||
}
|
||||
}`,
|
||||
},
|
||||
}
|
||||
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
for _, tc := range testCases {
|
||||
router := gin.New()
|
||||
router.Use(RetrieveMiddleware())
|
||||
router.Handle(http.MethodGet, "/api/show/:model", tc.endpoint)
|
||||
req, _ := http.NewRequest(http.MethodGet, "/api/show/test-model", nil)
|
||||
|
||||
resp := httptest.NewRecorder()
|
||||
router.ServeHTTP(resp, req)
|
||||
|
||||
var expected, actual map[string]any
|
||||
err := json.Unmarshal([]byte(tc.resp), &expected)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to unmarshal expected response: %v", err)
|
||||
}
|
||||
|
||||
err = json.Unmarshal(resp.Body.Bytes(), &actual)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to unmarshal actual response: %v", err)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(expected, actual) {
|
||||
t.Errorf("responses did not match\nExpected: %+v\nActual: %+v", expected, actual)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,137 +0,0 @@
|
||||
package recorder
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
"golang.org/x/term"
|
||||
|
||||
"github.com/gordonklaus/portaudio"
|
||||
)
|
||||
|
||||
const (
|
||||
sampleRate = 16000
|
||||
numChannels = 1
|
||||
bitsPerSample = 16
|
||||
)
|
||||
|
||||
func RecordAudio(f *os.File) error {
|
||||
fmt.Print("Recording. Press any key to stop.\n\n")
|
||||
|
||||
sig := make(chan os.Signal, 1)
|
||||
signal.Notify(sig, os.Interrupt, syscall.SIGTERM)
|
||||
|
||||
portaudio.Initialize()
|
||||
defer portaudio.Terminate()
|
||||
|
||||
in := make([]int16, 64)
|
||||
stream, err := portaudio.OpenDefaultStream(numChannels, 0, sampleRate, len(in), in)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer stream.Close()
|
||||
|
||||
err = stream.Start()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write WAV header with placeholder sizes
|
||||
writeWavHeader(f, sampleRate, numChannels, bitsPerSample)
|
||||
|
||||
var totalSamples uint32
|
||||
|
||||
// Set up terminal input reading
|
||||
oldState, err := term.MakeRaw(int(os.Stdin.Fd()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer term.Restore(int(os.Stdin.Fd()), oldState)
|
||||
|
||||
// Create a channel to handle the stop signal
|
||||
stop := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
_, err := unix.Read(int(os.Stdin.Fd()), make([]byte, 1))
|
||||
if err != nil {
|
||||
fmt.Println("Error reading from stdin:", err)
|
||||
return
|
||||
}
|
||||
// Send signal to stop recording
|
||||
stop <- struct{}{}
|
||||
}()
|
||||
|
||||
loop:
|
||||
for {
|
||||
err = stream.Read()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = binary.Write(f, binary.LittleEndian, in)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
totalSamples += uint32(len(in))
|
||||
|
||||
select {
|
||||
case <-stop:
|
||||
break loop
|
||||
case <-sig:
|
||||
break loop
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
err = stream.Stop()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Update WAV header with actual sizes
|
||||
updateWavHeader(f, totalSamples, numChannels, bitsPerSample)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeWavHeader(f *os.File, sampleRate int, numChannels int, bitsPerSample int) {
|
||||
subchunk1Size := 16
|
||||
audioFormat := 1
|
||||
byteRate := sampleRate * numChannels * (bitsPerSample / 8)
|
||||
blockAlign := numChannels * (bitsPerSample / 8)
|
||||
|
||||
// Write the RIFF header
|
||||
f.Write([]byte("RIFF"))
|
||||
binary.Write(f, binary.LittleEndian, uint32(0)) // Placeholder for file size
|
||||
f.Write([]byte("WAVE"))
|
||||
|
||||
// Write the fmt subchunk
|
||||
f.Write([]byte("fmt "))
|
||||
binary.Write(f, binary.LittleEndian, uint32(subchunk1Size))
|
||||
binary.Write(f, binary.LittleEndian, uint16(audioFormat))
|
||||
binary.Write(f, binary.LittleEndian, uint16(numChannels))
|
||||
binary.Write(f, binary.LittleEndian, uint32(sampleRate))
|
||||
binary.Write(f, binary.LittleEndian, uint32(byteRate))
|
||||
binary.Write(f, binary.LittleEndian, uint16(blockAlign))
|
||||
binary.Write(f, binary.LittleEndian, uint16(bitsPerSample))
|
||||
|
||||
// Write the data subchunk header
|
||||
f.Write([]byte("data"))
|
||||
binary.Write(f, binary.LittleEndian, uint32(0)) // Placeholder for data size
|
||||
}
|
||||
|
||||
func updateWavHeader(f *os.File, totalSamples uint32, numChannels int, bitsPerSample int) {
|
||||
fileSize := 36 + (totalSamples * uint32(numChannels) * uint32(bitsPerSample/8))
|
||||
dataSize := totalSamples * uint32(numChannels) * uint32(bitsPerSample/8)
|
||||
|
||||
// Seek to the start of the file and write updated sizes
|
||||
f.Seek(4, 0)
|
||||
binary.Write(f, binary.LittleEndian, uint32(fileSize))
|
||||
|
||||
f.Seek(40, 0)
|
||||
binary.Write(f, binary.LittleEndian, uint32(dataSize))
|
||||
}
|
||||
@@ -209,15 +209,15 @@ install_cuda_driver_yum() {
|
||||
case $PACKAGE_MANAGER in
|
||||
yum)
|
||||
$SUDO $PACKAGE_MANAGER -y install yum-utils
|
||||
if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo" >/dev/null ; then
|
||||
$SUDO $PACKAGE_MANAGER-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo
|
||||
if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-$1$2.repo" >/dev/null ; then
|
||||
$SUDO $PACKAGE_MANAGER-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-$1$2.repo
|
||||
else
|
||||
error $CUDA_REPO_ERR_MSG
|
||||
fi
|
||||
;;
|
||||
dnf)
|
||||
if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo" >/dev/null ; then
|
||||
$SUDO $PACKAGE_MANAGER config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo
|
||||
if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-$1$2.repo" >/dev/null ; then
|
||||
$SUDO $PACKAGE_MANAGER config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-$1$2.repo
|
||||
else
|
||||
error $CUDA_REPO_ERR_MSG
|
||||
fi
|
||||
@@ -245,8 +245,8 @@ install_cuda_driver_yum() {
|
||||
# ref: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#debian
|
||||
install_cuda_driver_apt() {
|
||||
status 'Installing NVIDIA repository...'
|
||||
if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-keyring_1.1-1_all.deb" >/dev/null ; then
|
||||
curl -fsSL -o $TEMP_DIR/cuda-keyring.deb https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-keyring_1.1-1_all.deb
|
||||
if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-keyring_1.1-1_all.deb" >/dev/null ; then
|
||||
curl -fsSL -o $TEMP_DIR/cuda-keyring.deb https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m | sed -e 's/aarch64/sbsa/')/cuda-keyring_1.1-1_all.deb
|
||||
else
|
||||
error $CUDA_REPO_ERR_MSG
|
||||
fi
|
||||
|
||||
@@ -216,9 +216,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
if err := setSparse(file); err != nil {
|
||||
return err
|
||||
}
|
||||
setSparse(file)
|
||||
|
||||
_ = file.Truncate(b.Total)
|
||||
|
||||
@@ -235,7 +233,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
|
||||
|
||||
newOpts.CheckRedirect = func(req *http.Request, via []*http.Request) error {
|
||||
if len(via) > 10 {
|
||||
return errors.New("maxium redirects exceeded (10) for directURL")
|
||||
return errors.New("maximum redirects exceeded (10) for directURL")
|
||||
}
|
||||
|
||||
// if the hostname is the same, allow the redirect
|
||||
|
||||
@@ -373,7 +373,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
|
||||
var messages []*api.Message
|
||||
parameters := make(map[string]any)
|
||||
|
||||
var layers []*Layer
|
||||
var layers []Layer
|
||||
for _, c := range modelfile.Commands {
|
||||
mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
|
||||
|
||||
@@ -499,7 +499,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
|
||||
|
||||
if c.Name != "license" {
|
||||
// replace
|
||||
layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
|
||||
layers = slices.DeleteFunc(layers, func(layer Layer) bool {
|
||||
if layer.MediaType != mediatype {
|
||||
return false
|
||||
}
|
||||
@@ -545,7 +545,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
|
||||
}
|
||||
|
||||
var err2 error
|
||||
layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
|
||||
layers = slices.DeleteFunc(layers, func(layer Layer) bool {
|
||||
switch layer.MediaType {
|
||||
case "application/vnd.ollama.image.message":
|
||||
// if there are new messages, remove the inherited ones
|
||||
@@ -625,12 +625,12 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
|
||||
return err
|
||||
}
|
||||
|
||||
layer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
|
||||
configLayer, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, layer := range append(layers, layer) {
|
||||
for _, layer := range append(layers, configLayer) {
|
||||
if layer.status != "" {
|
||||
fn(api.ProgressResponse{Status: layer.status})
|
||||
}
|
||||
@@ -639,7 +639,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
|
||||
old, _ := ParseNamedManifest(name)
|
||||
|
||||
fn(api.ProgressResponse{Status: "writing manifest"})
|
||||
if err := WriteManifest(name, layer, layers); err != nil {
|
||||
if err := WriteManifest(name, configLayer, layers); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -839,10 +839,10 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
return err
|
||||
}
|
||||
|
||||
var layers []*Layer
|
||||
var layers []Layer
|
||||
layers = append(layers, manifest.Layers...)
|
||||
if manifest.Config.Digest != "" {
|
||||
layers = append(layers, &manifest.Config)
|
||||
layers = append(layers, manifest.Config)
|
||||
}
|
||||
|
||||
for _, layer := range layers {
|
||||
@@ -911,10 +911,10 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
return fmt.Errorf("pull model manifest: %s", err)
|
||||
}
|
||||
|
||||
var layers []*Layer
|
||||
var layers []Layer
|
||||
layers = append(layers, manifest.Layers...)
|
||||
if manifest.Config.Digest != "" {
|
||||
layers = append(layers, &manifest.Config)
|
||||
layers = append(layers, manifest.Config)
|
||||
}
|
||||
|
||||
skipVerify := make(map[string]bool)
|
||||
|
||||
@@ -16,15 +16,15 @@ type Layer struct {
|
||||
status string
|
||||
}
|
||||
|
||||
func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
|
||||
func NewLayer(r io.Reader, mediatype string) (Layer, error) {
|
||||
blobs, err := GetBlobsPath("")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return Layer{}, err
|
||||
}
|
||||
|
||||
temp, err := os.CreateTemp(blobs, "sha256-")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return Layer{}, err
|
||||
}
|
||||
defer temp.Close()
|
||||
defer os.Remove(temp.Name())
|
||||
@@ -32,28 +32,28 @@ func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
|
||||
sha256sum := sha256.New()
|
||||
n, err := io.Copy(io.MultiWriter(temp, sha256sum), r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return Layer{}, err
|
||||
}
|
||||
|
||||
if err := temp.Close(); err != nil {
|
||||
return nil, err
|
||||
return Layer{}, err
|
||||
}
|
||||
|
||||
digest := fmt.Sprintf("sha256:%x", sha256sum.Sum(nil))
|
||||
blob, err := GetBlobsPath(digest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return Layer{}, err
|
||||
}
|
||||
|
||||
status := "using existing layer"
|
||||
if _, err := os.Stat(blob); err != nil {
|
||||
status = "creating new layer"
|
||||
if err := os.Rename(temp.Name(), blob); err != nil {
|
||||
return nil, err
|
||||
return Layer{}, err
|
||||
}
|
||||
}
|
||||
|
||||
return &Layer{
|
||||
return Layer{
|
||||
MediaType: mediatype,
|
||||
Digest: digest,
|
||||
Size: n,
|
||||
@@ -61,22 +61,22 @@ func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewLayerFromLayer(digest, mediatype, from string) (*Layer, error) {
|
||||
func NewLayerFromLayer(digest, mediatype, from string) (Layer, error) {
|
||||
if digest == "" {
|
||||
return nil, errors.New("creating new layer from layer with empty digest")
|
||||
return Layer{}, errors.New("creating new layer from layer with empty digest")
|
||||
}
|
||||
|
||||
blob, err := GetBlobsPath(digest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return Layer{}, err
|
||||
}
|
||||
|
||||
fi, err := os.Stat(blob)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return Layer{}, err
|
||||
}
|
||||
|
||||
return &Layer{
|
||||
return Layer{
|
||||
MediaType: mediatype,
|
||||
Digest: digest,
|
||||
Size: fi.Size(),
|
||||
@@ -109,7 +109,7 @@ func (l *Layer) Remove() error {
|
||||
}
|
||||
|
||||
for _, m := range ms {
|
||||
for _, layer := range append(m.Layers, &m.Config) {
|
||||
for _, layer := range append(m.Layers, m.Config) {
|
||||
if layer.Digest == l.Digest {
|
||||
// something is using this layer
|
||||
return nil
|
||||
|
||||
@@ -14,10 +14,10 @@ import (
|
||||
)
|
||||
|
||||
type Manifest struct {
|
||||
SchemaVersion int `json:"schemaVersion"`
|
||||
MediaType string `json:"mediaType"`
|
||||
Config Layer `json:"config"`
|
||||
Layers []*Layer `json:"layers"`
|
||||
SchemaVersion int `json:"schemaVersion"`
|
||||
MediaType string `json:"mediaType"`
|
||||
Config Layer `json:"config"`
|
||||
Layers []Layer `json:"layers"`
|
||||
|
||||
filepath string
|
||||
fi os.FileInfo
|
||||
@@ -25,7 +25,7 @@ type Manifest struct {
|
||||
}
|
||||
|
||||
func (m *Manifest) Size() (size int64) {
|
||||
for _, layer := range append(m.Layers, &m.Config) {
|
||||
for _, layer := range append(m.Layers, m.Config) {
|
||||
size += layer.Size
|
||||
}
|
||||
|
||||
@@ -46,7 +46,7 @@ func (m *Manifest) Remove() error {
|
||||
}
|
||||
|
||||
func (m *Manifest) RemoveLayers() error {
|
||||
for _, layer := range append(m.Layers, &m.Config) {
|
||||
for _, layer := range append(m.Layers, m.Config) {
|
||||
if layer.Digest != "" {
|
||||
if err := layer.Remove(); errors.Is(err, os.ErrNotExist) {
|
||||
slog.Debug("layer does not exist", "digest", layer.Digest)
|
||||
@@ -95,7 +95,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
func WriteManifest(name model.Name, config *Layer, layers []*Layer) error {
|
||||
func WriteManifest(name model.Name, config Layer, layers []Layer) error {
|
||||
manifests, err := GetManifestPath()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -115,7 +115,7 @@ func WriteManifest(name model.Name, config *Layer, layers []*Layer) error {
|
||||
m := Manifest{
|
||||
SchemaVersion: 2,
|
||||
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
|
||||
Config: *config,
|
||||
Config: config,
|
||||
Layers: layers,
|
||||
}
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@ import (
|
||||
var intermediateBlobs map[string]string = make(map[string]string)
|
||||
|
||||
type layerGGML struct {
|
||||
*Layer
|
||||
Layer
|
||||
*llm.GGML
|
||||
}
|
||||
|
||||
|
||||
304
server/routes.go
304
server/routes.go
@@ -10,23 +10,20 @@ import (
|
||||
"io"
|
||||
"log/slog"
|
||||
"math"
|
||||
"math/rand"
|
||||
"mime/multipart"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/gin-contrib/cors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
@@ -109,186 +106,6 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capabil
|
||||
return runner.llama, model, &opts, nil
|
||||
}
|
||||
|
||||
func (s *Server) runWhisperServer(c *gin.Context, portCh chan int, errCh chan error, speech *api.WhisperRequest) {
|
||||
var modelPath string
|
||||
if speech.Model == "" {
|
||||
modelPath = "/Users/royhan-ollama/.ollama/whisper/ggml-base.en.bin"
|
||||
} else {
|
||||
modelPath = speech.Model
|
||||
}
|
||||
|
||||
// default to 5 minutes
|
||||
var sessionDuration time.Duration
|
||||
if speech.KeepAlive != nil {
|
||||
sessionDuration = speech.KeepAlive.Duration
|
||||
} else {
|
||||
sessionDuration = 5 * time.Minute
|
||||
}
|
||||
|
||||
s.sched.whisperMu.Lock()
|
||||
if s.sched.whisperLoaded[modelPath] != nil {
|
||||
slog.Info(fmt.Sprintf("whisper server already running %s on port %d", modelPath, *s.sched.whisperLoaded[modelPath]))
|
||||
portCh <- *s.sched.whisperLoaded[modelPath]
|
||||
// Renew the expiration time
|
||||
s.sched.whisperExpiresAt[modelPath] = time.Now().Add(sessionDuration)
|
||||
s.sched.whisperMu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
whisperServer := "/Users/royhan-ollama/.ollama/server"
|
||||
|
||||
// Find an available port for whisper
|
||||
port := 0
|
||||
params := []string{}
|
||||
if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
|
||||
var l *net.TCPListener
|
||||
if l, err = net.ListenTCP("tcp", a); err == nil {
|
||||
port = l.Addr().(*net.TCPAddr).Port
|
||||
l.Close()
|
||||
}
|
||||
}
|
||||
if port == 0 {
|
||||
slog.Debug("ResolveTCPAddr failed")
|
||||
port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
|
||||
}
|
||||
finalParams := append(params, "--port", strconv.Itoa(port), "--model", modelPath)
|
||||
|
||||
cmd := exec.Command(whisperServer, finalParams...)
|
||||
slog.Info("starting whisper server", "cmd", cmd.String())
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
err := cmd.Start()
|
||||
if err != nil {
|
||||
slog.Error("failed to start whisper server", "error", err)
|
||||
errCh <- err
|
||||
return
|
||||
}
|
||||
|
||||
// Wait for server connection
|
||||
retries := 25
|
||||
var connErr error
|
||||
for range retries {
|
||||
time.Sleep(50 * time.Millisecond)
|
||||
conn, err := net.DialTimeout("tcp", fmt.Sprintf("localhost:%d", port), time.Second)
|
||||
if err == nil {
|
||||
conn.Close()
|
||||
connErr = nil
|
||||
break
|
||||
}
|
||||
connErr = err
|
||||
}
|
||||
|
||||
if connErr != nil {
|
||||
slog.Error("failed to connect to whisper server", "error", connErr)
|
||||
errCh <- connErr
|
||||
return
|
||||
}
|
||||
|
||||
portCh <- port
|
||||
s.sched.whisperLoaded[modelPath] = &port
|
||||
s.sched.whisperExpiresAt[modelPath] = time.Now().Add(sessionDuration)
|
||||
|
||||
s.sched.whisperMu.Unlock()
|
||||
|
||||
// Wait for the whisper server to exit
|
||||
defer func() {
|
||||
ticker := time.NewTicker(5 * time.Second)
|
||||
defer ticker.Stop()
|
||||
for range ticker.C {
|
||||
s.sched.whisperMu.Lock()
|
||||
if time.Now().After(s.sched.whisperExpiresAt[modelPath]) {
|
||||
slog.Info("exiting whisper server")
|
||||
delete(s.sched.whisperLoaded, modelPath)
|
||||
delete(s.sched.whisperExpiresAt, modelPath)
|
||||
s.sched.whisperMu.Unlock()
|
||||
|
||||
if err := cmd.Process.Kill(); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err})
|
||||
return
|
||||
}
|
||||
|
||||
slog.Debug("whisper server stopped")
|
||||
return
|
||||
}
|
||||
s.sched.whisperMu.Unlock()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func whisperInference(c *gin.Context, filePath string, port int) (*api.WhisperCompletion, error) {
|
||||
// Open the file
|
||||
file, err := os.Open(filePath)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to open file"})
|
||||
return nil, err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
// Create a buffer to hold the multipart form data
|
||||
buffer := &bytes.Buffer{}
|
||||
writer := multipart.NewWriter(buffer)
|
||||
|
||||
// Add the file to the multipart form
|
||||
part, err := writer.CreateFormFile("file", filepath.Base(filePath))
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to create form file"})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err := io.Copy(part, file); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to copy file"})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Add other fields to the form
|
||||
if err := writer.WriteField("temperature", "0.0"); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to write field"})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Close the writer to finalize the multipart form
|
||||
if err := writer.Close(); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to close writer"})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
endpoint := fmt.Sprintf("http://localhost:%s/inference", strconv.Itoa(port))
|
||||
|
||||
serverReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodPost, endpoint, buffer)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to create request"})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
serverReq.Header.Set("Content-Type", writer.FormDataContentType())
|
||||
|
||||
res, err := http.DefaultClient.Do(serverReq)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to send request"})
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to read response"})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var w api.WhisperCompletion
|
||||
if err := json.Unmarshal(body, &w); err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to unmarshal response"})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if w.Error != "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": w.Error})
|
||||
return nil, fmt.Errorf(w.Error)
|
||||
}
|
||||
|
||||
return &w, nil
|
||||
}
|
||||
|
||||
func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
checkpointStart := time.Now()
|
||||
var req api.GenerateRequest
|
||||
@@ -313,40 +130,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
caps = append(caps, CapabilityInsert)
|
||||
}
|
||||
|
||||
if req.Speech != nil {
|
||||
portCh := make(chan int, 1)
|
||||
errCh := make(chan error, 1)
|
||||
go s.runWhisperServer(c, portCh, errCh, req.Speech)
|
||||
|
||||
var port int
|
||||
|
||||
select {
|
||||
case port = <-portCh:
|
||||
case err := <-errCh:
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err})
|
||||
return
|
||||
}
|
||||
|
||||
w, err := whisperInference(c, req.Speech.Audio, port)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to generate completion"})
|
||||
return
|
||||
}
|
||||
|
||||
if req.Speech.Transcribe {
|
||||
c.JSON(http.StatusOK, api.GenerateResponse{
|
||||
Model: req.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
Response: w.Text,
|
||||
Done: true,
|
||||
DoneReason: "transcribe",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
req.Prompt += "\n" + w.Text
|
||||
}
|
||||
|
||||
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
|
||||
if errors.Is(err, errCapabilityCompletion) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
|
||||
@@ -564,6 +347,7 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
var count int
|
||||
for i, s := range input {
|
||||
tokens, err := r.Tokenize(c.Request.Context(), s)
|
||||
if err != nil {
|
||||
@@ -586,25 +370,36 @@ func (s *Server) EmbedHandler(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
count += len(tokens)
|
||||
|
||||
input[i] = s
|
||||
}
|
||||
embeddings, err := r.Embed(c.Request.Context(), input)
|
||||
if err != nil {
|
||||
slog.Error("embedding generation failed", "error", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
|
||||
return
|
||||
|
||||
var g errgroup.Group
|
||||
embeddings := make([][]float32, len(input))
|
||||
for i, text := range input {
|
||||
g.Go(func() error {
|
||||
embedding, err := r.Embedding(c.Request.Context(), text)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
embeddings[i] = normalize(embedding)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
for i, e := range embeddings.Embedding {
|
||||
embeddings.Embedding[i] = normalize(e)
|
||||
if err := g.Wait(); err != nil {
|
||||
slog.Error("embedding generation failed", "error", err)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Errorf("failed to generate embeddings: %v", err)})
|
||||
return
|
||||
}
|
||||
|
||||
resp := api.EmbedResponse{
|
||||
Model: req.Model,
|
||||
Embeddings: embeddings.Embedding,
|
||||
Embeddings: embeddings,
|
||||
TotalDuration: time.Since(checkpointStart),
|
||||
LoadDuration: checkpointLoaded.Sub(checkpointStart),
|
||||
PromptEvalCount: embeddings.PromptEvalCount,
|
||||
PromptEvalCount: count,
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
@@ -648,21 +443,20 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
embeddings, err := r.Embed(c.Request.Context(), []string{req.Prompt})
|
||||
embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
|
||||
if err != nil {
|
||||
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
|
||||
return
|
||||
}
|
||||
|
||||
embedding := make([]float64, len(embeddings.Embedding[0]))
|
||||
|
||||
for i, v := range embeddings.Embedding[0] {
|
||||
embedding[i] = float64(v)
|
||||
var e []float64
|
||||
for _, v := range embedding {
|
||||
e = append(e, float64(v))
|
||||
}
|
||||
|
||||
resp := api.EmbeddingResponse{
|
||||
Embedding: embedding,
|
||||
Embedding: e,
|
||||
}
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
@@ -1249,6 +1043,11 @@ func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
|
||||
|
||||
if addr, err := netip.ParseAddr(host); err == nil {
|
||||
if addr.IsLoopback() || addr.IsPrivate() || addr.IsUnspecified() || isLocalIP(addr) {
|
||||
if c.Request.Method == http.MethodOptions {
|
||||
c.AbortWithStatus(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
@@ -1280,6 +1079,7 @@ func (s *Server) GenerateRoutes() http.Handler {
|
||||
config.AllowOrigins = envconfig.Origins()
|
||||
|
||||
r := gin.Default()
|
||||
r.HandleMethodNotAllowed = true
|
||||
r.Use(
|
||||
cors.New(config),
|
||||
allowedHostsMiddleware(s.addr),
|
||||
@@ -1514,37 +1314,6 @@ func (s *Server) ProcessHandler(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
|
||||
}
|
||||
|
||||
func processAudio(c *gin.Context, s *Server, msgs []api.Message, req *api.WhisperRequest) error {
|
||||
slog.Info("processing audio")
|
||||
if req == nil {
|
||||
req = &api.WhisperRequest{}
|
||||
}
|
||||
portCh := make(chan int, 1)
|
||||
errCh := make(chan error, 1)
|
||||
go s.runWhisperServer(c, portCh, errCh, req)
|
||||
|
||||
var port int
|
||||
select {
|
||||
case port = <-portCh:
|
||||
case err := <-errCh:
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return err
|
||||
}
|
||||
|
||||
// could parallelize this
|
||||
for i, msg := range msgs {
|
||||
if msg.Audio != "" {
|
||||
w, err := whisperInference(c, msg.Audio, port)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to generate completion"})
|
||||
return err
|
||||
}
|
||||
msgs[i].Content += "\n" + w.Text
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) ChatHandler(c *gin.Context) {
|
||||
checkpointStart := time.Now()
|
||||
|
||||
@@ -1589,13 +1358,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...)
|
||||
}
|
||||
|
||||
if req.Speech != nil || req.RunSpeech {
|
||||
if err := processAudio(c, s, msgs, req.Speech); err != nil {
|
||||
slog.Error("failed to process audio", "error", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
|
||||
@@ -98,7 +98,7 @@ func TestDeleteDuplicateLayers(t *testing.T) {
|
||||
}
|
||||
|
||||
// create a manifest with duplicate layers
|
||||
if err := WriteManifest(n, config, []*Layer{config}); err != nil {
|
||||
if err := WriteManifest(n, config, []Layer{config}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
|
||||
@@ -46,10 +46,6 @@ type Scheduler struct {
|
||||
getGpuFn func() gpu.GpuInfoList
|
||||
getCpuFn func() gpu.GpuInfoList
|
||||
reschedDelay time.Duration
|
||||
|
||||
whisperLoaded map[string]*int
|
||||
whisperExpiresAt map[string]time.Time
|
||||
whisperMu sync.Mutex
|
||||
}
|
||||
|
||||
// Default automatic value for number of models we allow per GPU
|
||||
@@ -67,17 +63,15 @@ var ErrMaxQueue = errors.New("server busy, please try again. maximum pending re
|
||||
func InitScheduler(ctx context.Context) *Scheduler {
|
||||
maxQueue := envconfig.MaxQueue()
|
||||
sched := &Scheduler{
|
||||
pendingReqCh: make(chan *LlmRequest, maxQueue),
|
||||
finishedReqCh: make(chan *LlmRequest, maxQueue),
|
||||
expiredCh: make(chan *runnerRef, maxQueue),
|
||||
unloadedCh: make(chan interface{}, maxQueue),
|
||||
loaded: make(map[string]*runnerRef),
|
||||
newServerFn: llm.NewLlamaServer,
|
||||
getGpuFn: gpu.GetGPUInfo,
|
||||
getCpuFn: gpu.GetCPUInfo,
|
||||
reschedDelay: 250 * time.Millisecond,
|
||||
whisperLoaded: make(map[string]*int),
|
||||
whisperExpiresAt: make(map[string]time.Time),
|
||||
pendingReqCh: make(chan *LlmRequest, maxQueue),
|
||||
finishedReqCh: make(chan *LlmRequest, maxQueue),
|
||||
expiredCh: make(chan *runnerRef, maxQueue),
|
||||
unloadedCh: make(chan interface{}, maxQueue),
|
||||
loaded: make(map[string]*runnerRef),
|
||||
newServerFn: llm.NewLlamaServer,
|
||||
getGpuFn: gpu.GetGPUInfo,
|
||||
getCpuFn: gpu.GetCPUInfo,
|
||||
reschedDelay: 250 * time.Millisecond,
|
||||
}
|
||||
sched.loadFn = sched.load
|
||||
return sched
|
||||
@@ -116,10 +110,6 @@ func (s *Scheduler) Run(ctx context.Context) {
|
||||
go func() {
|
||||
s.processCompleted(ctx)
|
||||
}()
|
||||
|
||||
// go func() {
|
||||
// could clean up whisper servers in init thread
|
||||
// }
|
||||
}
|
||||
|
||||
func (s *Scheduler) processPending(ctx context.Context) {
|
||||
|
||||
@@ -708,8 +708,8 @@ type mockLlm struct {
|
||||
pingResp error
|
||||
waitResp error
|
||||
completionResp error
|
||||
embedResp *llm.EmbedResponse
|
||||
embedRespErr error
|
||||
embeddingResp []float32
|
||||
embeddingRespErr error
|
||||
tokenizeResp []int
|
||||
tokenizeRespErr error
|
||||
detokenizeResp string
|
||||
@@ -727,8 +727,8 @@ func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn
|
||||
return s.completionResp
|
||||
}
|
||||
|
||||
func (s *mockLlm) Embed(ctx context.Context, input []string) (*llm.EmbedResponse, error) {
|
||||
return s.embedResp, s.embedRespErr
|
||||
func (s *mockLlm) Embedding(ctx context.Context, input string) ([]float32, error) {
|
||||
return s.embeddingResp, s.embeddingRespErr
|
||||
}
|
||||
|
||||
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
|
||||
|
||||
@@ -4,6 +4,5 @@ package server
|
||||
|
||||
import "os"
|
||||
|
||||
func setSparse(file *os.File) error {
|
||||
return nil
|
||||
func setSparse(*os.File) {
|
||||
}
|
||||
|
||||
@@ -6,8 +6,9 @@ import (
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func setSparse(file *os.File) error {
|
||||
return windows.DeviceIoControl(
|
||||
func setSparse(file *os.File) {
|
||||
// exFat (and other FS types) don't support sparse files, so ignore errors
|
||||
windows.DeviceIoControl( //nolint:errcheck
|
||||
windows.Handle(file.Fd()), windows.FSCTL_SET_SPARSE,
|
||||
nil, 0,
|
||||
nil, 0,
|
||||
|
||||
@@ -26,7 +26,7 @@ import (
|
||||
var blobUploadManager sync.Map
|
||||
|
||||
type blobUpload struct {
|
||||
*Layer
|
||||
Layer
|
||||
|
||||
Total int64
|
||||
Completed atomic.Int64
|
||||
@@ -362,7 +362,7 @@ func (p *progressWriter) Rollback() {
|
||||
p.written = 0
|
||||
}
|
||||
|
||||
func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||
func uploadBlob(ctx context.Context, mp ModelPath, layer Layer, opts *registryOptions, fn func(api.ProgressResponse)) error {
|
||||
requestURL := mp.BaseURL()
|
||||
requestURL = requestURL.JoinPath("v2", mp.GetNamespaceRepository(), "blobs", layer.Digest)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user