Compare commits


16 Commits

Author SHA1 Message Date
Roy Han
30823ec925 update readme 2024-08-09 11:32:27 -07:00
Roy Han
89f3bae306 cli 2024-08-09 11:04:26 -07:00
Roy Han
ad7e822883 audio processing error prop 2024-08-07 14:05:22 -07:00
Roy Han
d503f04b32 expiration 2024-08-07 13:04:57 -07:00
Roy Han
8ccf543c53 chat doc 2024-08-07 13:04:57 -07:00
Roy Han
75ad6309b4 chat support 2024-08-07 13:04:57 -07:00
Roy Han
a5181a8c51 error handling 2024-08-07 13:04:57 -07:00
Roy Han
2a9feb0707 model flexibility 2024-08-07 13:04:57 -07:00
Roy Han
e4d35198a2 transcribe 2024-08-07 13:04:57 -07:00
Roy Han
17f9dc6d08 save whisper port 2024-08-07 13:04:57 -07:00
Roy Han
97d9dffa80 err check 2024-08-07 13:04:57 -07:00
Roy Han
65483180b9 working poc 2024-08-07 13:04:57 -07:00
Roy Han
1ac92eae7c submodule 2024-08-07 13:04:57 -07:00
Jesse Gross
69eb06c40e Merge pull request #6145 from ollama/jessegross/bug5840
Fix crash on startup when trying to clean up unused files (#5840)
2024-08-07 11:24:15 -07:00
Jesse Gross
1829fb61bd manifest: Fix crash on startup when trying to clean up unused files (#5840)
Currently if the config field is missing in the manifest file (or
corrupted), Ollama will crash when it tries to read it. This can
happen at startup or when pulling new models.

This data is mostly just used for showing model information so we
can be tolerant of it not being present - it is not required to
run the models. Besides avoiding crashing, this also gives us the
ability to restructure the config in the future by pulling it
into the main manifest file.
2024-08-07 10:30:44 -07:00
Jesse Gross
685a53534b manifest: Don't prune layers if we can't open a manifest file
If there is an error when opening a manifest file (corrupted, permission denied, etc.)
then the referenced layers will not be included in the list of active
layers. This causes them to be deleted when pruning happens at startup
or a model is pulled.

In such a situation, we should prefer to preserve data in the hopes that
it can be recovered rather than being aggressive about deletion.
2024-08-06 23:11:19 -07:00
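
Taken together, the two manifest fixes implement one pattern, which the server diffs below repeat at each call site: treat the config layer as optional metadata (skip it when its digest is empty), and propagate manifest read errors instead of swallowing them, so a partial layer list never drives pruning. A minimal sketch of that pattern, with `Layer` and `Manifest` simplified from the real server types:

```
package main

import "fmt"

// Simplified shapes; the real server types carry more fields.
type Layer struct {
	Digest string
}

type Manifest struct {
	Config Layer // value type: the zero Layer when config is absent
	Layers []*Layer
}

// activeDigests collects digests that must survive pruning. If any
// manifest fails to load, fail the whole walk: a partial set would
// let pruning delete layers that are still referenced.
func activeDigests(paths []string, load func(string) (*Manifest, error)) (map[string]struct{}, error) {
	keep := make(map[string]struct{})
	for _, p := range paths {
		m, err := load(p)
		if err != nil {
			return nil, err // was `return nil` (nolint:nilerr), which dropped live layers
		}
		for _, l := range m.Layers {
			keep[l.Digest] = struct{}{}
		}
		if m.Config.Digest != "" { // config is optional; tolerate its absence
			keep[m.Config.Digest] = struct{}{}
		}
	}
	return keep, nil
}

func main() {
	manifests := map[string]*Manifest{
		"ok": {Config: Layer{Digest: "sha256:cfg"}, Layers: []*Layer{{Digest: "sha256:weights"}}},
	}
	load := func(p string) (*Manifest, error) {
		if m, found := manifests[p]; found {
			return m, nil
		}
		return nil, fmt.Errorf("open %s: corrupted or missing", p)
	}
	if _, err := activeDigests([]string{"ok", "bad"}, load); err != nil {
		fmt.Println("refusing to prune:", err) // preserve data rather than delete aggressively
	}
}
```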
14 changed files with 661 additions and 200 deletions

5
.gitmodules vendored

@@ -1,4 +1,7 @@
 [submodule "llama.cpp"]
 	path = llm/llama.cpp
 	url = https://github.com/ggerganov/llama.cpp.git
-	shallow = true
+	shallow = true
+[submodule "llm/whisper.cpp"]
+	path = llm/whisper.cpp
+	url = git@github.com:ggerganov/whisper.cpp.git

api/types.go

@@ -36,6 +36,13 @@ func (e StatusError) Error() string {
// ImageData represents the raw binary data of an image file.
type ImageData []byte
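// WhisperRequest configures speech-to-text for a generate or chat call.
// Audio is the path to the input audio file; Model optionally points at a
// whisper model (a default path is used when empty); Transcribe returns
// just the transcription; KeepAlive controls how long the whisper server
// stays loaded.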
type WhisperRequest struct {
Model string `json:"model,omitempty"`
Audio string `json:"audio,omitempty"`
Transcribe bool `json:"transcribe,omitempty"`
KeepAlive *Duration `json:"keep_alive,omitempty"`
}
// GenerateRequest describes a request sent by [Client.Generate]. While you
// have to specify the Model and Prompt fields, all the other fields have
// reasonable defaults for basic uses.
@@ -80,6 +87,8 @@ type GenerateRequest struct {
// Options lists model-specific options. For example, temperature can be
// set through this field, if the model supports it.
Options map[string]interface{} `json:"options"`
Speech *WhisperRequest `json:"speech,omitempty"`
}
// ChatRequest describes a request sent by [Client.Chat].
@@ -105,6 +114,10 @@ type ChatRequest struct {
// Options lists model-specific options.
Options map[string]interface{} `json:"options"`
Speech *WhisperRequest `json:"speech,omitempty"`
RunSpeech bool `json:"run_speech,omitempty"`
}
type Tools []Tool
@@ -127,6 +140,7 @@ type Message struct {
Content string `json:"content"`
Images []ImageData `json:"images,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
Audio string `json:"audio,omitempty"`
}
func (m *Message) UnmarshalJSON(b []byte) error {
@@ -450,6 +464,11 @@ type GenerateResponse struct {
Metrics
}
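// WhisperCompletion is the whisper server's response: the transcribed
// text, or an error message if inference failed.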
type WhisperCompletion struct {
Text string `json:"text"`
Error string `json:"error,omitempty"`
}
// ModelDetails provides details about a model.
type ModelDetails struct {
ParentModel string `json:"parent_model"`

cmd/cmd.go

@@ -38,6 +38,7 @@ import (
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/recorder"
"github.com/ollama/ollama/server"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
@@ -380,6 +381,14 @@ func RunHandler(cmd *cobra.Command, args []string) error {
}
}
speech, err := cmd.Flags().GetBool("speech")
if err != nil {
return err
}
if speech {
return generateInteractiveAudio(cmd, opts)
}
return generateInteractive(cmd, opts)
}
return generate(cmd, opts)
@@ -862,6 +871,7 @@ type runOptions struct {
Options map[string]interface{}
MultiModal bool
KeepAlive *api.Duration
Audio bool
}
type displayResponseState struct {
@@ -970,6 +980,10 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
req.KeepAlive = opts.KeepAlive
}
if opts.Audio {
req.RunSpeech = true
}
if err := client.Chat(cancelCtx, req, fn); err != nil {
if errors.Is(err, context.Canceled) {
return nil, nil
@@ -1055,6 +1069,30 @@ func generate(cmd *cobra.Command, opts runOptions) error {
KeepAlive: opts.KeepAlive,
}
speech, err := cmd.Flags().GetBool("speech")
if err != nil {
return err
}
// create temp wav file with the recorder package
if speech {
tempFile, err := os.CreateTemp("", "recording-*.wav")
if err != nil {
return err
}
defer os.Remove(tempFile.Name())
fmt.Print("Speech Mode\n\n")
err = recorder.RecordAudio(tempFile)
if err != nil {
return err
}
request.Speech = &api.WhisperRequest{
Audio: tempFile.Name(),
}
}
if err := client.Generate(ctx, &request, fn); err != nil {
if errors.Is(err, context.Canceled) {
return nil
@@ -1262,6 +1300,7 @@ func NewCLI() *cobra.Command {
RunE: RunHandler,
}
runCmd.Flags().Bool("speech", false, "Speech to text mode")
runCmd.Flags().String("keepalive", "", "Duration to keep a model loaded (e.g. 5m)")
runCmd.Flags().Bool("verbose", false, "Show timings for response")
runCmd.Flags().Bool("insecure", false, "Use an insecure registry")

cmd/interactive.go

@@ -20,6 +20,7 @@ import (
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/readline"
"github.com/ollama/ollama/recorder"
"github.com/ollama/ollama/types/errtypes"
)
@@ -51,6 +52,40 @@ func loadModel(cmd *cobra.Command, opts *runOptions) error {
return client.Chat(cmd.Context(), chatReq, func(api.ChatResponse) error { return nil })
}
func generateInteractiveAudio(cmd *cobra.Command, opts runOptions) error {
for {
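// NOTE: no exit condition yet; docs/speech.md tracks "fix exiting interactive mode" as a TODO.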
p := progress.NewProgress(os.Stderr)
spinner := progress.NewSpinner("")
p.Add("", spinner)
// create temp wav file with the recorder package
tempFile, err := os.CreateTemp("", "recording-*.wav")
if err != nil {
return err
}
defer os.Remove(tempFile.Name())
err = recorder.RecordAudio(tempFile)
if err != nil {
return err
}
p.StopAndClear()
newMessage := api.Message{Role: "user", Audio: tempFile.Name()}
opts.Audio = true
opts.Messages = append(opts.Messages, newMessage)
assistant, err := chat(cmd, opts)
if err != nil {
return err
}
if assistant != nil {
opts.Messages = append(opts.Messages, *assistant)
}
}
}
func generateInteractive(cmd *cobra.Command, opts runOptions) error {
usage := func() {
fmt.Fprintln(os.Stderr, "Available Commands:")

83
docs/speech.md Normal file

@@ -0,0 +1,83 @@
# Speech to Text Prototype
## To run
`make {/path/to/whisper.cpp/server}`
- replace `whisperServer` in `routes.go` with the path to the built whisper.cpp server binary
## CLI
`./ollama run llama3 [PROMPT] --speech`
- processes voice audio with the provided prompt
`./ollama run llama3 --speech`
- enters interactive mode for continuous voice chat
- TODO: fix exiting interactive mode
Note: uses the default whisper model
## api/generate
### Request fields
- `speech` (required):
  - `audio` (required): path to the audio file
  - `model` (optional): path to a whisper model; the default is used if unset
  - `transcribe` (optional): if true, transcribes the audio and returns the transcription without running the language model
  - `keep_alive` (optional): sets how long the whisper model stays loaded in memory (default: `5m`)
- `prompt` (optional): if set, the transcribed text is appended to the prompt before generation
#### Transcription
```
curl http://localhost:11434/api/generate -d '{
"speech": {
"model": "/Users/royhan-ollama/.ollama/whisper/ggml-base.en.bin",
"audio": "/Users/royhan-ollama/ollama/llm/whisper.cpp/samples/jfk.wav",
"transcribe": true,
"keep_alive": "1m"
},
"stream": false
}' | jq
```
#### Response Generation
```
curl http://localhost:11434/api/generate -d '{
"model": "llama3",
"prompt": "What do you think about this quote?",
"speech": {
"model": "/Users/royhan-ollama/.ollama/whisper/ggml-base.en.bin",
"audio": "/Users/royhan-ollama/ollama/llm/whisper.cpp/samples/jfk.wav",
"keep_alive": "1m"
},
"stream": false
}' | jq
```
## api/chat
### Request fields
- `model` (required): language model to chat with
- `speech` (optional):
  - `model` (optional): path to a whisper model; the default is used if unset
  - `keep_alive` (optional): sets how long the whisper model stays loaded in memory (default: `5m`)
- `run_speech` (optional): speech mode runs when this is true or when `speech` is provided
- `messages[].audio` (required): path to an audio file on a message; the transcription is appended to that message's `content`
```
curl http://localhost:11434/api/chat -d '{
"model": "llama3",
"speech": {
"model": "/Users/royhan-ollama/.ollama/whisper/ggml-base.en.bin",
"keep_alive": "10m"
},
"messages": [
{
"role": "system",
"content": "You are a Canadian Nationalist"
},
{
"role": "user",
"content": "What do you think about this quote?",
"audio": "/Users/royhan-ollama/ollama/llm/whisper.cpp/samples/jfk.wav"
}
],
"stream": false
}' | jq
```
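
The same chat request can be issued from Go. This is a sketch only: `ClientFromEnvironment` and `Client.Chat` are the existing `api` package helpers, while the `Speech` and `Audio` fields exist only on this branch (see the `api/types.go` diff above):

```
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	stream := false
	req := &api.ChatRequest{
		Model: "llama3",
		Speech: &api.WhisperRequest{
			Model: "/Users/royhan-ollama/.ollama/whisper/ggml-base.en.bin",
		},
		Messages: []api.Message{
			{Role: "system", Content: "You are a Canadian Nationalist"},
			{
				Role:    "user",
				Content: "What do you think about this quote?",
				Audio:   "/Users/royhan-ollama/ollama/llm/whisper.cpp/samples/jfk.wav",
			},
		},
		Stream: &stream,
	}

	// The transcription is appended to the user message's content server-side.
	err = client.Chat(context.Background(), req, func(r api.ChatResponse) error {
		fmt.Print(r.Message.Content)
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
}
```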

1
go.mod

@@ -19,6 +19,7 @@ require (
github.com/agnivade/levenshtein v1.1.1
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
github.com/google/go-cmp v0.6.0
github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5
github.com/mattn/go-runewidth v0.0.14
github.com/nlpodyssey/gopickle v0.3.0
github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c

2
go.sum

@@ -115,6 +115,8 @@ github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y=
github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5 h1:5AlozfqaVjGYGhms2OsdUyfdJME76E6rx5MdGpjzZpc=
github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5/go.mod h1:WY8R6YKlI2ZI3UyzFk7P6yGSuS+hFwNtEzrexRyD7Es=
github.com/grpc-ecosystem/grpc-gateway v1.16.0/go.mod h1:BDjrQk3hbvj6Nolgz8mAMFbcEtjT1g+wF4CSlocrBnw=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=

1
llm/whisper.cpp Submodule

Submodule llm/whisper.cpp added at 6739eb83c3

137
recorder/recorder.go Normal file

@@ -0,0 +1,137 @@
package recorder
import (
"encoding/binary"
"fmt"
"os"
"os/signal"
"syscall"
"golang.org/x/sys/unix"
"golang.org/x/term"
"github.com/gordonklaus/portaudio"
)
const (
sampleRate = 16000
numChannels = 1
bitsPerSample = 16
)
func RecordAudio(f *os.File) error {
fmt.Print("Recording. Press any key to stop.\n\n")
sig := make(chan os.Signal, 1)
signal.Notify(sig, os.Interrupt, syscall.SIGTERM)
portaudio.Initialize()
defer portaudio.Terminate()
in := make([]int16, 64)
stream, err := portaudio.OpenDefaultStream(numChannels, 0, sampleRate, len(in), in)
if err != nil {
return err
}
defer stream.Close()
err = stream.Start()
if err != nil {
return err
}
// Write WAV header with placeholder sizes
writeWavHeader(f, sampleRate, numChannels, bitsPerSample)
var totalSamples uint32
// Set up terminal input reading
oldState, err := term.MakeRaw(int(os.Stdin.Fd()))
if err != nil {
return err
}
defer term.Restore(int(os.Stdin.Fd()), oldState)
// Create a channel to handle the stop signal
stop := make(chan struct{})
go func() {
_, err := unix.Read(int(os.Stdin.Fd()), make([]byte, 1))
if err != nil {
fmt.Println("Error reading from stdin:", err)
return
}
// Send signal to stop recording
stop <- struct{}{}
}()
loop:
for {
err = stream.Read()
if err != nil {
return err
}
err = binary.Write(f, binary.LittleEndian, in)
if err != nil {
return err
}
totalSamples += uint32(len(in))
select {
case <-stop:
break loop
case <-sig:
break loop
default:
}
}
err = stream.Stop()
if err != nil {
return err
}
// Update WAV header with actual sizes
updateWavHeader(f, totalSamples, numChannels, bitsPerSample)
return nil
}
func writeWavHeader(f *os.File, sampleRate int, numChannels int, bitsPerSample int) {
subchunk1Size := 16
audioFormat := 1
byteRate := sampleRate * numChannels * (bitsPerSample / 8)
blockAlign := numChannels * (bitsPerSample / 8)
// Write the RIFF header
f.Write([]byte("RIFF"))
binary.Write(f, binary.LittleEndian, uint32(0)) // Placeholder for file size
f.Write([]byte("WAVE"))
// Write the fmt subchunk
f.Write([]byte("fmt "))
binary.Write(f, binary.LittleEndian, uint32(subchunk1Size))
binary.Write(f, binary.LittleEndian, uint16(audioFormat))
binary.Write(f, binary.LittleEndian, uint16(numChannels))
binary.Write(f, binary.LittleEndian, uint32(sampleRate))
binary.Write(f, binary.LittleEndian, uint32(byteRate))
binary.Write(f, binary.LittleEndian, uint16(blockAlign))
binary.Write(f, binary.LittleEndian, uint16(bitsPerSample))
// Write the data subchunk header
f.Write([]byte("data"))
binary.Write(f, binary.LittleEndian, uint32(0)) // Placeholder for data size
}
func updateWavHeader(f *os.File, totalSamples uint32, numChannels int, bitsPerSample int) {
fileSize := 36 + (totalSamples * uint32(numChannels) * uint32(bitsPerSample/8))
dataSize := totalSamples * uint32(numChannels) * uint32(bitsPerSample/8)
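// WAV layout: bytes 4-7 hold the RIFF chunk size, i.e. the total file size
// minus the 8-byte "RIFF"+size preamble (36 = 44-byte header - 8), and
// bytes 40-43 hold the data subchunk size.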
// Seek to the start of the file and write updated sizes
f.Seek(4, 0)
binary.Write(f, binary.LittleEndian, uint32(fileSize))
f.Seek(40, 0)
binary.Write(f, binary.LittleEndian, uint32(dataSize))
}
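
For reference, both CLI code paths above consume this package the same way: record into a temporary WAV file, then hand the file path to the speech request. A standalone usage sketch:

```
package main

import (
	"fmt"
	"log"
	"os"

	"github.com/ollama/ollama/recorder"
)

func main() {
	// Same pattern as cmd/cmd.go: record into a temp WAV, pass the path on.
	tempFile, err := os.CreateTemp("", "recording-*.wav")
	if err != nil {
		log.Fatal(err)
	}
	defer os.Remove(tempFile.Name())

	if err := recorder.RecordAudio(tempFile); err != nil {
		log.Fatal(err)
	}
	fmt.Println("saved recording to", tempFile.Name())
}
```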

server/images.go

@@ -250,19 +250,21 @@ func GetModel(name string) (*Model, error) {
 		Template: template.DefaultTemplate,
 	}
-	filename, err := GetBlobsPath(manifest.Config.Digest)
-	if err != nil {
-		return nil, err
-	}
+	if manifest.Config.Digest != "" {
+		filename, err := GetBlobsPath(manifest.Config.Digest)
+		if err != nil {
+			return nil, err
+		}
-	configFile, err := os.Open(filename)
-	if err != nil {
-		return nil, err
-	}
-	defer configFile.Close()
+		configFile, err := os.Open(filename)
+		if err != nil {
+			return nil, err
+		}
+		defer configFile.Close()
-	if err := json.NewDecoder(configFile).Decode(&model.Config); err != nil {
-		return nil, err
+		if err := json.NewDecoder(configFile).Decode(&model.Config); err != nil {
+			return nil, err
+		}
 	}
for _, layer := range manifest.Layers {
@@ -714,8 +716,7 @@ func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]struct{})
 		// save (i.e. delete from the deleteMap) any files used in other manifests
 		manifest, _, err := GetManifest(fmp)
 		if err != nil {
-			//nolint:nilerr
-			return nil
+			return err
 		}
 		for _, layer := range manifest.Layers {
@@ -782,7 +783,8 @@ func PruneLayers() error {
 	err = deleteUnusedLayers(nil, deleteMap)
 	if err != nil {
-		return err
+		slog.Error(fmt.Sprintf("couldn't remove unused layers: %v", err))
+		return nil
 	}
 	slog.Info(fmt.Sprintf("total unused blobs removed: %d", len(deleteMap)))
@@ -839,7 +841,9 @@ func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
 	var layers []*Layer
 	layers = append(layers, manifest.Layers...)
-	layers = append(layers, manifest.Config)
+	if manifest.Config.Digest != "" {
+		layers = append(layers, &manifest.Config)
+	}
for _, layer := range layers {
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
@@ -890,7 +894,9 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
 		for _, l := range manifest.Layers {
 			deleteMap[l.Digest] = struct{}{}
 		}
-		deleteMap[manifest.Config.Digest] = struct{}{}
+		if manifest.Config.Digest != "" {
+			deleteMap[manifest.Config.Digest] = struct{}{}
+		}
}
}
@@ -907,7 +913,9 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
 	var layers []*Layer
 	layers = append(layers, manifest.Layers...)
-	layers = append(layers, manifest.Config)
+	if manifest.Config.Digest != "" {
+		layers = append(layers, &manifest.Config)
+	}
skipVerify := make(map[string]bool)
for _, layer := range layers {
@@ -971,7 +979,8 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
 		fn(api.ProgressResponse{Status: "removing any unused layers"})
 		err = deleteUnusedLayers(nil, deleteMap)
 		if err != nil {
-			return err
+			slog.Error(fmt.Sprintf("couldn't remove unused layers: %v", err))
+			fn(api.ProgressResponse{Status: fmt.Sprintf("couldn't remove unused layers: %v", err)})
 		}
}

server/layer.go

@@ -2,6 +2,7 @@ package server
import (
"crypto/sha256"
"errors"
"fmt"
"io"
"os"
@@ -61,6 +62,10 @@ func NewLayer(r io.Reader, mediatype string) (*Layer, error) {
}
func NewLayerFromLayer(digest, mediatype, from string) (*Layer, error) {
if digest == "" {
return nil, errors.New("creating new layer from layer with empty digest")
}
blob, err := GetBlobsPath(digest)
if err != nil {
return nil, err
@@ -81,6 +86,10 @@ func NewLayerFromLayer(digest, mediatype, from string) (*Layer, error) {
}
func (l *Layer) Open() (io.ReadSeekCloser, error) {
if l.Digest == "" {
return nil, errors.New("opening layer with empty digest")
}
blob, err := GetBlobsPath(l.Digest)
if err != nil {
return nil, err
@@ -90,13 +99,17 @@ func (l *Layer) Open() (io.ReadSeekCloser, error) {
}
func (l *Layer) Remove() error {
if l.Digest == "" {
return nil
}
ms, err := Manifests()
if err != nil {
return err
}
 	for _, m := range ms {
-		for _, layer := range append(m.Layers, m.Config) {
+		for _, layer := range append(m.Layers, &m.Config) {
 			if layer.Digest == l.Digest {
 				// something is using this layer
 				return nil

server/manifest.go

@@ -16,17 +16,16 @@ import (
 type Manifest struct {
 	SchemaVersion int    `json:"schemaVersion"`
 	MediaType     string `json:"mediaType"`
-	Config        *Layer   `json:"config"`
+	Config        Layer    `json:"config"`
 	Layers        []*Layer `json:"layers"`
 	name     model.Name
 	filepath string
 	fi       os.FileInfo
 	digest   string
 }
 func (m *Manifest) Size() (size int64) {
-	for _, layer := range append(m.Layers, m.Config) {
+	for _, layer := range append(m.Layers, &m.Config) {
 		size += layer.Size
 	}
@@ -47,11 +46,13 @@ func (m *Manifest) Remove() error {
}
 func (m *Manifest) RemoveLayers() error {
-	for _, layer := range append(m.Layers, m.Config) {
-		if err := layer.Remove(); errors.Is(err, os.ErrNotExist) {
-			slog.Debug("layer does not exist", "digest", layer.Digest)
-		} else if err != nil {
-			return err
+	for _, layer := range append(m.Layers, &m.Config) {
+		if layer.Digest != "" {
+			if err := layer.Remove(); errors.Is(err, os.ErrNotExist) {
+				slog.Debug("layer does not exist", "digest", layer.Digest)
+			} else if err != nil {
+				return err
+			}
 		}
 	}
@@ -70,6 +71,7 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
 	p := filepath.Join(manifests, n.Filepath())
+	var m Manifest
 	f, err := os.Open(p)
 	if err != nil {
 		return nil, err
@@ -81,13 +83,11 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
return nil, err
 	}
-	var m Manifest
 	sha256sum := sha256.New()
if err := json.NewDecoder(io.TeeReader(f, sha256sum)).Decode(&m); err != nil {
return nil, err
}
m.name = n
m.filepath = p
m.fi = fi
m.digest = hex.EncodeToString(sha256sum.Sum(nil))
@@ -115,7 +115,7 @@ func WriteManifest(name model.Name, config *Layer, layers []*Layer) error {
 	m := Manifest{
 		SchemaVersion: 2,
 		MediaType:     "application/vnd.docker.distribution.manifest.v2+json",
-		Config:        config,
+		Config:        *config,
 		Layers:        layers,
 	}

server/routes.go

@@ -10,13 +10,17 @@ import (
"io"
"log/slog"
"math"
"math/rand"
"mime/multipart"
"net"
"net/http"
"net/netip"
"os"
"os/exec"
"os/signal"
"path/filepath"
"slices"
"strconv"
"strings"
"syscall"
"time"
@@ -105,6 +109,186 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capabil
return runner.llama, model, &opts, nil
}
func (s *Server) runWhisperServer(c *gin.Context, portCh chan int, errCh chan error, speech *api.WhisperRequest) {
var modelPath string
if speech.Model == "" {
modelPath = "/Users/royhan-ollama/.ollama/whisper/ggml-base.en.bin"
} else {
modelPath = speech.Model
}
// default to 5 minutes
var sessionDuration time.Duration
if speech.KeepAlive != nil {
sessionDuration = speech.KeepAlive.Duration
} else {
sessionDuration = 5 * time.Minute
}
s.sched.whisperMu.Lock()
if s.sched.whisperLoaded[modelPath] != nil {
slog.Info(fmt.Sprintf("whisper server already running %s on port %d", modelPath, *s.sched.whisperLoaded[modelPath]))
portCh <- *s.sched.whisperLoaded[modelPath]
// Renew the expiration time
s.sched.whisperExpiresAt[modelPath] = time.Now().Add(sessionDuration)
s.sched.whisperMu.Unlock()
return
}
whisperServer := "/Users/royhan-ollama/.ollama/server"
// Find an available port for whisper
port := 0
params := []string{}
if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
var l *net.TCPListener
if l, err = net.ListenTCP("tcp", a); err == nil {
port = l.Addr().(*net.TCPAddr).Port
l.Close()
}
}
if port == 0 {
slog.Debug("ResolveTCPAddr failed")
port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
}
finalParams := append(params, "--port", strconv.Itoa(port), "--model", modelPath)
cmd := exec.Command(whisperServer, finalParams...)
slog.Info("starting whisper server", "cmd", cmd.String())
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
err := cmd.Start()
if err != nil {
slog.Error("failed to start whisper server", "error", err)
errCh <- err
return
}
// Wait for server connection
retries := 25
var connErr error
for range retries {
time.Sleep(50 * time.Millisecond)
conn, err := net.DialTimeout("tcp", fmt.Sprintf("localhost:%d", port), time.Second)
if err == nil {
conn.Close()
connErr = nil
break
}
connErr = err
}
if connErr != nil {
slog.Error("failed to connect to whisper server", "error", connErr)
errCh <- connErr
return
}
portCh <- port
s.sched.whisperLoaded[modelPath] = &port
s.sched.whisperExpiresAt[modelPath] = time.Now().Add(sessionDuration)
s.sched.whisperMu.Unlock()
// Kill the whisper server once its keep-alive window expires
defer func() {
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
for range ticker.C {
s.sched.whisperMu.Lock()
if time.Now().After(s.sched.whisperExpiresAt[modelPath]) {
slog.Info("exiting whisper server")
delete(s.sched.whisperLoaded, modelPath)
delete(s.sched.whisperExpiresAt, modelPath)
s.sched.whisperMu.Unlock()
if err := cmd.Process.Kill(); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err})
return
}
slog.Debug("whisper server stopped")
return
}
s.sched.whisperMu.Unlock()
}
}()
}
func whisperInference(c *gin.Context, filePath string, port int) (*api.WhisperCompletion, error) {
// Open the file
file, err := os.Open(filePath)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to open file"})
return nil, err
}
defer file.Close()
// Create a buffer to hold the multipart form data
buffer := &bytes.Buffer{}
writer := multipart.NewWriter(buffer)
// Add the file to the multipart form
part, err := writer.CreateFormFile("file", filepath.Base(filePath))
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to create form file"})
return nil, err
}
if _, err := io.Copy(part, file); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to copy file"})
return nil, err
}
// Add other fields to the form
if err := writer.WriteField("temperature", "0.0"); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to write field"})
return nil, err
}
// Close the writer to finalize the multipart form
if err := writer.Close(); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to close writer"})
return nil, err
}
endpoint := fmt.Sprintf("http://localhost:%s/inference", strconv.Itoa(port))
serverReq, err := http.NewRequestWithContext(c.Request.Context(), http.MethodPost, endpoint, buffer)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to create request"})
return nil, err
}
serverReq.Header.Set("Content-Type", writer.FormDataContentType())
res, err := http.DefaultClient.Do(serverReq)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to send request"})
return nil, err
}
defer res.Body.Close()
body, err := io.ReadAll(res.Body)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to read response"})
return nil, err
}
var w api.WhisperCompletion
if err := json.Unmarshal(body, &w); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to unmarshal response"})
return nil, err
}
if w.Error != "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": w.Error})
return nil, errors.New(w.Error)
}
return &w, nil
}
func (s *Server) GenerateHandler(c *gin.Context) {
checkpointStart := time.Now()
var req api.GenerateRequest
@@ -129,6 +313,40 @@ func (s *Server) GenerateHandler(c *gin.Context) {
caps = append(caps, CapabilityInsert)
}
if req.Speech != nil {
portCh := make(chan int, 1)
errCh := make(chan error, 1)
go s.runWhisperServer(c, portCh, errCh, req.Speech)
var port int
select {
case port = <-portCh:
case err := <-errCh:
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err})
return
}
w, err := whisperInference(c, req.Speech.Audio, port)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to generate completion"})
return
}
if req.Speech.Transcribe {
c.JSON(http.StatusOK, api.GenerateResponse{
Model: req.Model,
CreatedAt: time.Now().UTC(),
Response: w.Text,
Done: true,
DoneReason: "transcribe",
})
return
}
req.Prompt += "\n" + w.Text
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
@@ -703,153 +921,6 @@ func (s *Server) ShowModelHandler(c *gin.Context) {
c.JSON(http.StatusOK, resp)
}
func manifestLayers(m *Manifest, exclude []string) (map[string]any, error) {
r := map[string]any{
"name": m.name.DisplayShortest(),
"digest": m.digest,
"size": m.Size(),
"modified_at": m.fi.ModTime(),
}
excludeAll := slices.Contains(exclude, "all")
excludeDetails := slices.Contains(exclude, "details")
for _, layer := range m.Layers {
var errExcludeKey = errors.New("exclude key")
key, content, err := func() (string, any, error) {
key := strings.TrimPrefix(layer.MediaType, "application/vnd.ollama.image.")
if slices.Contains(exclude, key) || excludeAll {
return "", nil, errExcludeKey
}
f, err := layer.Open()
if err != nil {
return "", nil, err
}
defer f.Close()
switch key {
case "model", "projector", "adapter":
ggml, _, err := llm.DecodeGGML(f, 0)
if err != nil {
return "", nil, err
}
content := map[string]any{
"architecture": ggml.KV().Architecture(),
"file_type": ggml.KV().FileType().String(),
"parameter_count": ggml.KV().ParameterCount(),
}
if !slices.Contains(exclude, key+".details") && !excludeAll && !excludeDetails {
// exclude any extraneous or redundant fields
delete(ggml.KV(), "general.basename")
delete(ggml.KV(), "general.description")
delete(ggml.KV(), "general.filename")
delete(ggml.KV(), "general.finetune")
delete(ggml.KV(), "general.languages")
delete(ggml.KV(), "general.license")
delete(ggml.KV(), "general.license.link")
delete(ggml.KV(), "general.name")
delete(ggml.KV(), "general.paramter_count")
delete(ggml.KV(), "general.size_label")
delete(ggml.KV(), "general.tags")
delete(ggml.KV(), "general.type")
delete(ggml.KV(), "general.quantization_version")
delete(ggml.KV(), "tokenizer.chat_template")
content["details"] = ggml.KV()
}
return key, content, nil
case "params", "messages":
var content any
if err := json.NewDecoder(f).Decode(&content); err != nil {
return "", nil, err
}
return key, content, nil
case "template", "system", "license":
bts, err := io.ReadAll(f)
if err != nil {
return "", nil, err
}
if key == "license" {
return key, []any{string(bts)}, nil
}
return key, string(bts), nil
}
return layer.MediaType, nil, nil
}()
if errors.Is(err, errExcludeKey) {
continue
} else if err != nil {
return nil, err
}
if s, ok := r[key].([]any); ok {
r[key] = append(s, content)
} else {
r[key] = content
}
}
return r, nil
}
func (s *Server) GetModelsHandler(c *gin.Context) {
ms, err := Manifests()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
var rs []map[string]any
for _, m := range ms {
r, err := manifestLayers(m, c.QueryArray("exclude"))
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
rs = append(rs, r)
}
slices.SortStableFunc(rs, func(i, j map[string]any) int {
// most recently modified first
return cmp.Compare(
j["modified_at"].(time.Time).Unix(),
i["modified_at"].(time.Time).Unix(),
)
})
c.JSON(http.StatusOK, rs)
}
func (s *Server) GetModelHandler(c *gin.Context) {
n := model.ParseName(strings.TrimPrefix(c.Param("model"), "/"))
if !n.IsValid() {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid model name"})
return
}
m, err := ParseNamedManifest(n)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
r, err := manifestLayers(m, c.QueryArray("exclude"))
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, r)
}
func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
m, err := GetModel(req.Model)
if err != nil {
@@ -971,17 +1042,20 @@ func (s *Server) ListModelsHandler(c *gin.Context) {
 	models := []api.ListModelResponse{}
 	for n, m := range ms {
-		f, err := m.Config.Open()
-		if err != nil {
-			slog.Warn("bad manifest filepath", "name", n, "error", err)
-			continue
-		}
-		defer f.Close()
 		var cf ConfigV2
-		if err := json.NewDecoder(f).Decode(&cf); err != nil {
-			slog.Warn("bad manifest config", "name", n, "error", err)
-			continue
+		if m.Config.Digest != "" {
+			f, err := m.Config.Open()
+			if err != nil {
+				slog.Warn("bad manifest filepath", "name", n, "error", err)
+				continue
+			}
+			defer f.Close()
+			if err := json.NewDecoder(f).Decode(&cf); err != nil {
+				slog.Warn("bad manifest config", "name", n, "error", err)
+				continue
+			}
 		}
// tag should never be masked
@@ -1237,9 +1311,6 @@ func (s *Server) GenerateRoutes() http.Handler {
 		c.String(http.StatusOK, "Ollama is running")
 	})
-	r.Handle(method, "/api/models", s.GetModelsHandler)
-	r.Handle(method, "/api/models/*model", s.GetModelHandler)
 	r.Handle(method, "/api/tags", s.ListModelsHandler)
r.Handle(method, "/api/version", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"version": version.Version})
@@ -1443,6 +1514,37 @@ func (s *Server) ProcessHandler(c *gin.Context) {
c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
}
func processAudio(c *gin.Context, s *Server, msgs []api.Message, req *api.WhisperRequest) error {
slog.Info("processing audio")
if req == nil {
req = &api.WhisperRequest{}
}
portCh := make(chan int, 1)
errCh := make(chan error, 1)
go s.runWhisperServer(c, portCh, errCh, req)
var port int
select {
case port = <-portCh:
case err := <-errCh:
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return err
}
// could parallelize this
for i, msg := range msgs {
if msg.Audio != "" {
w, err := whisperInference(c, msg.Audio, port)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "failed to generate completion"})
return err
}
msgs[i].Content += "\n" + w.Text
}
}
return nil
}
func (s *Server) ChatHandler(c *gin.Context) {
checkpointStart := time.Now()
@@ -1487,6 +1589,13 @@ func (s *Server) ChatHandler(c *gin.Context) {
msgs = append([]api.Message{{Role: "system", Content: m.System}}, msgs...)
}
if req.Speech != nil || req.RunSpeech {
if err := processAudio(c, s, msgs, req.Speech); err != nil {
slog.Error("failed to process audio", "error", err)
return
}
}
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, msgs, req.Tools)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})

server/sched.go

@@ -46,6 +46,10 @@ type Scheduler struct {
getGpuFn func() gpu.GpuInfoList
getCpuFn func() gpu.GpuInfoList
reschedDelay time.Duration
whisperLoaded map[string]*int
whisperExpiresAt map[string]time.Time
whisperMu sync.Mutex
}
// Default automatic value for number of models we allow per GPU
@@ -63,15 +67,17 @@ var ErrMaxQueue = errors.New("server busy, please try again. maximum pending re
func InitScheduler(ctx context.Context) *Scheduler {
maxQueue := envconfig.MaxQueue()
 	sched := &Scheduler{
-		pendingReqCh:  make(chan *LlmRequest, maxQueue),
-		finishedReqCh: make(chan *LlmRequest, maxQueue),
-		expiredCh:     make(chan *runnerRef, maxQueue),
-		unloadedCh:    make(chan interface{}, maxQueue),
-		loaded:        make(map[string]*runnerRef),
-		newServerFn:   llm.NewLlamaServer,
-		getGpuFn:      gpu.GetGPUInfo,
-		getCpuFn:      gpu.GetCPUInfo,
-		reschedDelay:  250 * time.Millisecond,
+		pendingReqCh:     make(chan *LlmRequest, maxQueue),
+		finishedReqCh:    make(chan *LlmRequest, maxQueue),
+		expiredCh:        make(chan *runnerRef, maxQueue),
+		unloadedCh:       make(chan interface{}, maxQueue),
+		loaded:           make(map[string]*runnerRef),
+		newServerFn:      llm.NewLlamaServer,
+		getGpuFn:         gpu.GetGPUInfo,
+		getCpuFn:         gpu.GetCPUInfo,
+		reschedDelay:     250 * time.Millisecond,
+		whisperLoaded:    make(map[string]*int),
+		whisperExpiresAt: make(map[string]time.Time),
 	}
sched.loadFn = sched.load
return sched
@@ -110,6 +116,10 @@ func (s *Scheduler) Run(ctx context.Context) {
go func() {
s.processCompleted(ctx)
}()
// go func() {
// could clean up whisper servers in init thread
// }
}
func (s *Scheduler) processPending(ctx context.Context) {