Compare commits

..

57 Commits

Author SHA1 Message Date
Josh Yan
a548eb6003 a8db2a9 2024-07-10 13:10:58 -07:00
Josh Yan
f92818d90d patch again 2024-07-10 13:06:40 -07:00
Josh Yan
1ef59057d0 patch llama.cpp 2024-07-10 13:02:37 -07:00
Josh Yan
106fe6b4ae patch 2024-07-10 12:46:45 -07:00
Josh Yan
5fd359d117 added patch 2024-07-10 12:46:45 -07:00
Josh Yan
b0e4e8d76c change 2024-07-10 12:46:44 -07:00
Josh Yan
e59453982d logs 2024-07-10 12:46:22 -07:00
Josh Yan
369113970a wooh 2024-07-10 12:46:14 -07:00
Josh Yan
26ed829415 test 2024-07-10 12:46:14 -07:00
Josh Yan
542134bf50 new 2024-07-10 12:46:09 -07:00
Josh Yan
9e0b8f1fe2 another change 2024-07-10 12:46:06 -07:00
Josh Yan
c498609ba3 cast 2024-07-10 12:45:47 -07:00
Josh Yan
c800a67f1b cast 2024-07-10 12:45:12 -07:00
Josh Yan
dfc62648f3 cast 2024-07-10 12:45:12 -07:00
Josh Yan
24e8292e94 new changes 2024-07-10 12:45:10 -07:00
Josh Yan
c63b4ecbf7 quantize 2024-07-10 12:44:40 -07:00
Josh Yan
ee2b9b076c stop spinner 2024-07-10 12:40:29 -07:00
Josh Yan
bec9100f32 tensor count 2024-07-10 12:40:29 -07:00
Josh Yan
1344843515 image 2024-07-10 12:40:29 -07:00
Josh Yan
e87eafe5cd quantize percentage 2024-07-10 12:40:29 -07:00
Josh Yan
6bab0e2368 lint 2024-07-10 12:36:32 -07:00
Josh Yan
c4cccaf936 remove rebase err 2024-07-10 11:37:55 -07:00
Josh Yan
9fe5c393e4 hi 2024-07-10 11:35:01 -07:00
Josh Yan
007c988dba rmv double msg 2024-07-10 11:34:31 -07:00
Josh Yan
91d21e7c7b rmv double msg 2024-07-10 11:34:31 -07:00
Josh Yan
3e64284f69 percent 2024-07-10 11:34:31 -07:00
Josh Yan
39910f2ab2 percent 2024-07-10 11:34:29 -07:00
Josh Yan
96d0cd92f2 rebase 2024-07-10 11:31:53 -07:00
Josh Yan
3a724a7c80 isLocal firstdraft 2024-07-10 11:31:12 -07:00
Josh Yan
f520f0056e rm config 2024-07-10 11:29:51 -07:00
Josh Yan
d25f85ede4 on disk copy 2024-07-10 11:29:49 -07:00
Josh Yan
b48420b74b percent 2024-07-10 11:27:33 -07:00
Josh Yan
784958a1cb transfer data 2024-07-10 11:24:50 -07:00
Josh Yan
ae65cc8dea progress 2024-07-10 11:24:48 -07:00
Josh Yan
a037528bba lint 2024-07-10 11:20:02 -07:00
Josh Yan
04bf41deb5 clean 2024-07-10 11:20:02 -07:00
Josh Yan
c23cec9547 removed cmt and prints 2024-07-10 11:20:02 -07:00
Josh Yan
8377dc48d0 removed client isLocal() 2024-07-10 11:20:02 -07:00
Josh Yan
3aee405dfa lint 2024-07-10 11:20:02 -07:00
Josh Yan
9b3f47b674 lint 2024-07-10 11:20:02 -07:00
Josh Yan
f5441f01a2 lint 2024-07-10 11:20:02 -07:00
Josh Yan
ab165df43a syscopy windows 2024-07-10 11:20:02 -07:00
Josh Yan
79cc4c9585 os copy 2024-07-10 11:20:02 -07:00
Josh Yan
bc3f59a6ad rmv prints 2024-07-10 11:20:02 -07:00
Josh Yan
1a85cb904c local copy 2024-07-10 11:20:02 -07:00
Josh Yan
10ea0987e9 isLocal firstdraft 2024-07-10 11:19:50 -07:00
Josh Yan
413d368a6a clean 2024-07-10 11:19:32 -07:00
Josh Yan
cabf375059 rm bench 2024-07-10 11:19:32 -07:00
Josh Yan
ca0ee1d4fe rm config 2024-07-10 11:19:32 -07:00
Josh Yan
1142999aab rm config 2024-07-10 11:19:32 -07:00
Josh Yan
0d5a72aba9 clean 2024-07-10 11:19:32 -07:00
Josh Yan
ea837412c2 local path 2024-07-10 11:19:32 -07:00
Josh Yan
736ad6f438 still works 2024-07-10 11:19:32 -07:00
Josh Yan
64607d16a5 working 2024-07-10 11:19:32 -07:00
Josh Yan
a6cfe7f00b benchmark 2024-07-10 11:19:32 -07:00
Josh Yan
c3b411a515 on disk copy 2024-07-10 11:19:32 -07:00
Josh Yan
928f37e3ae start tests 2024-07-10 11:19:32 -07:00
141 changed files with 1517 additions and 4726 deletions

View File

@@ -31,7 +31,7 @@ jobs:
security set-keychain-settings -lut 3600 build.keychain security set-keychain-settings -lut 3600 build.keychain
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version: "stable" go-version-file: go.mod
cache: true cache: true
- name: Build Darwin - name: Build Darwin
env: env:
@@ -87,7 +87,7 @@ jobs:
write-host "plugin installed" write-host "plugin installed"
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version: "stable" go-version-file: go.mod
cache: true cache: true
- run: go get ./... - run: go get ./...
- run: | - run: |
@@ -141,13 +141,13 @@ jobs:
write-host "plugin installed" write-host "plugin installed"
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version: "stable" go-version-file: go.mod
cache: true cache: true
- name: 'Install ROCm' - name: 'Install ROCm'
run: | run: |
$ErrorActionPreference = "Stop" $ErrorActionPreference = "Stop"
write-host "downloading AMD HIP Installer" write-host "downloading AMD HIP Installer"
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe" Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-23.Q4-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
write-host "Installing AMD HIP" write-host "Installing AMD HIP"
Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait
write-host "Completed AMD HIP" write-host "Completed AMD HIP"
@@ -218,7 +218,7 @@ jobs:
write-host "plugin installed" write-host "plugin installed"
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version: "stable" go-version-file: go.mod
cache: true cache: true
- name: 'Install CUDA' - name: 'Install CUDA'
run: | run: |
@@ -306,7 +306,7 @@ jobs:
write-host "plugin installed" write-host "plugin installed"
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version: "stable" go-version-file: go.mod
cache: true cache: true
- run: go get - run: go get
- uses: actions/download-artifact@v4 - uses: actions/download-artifact@v4

View File

@@ -63,7 +63,7 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version: "stable" go-version-file: go.mod
cache: true cache: true
- run: go get ./... - run: go get ./...
- run: | - run: |
@@ -126,7 +126,7 @@ jobs:
strategy: strategy:
matrix: matrix:
rocm-version: rocm-version:
- '6.1.2' - '6.1.1'
runs-on: linux runs-on: linux
container: rocm/dev-ubuntu-20.04:${{ matrix.rocm-version }} container: rocm/dev-ubuntu-20.04:${{ matrix.rocm-version }}
steps: steps:
@@ -163,13 +163,13 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version: "stable" go-version-file: go.mod
cache: true cache: true
- name: 'Install ROCm' - name: 'Install ROCm'
run: | run: |
$ErrorActionPreference = "Stop" $ErrorActionPreference = "Stop"
write-host "downloading AMD HIP Installer" write-host "downloading AMD HIP Installer"
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe" Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-23.Q4-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
write-host "Installing AMD HIP" write-host "Installing AMD HIP"
Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait
write-host "Completed AMD HIP" write-host "Completed AMD HIP"
@@ -200,7 +200,7 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version: "stable" go-version-file: go.mod
cache: true cache: true
- name: 'Install CUDA' - name: 'Install CUDA'
run: | run: |
@@ -255,7 +255,7 @@ jobs:
submodules: recursive submodules: recursive
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version: "stable" go-version-file: go.mod
cache: false cache: false
- run: | - run: |
case ${{ matrix.arch }} in case ${{ matrix.arch }} in
@@ -297,7 +297,7 @@ jobs:
submodules: recursive submodules: recursive
- uses: actions/setup-go@v5 - uses: actions/setup-go@v5
with: with:
go-version: "stable" go-version-file: go.mod
cache: true cache: true
- run: | - run: |
case ${{ matrix.arch }} in case ${{ matrix.arch }} in

View File

@@ -1,8 +1,8 @@
ARG GOLANG_VERSION=1.22.5 ARG GOLANG_VERSION=1.22.1
ARG CMAKE_VERSION=3.22.1 ARG CMAKE_VERSION=3.22.1
# this CUDA_VERSION corresponds with the one specified in docs/gpu.md # this CUDA_VERSION corresponds with the one specified in docs/gpu.md
ARG CUDA_VERSION=11.3.1 ARG CUDA_VERSION=11.3.1
ARG ROCM_VERSION=6.1.2 ARG ROCM_VERSION=6.1.1
# Copy the minimal context we need to run the generate scripts # Copy the minimal context we need to run the generate scripts
FROM scratch AS llm-code FROM scratch AS llm-code

View File

@@ -35,10 +35,10 @@ The official [Ollama Docker image](https://hub.docker.com/r/ollama/ollama) `olla
## Quickstart ## Quickstart
To run and chat with [Llama 3.1](https://ollama.com/library/llama3.1): To run and chat with [Llama 3](https://ollama.com/library/llama3):
``` ```
ollama run llama3.1 ollama run llama3
``` ```
## Model library ## Model library
@@ -49,9 +49,8 @@ Here are some example models that can be downloaded:
| Model | Parameters | Size | Download | | Model | Parameters | Size | Download |
| ------------------ | ---------- | ----- | ------------------------------ | | ------------------ | ---------- | ----- | ------------------------------ |
| Llama 3.1 | 8B | 4.7GB | `ollama run llama3.1` | | Llama 3 | 8B | 4.7GB | `ollama run llama3` |
| Llama 3.1 | 70B | 40GB | `ollama run llama3.1:70b` | | Llama 3 | 70B | 40GB | `ollama run llama3:70b` |
| Llama 3.1 | 405B | 231GB | `ollama run llama3.1:405b` |
| Phi 3 Mini | 3.8B | 2.3GB | `ollama run phi3` | | Phi 3 Mini | 3.8B | 2.3GB | `ollama run phi3` |
| Phi 3 Medium | 14B | 7.9GB | `ollama run phi3:medium` | | Phi 3 Medium | 14B | 7.9GB | `ollama run phi3:medium` |
| Gemma 2 | 9B | 5.5GB | `ollama run gemma2` | | Gemma 2 | 9B | 5.5GB | `ollama run gemma2` |
@@ -65,8 +64,7 @@ Here are some example models that can be downloaded:
| LLaVA | 7B | 4.5GB | `ollama run llava` | | LLaVA | 7B | 4.5GB | `ollama run llava` |
| Solar | 10.7B | 6.1GB | `ollama run solar` | | Solar | 10.7B | 6.1GB | `ollama run solar` |
> [!NOTE] > Note: You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
> You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
## Customize a model ## Customize a model
@@ -98,16 +96,16 @@ See the [guide](docs/import.md) on importing models for more information.
### Customize a prompt ### Customize a prompt
Models from the Ollama library can be customized with a prompt. For example, to customize the `llama3.1` model: Models from the Ollama library can be customized with a prompt. For example, to customize the `llama3` model:
``` ```
ollama pull llama3.1 ollama pull llama3
``` ```
Create a `Modelfile`: Create a `Modelfile`:
``` ```
FROM llama3.1 FROM llama3
# set the temperature to 1 [higher is more creative, lower is more coherent] # set the temperature to 1 [higher is more creative, lower is more coherent]
PARAMETER temperature 1 PARAMETER temperature 1
@@ -142,7 +140,7 @@ ollama create mymodel -f ./Modelfile
### Pull a model ### Pull a model
``` ```
ollama pull llama3.1 ollama pull llama3
``` ```
> This command can also be used to update a local model. Only the diff will be pulled. > This command can also be used to update a local model. Only the diff will be pulled.
@@ -150,13 +148,13 @@ ollama pull llama3.1
### Remove a model ### Remove a model
``` ```
ollama rm llama3.1 ollama rm llama3
``` ```
### Copy a model ### Copy a model
``` ```
ollama cp llama3.1 my-model ollama cp llama3 my-model
``` ```
### Multiline input ### Multiline input
@@ -180,14 +178,14 @@ The image features a yellow smiley face, which is likely the central focus of th
### Pass the prompt as an argument ### Pass the prompt as an argument
``` ```
$ ollama run llama3.1 "Summarize this file: $(cat README.md)" $ ollama run llama3 "Summarize this file: $(cat README.md)"
Ollama is a lightweight, extensible framework for building and running language models on the local machine. It provides a simple API for creating, running, and managing models, as well as a library of pre-built models that can be easily used in a variety of applications. Ollama is a lightweight, extensible framework for building and running language models on the local machine. It provides a simple API for creating, running, and managing models, as well as a library of pre-built models that can be easily used in a variety of applications.
``` ```
### Show model information ### Show model information
``` ```
ollama show llama3.1 ollama show llama3
``` ```
### List models on your computer ### List models on your computer
@@ -215,7 +213,7 @@ Next, start the server:
Finally, in a separate shell, run a model: Finally, in a separate shell, run a model:
``` ```
./ollama run llama3.1 ./ollama run llama3
``` ```
## REST API ## REST API
@@ -226,7 +224,7 @@ Ollama has a REST API for running and managing models.
``` ```
curl http://localhost:11434/api/generate -d '{ curl http://localhost:11434/api/generate -d '{
"model": "llama3.1", "model": "llama3",
"prompt":"Why is the sky blue?" "prompt":"Why is the sky blue?"
}' }'
``` ```
@@ -235,7 +233,7 @@ curl http://localhost:11434/api/generate -d '{
``` ```
curl http://localhost:11434/api/chat -d '{ curl http://localhost:11434/api/chat -d '{
"model": "llama3.1", "model": "llama3",
"messages": [ "messages": [
{ "role": "user", "content": "why is the sky blue?" } { "role": "user", "content": "why is the sky blue?" }
] ]
@@ -295,10 +293,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS) - [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS)
- [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama) - [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama)
- [Ollama with Google Mesop](https://github.com/rapidarchitect/ollama_mesop/) (Mesop Chat Client implementation with Ollama) - [Ollama with Google Mesop](https://github.com/rapidarchitect/ollama_mesop/) (Mesop Chat Client implementation with Ollama)
- [Kerlig AI](https://www.kerlig.com/) (AI writing assistant for macOS)
- [AI Studio](https://github.com/MindWorkAI/AI-Studio)
- [Sidellama](https://github.com/gyopak/sidellama) (browser-based LLM client)
- [LLMStack](https://github.com/trypromptly/LLMStack) (No-code multi-agent framework to build LLM agents and workflows)
### Terminal ### Terminal
@@ -390,7 +384,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Llama Coder](https://github.com/ex3ndr/llama-coder) (Copilot alternative using Ollama) - [Llama Coder](https://github.com/ex3ndr/llama-coder) (Copilot alternative using Ollama)
- [Ollama Copilot](https://github.com/bernardo-bruning/ollama-copilot) (Proxy that allows you to use ollama as a copilot like Github copilot) - [Ollama Copilot](https://github.com/bernardo-bruning/ollama-copilot) (Proxy that allows you to use ollama as a copilot like Github copilot)
- [twinny](https://github.com/rjmacarthy/twinny) (Copilot and Copilot chat alternative using Ollama) - [twinny](https://github.com/rjmacarthy/twinny) (Copilot and Copilot chat alternative using Ollama)
- [Wingman-AI](https://github.com/RussellCanfield/wingman-ai) (Copilot code and chat alternative using Ollama and Hugging Face) - [Wingman-AI](https://github.com/RussellCanfield/wingman-ai) (Copilot code and chat alternative using Ollama and HuggingFace)
- [Page Assist](https://github.com/n4ze3m/page-assist) (Chrome Extension) - [Page Assist](https://github.com/n4ze3m/page-assist) (Chrome Extension)
- [AI Telegram Bot](https://github.com/tusharhero/aitelegrambot) (Telegram bot using Ollama in backend) - [AI Telegram Bot](https://github.com/tusharhero/aitelegrambot) (Telegram bot using Ollama in backend)
- [AI ST Completion](https://github.com/yaroslavyaroslav/OpenAI-sublime-text) (Sublime Text 4 AI assistant plugin with Ollama support) - [AI ST Completion](https://github.com/yaroslavyaroslav/OpenAI-sublime-text) (Sublime Text 4 AI assistant plugin with Ollama support)

View File

@@ -17,14 +17,20 @@ import (
"bufio" "bufio"
"bytes" "bytes"
"context" "context"
"encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"os"
"path/filepath"
"runtime" "runtime"
"strings"
"time"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
@@ -347,16 +353,7 @@ func (c *Client) Heartbeat(ctx context.Context) error {
return nil return nil
} }
// Embed generates embeddings from a model. // Embeddings generates embeddings from a model.
func (c *Client) Embed(ctx context.Context, req *EmbedRequest) (*EmbedResponse, error) {
var resp EmbedResponse
if err := c.do(ctx, http.MethodPost, "/api/embed", req, &resp); err != nil {
return nil, err
}
return &resp, nil
}
// Embeddings generates an embedding from a model.
func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) { func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) {
var resp EmbeddingResponse var resp EmbeddingResponse
if err := c.do(ctx, http.MethodPost, "/api/embeddings", req, &resp); err != nil { if err := c.do(ctx, http.MethodPost, "/api/embeddings", req, &resp); err != nil {
@@ -383,3 +380,27 @@ func (c *Client) Version(ctx context.Context) (string, error) {
return version.Version, nil return version.Version, nil
} }
func Authorization(ctx context.Context, request *http.Request) (string, error) {
data := []byte(fmt.Sprintf("%s,%s,%d", request.Method, request.URL.RequestURI(), time.Now().Unix()))
home, err := os.UserHomeDir()
if err != nil {
return "", err
}
knownHostsFile, err := os.OpenFile(filepath.Join(home, ".ollama", "known_hosts"), os.O_CREATE|os.O_RDWR|os.O_APPEND, 0600)
if err != nil {
return "", err
}
defer knownHostsFile.Close()
token, err := auth.Sign(ctx, data)
if err != nil {
return "", err
}
// interleave request data into the token
key, sig, _ := strings.Cut(token, ":")
return fmt.Sprintf("%s:%s:%s", key, base64.StdEncoding.EncodeToString(data), sig), nil
}

View File

@@ -47,9 +47,6 @@ type GenerateRequest struct {
// Prompt is the textual prompt to send to the model. // Prompt is the textual prompt to send to the model.
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
// Suffix is the text that comes after the inserted text.
Suffix string `json:"suffix"`
// System overrides the model's default system message/prompt. // System overrides the model's default system message/prompt.
System string `json:"system"` System string `json:"system"`
@@ -100,85 +97,17 @@ type ChatRequest struct {
// followin the request. // followin the request.
KeepAlive *Duration `json:"keep_alive,omitempty"` KeepAlive *Duration `json:"keep_alive,omitempty"`
// Tools is an optional list of tools the model has access to.
Tools `json:"tools,omitempty"`
// Options lists model-specific options. // Options lists model-specific options.
Options map[string]interface{} `json:"options"` Options map[string]interface{} `json:"options"`
} }
type Tools []Tool
func (t Tools) String() string {
bts, _ := json.Marshal(t)
return string(bts)
}
func (t Tool) String() string {
bts, _ := json.Marshal(t)
return string(bts)
}
// Message is a single message in a chat sequence. The message contains the // Message is a single message in a chat sequence. The message contains the
// role ("system", "user", or "assistant"), the content and an optional list // role ("system", "user", or "assistant"), the content and an optional list
// of images. // of images.
type Message struct { type Message struct {
Role string `json:"role"` Role string `json:"role"`
Content string `json:"content"` Content string `json:"content"`
Images []ImageData `json:"images,omitempty"` Images []ImageData `json:"images,omitempty"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
}
func (m *Message) UnmarshalJSON(b []byte) error {
type Alias Message
var a Alias
if err := json.Unmarshal(b, &a); err != nil {
return err
}
*m = Message(a)
m.Role = strings.ToLower(m.Role)
return nil
}
type ToolCall struct {
Function ToolCallFunction `json:"function"`
}
type ToolCallFunction struct {
Name string `json:"name"`
Arguments ToolCallFunctionArguments `json:"arguments"`
}
type ToolCallFunctionArguments map[string]any
func (t *ToolCallFunctionArguments) String() string {
bts, _ := json.Marshal(t)
return string(bts)
}
type Tool struct {
Type string `json:"type"`
Function ToolFunction `json:"function"`
}
type ToolFunction struct {
Name string `json:"name"`
Description string `json:"description"`
Parameters struct {
Type string `json:"type"`
Required []string `json:"required"`
Properties map[string]struct {
Type string `json:"type"`
Description string `json:"description"`
Enum []string `json:"enum,omitempty"`
} `json:"properties"`
} `json:"parameters"`
}
func (t *ToolFunction) String() string {
bts, _ := json.Marshal(t)
return string(bts)
} }
// ChatResponse is the response returned by [Client.Chat]. Its fields are // ChatResponse is the response returned by [Client.Chat]. Its fields are
@@ -214,7 +143,6 @@ type Options struct {
NumPredict int `json:"num_predict,omitempty"` NumPredict int `json:"num_predict,omitempty"`
TopK int `json:"top_k,omitempty"` TopK int `json:"top_k,omitempty"`
TopP float32 `json:"top_p,omitempty"` TopP float32 `json:"top_p,omitempty"`
MinP float32 `json:"min_p,omitempty"`
TFSZ float32 `json:"tfs_z,omitempty"` TFSZ float32 `json:"tfs_z,omitempty"`
TypicalP float32 `json:"typical_p,omitempty"` TypicalP float32 `json:"typical_p,omitempty"`
RepeatLastN int `json:"repeat_last_n,omitempty"` RepeatLastN int `json:"repeat_last_n,omitempty"`
@@ -245,30 +173,6 @@ type Runner struct {
NumThread int `json:"num_thread,omitempty"` NumThread int `json:"num_thread,omitempty"`
} }
// EmbedRequest is the request passed to [Client.Embed].
type EmbedRequest struct {
// Model is the model name.
Model string `json:"model"`
// Input is the input to embed.
Input any `json:"input"`
// KeepAlive controls how long the model will stay loaded in memory following
// this request.
KeepAlive *Duration `json:"keep_alive,omitempty"`
Truncate *bool `json:"truncate,omitempty"`
// Options lists model-specific options.
Options map[string]interface{} `json:"options"`
}
// EmbedResponse is the response from [Client.Embed].
type EmbedResponse struct {
Model string `json:"model"`
Embeddings [][]float32 `json:"embeddings"`
}
// EmbeddingRequest is the request passed to [Client.Embeddings]. // EmbeddingRequest is the request passed to [Client.Embeddings].
type EmbeddingRequest struct { type EmbeddingRequest struct {
// Model is the model name. // Model is the model name.
@@ -315,10 +219,8 @@ type DeleteRequest struct {
// ShowRequest is the request passed to [Client.Show]. // ShowRequest is the request passed to [Client.Show].
type ShowRequest struct { type ShowRequest struct {
Model string `json:"model"` Model string `json:"model"`
System string `json:"system"` System string `json:"system"`
// Template is deprecated
Template string `json:"template"` Template string `json:"template"`
Verbose bool `json:"verbose"` Verbose bool `json:"verbose"`
@@ -365,6 +267,7 @@ type PullRequest struct {
type ProgressResponse struct { type ProgressResponse struct {
Status string `json:"status"` Status string `json:"status"`
Digest string `json:"digest,omitempty"` Digest string `json:"digest,omitempty"`
Quantize string `json:"quantize,omitempty"`
Total int64 `json:"total,omitempty"` Total int64 `json:"total,omitempty"`
Completed int64 `json:"completed,omitempty"` Completed int64 `json:"completed,omitempty"`
} }

View File

@@ -208,26 +208,3 @@ func TestUseMmapFormatParams(t *testing.T) {
}) })
} }
} }
func TestMessage_UnmarshalJSON(t *testing.T) {
tests := []struct {
input string
expected string
}{
{`{"role": "USER", "content": "Hello!"}`, "user"},
{`{"role": "System", "content": "Initialization complete."}`, "system"},
{`{"role": "assistant", "content": "How can I help you?"}`, "assistant"},
{`{"role": "TOOl", "content": "Access granted."}`, "tool"},
}
for _, test := range tests {
var msg Message
if err := json.Unmarshal([]byte(test.input), &msg); err != nil {
t.Errorf("Unexpected error: %v", err)
}
if msg.Role != test.expected {
t.Errorf("role not lowercased: got %v, expected %v", msg.Role, test.expected)
}
}
}

View File

@@ -127,10 +127,6 @@ Type: filesandordirs; Name: "{%USERPROFILE}\.ollama\models"
Type: filesandordirs; Name: "{%USERPROFILE}\.ollama\history" Type: filesandordirs; Name: "{%USERPROFILE}\.ollama\history"
; NOTE: if the user has a custom OLLAMA_MODELS it will be preserved ; NOTE: if the user has a custom OLLAMA_MODELS it will be preserved
[InstallDelete]
Type: filesandordirs; Name: "{%TEMP}\ollama*"
Type: filesandordirs; Name: "{%LOCALAPPDATA}\Programs\Ollama"
[Messages] [Messages]
WizardReady=Ollama Windows Preview WizardReady=Ollama Windows Preview
ReadyLabel1=%nLet's get you up and running with your own large language models. ReadyLabel1=%nLet's get you up and running with your own large language models.
@@ -138,7 +134,7 @@ SetupAppRunningError=Another Ollama installer is running.%n%nPlease cancel or fi
;FinishedHeadingLabel=Run your first model ;FinishedHeadingLabel=Run your first model
;FinishedLabel=%nRun this command in a PowerShell or cmd terminal.%n%n%n ollama run llama3.1 ;FinishedLabel=%nRun this command in a PowerShell or cmd terminal.%n%n%n ollama run llama3
;ClickFinish=%n ;ClickFinish=%n
[Registry] [Registry]

View File

@@ -4,5 +4,5 @@ write-host "Welcome to Ollama!"
write-host "" write-host ""
write-host "Run your first model:" write-host "Run your first model:"
write-host "" write-host ""
write-host "`tollama run llama3.1" write-host "`tollama run llama3"
write-host "" write-host ""

View File

@@ -10,42 +10,37 @@ import (
"log/slog" "log/slog"
"os" "os"
"path/filepath" "path/filepath"
"strings"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
) )
const defaultPrivateKey = "id_ed25519" const defaultPrivateKey = "id_ed25519"
func keyPath() (string, error) { func keyPath() (ssh.Signer, error) {
home, err := os.UserHomeDir() home, err := os.UserHomeDir()
if err != nil { if err != nil {
return "", err return nil, err
}
return filepath.Join(home, ".ollama", defaultPrivateKey), nil
}
func GetPublicKey() (string, error) {
keyPath, err := keyPath()
if err != nil {
return "", err
} }
keyPath := filepath.Join(home, ".ollama", defaultPrivateKey)
privateKeyFile, err := os.ReadFile(keyPath) privateKeyFile, err := os.ReadFile(keyPath)
if err != nil { if err != nil {
slog.Info(fmt.Sprintf("Failed to load private key: %v", err)) slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
return "", err return nil, err
} }
privateKey, err := ssh.ParsePrivateKey(privateKeyFile) return ssh.ParsePrivateKey(privateKeyFile)
}
func GetPublicKey() (ssh.PublicKey, error) {
privateKey, err := keyPath()
// if privateKey, try public key directly
if err != nil { if err != nil {
return "", err return nil, err
} }
publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey()) return privateKey.PublicKey(), nil
return strings.TrimSpace(string(publicKey)), nil
} }
func NewNonce(r io.Reader, length int) (string, error) { func NewNonce(r io.Reader, length int) (string, error) {
@@ -58,25 +53,20 @@ func NewNonce(r io.Reader, length int) (string, error) {
} }
func Sign(ctx context.Context, bts []byte) (string, error) { func Sign(ctx context.Context, bts []byte) (string, error) {
keyPath, err := keyPath() privateKey, err := keyPath()
if err != nil {
return "", err
}
privateKeyFile, err := os.ReadFile(keyPath)
if err != nil {
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
return "", err
}
privateKey, err := ssh.ParsePrivateKey(privateKeyFile)
if err != nil { if err != nil {
return "", err return "", err
} }
// get the pubkey, but remove the type // get the pubkey, but remove the type
publicKey := ssh.MarshalAuthorizedKey(privateKey.PublicKey()) publicKey, err := GetPublicKey()
parts := bytes.Split(publicKey, []byte(" ")) if err != nil {
return "", err
}
publicKeyBytes := ssh.MarshalAuthorizedKey(publicKey)
parts := bytes.Split(publicKeyBytes, []byte(" "))
if len(parts) < 2 { if len(parts) < 2 {
return "", fmt.Errorf("malformed public key") return "", fmt.Errorf("malformed public key")
} }

View File

@@ -7,6 +7,7 @@ import (
"crypto/ed25519" "crypto/ed25519"
"crypto/rand" "crypto/rand"
"crypto/sha256" "crypto/sha256"
"encoding/json"
"encoding/pem" "encoding/pem"
"errors" "errors"
"fmt" "fmt"
@@ -15,6 +16,7 @@ import (
"math" "math"
"net" "net"
"net/http" "net/http"
"net/url"
"os" "os"
"os/signal" "os/signal"
"path/filepath" "path/filepath"
@@ -78,6 +80,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
status := "transferring model data" status := "transferring model data"
spinner := progress.NewSpinner(status) spinner := progress.NewSpinner(status)
p.Add(status, spinner) p.Add(status, spinner)
defer p.Stop()
for i := range modelfile.Commands { for i := range modelfile.Commands {
switch modelfile.Commands[i].Name { switch modelfile.Commands[i].Name {
@@ -112,16 +115,17 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
path = tempfile path = tempfile
} }
digest, err := createBlob(cmd, client, path) // spinner.Stop()
digest, err := createBlob(cmd, client, path, spinner)
if err != nil { if err != nil {
return err return err
} }
modelfile.Commands[i].Args = "@" + digest modelfile.Commands[i].Args = "@" + digest
} }
} }
bars := make(map[string]*progress.Bar) bars := make(map[string]*progress.Bar)
var quantizeSpin *progress.Spinner
fn := func(resp api.ProgressResponse) error { fn := func(resp api.ProgressResponse) error {
if resp.Digest != "" { if resp.Digest != "" {
spinner.Stop() spinner.Stop()
@@ -134,11 +138,20 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
} }
bar.Set(resp.Completed) bar.Set(resp.Completed)
} else if resp.Quantize != "" {
spinner.Stop()
if quantizeSpin != nil {
quantizeSpin.SetMessage(resp.Status)
} else {
quantizeSpin = progress.NewSpinner(resp.Status)
p.Add("quantize", quantizeSpin)
}
} else if status != resp.Status { } else if status != resp.Status {
spinner.Stop() spinner.Stop()
status = resp.Status status = resp.Status
spinner = progress.NewSpinner(status) spinner := progress.NewSpinner(status)
p.Add(status, spinner) p.Add(status, spinner)
} }
@@ -263,13 +276,22 @@ func tempZipFiles(path string) (string, error) {
return tempfile.Name(), nil return tempfile.Name(), nil
} }
func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) { var ErrBlobExists = errors.New("blob exists")
func createBlob(cmd *cobra.Command, client *api.Client, path string, spinner *progress.Spinner) (string, error) {
bin, err := os.Open(path) bin, err := os.Open(path)
if err != nil { if err != nil {
return "", err return "", err
} }
defer bin.Close() defer bin.Close()
// Get file info to retrieve the size
fileInfo, err := bin.Stat()
if err != nil {
return "", err
}
fileSize := fileInfo.Size()
hash := sha256.New() hash := sha256.New()
if _, err := io.Copy(hash, bin); err != nil { if _, err := io.Copy(hash, bin); err != nil {
return "", err return "", err
@@ -279,13 +301,157 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
return "", err return "", err
} }
var pw progressWriter
// Create a progress bar and start a goroutine to update it
// JK Let's use a percentage
//bar := progress.NewBar("transferring model data...", fileSize, 0)
//p.Add("transferring model data", bar)
status := "transferring model data 0%"
spinner.SetMessage(status)
ticker := time.NewTicker(60 * time.Millisecond)
done := make(chan struct{})
defer close(done)
go func() {
defer ticker.Stop()
for {
select {
case <-ticker.C:
spinner.SetMessage(fmt.Sprintf("transferring model data %d%%", int(100*pw.n/fileSize)))
case <-done:
spinner.SetMessage("transferring model data 100%")
return
}
}
}()
digest := fmt.Sprintf("sha256:%x", hash.Sum(nil)) digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil {
// We check if we can find the models directory locally
// If we can, we return the path to the directory
// If we can't, we return an error
// If the blob exists already, we return the digest
dest, err := getLocalPath(cmd.Context(), digest)
if errors.Is(err, ErrBlobExists) {
return digest, nil
}
// Successfully found the model directory
if err == nil {
// Copy blob in via OS specific copy
// Linux errors out to use io.copy
err = localCopy(path, dest)
if err == nil {
return digest, nil
}
// Default copy using io.copy
err = defaultCopy(path, dest)
if err == nil {
return digest, nil
}
}
// If at any point copying the blob over locally fails, we default to the copy through the server
if err = client.CreateBlob(cmd.Context(), digest, io.TeeReader(bin, &pw)); err != nil {
return "", err return "", err
} }
return digest, nil return digest, nil
} }
type progressWriter struct {
n int64
}
func (w *progressWriter) Write(p []byte) (n int, err error) {
w.n += int64(len(p))
return len(p), nil
}
func getLocalPath(ctx context.Context, digest string) (string, error) {
ollamaHost := envconfig.Host
client := http.DefaultClient
base := &url.URL{
Scheme: ollamaHost.Scheme,
Host: net.JoinHostPort(ollamaHost.Host, ollamaHost.Port),
}
data, err := json.Marshal(digest)
if err != nil {
return "", err
}
reqBody := bytes.NewReader(data)
path := fmt.Sprintf("/api/blobs/%s", digest)
requestURL := base.JoinPath(path)
request, err := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), reqBody)
if err != nil {
return "", err
}
authz, err := api.Authorization(ctx, request)
if err != nil {
return "", err
}
request.Header.Set("Authorization", authz)
request.Header.Set("User-Agent", fmt.Sprintf("ollama/%s (%s %s) Go/%s", version.Version, runtime.GOARCH, runtime.GOOS, runtime.Version()))
request.Header.Set("X-Redirect-Create", "1")
resp, err := client.Do(request)
if err != nil {
return "", err
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusTemporaryRedirect {
dest := resp.Header.Get("LocalLocation")
return dest, nil
}
return "", ErrBlobExists
}
func defaultCopy(path string, dest string) error {
// This function should be called if the server is local
// It should find the model directory, copy the blob over, and return the digest
dirPath := filepath.Dir(dest)
if err := os.MkdirAll(dirPath, 0o755); err != nil {
return err
}
// Copy blob over
sourceFile, err := os.Open(path)
if err != nil {
return fmt.Errorf("could not open source file: %v", err)
}
defer sourceFile.Close()
destFile, err := os.Create(dest)
if err != nil {
return fmt.Errorf("could not create destination file: %v", err)
}
defer destFile.Close()
_, err = io.CopyBuffer(destFile, sourceFile, make([]byte, 4*1024*1024))
if err != nil {
return fmt.Errorf("error copying file: %v", err)
}
err = destFile.Sync()
if err != nil {
return fmt.Errorf("error flushing file: %v", err)
}
return nil
}
func RunHandler(cmd *cobra.Command, args []string) error { func RunHandler(cmd *cobra.Command, args []string) error {
interactive := true interactive := true
@@ -379,11 +545,13 @@ func errFromUnknownKey(unknownKeyErr error) error {
if len(matches) > 0 { if len(matches) > 0 {
serverPubKey := matches[0] serverPubKey := matches[0]
localPubKey, err := auth.GetPublicKey() publicKey, err := auth.GetPublicKey()
if err != nil { if err != nil {
return unknownKeyErr return unknownKeyErr
} }
localPubKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(publicKey)))
if runtime.GOOS == "linux" && serverPubKey != localPubKey { if runtime.GOOS == "linux" && serverPubKey != localPubKey {
// try the ollama service public key // try the ollama service public key
svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub") svcPubKey, err := os.ReadFile("/usr/share/ollama/.ollama/id_ed25519.pub")
@@ -843,6 +1011,7 @@ type runOptions struct {
WordWrap bool WordWrap bool
Format string Format string
System string System string
Template string
Images []api.ImageData Images []api.ImageData
Options map[string]interface{} Options map[string]interface{}
MultiModal bool MultiModal bool
@@ -1036,6 +1205,7 @@ func generate(cmd *cobra.Command, opts runOptions) error {
Images: opts.Images, Images: opts.Images,
Format: opts.Format, Format: opts.Format,
System: opts.System, System: opts.System,
Template: opts.Template,
Options: opts.Options, Options: opts.Options,
KeepAlive: opts.KeepAlive, KeepAlive: opts.KeepAlive,
} }
@@ -1341,10 +1511,10 @@ func NewCLI() *cobra.Command {
envVars["OLLAMA_NUM_PARALLEL"], envVars["OLLAMA_NUM_PARALLEL"],
envVars["OLLAMA_NOPRUNE"], envVars["OLLAMA_NOPRUNE"],
envVars["OLLAMA_ORIGINS"], envVars["OLLAMA_ORIGINS"],
envVars["OLLAMA_SCHED_SPREAD"],
envVars["OLLAMA_TMPDIR"], envVars["OLLAMA_TMPDIR"],
envVars["OLLAMA_FLASH_ATTENTION"], envVars["OLLAMA_FLASH_ATTENTION"],
envVars["OLLAMA_LLM_LIBRARY"], envVars["OLLAMA_LLM_LIBRARY"],
envVars["OLLAMA_MAX_VRAM"],
}) })
default: default:
appendEnvDocs(cmd, envs) appendEnvDocs(cmd, envs)

23
cmd/copy_darwin.go Normal file
View File

@@ -0,0 +1,23 @@
package cmd
import (
"os"
"path/filepath"
"golang.org/x/sys/unix"
)
func localCopy(src, target string) error {
dirPath := filepath.Dir(target)
if err := os.MkdirAll(dirPath, 0o755); err != nil {
return err
}
err := unix.Clonefile(src, target, 0)
if err != nil {
return err
}
return nil
}

7
cmd/copy_linux.go Normal file
View File

@@ -0,0 +1,7 @@
package cmd
import "errors"
func localCopy(src, target string) error {
return errors.New("no local copy implementation for linux")
}

67
cmd/copy_windows.go Normal file
View File

@@ -0,0 +1,67 @@
//go:build windows
// +build windows
package cmd
import (
"os"
"path/filepath"
"syscall"
"unsafe"
)
func localCopy(src, target string) error {
// Create target directory if it doesn't exist
dirPath := filepath.Dir(target)
if err := os.MkdirAll(dirPath, 0o755); err != nil {
return err
}
// Open source file
sourceFile, err := os.Open(src)
if err != nil {
return err
}
defer sourceFile.Close()
// Create target file
targetFile, err := os.Create(target)
if err != nil {
return err
}
defer targetFile.Close()
// Use CopyFileExW to copy the file
err = copyFileEx(src, target)
if err != nil {
return err
}
return nil
}
func copyFileEx(src, dst string) error {
kernel32 := syscall.NewLazyDLL("kernel32.dll")
copyFileEx := kernel32.NewProc("CopyFileExW")
srcPtr, err := syscall.UTF16PtrFromString(src)
if err != nil {
return err
}
dstPtr, err := syscall.UTF16PtrFromString(dst)
if err != nil {
return err
}
r1, _, err := copyFileEx.Call(
uintptr(unsafe.Pointer(srcPtr)),
uintptr(unsafe.Pointer(dstPtr)),
0, 0, 0, 0)
if r1 == 0 {
return err
}
return nil
}

View File

@@ -1,7 +1,6 @@
package cmd package cmd
import ( import (
"cmp"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@@ -10,14 +9,13 @@ import (
"path/filepath" "path/filepath"
"regexp" "regexp"
"slices" "slices"
"sort"
"strings" "strings"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"golang.org/x/exp/maps"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress" "github.com/ollama/ollama/progress"
"github.com/ollama/ollama/readline" "github.com/ollama/ollama/readline"
"github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/errtypes"
@@ -29,6 +27,7 @@ const (
MultilineNone MultilineState = iota MultilineNone MultilineState = iota
MultilinePrompt MultilinePrompt
MultilineSystem MultilineSystem
MultilineTemplate
) )
func loadModel(cmd *cobra.Command, opts *runOptions) error { func loadModel(cmd *cobra.Command, opts *runOptions) error {
@@ -95,6 +94,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(os.Stderr, "Available Commands:") fmt.Fprintln(os.Stderr, "Available Commands:")
fmt.Fprintln(os.Stderr, " /set parameter ... Set a parameter") fmt.Fprintln(os.Stderr, " /set parameter ... Set a parameter")
fmt.Fprintln(os.Stderr, " /set system <string> Set system message") fmt.Fprintln(os.Stderr, " /set system <string> Set system message")
fmt.Fprintln(os.Stderr, " /set template <string> Set prompt template")
fmt.Fprintln(os.Stderr, " /set history Enable history") fmt.Fprintln(os.Stderr, " /set history Enable history")
fmt.Fprintln(os.Stderr, " /set nohistory Disable history") fmt.Fprintln(os.Stderr, " /set nohistory Disable history")
fmt.Fprintln(os.Stderr, " /set wordwrap Enable wordwrap") fmt.Fprintln(os.Stderr, " /set wordwrap Enable wordwrap")
@@ -140,7 +140,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(os.Stderr, " /set parameter num_predict <int> Max number of tokens to predict") fmt.Fprintln(os.Stderr, " /set parameter num_predict <int> Max number of tokens to predict")
fmt.Fprintln(os.Stderr, " /set parameter top_k <int> Pick from top k num of tokens") fmt.Fprintln(os.Stderr, " /set parameter top_k <int> Pick from top k num of tokens")
fmt.Fprintln(os.Stderr, " /set parameter top_p <float> Pick token based on sum of probabilities") fmt.Fprintln(os.Stderr, " /set parameter top_p <float> Pick token based on sum of probabilities")
fmt.Fprintln(os.Stderr, " /set parameter min_p <float> Pick token based on top token probability * min_p")
fmt.Fprintln(os.Stderr, " /set parameter num_ctx <int> Set the context size") fmt.Fprintln(os.Stderr, " /set parameter num_ctx <int> Set the context size")
fmt.Fprintln(os.Stderr, " /set parameter temperature <float> Set creativity level") fmt.Fprintln(os.Stderr, " /set parameter temperature <float> Set creativity level")
fmt.Fprintln(os.Stderr, " /set parameter repeat_penalty <float> How strongly to penalize repetitions") fmt.Fprintln(os.Stderr, " /set parameter repeat_penalty <float> How strongly to penalize repetitions")
@@ -205,6 +204,10 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System}) opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System})
fmt.Println("Set system message.") fmt.Println("Set system message.")
sb.Reset() sb.Reset()
case MultilineTemplate:
opts.Template = sb.String()
fmt.Println("Set prompt template.")
sb.Reset()
} }
multiline = MultilineNone multiline = MultilineNone
@@ -323,13 +326,17 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
} }
fmt.Printf("Set parameter '%s' to '%s'\n", args[2], strings.Join(params, ", ")) fmt.Printf("Set parameter '%s' to '%s'\n", args[2], strings.Join(params, ", "))
opts.Options[args[2]] = fp[args[2]] opts.Options[args[2]] = fp[args[2]]
case "system": case "system", "template":
if len(args) < 3 { if len(args) < 3 {
usageSet() usageSet()
continue continue
} }
multiline = MultilineSystem if args[1] == "system" {
multiline = MultilineSystem
} else if args[1] == "template" {
multiline = MultilineTemplate
}
line := strings.Join(args[2:], " ") line := strings.Join(args[2:], " ")
line, ok := strings.CutPrefix(line, `"""`) line, ok := strings.CutPrefix(line, `"""`)
@@ -349,17 +356,23 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
continue continue
} }
opts.System = sb.String() // for display in modelfile if args[1] == "system" {
newMessage := api.Message{Role: "system", Content: sb.String()} opts.System = sb.String() // for display in modelfile
// Check if the slice is not empty and the last message is from 'system' newMessage := api.Message{Role: "system", Content: sb.String()}
if len(opts.Messages) > 0 && opts.Messages[len(opts.Messages)-1].Role == "system" { // Check if the slice is not empty and the last message is from 'system'
// Replace the last message if len(opts.Messages) > 0 && opts.Messages[len(opts.Messages)-1].Role == "system" {
opts.Messages[len(opts.Messages)-1] = newMessage // Replace the last message
} else { opts.Messages[len(opts.Messages)-1] = newMessage
opts.Messages = append(opts.Messages, newMessage) } else {
opts.Messages = append(opts.Messages, newMessage)
}
fmt.Println("Set system message.")
sb.Reset()
} else if args[1] == "template" {
opts.Template = sb.String()
fmt.Println("Set prompt template.")
sb.Reset()
} }
fmt.Println("Set system message.")
sb.Reset()
sb.Reset() sb.Reset()
continue continue
@@ -378,9 +391,10 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
return err return err
} }
req := &api.ShowRequest{ req := &api.ShowRequest{
Name: opts.Model, Name: opts.Model,
System: opts.System, System: opts.System,
Options: opts.Options, Template: opts.Template,
Options: opts.Options,
} }
resp, err := client.Show(cmd.Context(), req) resp, err := client.Show(cmd.Context(), req)
if err != nil { if err != nil {
@@ -423,9 +437,12 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Println("No system message was specified for this model.") fmt.Println("No system message was specified for this model.")
} }
case "template": case "template":
if resp.Template != "" { switch {
case opts.Template != "":
fmt.Println(opts.Template + "\n")
case resp.Template != "":
fmt.Println(resp.Template) fmt.Println(resp.Template)
} else { default:
fmt.Println("No prompt template was specified for this model.") fmt.Println("No prompt template was specified for this model.")
} }
default: default:
@@ -509,35 +526,35 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
} }
func buildModelfile(opts runOptions) string { func buildModelfile(opts runOptions) string {
var f parser.File var mf strings.Builder
f.Commands = append(f.Commands, parser.Command{Name: "model", Args: cmp.Or(opts.ParentModel, opts.Model)}) model := opts.ParentModel
if model == "" {
model = opts.Model
}
fmt.Fprintf(&mf, "FROM %s\n", model)
if opts.System != "" { if opts.System != "" {
f.Commands = append(f.Commands, parser.Command{Name: "system", Args: opts.System}) fmt.Fprintf(&mf, "SYSTEM \"\"\"%s\"\"\"\n", opts.System)
} }
keys := maps.Keys(opts.Options) if opts.Template != "" {
slices.Sort(keys) fmt.Fprintf(&mf, "TEMPLATE \"\"\"%s\"\"\"\n", opts.Template)
}
keys := make([]string, 0)
for k := range opts.Options {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys { for _, k := range keys {
v := opts.Options[k] fmt.Fprintf(&mf, "PARAMETER %s %v\n", k, opts.Options[k])
var cmds []parser.Command
switch t := v.(type) {
case []string:
for _, s := range t {
cmds = append(cmds, parser.Command{Name: k, Args: s})
}
default:
cmds = append(cmds, parser.Command{Name: k, Args: fmt.Sprintf("%v", t)})
}
f.Commands = append(f.Commands, cmds...)
} }
fmt.Fprintln(&mf)
for _, msg := range opts.Messages { for _, msg := range opts.Messages {
f.Commands = append(f.Commands, parser.Command{Name: "message", Args: fmt.Sprintf("%s: %s", msg.Role, msg.Content)}) fmt.Fprintf(&mf, "MESSAGE %s \"\"\"%s\"\"\"\n", msg.Role, msg.Content)
} }
return f.String() return mf.String()
} }
func normalizeFilePath(fp string) string { func normalizeFilePath(fp string) string {

View File

@@ -1,10 +1,12 @@
package cmd package cmd
import ( import (
"bytes"
"testing" "testing"
"text/template"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
) )
@@ -55,53 +57,61 @@ d:\path with\spaces\seven.svg inbetween7 c:\users\jdoe\eight.png inbetween8
func TestModelfileBuilder(t *testing.T) { func TestModelfileBuilder(t *testing.T) {
opts := runOptions{ opts := runOptions{
Model: "hork", Model: "hork",
System: "You are part horse and part shark, but all hork. Do horklike things", System: "You are part horse and part shark, but all hork. Do horklike things",
Template: "This is a template.",
Messages: []api.Message{ Messages: []api.Message{
{Role: "user", Content: "Hey there hork!"}, {Role: "user", Content: "Hey there hork!"},
{Role: "assistant", Content: "Yes it is true, I am half horse, half shark."}, {Role: "assistant", Content: "Yes it is true, I am half horse, half shark."},
}, },
Options: map[string]any{ Options: map[string]interface{}{},
"temperature": 0.9,
"seed": 42,
"penalize_newline": false,
"stop": []string{"hi", "there"},
},
} }
t.Run("model", func(t *testing.T) { opts.Options["temperature"] = 0.9
expect := `FROM hork opts.Options["seed"] = 42
SYSTEM You are part horse and part shark, but all hork. Do horklike things opts.Options["penalize_newline"] = false
opts.Options["stop"] = []string{"hi", "there"}
mf := buildModelfile(opts)
expectedModelfile := `FROM {{.Model}}
SYSTEM """{{.System}}"""
TEMPLATE """{{.Template}}"""
PARAMETER penalize_newline false PARAMETER penalize_newline false
PARAMETER seed 42 PARAMETER seed 42
PARAMETER stop hi PARAMETER stop [hi there]
PARAMETER stop there
PARAMETER temperature 0.9 PARAMETER temperature 0.9
MESSAGE user Hey there hork!
MESSAGE assistant Yes it is true, I am half horse, half shark. MESSAGE user """Hey there hork!"""
MESSAGE assistant """Yes it is true, I am half horse, half shark."""
` `
actual := buildModelfile(opts) tmpl, err := template.New("").Parse(expectedModelfile)
if diff := cmp.Diff(expect, actual); diff != "" { require.NoError(t, err)
t.Errorf("mismatch (-want +got):\n%s", diff)
}
})
t.Run("parent model", func(t *testing.T) { var buf bytes.Buffer
opts.ParentModel = "horseshark" err = tmpl.Execute(&buf, opts)
expect := `FROM horseshark require.NoError(t, err)
SYSTEM You are part horse and part shark, but all hork. Do horklike things assert.Equal(t, buf.String(), mf)
opts.ParentModel = "horseshark"
mf = buildModelfile(opts)
expectedModelfile = `FROM {{.ParentModel}}
SYSTEM """{{.System}}"""
TEMPLATE """{{.Template}}"""
PARAMETER penalize_newline false PARAMETER penalize_newline false
PARAMETER seed 42 PARAMETER seed 42
PARAMETER stop hi PARAMETER stop [hi there]
PARAMETER stop there
PARAMETER temperature 0.9 PARAMETER temperature 0.9
MESSAGE user Hey there hork!
MESSAGE assistant Yes it is true, I am half horse, half shark. MESSAGE user """Hey there hork!"""
MESSAGE assistant """Yes it is true, I am half horse, half shark."""
` `
actual := buildModelfile(opts)
if diff := cmp.Diff(expect, actual); diff != "" { tmpl, err = template.New("").Parse(expectedModelfile)
t.Errorf("mismatch (-want +got):\n%s", diff) require.NoError(t, err)
}
}) var parentBuf bytes.Buffer
err = tmpl.Execute(&parentBuf, opts)
require.NoError(t, err)
assert.Equal(t, parentBuf.String(), mf)
} }

View File

@@ -71,11 +71,6 @@ func (m *MistralModel) WriteGGUF(ws io.WriteSeeker) error {
"tokenizer.ggml.unknown_token_id": uint32(0), "tokenizer.ggml.unknown_token_id": uint32(0),
} }
if m.Params.HeadDimension > 0 {
kv["llama.attention.key_length"] = uint32(m.Params.HeadDimension)
kv["llama.attention.value_length"] = uint32(m.Params.HeadDimension)
}
return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors) return llm.NewGGUFV3(m.Params.ByteOrder).Encode(ws, kv, m.Tensors)
} }

View File

@@ -40,7 +40,6 @@ Generate a response for a given prompt with a provided model. This is a streamin
- `model`: (required) the [model name](#model-names) - `model`: (required) the [model name](#model-names)
- `prompt`: the prompt to generate a response for - `prompt`: the prompt to generate a response for
- `suffix`: the text after the model response
- `images`: (optional) a list of base64-encoded images (for multimodal models such as `llava`) - `images`: (optional) a list of base64-encoded images (for multimodal models such as `llava`)
Advanced parameters (optional): Advanced parameters (optional):
@@ -58,8 +57,7 @@ Advanced parameters (optional):
Enable JSON mode by setting the `format` parameter to `json`. This will structure the response as a valid JSON object. See the JSON mode [example](#request-json-mode) below. Enable JSON mode by setting the `format` parameter to `json`. This will structure the response as a valid JSON object. See the JSON mode [example](#request-json-mode) below.
> [!IMPORTANT] > Note: it's important to instruct the model to use JSON in the `prompt`. Otherwise, the model may generate large amounts whitespace.
> It's important to instruct the model to use JSON in the `prompt`. Otherwise, the model may generate large amounts whitespace.
### Examples ### Examples
@@ -150,44 +148,8 @@ If `stream` is set to `false`, the response will be a single JSON object:
} }
``` ```
#### Request (with suffix)
##### Request
```shell
curl http://localhost:11434/api/generate -d '{
"model": "codellama:code",
"prompt": "def compute_gcd(a, b):",
"suffix": " return result",
"options": {
"temperature": 0
},
"stream": false
}'
```
##### Response
```json
{
"model": "codellama:code",
"created_at": "2024-07-22T20:47:51.147561Z",
"response": "\n if a == 0:\n return b\n else:\n return compute_gcd(b % a, a)\n\ndef compute_lcm(a, b):\n result = (a * b) / compute_gcd(a, b)\n",
"done": true,
"done_reason": "stop",
"context": [...],
"total_duration": 1162761250,
"load_duration": 6683708,
"prompt_eval_count": 17,
"prompt_eval_duration": 201222000,
"eval_count": 63,
"eval_duration": 953997000
}
```
#### Request (JSON mode) #### Request (JSON mode)
> [!IMPORTANT]
> When `format` is set to `json`, the output will always be a well-formed JSON object. It's important to also instruct the model to respond in JSON. > When `format` is set to `json`, the output will always be a well-formed JSON object. It's important to also instruct the model to respond in JSON.
##### Request ##### Request
@@ -336,7 +298,6 @@ curl http://localhost:11434/api/generate -d '{
"num_predict": 100, "num_predict": 100,
"top_k": 20, "top_k": 20,
"top_p": 0.9, "top_p": 0.9,
"min_p": 0.0,
"tfs_z": 0.5, "tfs_z": 0.5,
"typical_p": 0.7, "typical_p": 0.7,
"repeat_last_n": 33, "repeat_last_n": 33,
@@ -419,14 +380,12 @@ Generate the next message in a chat with a provided model. This is a streaming e
- `model`: (required) the [model name](#model-names) - `model`: (required) the [model name](#model-names)
- `messages`: the messages of the chat, this can be used to keep a chat memory - `messages`: the messages of the chat, this can be used to keep a chat memory
- `tools`: tools for the model to use if supported. Requires `stream` to be set to `false`
The `message` object has the following fields: The `message` object has the following fields:
- `role`: the role of the message, either `system`, `user`, `assistant`, or `tool` - `role`: the role of the message, either `system`, `user` or `assistant`
- `content`: the content of the message - `content`: the content of the message
- `images` (optional): a list of images to include in the message (for multimodal models such as `llava`) - `images` (optional): a list of images to include in the message (for multimodal models such as `llava`)
- `tool_calls` (optional): a list of tools the model wants to use
Advanced parameters (optional): Advanced parameters (optional):
@@ -587,7 +546,7 @@ Final response:
##### Request ##### Request
Send a chat message with images. The images should be provided as an array, with the individual images encoded in Base64. Send a chat message with a conversation history.
```shell ```shell
curl http://localhost:11434/api/chat -d '{ curl http://localhost:11434/api/chat -d '{
@@ -663,79 +622,6 @@ curl http://localhost:11434/api/chat -d '{
} }
``` ```
#### Chat request (with tools)
##### Request
```
curl http://localhost:11434/api/chat -d '{
"model": "mistral",
"messages": [
{
"role": "user",
"content": "What is the weather today in Paris?"
}
],
"stream": false,
"tools": [
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather for a location",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The location to get the weather for, e.g. San Francisco, CA"
},
"format": {
"type": "string",
"description": "The format to return the weather in, e.g. 'celsius' or 'fahrenheit'",
"enum": ["celsius", "fahrenheit"]
}
},
"required": ["location", "format"]
}
}
}
]
}'
```
##### Response
```json
{
"model": "mistral:7b-instruct-v0.3-q4_K_M",
"created_at": "2024-07-22T20:33:28.123648Z",
"message": {
"role": "assistant",
"content": "",
"tool_calls": [
{
"function": {
"name": "get_current_weather",
"arguments": {
"format": "celsius",
"location": "Paris, FR"
}
}
}
]
},
"done_reason": "stop",
"done": true,
"total_duration": 885095291,
"load_duration": 3753500,
"prompt_eval_count": 122,
"prompt_eval_duration": 328493000,
"eval_count": 33,
"eval_duration": 552222000
}
```
## Create a Model ## Create a Model
```shell ```shell
@@ -1140,7 +1026,7 @@ If `stream` is set to `false`, then the response is a single JSON object:
## Generate Embeddings ## Generate Embeddings
```shell ```shell
POST /api/embed POST /api/embeddings
``` ```
Generate embeddings from a model Generate embeddings from a model
@@ -1148,11 +1034,10 @@ Generate embeddings from a model
### Parameters ### Parameters
- `model`: name of model to generate embeddings from - `model`: name of model to generate embeddings from
- `input`: text or list of text to generate embeddings for - `prompt`: text to generate embeddings for
Advanced parameters: Advanced parameters:
- `truncate`: truncates the end of each input to fit within context length. Returns error if `false` and context length is exceeded. Defaults to `true`
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature` - `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`) - `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
@@ -1161,9 +1046,9 @@ Advanced parameters:
#### Request #### Request
```shell ```shell
curl http://localhost:11434/api/embed -d '{ curl http://localhost:11434/api/embeddings -d '{
"model": "all-minilm", "model": "all-minilm",
"input": "Why is the sky blue?" "prompt": "Here is an article about llamas..."
}' }'
``` ```
@@ -1171,35 +1056,10 @@ curl http://localhost:11434/api/embed -d '{
```json ```json
{ {
"model": "all-minilm", "embedding": [
"embeddings": [[ 0.5670403838157654, 0.009260174818336964, 0.23178744316101074, -0.2916173040866852, -0.8924556970596313,
0.010071029, -0.0017594862, 0.05007221, 0.04692972, 0.054916814, 0.8785552978515625, -0.34576427936553955, 0.5742510557174683, -0.04222835972905159, -0.137906014919281
0.008599704, 0.105441414, -0.025878139, 0.12958129, 0.031952348 ]
]]
}
```
#### Request (Multiple input)
```shell
curl http://localhost:11434/api/embed -d '{
"model": "all-minilm",
"input": ["Why is the sky blue?", "Why is the grass green?"]
}'
```
#### Response
```json
{
"model": "all-minilm",
"embeddings": [[
0.010071029, -0.0017594862, 0.05007221, 0.04692972, 0.054916814,
0.008599704, 0.105441414, -0.025878139, 0.12958129, 0.031952348
],[
-0.0098027075, 0.06042469, 0.025257962, -0.006364387, 0.07272725,
0.017194884, 0.09032035, -0.051705178, 0.09951512, 0.09072481
]]
} }
``` ```
@@ -1246,45 +1106,3 @@ A single JSON object will be returned.
] ]
} }
``` ```
## Generate Embedding
> Note: this endpoint has been superseded by `/api/embed`
```shell
POST /api/embeddings
```
Generate embeddings from a model
### Parameters
- `model`: name of model to generate embeddings from
- `prompt`: text to generate embeddings for
Advanced parameters:
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
### Examples
#### Request
```shell
curl http://localhost:11434/api/embeddings -d '{
"model": "all-minilm",
"prompt": "Here is an article about llamas..."
}'
```
#### Response
```json
{
"embedding": [
0.5670403838157654, 0.009260174818336964, 0.23178744316101074, -0.2916173040866852, -0.8924556970596313,
0.8785552978515625, -0.34576427936553955, 0.5742510557174683, -0.04222835972905159, -0.137906014919281
]
}
```

View File

@@ -63,7 +63,7 @@ docker run -d --device /dev/kfd --device /dev/dri -v ollama:/root/.ollama -p 114
Now you can run a model: Now you can run a model:
``` ```
docker exec -it ollama ollama run llama3.1 docker exec -it ollama ollama run llama3
``` ```
### Try different models ### Try different models

View File

@@ -227,7 +227,7 @@ curl http://localhost:11434/api/chat -d '{"model": "mistral"}'
To preload a model using the CLI, use the command: To preload a model using the CLI, use the command:
```shell ```shell
ollama run llama3.1 "" ollama run llama3 ""
``` ```
## How do I keep a model loaded in memory or make it unload immediately? ## How do I keep a model loaded in memory or make it unload immediately?
@@ -272,8 +272,4 @@ The following server settings may be used to adjust how Ollama handles concurren
- `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time. The default will auto-select either 4 or 1 based on available memory. - `OLLAMA_NUM_PARALLEL` - The maximum number of parallel requests each model will process at the same time. The default will auto-select either 4 or 1 based on available memory.
- `OLLAMA_MAX_QUEUE` - The maximum number of requests Ollama will queue when busy before rejecting additional requests. The default is 512 - `OLLAMA_MAX_QUEUE` - The maximum number of requests Ollama will queue when busy before rejecting additional requests. The default is 512
Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6.2 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPUs VRAM. Note: Windows with Radeon GPUs currently default to 1 model maximum due to limitations in ROCm v5.7 for available VRAM reporting. Once ROCm v6 is available, Windows Radeon will follow the defaults above. You may enable concurrent model loads on Radeon on Windows, but ensure you don't load more models than will fit into your GPUs VRAM.
## How does Ollama load models on multiple GPUs?
Installing multiple GPUs of the same brand can be a great way to increase your available VRAM to load larger models. When you load a new model, Ollama evaluates the required VRAM for the model against what is currently available. If the model will entirely fit on any single GPU, Ollama will load the model on that GPU. This typically provides the best performance as it reduces the amount of data transfering across the PCI bus during inference. If the model does not fit entirely on one GPU, then it will be spread across all the available GPUs.

View File

@@ -46,24 +46,13 @@ sudo modprobe nvidia_uvm`
## AMD Radeon ## AMD Radeon
Ollama supports the following AMD GPUs: Ollama supports the following AMD GPUs:
### Linux Support
| Family | Cards and accelerators | | Family | Cards and accelerators |
| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- | | -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` `Vega 64` `Vega 56` | | AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` `Vega 64` `Vega 56` |
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` `V420` `V340` `V320` `Vega II Duo` `Vega II` `VII` `SSG` | | AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` `V420` `V340` `V320` `Vega II Duo` `Vega II` `VII` `SSG` |
| AMD Instinct | `MI300X` `MI300A` `MI300` `MI250X` `MI250` `MI210` `MI200` `MI100` `MI60` `MI50` | | AMD Instinct | `MI300X` `MI300A` `MI300` `MI250X` `MI250` `MI210` `MI200` `MI100` `MI60` `MI50` |
### Windows Support ### Overrides
With ROCm v6.1, the following GPUs are supported on Windows.
| Family | Cards and accelerators |
| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` |
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` |
### Overrides on Linux
Ollama leverages the AMD ROCm library, which does not support all AMD GPUs. In Ollama leverages the AMD ROCm library, which does not support all AMD GPUs. In
some cases you can force the system to try to use a similar LLVM target that is some cases you can force the system to try to use a similar LLVM target that is
close. For example The Radeon RX 5400 is `gfx1034` (also known as 10.3.4) close. For example The Radeon RX 5400 is `gfx1034` (also known as 10.3.4)
@@ -74,7 +63,7 @@ would set `HSA_OVERRIDE_GFX_VERSION="10.3.0"` as an environment variable for the
server. If you have an unsupported AMD GPU you can experiment using the list of server. If you have an unsupported AMD GPU you can experiment using the list of
supported types below. supported types below.
At this time, the known supported GPU types on linux are the following LLVM Targets. At this time, the known supported GPU types are the following LLVM Targets.
This table shows some example GPUs that map to these LLVM targets: This table shows some example GPUs that map to these LLVM targets:
| **LLVM Target** | **An Example GPU** | | **LLVM Target** | **An Example GPU** |
|-----------------|---------------------| |-----------------|---------------------|

View File

@@ -1,7 +1,6 @@
# Ollama Model File # Ollama Model File
> [!NOTE] > Note: `Modelfile` syntax is in development
> `Modelfile` syntax is in development
A model file is the blueprint to create and share models with Ollama. A model file is the blueprint to create and share models with Ollama.
@@ -141,7 +140,6 @@ PARAMETER <parameter> <parametervalue>
| num_predict | Maximum number of tokens to predict when generating text. (Default: 128, -1 = infinite generation, -2 = fill context) | int | num_predict 42 | | num_predict | Maximum number of tokens to predict when generating text. (Default: 128, -1 = infinite generation, -2 = fill context) | int | num_predict 42 |
| top_k | Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40) | int | top_k 40 | | top_k | Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40) | int | top_k 40 |
| top_p | Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9) | float | top_p 0.9 | | top_p | Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9) | float | top_p 0.9 |
| min_p | Alternative to the top_p, and aims to ensure a balance of quality and variety. The parameter *p* represents the minimum probability for a token to be considered, relative to the probability of the most likely token. For example, with *p*=0.05 and the most likely token having a probability of 0.9, logits with a value less than 0.045 are filtered out. (Default: 0.0) | float | min_p 0.05 |
### TEMPLATE ### TEMPLATE

View File

@@ -78,8 +78,8 @@ curl http://localhost:11434/v1/chat/completions \
- [x] Streaming - [x] Streaming
- [x] JSON mode - [x] JSON mode
- [x] Reproducible outputs - [x] Reproducible outputs
- [x] Tools (streaming support coming soon)
- [ ] Vision - [ ] Vision
- [ ] Function calling
- [ ] Logprobs - [ ] Logprobs
#### Supported request fields #### Supported request fields
@@ -97,12 +97,16 @@ curl http://localhost:11434/v1/chat/completions \
- [x] `temperature` - [x] `temperature`
- [x] `top_p` - [x] `top_p`
- [x] `max_tokens` - [x] `max_tokens`
- [x] `tools`
- [ ] `tool_choice`
- [ ] `logit_bias` - [ ] `logit_bias`
- [ ] `tools`
- [ ] `tool_choice`
- [ ] `user` - [ ] `user`
- [ ] `n` - [ ] `n`
#### Notes
- `usage.prompt_tokens` will be 0 for completions where prompt evaluation is cached
## Models ## Models
Before using a model, pull it locally `ollama pull`: Before using a model, pull it locally `ollama pull`:

View File

@@ -1,173 +0,0 @@
# Template
Ollama provides a powerful templating engine backed by Go's built-in templating engine to construct prompts for your large language model. This feature is a valuable tool to get the most out of your models.
## Basic Template Structure
A basic Go template consists of three main parts:
* **Layout**: The overall structure of the template.
* **Variables**: Placeholders for dynamic data that will be replaced with actual values when the template is rendered.
* **Functions**: Custom functions or logic that can be used to manipulate the template's content.
Here's an example of a simple chat template:
```gotmpl
{{- range .Messages }}
{{ .Role }}: {{ .Content }}
{{- end }}
```
In this example, we have:
* A basic messages structure (layout)
* Three variables: `Messages`, `Role`, and `Content` (variables)
* A custom function (action) that iterates over an array of items (`range .Messages`) and displays each item
## Adding templates to your model
By default, models imported into Ollama have a default template of `{{ .Prompt }}`, i.e. user inputs are sent verbatim to the LLM. This is appropriate for text or code completion models but lacks essential markers for chat or instruction models.
Omitting a template in these models puts the responsibility of correctly templating input onto the user. Adding a template allows users to easily get the best results from the model.
To add templates in your model, you'll need to add a `TEMPLATE` command to the Modelfile. Here's an example using Meta's Llama 3.
```dockerfile
FROM llama3
TEMPLATE """{{- if .System }}<|start_header_id|>system<|end_header_id|>
{{ .System }}<|eot_id|>
{{- end }}
{{- range .Messages }}<|start_header_id|>{{ .Role }}<|end_header_id|>
{{ .Content }}<|eot_id|>
{{- end }}<|start_header_id|>assistant<|end_header_id|>
"""
```
## Variables
`System` (string): system prompt
`Prompt` (string): user prompt
`Response` (string): assistant response
`Suffix` (string): text inserted after the assistant's response
`Messages` (list): list of messages
`Messages[].Role` (string): role which can be one of `system`, `user`, `assistant`, or `tool`
`Messages[].Content` (string): message content
`Messages[].ToolCalls` (list): list of tools the model wants to call
`Messages[].ToolCalls[].Function` (object): function to call
`Messages[].ToolCalls[].Function.Name` (string): function name
`Messages[].ToolCalls[].Function.Arguments` (map): mapping of argument name to argument value
`Tools` (list): list of tools the model can access
`Tools[].Type` (string): schema type. `type` is always `function`
`Tools[].Function` (object): function definition
`Tools[].Function.Name` (string): function name
`Tools[].Function.Description` (string): function description
`Tools[].Function.Parameters` (object): function parameters
`Tools[].Function.Parameters.Type` (string): schema type. `type` is always `object`
`Tools[].Function.Parameters.Required` (list): list of required properties
`Tools[].Function.Parameters.Properties` (map): mapping of property name to property definition
`Tools[].Function.Parameters.Properties[].Type` (string): property type
`Tools[].Function.Parameters.Properties[].Description` (string): property description
`Tools[].Function.Parameters.Properties[].Enum` (list): list of valid values
## Tips and Best Practices
Keep the following tips and best practices in mind when working with Go templates:
* **Be mindful of dot**: Control flow structures like `range` and `with` changes the value `.`
* **Out-of-scope variables**: Use `$.` to reference variables not currently in scope, starting from the root
* **Whitespace control**: Use `-` to trim leading (`{{-`) and trailing (`-}}`) whitespace
## Examples
### Example Messages
#### ChatML
ChatML is a popular template format. It can be used for models such as Databrick's DBRX, Intel's Neural Chat, and Microsoft's Orca 2.
```gotmpl
{{- if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}
{{- range .Messages }}<|im_start|>{{ .Role }}
{{ .Content }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ else }}
{{ if .System }}<|im_start|>system
{{ .System }}<|im_end|>
```
### Example Tools
Tools support can be added to a model by adding a `{{ .Tools }}` node to the template. This feature is useful for models trained to call external tools and can a powerful tool for retrieving real-time data or performing complex tasks.
#### Mistral
Mistral v0.3 and Mixtral 8x22B supports tool calling.
```gotmpl
{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}
{{- if and (le (len (slice $.Messages $index)) 2) $.Tools }}[AVAILABLE_TOOLS] {{ json $.Tools }}[/AVAILABLE_TOOLS]
{{- end }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}
{{ end }}{{ .Content }}[/INST]
{{- else if eq .Role "assistant" }}
{{- if .Content }} {{ .Content }}</s>
{{- else if .ToolCalls }}[TOOL_CALLS] [
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ json .Function.Arguments }}}
{{- end }}]</s>
{{- end }}
{{- else if eq .Role "tool" }}[TOOL_RESULTS] {"content": {{ .Content }}}[/TOOL_RESULTS]
{{- end }}
{{- end }}
```
### Example Fill-in-Middle
Fill-in-middle support can be added to a model by adding a `{{ .Suffix }}` node to the template. This feature is useful for models that are trained to generate text in the middle of user input, such as code completion models.
#### CodeLlama
CodeLlama [7B](https://ollama.com/library/codellama:7b-code) and [13B](https://ollama.com/library/codellama:13b-code) code completion models support fill-in-middle.
```gotmpl
<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
```
> [!NOTE]
> CodeLlama 34B and 70B code completion and all instruct and Python fine-tuned models do not support fill-in-middle.
#### Codestral
Codestral [22B](https://ollama.com/library/codestral:22b) supports fill-in-middle.
```gotmpl
[SUFFIX]{{ .Suffix }}[PREFIX] {{ .Prompt }}
```

View File

@@ -15,7 +15,7 @@ import { Ollama } from "@langchain/community/llms/ollama";
const ollama = new Ollama({ const ollama = new Ollama({
baseUrl: "http://localhost:11434", baseUrl: "http://localhost:11434",
model: "llama3.1", model: "llama3",
}); });
const answer = await ollama.invoke(`why is the sky blue?`); const answer = await ollama.invoke(`why is the sky blue?`);
@@ -23,7 +23,7 @@ const answer = await ollama.invoke(`why is the sky blue?`);
console.log(answer); console.log(answer);
``` ```
That will get us the same thing as if we ran `ollama run llama3.1 "why is the sky blue"` in the terminal. But we want to load a document from the web to ask a question against. **Cheerio** is a great library for ingesting a webpage, and **LangChain** uses it in their **CheerioWebBaseLoader**. So let's install **Cheerio** and build that part of the app. That will get us the same thing as if we ran `ollama run llama3 "why is the sky blue"` in the terminal. But we want to load a document from the web to ask a question against. **Cheerio** is a great library for ingesting a webpage, and **LangChain** uses it in their **CheerioWebBaseLoader**. So let's install **Cheerio** and build that part of the app.
```bash ```bash
npm install cheerio npm install cheerio

View File

@@ -23,8 +23,6 @@ Logs will often be helpful in diagnosing the problem (see
* NVIDIA 452.39 or newer Drivers if you have an NVIDIA card * NVIDIA 452.39 or newer Drivers if you have an NVIDIA card
* AMD Radeon Driver https://www.amd.com/en/support if you have a Radeon card * AMD Radeon Driver https://www.amd.com/en/support if you have a Radeon card
Ollama uses unicode characters for progress indication, which may render as unknown squares in some older terminal fonts in Windows 10. If you see this, try changing your terminal font settings.
## API Access ## API Access
Here's a quick example showing API access from `powershell` Here's a quick example showing API access from `powershell`

View File

@@ -43,6 +43,8 @@ var (
MaxRunners int MaxRunners int
// Set via OLLAMA_MAX_QUEUE in the environment // Set via OLLAMA_MAX_QUEUE in the environment
MaxQueuedRequests int MaxQueuedRequests int
// Set via OLLAMA_MAX_VRAM in the environment
MaxVRAM uint64
// Set via OLLAMA_MODELS in the environment // Set via OLLAMA_MODELS in the environment
ModelsDir string ModelsDir string
// Set via OLLAMA_NOHISTORY in the environment // Set via OLLAMA_NOHISTORY in the environment
@@ -87,6 +89,7 @@ func AsMap() map[string]EnvVar {
"OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary, "Set LLM library to bypass autodetection"}, "OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary, "Set LLM library to bypass autodetection"},
"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models per GPU"}, "OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models per GPU"},
"OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueuedRequests, "Maximum number of queued requests"}, "OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueuedRequests, "Maximum number of queued requests"},
"OLLAMA_MAX_VRAM": {"OLLAMA_MAX_VRAM", MaxVRAM, "Maximum VRAM"},
"OLLAMA_MODELS": {"OLLAMA_MODELS", ModelsDir, "The path to the models directory"}, "OLLAMA_MODELS": {"OLLAMA_MODELS", ModelsDir, "The path to the models directory"},
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory, "Do not preserve readline history"}, "OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory, "Do not preserve readline history"},
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune, "Do not prune model blobs on startup"}, "OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune, "Do not prune model blobs on startup"},
@@ -191,6 +194,16 @@ func LoadConfig() {
TmpDir = clean("OLLAMA_TMPDIR") TmpDir = clean("OLLAMA_TMPDIR")
userLimit := clean("OLLAMA_MAX_VRAM")
if userLimit != "" {
avail, err := strconv.ParseUint(userLimit, 10, 64)
if err != nil {
slog.Error("invalid setting, ignoring", "OLLAMA_MAX_VRAM", userLimit, "error", err)
} else {
MaxVRAM = avail
}
}
LLMLibrary = clean("OLLAMA_LLM_LIBRARY") LLMLibrary = clean("OLLAMA_LLM_LIBRARY")
if onp := clean("OLLAMA_NUM_PARALLEL"); onp != "" { if onp := clean("OLLAMA_NUM_PARALLEL"); onp != "" {

View File

@@ -49,17 +49,9 @@ func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
} }
func commonAMDValidateLibDir() (string, error) { func commonAMDValidateLibDir() (string, error) {
// Favor our bundled version // We try to favor system paths first, so that we can wire up the subprocess to use
// the system version. Only use our bundled version if the system version doesn't work
// Installer payload location if we're running the installed binary // This gives users a more recovery options if versions have subtle problems at runtime
exe, err := os.Executable()
if err == nil {
rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
if rocmLibUsable(rocmTargetDir) {
slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
return rocmTargetDir, nil
}
}
// Prefer explicit HIP env var // Prefer explicit HIP env var
hipPath := os.Getenv("HIP_PATH") hipPath := os.Getenv("HIP_PATH")
@@ -95,5 +87,14 @@ func commonAMDValidateLibDir() (string, error) {
} }
} }
// Installer payload location if we're running the installed binary
exe, err := os.Executable()
if err == nil {
rocmTargetDir := filepath.Join(filepath.Dir(exe), "rocm")
if rocmLibUsable(rocmTargetDir) {
slog.Debug("detected ROCM next to ollama executable " + rocmTargetDir)
return rocmTargetDir, nil
}
}
return "", fmt.Errorf("no suitable rocm found, falling back to CPU") return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
} }

View File

@@ -33,10 +33,9 @@ type HipLib struct {
} }
func NewHipLib() (*HipLib, error) { func NewHipLib() (*HipLib, error) {
// At runtime we depend on v6, so discover GPUs with the same library for a consistent set of GPUs h, err := windows.LoadLibrary("amdhip64.dll")
h, err := windows.LoadLibrary("amdhip64_6.dll")
if err != nil { if err != nil {
return nil, fmt.Errorf("unable to load amdhip64_6.dll, please make sure to upgrade to the latest amd driver: %w", err) return nil, fmt.Errorf("unable to load amdhip64.dll: %w", err)
} }
hl := &HipLib{} hl := &HipLib{}
hl.dll = h hl.dll = h
@@ -85,8 +84,9 @@ func (hl *HipLib) AMDDriverVersion() (driverMajor, driverMinor int, err error) {
} }
slog.Debug("hipDriverGetVersion", "version", version) slog.Debug("hipDriverGetVersion", "version", version)
driverMajor = version / 10000000 // TODO - this isn't actually right, but the docs claim hipDriverGetVersion isn't accurate anyway...
driverMinor = (version - (driverMajor * 10000000)) / 100000 driverMajor = version / 1000
driverMinor = (version - (driverMajor * 1000)) / 10
return driverMajor, driverMinor, nil return driverMajor, driverMinor, nil
} }

View File

@@ -10,7 +10,6 @@ import (
"path/filepath" "path/filepath"
"regexp" "regexp"
"slices" "slices"
"sort"
"strconv" "strconv"
"strings" "strings"
@@ -83,20 +82,6 @@ func AMDGetGPUInfo() []RocmGPUInfo {
// The amdgpu driver always exposes the host CPU(s) first, but we have to skip them and subtract // The amdgpu driver always exposes the host CPU(s) first, but we have to skip them and subtract
// from the other IDs to get alignment with the HIP libraries expectations (zero is the first GPU, not the CPU) // from the other IDs to get alignment with the HIP libraries expectations (zero is the first GPU, not the CPU)
matches, _ := filepath.Glob(GPUPropertiesFileGlob) matches, _ := filepath.Glob(GPUPropertiesFileGlob)
sort.Slice(matches, func(i, j int) bool {
// /sys/class/kfd/kfd/topology/nodes/<number>/properties
a, err := strconv.ParseInt(filepath.Base(filepath.Dir(matches[i])), 10, 64)
if err != nil {
slog.Debug("parse err", "error", err, "match", matches[i])
return false
}
b, err := strconv.ParseInt(filepath.Base(filepath.Dir(matches[j])), 10, 64)
if err != nil {
slog.Debug("parse err", "error", err, "match", matches[i])
return false
}
return a < b
})
cpuCount := 0 cpuCount := 0
for _, match := range matches { for _, match := range matches {
slog.Debug("evaluating amdgpu node " + match) slog.Debug("evaluating amdgpu node " + match)

View File

@@ -22,8 +22,8 @@ const (
var ( var (
// Used to validate if the given ROCm lib is usable // Used to validate if the given ROCm lib is usable
ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // This is not sufficient to discern v5 vs v6 ROCmLibGlobs = []string{"hipblas.dll", "rocblas"} // TODO - probably include more coverage of files here...
RocmStandardLocations = []string{"C:\\Program Files\\AMD\\ROCm\\6.1\\bin"} // TODO glob? RocmStandardLocations = []string{"C:\\Program Files\\AMD\\ROCm\\5.7\\bin"} // TODO glob?
) )
func AMDGetGPUInfo() []RocmGPUInfo { func AMDGetGPUInfo() []RocmGPUInfo {
@@ -35,11 +35,12 @@ func AMDGetGPUInfo() []RocmGPUInfo {
} }
defer hl.Release() defer hl.Release()
driverMajor, driverMinor, err := hl.AMDDriverVersion() // TODO - this reports incorrect version information, so omitting for now
if err != nil { // driverMajor, driverMinor, err := hl.AMDDriverVersion()
// For now this is benign, but we may eventually need to fail compatibility checks // if err != nil {
slog.Debug("error looking up amd driver version", "error", err) // // For now this is benign, but we may eventually need to fail compatibility checks
} // slog.Debug("error looking up amd driver version", "error", err)
// }
// Note: the HIP library automatically handles subsetting to any HIP_VISIBLE_DEVICES the user specified // Note: the HIP library automatically handles subsetting to any HIP_VISIBLE_DEVICES the user specified
count := hl.HipGetDeviceCount() count := hl.HipGetDeviceCount()
@@ -92,8 +93,7 @@ func AMDGetGPUInfo() []RocmGPUInfo {
continue continue
} }
if gfxOverride == "" { if gfxOverride == "" {
// Strip off Target Features when comparing if !slices.Contains[[]string, string](supported, gfx) {
if !slices.Contains[[]string, string](supported, strings.Split(gfx, ":")[0]) {
slog.Warn("amdgpu is not supported", "gpu", i, "gpu_type", gfx, "library", libDir, "supported_types", supported) slog.Warn("amdgpu is not supported", "gpu", i, "gpu_type", gfx, "library", libDir, "supported_types", supported)
// TODO - consider discrete markdown just for ROCM troubleshooting? // TODO - consider discrete markdown just for ROCM troubleshooting?
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage") slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage")
@@ -132,8 +132,10 @@ func AMDGetGPUInfo() []RocmGPUInfo {
MinimumMemory: rocmMinimumMemory, MinimumMemory: rocmMinimumMemory,
Name: name, Name: name,
Compute: gfx, Compute: gfx,
DriverMajor: driverMajor,
DriverMinor: driverMinor, // TODO - this information isn't accurate on windows, so don't report it until we find the right way to retrieve
// DriverMajor: driverMajor,
// DriverMinor: driverMinor,
}, },
index: i, index: i,
} }

View File

@@ -274,28 +274,6 @@ func GetGPUInfo() GpuInfoList {
gpuInfo.DriverMajor = driverMajor gpuInfo.DriverMajor = driverMajor
gpuInfo.DriverMinor = driverMinor gpuInfo.DriverMinor = driverMinor
// query the management library as well so we can record any skew between the two
// which represents overhead on the GPU we must set aside on subsequent updates
if cHandles.nvml != nil {
C.nvml_get_free(*cHandles.nvml, C.int(gpuInfo.index), &memInfo.free, &memInfo.total, &memInfo.used)
if memInfo.err != nil {
slog.Warn("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
C.free(unsafe.Pointer(memInfo.err))
} else {
if memInfo.free != 0 && uint64(memInfo.free) > gpuInfo.FreeMemory {
gpuInfo.OSOverhead = uint64(memInfo.free) - gpuInfo.FreeMemory
slog.Info("detected OS VRAM overhead",
"id", gpuInfo.ID,
"library", gpuInfo.Library,
"compute", gpuInfo.Compute,
"driver", fmt.Sprintf("%d.%d", gpuInfo.DriverMajor, gpuInfo.DriverMinor),
"name", gpuInfo.Name,
"overhead", format.HumanBytes2(gpuInfo.OSOverhead),
)
}
}
}
// TODO potentially sort on our own algorithm instead of what the underlying GPU library does... // TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
cudaGPUs = append(cudaGPUs, gpuInfo) cudaGPUs = append(cudaGPUs, gpuInfo)
} }
@@ -360,17 +338,14 @@ func GetGPUInfo() GpuInfoList {
"before", "before",
"total", format.HumanBytes2(cpus[0].TotalMemory), "total", format.HumanBytes2(cpus[0].TotalMemory),
"free", format.HumanBytes2(cpus[0].FreeMemory), "free", format.HumanBytes2(cpus[0].FreeMemory),
"free_swap", format.HumanBytes2(cpus[0].FreeSwap),
), ),
slog.Group( slog.Group(
"now", "now",
"total", format.HumanBytes2(mem.TotalMemory), "total", format.HumanBytes2(mem.TotalMemory),
"free", format.HumanBytes2(mem.FreeMemory), "free", format.HumanBytes2(mem.FreeMemory),
"free_swap", format.HumanBytes2(mem.FreeSwap),
), ),
) )
cpus[0].FreeMemory = mem.FreeMemory cpus[0].FreeMemory = mem.FreeMemory
cpus[0].FreeSwap = mem.FreeSwap
} }
var memInfo C.mem_info_t var memInfo C.mem_info_t
@@ -399,14 +374,9 @@ func GetGPUInfo() GpuInfoList {
slog.Warn("error looking up nvidia GPU memory") slog.Warn("error looking up nvidia GPU memory")
continue continue
} }
if cHandles.nvml != nil && gpu.OSOverhead > 0 {
// When using the management library update based on recorded overhead
memInfo.free -= C.uint64_t(gpu.OSOverhead)
}
slog.Debug("updating cuda memory data", slog.Debug("updating cuda memory data",
"gpu", gpu.ID, "gpu", gpu.ID,
"name", gpu.Name, "name", gpu.Name,
"overhead", format.HumanBytes2(gpu.OSOverhead),
slog.Group( slog.Group(
"before", "before",
"total", format.HumanBytes2(gpu.TotalMemory), "total", format.HumanBytes2(gpu.TotalMemory),

View File

@@ -57,7 +57,6 @@ func GetCPUMem() (memInfo, error) {
return memInfo{ return memInfo{
TotalMemory: uint64(C.getPhysicalMemory()), TotalMemory: uint64(C.getPhysicalMemory()),
FreeMemory: uint64(C.getFreeMemory()), FreeMemory: uint64(C.getFreeMemory()),
// FreeSwap omitted as Darwin uses dynamic paging
}, nil }, nil
} }

View File

@@ -50,7 +50,7 @@ var OneapiMgmtName = "libze_intel_gpu.so"
func GetCPUMem() (memInfo, error) { func GetCPUMem() (memInfo, error) {
var mem memInfo var mem memInfo
var total, available, free, buffers, cached, freeSwap uint64 var total, available, free, buffers, cached uint64
f, err := os.Open("/proc/meminfo") f, err := os.Open("/proc/meminfo")
if err != nil { if err != nil {
return mem, err return mem, err
@@ -70,21 +70,20 @@ func GetCPUMem() (memInfo, error) {
_, err = fmt.Sscanf(line, "Buffers:%d", &buffers) _, err = fmt.Sscanf(line, "Buffers:%d", &buffers)
case strings.HasPrefix(line, "Cached:"): case strings.HasPrefix(line, "Cached:"):
_, err = fmt.Sscanf(line, "Cached:%d", &cached) _, err = fmt.Sscanf(line, "Cached:%d", &cached)
case strings.HasPrefix(line, "SwapFree:"):
_, err = fmt.Sscanf(line, "SwapFree:%d", &freeSwap)
default: default:
continue continue
} }
if err != nil { if err != nil {
return mem, err return mem, err
} }
if total > 0 && available > 0 {
mem.TotalMemory = total * format.KibiByte
mem.FreeMemory = available * format.KibiByte
return mem, nil
}
} }
mem.TotalMemory = total * format.KibiByte mem.TotalMemory = total * format.KibiByte
mem.FreeSwap = freeSwap * format.KibiByte mem.FreeMemory = (free + buffers + cached) * format.KibiByte
if available > 0 {
mem.FreeMemory = available * format.KibiByte
} else {
mem.FreeMemory = (free + buffers + cached) * format.KibiByte
}
return mem, nil return mem, nil
} }

View File

@@ -51,5 +51,5 @@ func GetCPUMem() (memInfo, error) {
if r1 == 0 { if r1 == 0 {
return memInfo{}, fmt.Errorf("GlobalMemoryStatusEx failed: %w", err) return memInfo{}, fmt.Errorf("GlobalMemoryStatusEx failed: %w", err)
} }
return memInfo{TotalMemory: memStatus.TotalPhys, FreeMemory: memStatus.AvailPhys, FreeSwap: memStatus.AvailPageFile}, nil return memInfo{TotalMemory: memStatus.TotalPhys, FreeMemory: memStatus.AvailPhys}, nil
} }

View File

@@ -10,7 +10,6 @@ import (
type memInfo struct { type memInfo struct {
TotalMemory uint64 `json:"total_memory,omitempty"` TotalMemory uint64 `json:"total_memory,omitempty"`
FreeMemory uint64 `json:"free_memory,omitempty"` FreeMemory uint64 `json:"free_memory,omitempty"`
FreeSwap uint64 `json:"free_swap,omitempty"`
} }
// Beginning of an `ollama info` command // Beginning of an `ollama info` command
@@ -53,8 +52,7 @@ type CPUInfo struct {
type CudaGPUInfo struct { type CudaGPUInfo struct {
GpuInfo GpuInfo
OSOverhead uint64 // Memory overhead between the driver library and management library index int //nolint:unused,nolintlint
index int //nolint:unused,nolintlint
} }
type CudaGPUInfoList []CudaGPUInfo type CudaGPUInfoList []CudaGPUInfo

View File

@@ -69,7 +69,7 @@ func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
reqLimit := len(req) reqLimit := len(req)
iterLimit := 5 iterLimit := 5
vram := os.Getenv("OLLAMA_MAX_VRAM") // TODO - discover actual VRAM vram := os.Getenv("OLLAMA_MAX_VRAM")
if vram != "" { if vram != "" {
max, err := strconv.ParseUint(vram, 10, 64) max, err := strconv.ParseUint(vram, 10, 64)
require.NoError(t, err) require.NoError(t, err)
@@ -106,7 +106,7 @@ func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
// Stress the system if we know how much VRAM it has, and attempt to load more models than will fit // Stress the system if we know how much VRAM it has, and attempt to load more models than will fit
func TestMultiModelStress(t *testing.T) { func TestMultiModelStress(t *testing.T) {
vram := os.Getenv("OLLAMA_MAX_VRAM") // TODO - discover actual VRAM vram := os.Getenv("OLLAMA_MAX_VRAM")
if vram == "" { if vram == "" {
t.Skip("OLLAMA_MAX_VRAM not specified, can't pick the right models for the stress test") t.Skip("OLLAMA_MAX_VRAM not specified, can't pick the right models for the stress test")
} }

View File

@@ -12,7 +12,7 @@ import (
func TestContextExhaustion(t *testing.T) { func TestContextExhaustion(t *testing.T) {
// Longer needed for small footprint GPUs // Longer needed for small footprint GPUs
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute) ctx, cancel := context.WithTimeout(context.Background(), 6*time.Minute)
defer cancel() defer cancel()
// Set up the test data // Set up the test data
req := api.GenerateRequest{ req := api.GenerateRequest{
@@ -25,10 +25,5 @@ func TestContextExhaustion(t *testing.T) {
"num_ctx": 128, "num_ctx": 128,
}, },
} }
client, _, cleanup := InitServerConnection(ctx, t) GenerateTestHelper(ctx, t, req, []string{"once", "upon", "lived"})
defer cleanup()
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("PullIfMissing failed: %v", err)
}
DoGenerate(ctx, t, client, req, []string{"once", "upon", "lived"}, 120*time.Second, 10*time.Second)
} }

View File

@@ -1,201 +0,0 @@
//go:build integration
package integration
import (
"context"
"math"
"testing"
"time"
"github.com/ollama/ollama/api"
)
func floatsEqual32(a, b float32) bool {
return math.Abs(float64(a-b)) <= 1e-4
}
func floatsEqual64(a, b float64) bool {
return math.Abs(a-b) <= 1e-4
}
func TestAllMiniLMEmbeddings(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
req := api.EmbeddingRequest{
Model: "all-minilm",
Prompt: "why is the sky blue?",
}
res, err := embeddingTestHelper(ctx, t, req)
if err != nil {
t.Fatalf("error: %v", err)
}
if len(res.Embedding) != 384 {
t.Fatalf("expected 384 floats, got %d", len(res.Embedding))
}
if !floatsEqual64(res.Embedding[0], 0.06642947345972061) {
t.Fatalf("expected 0.06642947345972061, got %.16f", res.Embedding[0])
}
}
func TestAllMiniLMEmbed(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
req := api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
}
res, err := embedTestHelper(ctx, t, req)
if err != nil {
t.Fatalf("error: %v", err)
}
if len(res.Embeddings) != 1 {
t.Fatalf("expected 1 embedding, got %d", len(res.Embeddings))
}
if len(res.Embeddings[0]) != 384 {
t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
}
if !floatsEqual32(res.Embeddings[0][0], 0.010071031) {
t.Fatalf("expected 0.010071031, got %.8f", res.Embeddings[0][0])
}
}
func TestAllMiniLMBatchEmbed(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
req := api.EmbedRequest{
Model: "all-minilm",
Input: []string{"why is the sky blue?", "why is the grass green?"},
}
res, err := embedTestHelper(ctx, t, req)
if err != nil {
t.Fatalf("error: %v", err)
}
if len(res.Embeddings) != 2 {
t.Fatalf("expected 2 embeddings, got %d", len(res.Embeddings))
}
if len(res.Embeddings[0]) != 384 {
t.Fatalf("expected 384 floats, got %d", len(res.Embeddings[0]))
}
if !floatsEqual32(res.Embeddings[0][0], 0.010071031) || !floatsEqual32(res.Embeddings[1][0], -0.009802706) {
t.Fatalf("expected 0.010071031 and -0.009802706, got %.8f and %.8f", res.Embeddings[0][0], res.Embeddings[1][0])
}
}
func TestAllMiniLMEmbedTruncate(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
truncTrue, truncFalse := true, false
type testReq struct {
Name string
Request api.EmbedRequest
}
reqs := []testReq{
{
Name: "Target Truncation",
Request: api.EmbedRequest{
Model: "all-minilm",
Input: "why",
},
},
{
Name: "Default Truncate",
Request: api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Options: map[string]any{"num_ctx": 1},
},
},
{
Name: "Explicit Truncate",
Request: api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Truncate: &truncTrue,
Options: map[string]any{"num_ctx": 1},
},
},
}
res := make(map[string]*api.EmbedResponse)
for _, req := range reqs {
response, err := embedTestHelper(ctx, t, req.Request)
if err != nil {
t.Fatalf("error: %v", err)
}
res[req.Name] = response
}
if res["Target Truncation"].Embeddings[0][0] != res["Default Truncate"].Embeddings[0][0] {
t.Fatal("expected default request to truncate correctly")
}
if res["Default Truncate"].Embeddings[0][0] != res["Explicit Truncate"].Embeddings[0][0] {
t.Fatal("expected default request and truncate true request to be the same")
}
// check that truncate set to false returns an error if context length is exceeded
_, err := embedTestHelper(ctx, t, api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Truncate: &truncFalse,
Options: map[string]any{"num_ctx": 1},
})
if err == nil {
t.Fatal("expected error, got nil")
}
}
func embeddingTestHelper(ctx context.Context, t *testing.T, req api.EmbeddingRequest) (*api.EmbeddingResponse, error) {
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("failed to pull model %s: %v", req.Model, err)
}
response, err := client.Embeddings(ctx, &req)
if err != nil {
return nil, err
}
return response, nil
}
func embedTestHelper(ctx context.Context, t *testing.T, req api.EmbedRequest) (*api.EmbedResponse, error) {
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, req.Model); err != nil {
t.Fatalf("failed to pull model %s: %v", req.Model, err)
}
response, err := client.Embed(ctx, &req)
if err != nil {
return nil, err
}
return response, nil
}

View File

@@ -41,7 +41,6 @@
#if defined(_WIN32) #if defined(_WIN32)
#include <windows.h> #include <windows.h>
#include <errhandlingapi.h>
#endif #endif
#include <cstddef> #include <cstddef>
@@ -2438,6 +2437,15 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, g
params.lora_adapter.emplace_back(lora_adapter, std::stof(argv[i])); params.lora_adapter.emplace_back(lora_adapter, std::stof(argv[i]));
params.use_mmap = false; params.use_mmap = false;
} }
else if (arg == "--lora-base")
{
if (++i >= argc)
{
invalid_param = true;
break;
}
params.lora_base = argv[i];
}
else if (arg == "-v" || arg == "--verbose") else if (arg == "-v" || arg == "--verbose")
{ {
server_verbose = true; server_verbose = true;
@@ -2729,9 +2737,6 @@ int wmain(int argc, wchar_t **wargv) {
for (int i = 0; i < argc; ++i) { for (int i = 0; i < argc; ++i) {
argv[i] = wchar_to_char(wargv[i]); argv[i] = wchar_to_char(wargv[i]);
} }
// Adjust error mode to avoid error dialog after we start.
SetErrorMode(SEM_FAILCRITICALERRORS);
#else #else
int main(int argc, char **argv) { int main(int argc, char **argv) {
#endif #endif
@@ -3183,33 +3188,26 @@ int main(int argc, char **argv) {
prompt = ""; prompt = "";
} }
if (prompt.size() == 1) { json image_data;
prompt = prompt[0]; if (body.count("image_data") != 0) {
image_data = body["image_data"];
}
else
{
image_data = "";
} }
// create and queue the task // create and queue the task
json responses; const int task_id = llama.queue_tasks.get_new_id();
{ llama.queue_results.add_waiting_task_id(task_id);
const int id_task = llama.queue_tasks.get_new_id(); llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, true, -1);
llama.queue_results.add_waiting_task_id(id_task);
llama.request_completion(id_task, {{"prompt", prompt}}, true, -1);
// get the result // get the result
task_result result = llama.queue_results.recv(id_task); task_result result = llama.queue_results.recv(task_id);
llama.queue_results.remove_waiting_task_id(id_task); llama.queue_results.remove_waiting_task_id(task_id);
if (result.error) {
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
}
responses = result.result_json.value("results", std::vector<json>{result.result_json}); // send the result
json embeddings = json::array(); return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
for (auto & elem : responses) {
embeddings.push_back(elem.at("embedding"));
}
// send the result
json embedding_res = json{{"embedding", embeddings}};
return res.set_content(embedding_res.dump(), "application/json; charset=utf-8");
}
}); });
// GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!? // GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!?

View File

@@ -178,7 +178,7 @@ if [ -z "${OLLAMA_SKIP_CUDA_GENERATE}" -a -d "${CUDA_LIB_DIR}" ]; then
CMAKE_CUDA_DEFS="-DGGML_CUDA=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${OLLAMA_CUSTOM_CUDA_DEFS}" CMAKE_CUDA_DEFS="-DGGML_CUDA=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${OLLAMA_CUSTOM_CUDA_DEFS}"
echo "Building custom CUDA GPU" echo "Building custom CUDA GPU"
else else
CMAKE_CUDA_DEFS="-DGGML_CUDA=on -DCMAKE_CUDA_FLAGS=-t8 -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES}" CMAKE_CUDA_DEFS="-DGGML_CUDA=on -DCMAKE_CUDA_FLAGS=-t8 -DGGML_CUDA_FORCE_MMQ=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} -DCMAKE_LIBRARY_PATH=/usr/local/cuda/compat"
fi fi
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} ${ARM64_DEFS} ${CMAKE_CUDA_DEFS}" CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} ${ARM64_DEFS} ${CMAKE_CUDA_DEFS}"
BUILD_DIR="../build/linux/${ARCH}/cuda${CUDA_VARIANT}" BUILD_DIR="../build/linux/${ARCH}/cuda${CUDA_VARIANT}"

View File

@@ -6,9 +6,18 @@ function amdGPUs {
if ($env:AMDGPU_TARGETS) { if ($env:AMDGPU_TARGETS) {
return $env:AMDGPU_TARGETS return $env:AMDGPU_TARGETS
} }
# Current supported rocblas list from ROCm v6.1.2 on windows # TODO - load from some common data file for linux + windows build consistency
# https://rocm.docs.amd.com/projects/install-on-windows/en/latest/reference/system-requirements.html#windows-supported-gpus
$GPU_LIST = @( $GPU_LIST = @(
"gfx900"
"gfx906:xnack-"
"gfx908:xnack-"
"gfx90a:xnack+"
"gfx90a:xnack-"
"gfx940"
"gfx941"
"gfx942"
"gfx1010"
"gfx1012"
"gfx1030" "gfx1030"
"gfx1100" "gfx1100"
"gfx1101" "gfx1101"
@@ -386,6 +395,7 @@ function build_rocm() {
sign sign
install install
# Assumes v5.7, may need adjustments for v6
rm -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\" rm -ea 0 -recurse -force -path "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"
md "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\rocblas\library\" -ea 0 > $null md "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\rocblas\library\" -ea 0 > $null
cp "${env:HIP_PATH}\bin\hipblas.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\" cp "${env:HIP_PATH}\bin\hipblas.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"

View File

@@ -424,32 +424,6 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
4*batch*(3*embedding+vocab)+embedding*vocab*105/128, 4*batch*(3*embedding+vocab)+embedding*vocab*105/128,
4*batch*(2*embedding+1+2*embeddingHeadsK*headsKV+context+context*headsKV)+4*embeddingHeadsK*context*headsKV+embedding*embeddingHeadsK*headsKV*9/16, 4*batch*(2*embedding+1+2*embeddingHeadsK*headsKV+context+context*headsKV)+4*embeddingHeadsK*context*headsKV+embedding*embeddingHeadsK*headsKV*9/16,
) )
case "chatglm":
fullOffload = 4 * batch * (embedding + vocab)
partialOffload = 4*batch*(embedding+vocab) + embedding*vocab*105/128
if qkvBias, ok := layers["blk.0"]["attn_qkv.bias"]; ok {
fullOffload = max(
fullOffload,
4*batch*(2+
2*embedding+
context+
context*heads+
embeddingHeadsK*heads+
qkvBias.Shape[0]),
)
partialOffload = max(
partialOffload,
4*batch*(1+
2*embedding+
embeddingHeadsK*heads+
context+
context*heads)+
4*embeddingHeadsK*context+
4*context*embeddingHeadsK+
4*qkvBias.Shape[0],
)
}
} }
return return

View File

@@ -537,7 +537,6 @@ var ggufKVOrder = map[string][]string{
"tokenizer.ggml.add_bos_token", "tokenizer.ggml.add_bos_token",
"tokenizer.ggml.add_eos_token", "tokenizer.ggml.add_eos_token",
"tokenizer.chat_template", "tokenizer.chat_template",
"bert.pooling_type",
}, },
} }

View File

@@ -10,10 +10,17 @@ package llm
// #cgo linux,arm64 LDFLAGS: -L${SRCDIR}/build/linux/arm64_static -L${SRCDIR}/build/linux/arm64_static/src -L${SRCDIR}/build/linux/arm64_static/ggml/src // #cgo linux,arm64 LDFLAGS: -L${SRCDIR}/build/linux/arm64_static -L${SRCDIR}/build/linux/arm64_static/src -L${SRCDIR}/build/linux/arm64_static/ggml/src
// #include <stdlib.h> // #include <stdlib.h>
// #include "llama.h" // #include "llama.h"
// bool update_quantize_progress(float progress, void* data) {
// *((float*)data) = progress;
// return true;
// }
import "C" import "C"
import ( import (
"fmt" "fmt"
"unsafe" "unsafe"
"time"
"github.com/ollama/ollama/api"
) )
// SystemInfo is an unused example of calling llama.cpp functions using CGo // SystemInfo is an unused example of calling llama.cpp functions using CGo
@@ -21,7 +28,7 @@ func SystemInfo() string {
return C.GoString(C.llama_print_system_info()) return C.GoString(C.llama_print_system_info())
} }
func Quantize(infile, outfile string, ftype fileType) error { func Quantize(infile, outfile string, ftype fileType, fn func(resp api.ProgressResponse), tensorCount int) error {
cinfile := C.CString(infile) cinfile := C.CString(infile)
defer C.free(unsafe.Pointer(cinfile)) defer C.free(unsafe.Pointer(cinfile))
@@ -32,8 +39,42 @@ func Quantize(infile, outfile string, ftype fileType) error {
params.nthread = -1 params.nthread = -1
params.ftype = ftype.Value() params.ftype = ftype.Value()
// Initialize "global" to store progress
store := C.malloc(C.sizeof_float)
defer C.free(unsafe.Pointer(store))
// Initialize store value, e.g., setting initial progress to 0
*(*C.float)(store) = 0.0
params.quantize_callback_data = store
params.quantize_callback = (C.llama_progress_callback)(C.update_quantize_progress)
ticker := time.NewTicker(60 * time.Millisecond)
done := make(chan struct{})
defer close(done)
go func() {
defer ticker.Stop()
for {
select {
case <-ticker.C:
fn(api.ProgressResponse{
Status: fmt.Sprintf("quantizing model %d/%d", int(*((*C.float)(store))), tensorCount),
Quantize: "quant",
})
fmt.Println("Progress: ", *((*C.float)(store)))
case <-done:
fn(api.ProgressResponse{
Status: fmt.Sprintf("quantizing model %d/%d", tensorCount, tensorCount),
Quantize: "quant",
})
return
}
}
}()
if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 { if rc := C.llama_model_quantize(cinfile, coutfile, &params); rc != 0 {
return fmt.Errorf("failed to quantize model. This model architecture may not be supported, or you may need to upgrade Ollama to the latest version") return fmt.Errorf("llama_model_quantize: %d", rc)
} }
return nil return nil

View File

@@ -2,10 +2,7 @@ package llm
import ( import (
"embed" "embed"
"syscall"
) )
//go:embed build/darwin/x86_64/*/bin/* //go:embed build/darwin/x86_64/*/bin/*
var libEmbed embed.FS var libEmbed embed.FS
var LlamaServerSysProcAttr = &syscall.SysProcAttr{}

View File

@@ -2,10 +2,7 @@ package llm
import ( import (
"embed" "embed"
"syscall"
) )
//go:embed build/darwin/arm64/*/bin/* //go:embed build/darwin/arm64/*/bin/*
var libEmbed embed.FS var libEmbed embed.FS
var LlamaServerSysProcAttr = &syscall.SysProcAttr{}

View File

@@ -1,11 +1,6 @@
package llm package llm
import ( import "embed"
"embed"
"syscall"
)
//go:embed build/linux/*/*/bin/* //go:embed build/linux/*/*/bin/*
var libEmbed embed.FS var libEmbed embed.FS
var LlamaServerSysProcAttr = &syscall.SysProcAttr{}

View File

@@ -1,20 +1,6 @@
package llm package llm
import ( import "embed"
"embed"
"syscall"
)
// unused on windows // unused on windows
var libEmbed embed.FS var libEmbed embed.FS
const CREATE_DEFAULT_ERROR_MODE = 0x04000000
var LlamaServerSysProcAttr = &syscall.SysProcAttr{
// Wire up the default error handling logic If for some reason a DLL is
// missing in the path this will pop up a GUI Dialog explaining the fault so
// the user can either fix their PATH, or report a bug. Without this
// setting, the process exits immediately with a generic exit status but no
// way to (easily) figure out what the actual missing DLL was.
CreationFlags: CREATE_DEFAULT_ERROR_MODE,
}

View File

@@ -1,8 +1,8 @@
diff --git a/src/llama.cpp b/src/llama.cpp diff --git a/src/llama.cpp b/src/llama.cpp
index a207451f..2ddf431d 100644 index 2b9ace28..172640e2 100644
--- a/src/llama.cpp --- a/src/llama.cpp
+++ b/src/llama.cpp +++ b/src/llama.cpp
@@ -5347,16 +5347,7 @@ static void llm_load_vocab( @@ -5357,16 +5357,7 @@ static void llm_load_vocab(
if (vocab.type == LLAMA_VOCAB_TYPE_BPE) { if (vocab.type == LLAMA_VOCAB_TYPE_BPE) {
vocab.tokenizer_add_space_prefix = false; vocab.tokenizer_add_space_prefix = false;
vocab.tokenizer_clean_spaces = true; vocab.tokenizer_clean_spaces = true;
@@ -20,9 +20,9 @@ index a207451f..2ddf431d 100644
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
} else if ( } else if (
tokenizer_pre == "llama3" || tokenizer_pre == "llama3" ||
@@ -5443,7 +5434,8 @@ static void llm_load_vocab( @@ -5439,7 +5430,8 @@ static void llm_load_vocab(
tokenizer_pre == "codeshell") { tokenizer_pre == "jais") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_CODESHELL; vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_JAIS;
} else { } else {
- throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); - throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
+ LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__); + LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__);

13
llm/patches/06-qwen2.diff Normal file
View File

@@ -0,0 +1,13 @@
diff --git a/src/llama.cpp b/src/llama.cpp
index 40d2ec2c..f34eb79a 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -6943,7 +6943,7 @@ static struct ggml_tensor * llm_build_kqv(
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
cb(kq, "kq", il);
- if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
+ if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) {
// for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
// ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);

View File

@@ -1,358 +0,0 @@
diff --git a/common/common.cpp b/common/common.cpp
index dbb724fb..c26fe6ee 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -2087,14 +2087,27 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
for (unsigned int i = 0; i < params.lora_adapter.size(); ++i) {
const std::string & lora_adapter = std::get<0>(params.lora_adapter[i]);
float lora_scale = std::get<1>(params.lora_adapter[i]);
+
+ // try to load as gguf
auto adapter = llama_lora_adapter_init(model, lora_adapter.c_str());
if (adapter == nullptr) {
- fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
- llama_free(lctx);
- llama_free_model(model);
- return std::make_tuple(nullptr, nullptr);
+ fprintf(stderr, "%s: error: failed to apply lora adapter, trying ggla\n", __func__);
+
+ // if that fails, try loading as ggla for compatibility
+ int err = llama_model_apply_lora_from_file(model,
+ lora_adapter.c_str(),
+ lora_scale,
+ nullptr,
+ params.n_threads);
+ if (err != 0) {
+ fprintf(stderr, "%s: error: failed to apply lora adapter\n", __func__);
+ llama_free(lctx);
+ llama_free_model(model);
+ return std::make_tuple(nullptr, nullptr);
+ }
+ } else {
+ llama_lora_adapter_set(lctx, adapter, lora_scale);
}
- llama_lora_adapter_set(lctx, adapter, lora_scale);
}
if (params.ignore_eos) {
diff --git a/include/llama.h b/include/llama.h
index 93fd77ca..b0fb37a6 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -1160,6 +1160,20 @@ extern "C" {
LLAMA_API void llama_dump_timing_info_yaml(FILE * stream, const struct llama_context * ctx);
+ // Apply a LoRA adapter to a loaded model
+ // path_base_model is the path to a higher quality model to use as a base for
+ // the layers modified by the adapter. Can be NULL to use the current loaded model.
+ // The model needs to be reloaded before applying a new adapter, otherwise the adapter
+ // will be applied on top of the previous one
+ // Returns 0 on success
+ LLAMA_API int32_t llama_model_apply_lora_from_file(
+ const struct llama_model * model,
+ const char * path_lora,
+ float scale,
+ const char * path_base_model,
+ int32_t n_threads);
+
+
#ifdef __cplusplus
}
#endif
diff --git a/src/llama.cpp b/src/llama.cpp
index 80a0dd0f..9d7b0e17 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -21880,3 +21880,290 @@ static void llama_log_callback_default(ggml_log_level level, const char * text,
fputs(text, stderr);
fflush(stderr);
}
+
+static int llama_apply_lora_from_file_internal(
+ const struct llama_model & model, const char * path_lora, float scale, const char * path_base_model, int n_threads
+) {
+ LLAMA_LOG_INFO("%s: applying lora adapter from '%s' - please wait ...\n", __func__, path_lora);
+
+ const int64_t t_start_lora_us = ggml_time_us();
+
+ llama_file fin(path_lora, "rb");
+
+ // verify magic and version
+ {
+ uint32_t magic = fin.read_u32();
+ if (magic != LLAMA_FILE_MAGIC_GGLA) {
+ LLAMA_LOG_ERROR("%s: bad file magic\n", __func__);
+ return 1;
+ }
+
+ uint32_t format_version = fin.read_u32();
+ if (format_version != 1) {
+ LLAMA_LOG_ERROR("%s: unsupported file version\n", __func__ );
+ return 1;
+ }
+ }
+
+ int32_t lora_r = fin.read_u32();
+ int32_t lora_alpha = fin.read_u32();
+ float scaling = scale * (float)lora_alpha / (float)lora_r;
+
+ LLAMA_LOG_INFO("%s: r = %d, alpha = %d, scaling = %.2f\n", __func__, lora_r, lora_alpha, scaling);
+
+ // load base model
+ std::unique_ptr<llama_model_loader> ml;
+ if (path_base_model) {
+ LLAMA_LOG_INFO("%s: loading base model from '%s'\n", __func__, path_base_model);
+ ml.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*check_tensors*/ false, /*kv_overrides*/ nullptr));
+ ml->init_mappings(/*prefetch*/ false); // no prefetching
+ }
+
+ struct tensor_meta {
+ std::string name;
+ ggml_type type;
+ int32_t ne[2];
+ size_t offset;
+ };
+ std::map<std::string, tensor_meta> tensor_meta_map;
+
+ // load all tensor meta
+ while (true) {
+ if (fin.tell() == fin.size) {
+ // eof
+ break;
+ }
+
+ int32_t n_dims;
+ int32_t name_len;
+ int32_t ftype;
+
+ fin.read_raw(&n_dims, sizeof(n_dims));
+ fin.read_raw(&name_len, sizeof(name_len));
+ fin.read_raw(&ftype, sizeof(ftype));
+
+ if (n_dims != 1 && n_dims != 2) {
+ LLAMA_LOG_ERROR("%s: unsupported tensor dimension %d\n", __func__, n_dims);
+ return 1;
+ }
+
+ int32_t ne[2] = { 1, 1 };
+ for (int i = 0; i < n_dims; ++i) {
+ fin.read_raw(&ne[i], sizeof(ne[i]));
+ }
+
+ std::string name;
+ {
+ GGML_ASSERT(name_len < GGML_MAX_NAME);
+ char buf[GGML_MAX_NAME];
+ fin.read_raw(buf, name_len);
+ name = std::string(buf, name_len);
+ }
+
+ // check for lora suffix
+ std::string lora_suffix;
+ if (name.length() > 6) {
+ lora_suffix = name.substr(name.length() - 6);
+ }
+ if (lora_suffix != ".loraA" && lora_suffix != ".loraB") {
+ LLAMA_LOG_ERROR("%s: error: '%s' is not a lora tensor\n", __func__, name.c_str());
+ return 1;
+ }
+
+ // tensor type
+ ggml_type wtype;
+ switch (ftype) {
+ case 0: wtype = GGML_TYPE_F32; break;
+ case 1: wtype = GGML_TYPE_F16; break;
+ default:
+ {
+ LLAMA_LOG_ERROR("%s: invalid tensor data type '%d'\n",
+ __func__, ftype);
+ return 1;
+ }
+ }
+
+ // data offset
+ size_t offset = fin.tell();
+ offset = (offset + 31) & -32;
+
+ // skip tensor data
+ fin.seek(offset + ggml_row_size(wtype, ne[0]) * ne[1], SEEK_SET);
+
+ tensor_meta_map.emplace(name, tensor_meta{ name, wtype, { ne[0], ne[1] }, offset });
+ }
+
+ bool warned = false;
+ int n_tensors = 0;
+
+ // apply
+ ggml_backend_t backend_cpu = ggml_backend_cpu_init();
+ if (backend_cpu == nullptr) {
+ LLAMA_LOG_ERROR("%s: error: failed to initialize cpu backend\n", __func__);
+ return 1;
+ }
+ ggml_backend_cpu_set_n_threads(backend_cpu, n_threads);
+
+ std::vector<no_init<uint8_t>> read_buf;
+ for (const auto & it : model.tensors_by_name) {
+ const std::string & base_name = it.first;
+ ggml_tensor * model_t = it.second;
+
+ if (tensor_meta_map.find(base_name + ".loraA") == tensor_meta_map.end() ||
+ tensor_meta_map.find(base_name + ".loraB") == tensor_meta_map.end()) {
+ continue;
+ }
+
+ tensor_meta & metaA = tensor_meta_map.at(base_name + ".loraA");
+ tensor_meta & metaB = tensor_meta_map.at(base_name + ".loraB");
+
+ ggml_init_params lora_init_params = {
+ /* .mem_size */ ggml_tensor_overhead()*128 + ggml_graph_overhead(),
+ /* .mem_buffer */ nullptr,
+ /* .no_alloc */ true,
+ };
+ ggml_context * lora_ctx = ggml_init(lora_init_params);
+ if (lora_ctx == nullptr) {
+ LLAMA_LOG_ERROR("%s: error: failed to initialize lora context\n", __func__);
+ ggml_backend_free(backend_cpu);
+ return 1;
+ }
+
+ // create tensors
+ ggml_tensor * loraA = ggml_new_tensor_2d(lora_ctx, metaA.type, metaA.ne[0], metaA.ne[1]);
+ ggml_tensor * loraB = ggml_new_tensor_2d(lora_ctx, metaB.type, metaB.ne[0], metaB.ne[1]);
+ ggml_set_name(loraA, metaA.name.c_str());
+ ggml_set_name(loraB, metaB.name.c_str());
+
+ ggml_tensor * base_t;
+ if (ml) {
+ if (!ml->get_tensor_meta(base_name.c_str())) {
+ LLAMA_LOG_ERROR("%s: error: tensor '%s' not found in base model\n", __func__, base_name.c_str());
+ return 1;
+ }
+ base_t = ggml_dup_tensor(lora_ctx, ml->get_tensor_meta(base_name.c_str()));
+ } else {
+ base_t = ggml_dup_tensor(lora_ctx, model_t);
+ }
+ ggml_set_name(base_t, base_name.c_str());
+
+ // allocate in backend buffer
+ ggml_backend_buffer_t lora_buf = ggml_backend_alloc_ctx_tensors_from_buft(lora_ctx, ggml_backend_cpu_buffer_type());
+ if (lora_buf == nullptr) {
+ LLAMA_LOG_ERROR("%s: error: failed to allocate lora tensors\n", __func__);
+ return 1;
+ }
+
+ // load tensor data
+ auto load_tensor = [&read_buf, &fin](const tensor_meta & tensor_meta, ggml_tensor * tensor) {
+ read_buf.resize(ggml_nbytes(tensor));
+ fin.seek(tensor_meta.offset, SEEK_SET);
+ fin.read_raw(read_buf.data(), ggml_nbytes(tensor));
+ ggml_backend_tensor_set(tensor, read_buf.data(), 0, read_buf.size());
+ };
+ load_tensor(metaA, loraA);
+ load_tensor(metaB, loraB);
+
+ // load base model tensor data
+ if (ml) {
+ ml->load_data_for(base_t);
+ } else {
+ ggml_backend_tensor_copy(model_t, base_t);
+ }
+
+ if (ggml_is_quantized(base_t->type) && !warned) {
+ LLAMA_LOG_WARN("%s: warning: using a lora adapter with a quantized model may result in poor quality, "
+ "use a f16 or f32 base model with --lora-base\n", __func__);
+ warned = true;
+ }
+
+ if (base_t->ne[0] != loraA->ne[1] || base_t->ne[1] != loraB->ne[1]) {
+ LLAMA_LOG_ERROR("%s: incompatible tensor dimensions (%" PRId64 " and %" PRId64 ");"
+ " are you sure that this adapter is for this model?\n", __func__, base_t->ne[0], loraA->ne[1]);
+ ggml_free(lora_ctx);
+ ggml_backend_buffer_free(lora_buf);
+ ggml_backend_free(backend_cpu);
+ return 1;
+ }
+
+ auto build_lora_graph = [&]() {
+ // w = w + BA*s
+ ggml_tensor * BA = ggml_mul_mat(lora_ctx, loraA, loraB);
+ ggml_set_name(BA, "BA");
+
+ if (scaling != 1.0f) {
+ BA = ggml_scale(lora_ctx, BA, scaling);
+ ggml_set_name(BA, "BA_scaled");
+ }
+
+ ggml_tensor * r;
+ r = ggml_add_inplace(lora_ctx, base_t, BA);
+ ggml_set_name(r, "r_add");
+
+ if (base_t->type != model_t->type) {
+ // convert the result to the model type
+ r = ggml_cast(lora_ctx, r, model_t->type);
+ ggml_set_name(r, "r_cast");
+ }
+
+ return r;
+ };
+
+ ggml_cgraph * gf = ggml_new_graph(lora_ctx);
+ ggml_tensor * r = build_lora_graph();
+ ggml_build_forward_expand(gf, r);
+
+ ggml_backend_buffer_t graph_buf = ggml_backend_alloc_ctx_tensors_from_buft(lora_ctx, ggml_backend_cpu_buffer_type());
+ if (graph_buf == nullptr) {
+ LLAMA_LOG_ERROR("%s: error: failed to allocate graph tensors\n", __func__);
+ ggml_free(lora_ctx);
+ ggml_backend_buffer_free(lora_buf);
+ ggml_backend_free(backend_cpu);
+ return 1;
+ }
+
+ ggml_backend_graph_compute(backend_cpu, gf);
+
+ ggml_backend_tensor_set(model_t, r->data, 0, ggml_nbytes(r));
+
+#if 0
+ // TODO: use scheduler with fallback to CPU for less copies between CPU and GPU
+ //ggml_backend_sched_t sched = ggml_backend_sched_new(backends.data(), backends.size(), GGML_DEFAULT_GRAPH_SIZE);
+
+ // sched compute
+ ggml_build_forward_expand(gf, build_graph());
+ ggml_backend_sched_init_measure(sched, gf);
+
+ // create the graph again, since the previous one was destroyed by the measure
+ ggml_graph_clear(gf);
+ ggml_build_forward_expand(gf, build_graph());
+ ggml_backend_sched_graph_compute(sched, gf);
+ ggml_backend_sched_free(sched);
+#endif
+
+ ggml_backend_buffer_free(lora_buf);
+ ggml_backend_buffer_free(graph_buf);
+ ggml_free(lora_ctx);
+
+ n_tensors++;
+ if (n_tensors % 4 == 0) {
+ LLAMA_LOG_INFO(".");
+ }
+ }
+
+ ggml_backend_free(backend_cpu);
+
+ const int64_t t_lora_us = ggml_time_us() - t_start_lora_us;
+ LLAMA_LOG_INFO(" done (%.2f ms)\n", t_lora_us / 1000.0);
+
+ return 0;
+}
+
+int32_t llama_model_apply_lora_from_file(const struct llama_model * model, const char * path_lora, float scale, const char * path_base_model, int32_t n_threads) {
+ try {
+ return llama_apply_lora_from_file_internal(*model, path_lora, scale, path_base_model, n_threads);
+ } catch (const std::exception & err) {
+ LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what());
+ return 1;
+ }
+}
\ No newline at end of file

View File

@@ -0,0 +1,53 @@
From fa509abf281177eacdc71a2a14432c4e6ed74a47 Mon Sep 17 00:00:00 2001
From: Josh Yan <jyan00017@gmail.com>
Date: Wed, 10 Jul 2024 12:58:31 -0700
Subject: [PATCH] quantize callback
---
llama.cpp | 8 ++++++++
llama.h | 3 +++
2 files changed, 11 insertions(+)
diff --git a/llama.cpp b/llama.cpp
index 61948751..d3126510 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -15586,6 +15586,12 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
const auto tn = LLM_TN(model.arch);
new_ofstream(0);
for (int i = 0; i < ml.n_tensors; ++i) {
+ if (params->quantize_callback){
+ if (!params->quantize_callback(i, params->quantize_callback_data)) {
+ return;
+ }
+ }
+
auto weight = ml.get_weight(i);
struct ggml_tensor * tensor = weight->tensor;
if (weight->idx != cur_split && params->keep_split) {
@@ -16119,6 +16125,8 @@ struct llama_model_quantize_params llama_model_quantize_default_params() {
/*.keep_split =*/ false,
/*.imatrix =*/ nullptr,
/*.kv_overrides =*/ nullptr,
+ /*.quantize_callback =*/ nullptr,
+ /*.quantize_callback_data =*/ nullptr,
};
return result;
diff --git a/llama.h b/llama.h
index da310ffa..3cbe6023 100644
--- a/llama.h
+++ b/llama.h
@@ -337,6 +337,9 @@ extern "C" {
bool keep_split; // quantize to the same number of shards
void * imatrix; // pointer to importance matrix data
void * kv_overrides; // pointer to vector containing overrides
+
+ llama_progress_callback quantize_callback; // callback to report quantization progress
+ void * quantize_callback_data; // user data for the callback
} llama_model_quantize_params;
// grammar types
--
2.39.3 (Apple Git-146)

View File

@@ -33,7 +33,7 @@ type LlamaServer interface {
Ping(ctx context.Context) error Ping(ctx context.Context) error
WaitUntilRunning(ctx context.Context) error WaitUntilRunning(ctx context.Context) error
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
Embed(ctx context.Context, input []string) ([][]float32, error) Embedding(ctx context.Context, prompt string) ([]float64, error)
Tokenize(ctx context.Context, content string) ([]int, error) Tokenize(ctx context.Context, content string) ([]int, error)
Detokenize(ctx context.Context, tokens []int) (string, error) Detokenize(ctx context.Context, tokens []int) (string, error)
Close() error Close() error
@@ -88,7 +88,6 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
var estimate MemoryEstimate var estimate MemoryEstimate
var systemTotalMemory uint64 var systemTotalMemory uint64
var systemFreeMemory uint64 var systemFreeMemory uint64
var systemSwapFreeMemory uint64
systemMemInfo, err := gpu.GetCPUMem() systemMemInfo, err := gpu.GetCPUMem()
if err != nil { if err != nil {
@@ -96,8 +95,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
} else { } else {
systemTotalMemory = systemMemInfo.TotalMemory systemTotalMemory = systemMemInfo.TotalMemory
systemFreeMemory = systemMemInfo.FreeMemory systemFreeMemory = systemMemInfo.FreeMemory
systemSwapFreeMemory = systemMemInfo.FreeSwap slog.Debug("system memory", "total", format.HumanBytes2(systemTotalMemory), "free", systemFreeMemory)
slog.Debug("system memory", "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "free_swap", format.HumanBytes2(systemSwapFreeMemory))
} }
// If the user wants zero GPU layers, reset the gpu list to be CPU/system ram info // If the user wants zero GPU layers, reset the gpu list to be CPU/system ram info
@@ -124,16 +122,6 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
} }
} }
// On linux, over-allocating CPU memory will almost always result in an error
if runtime.GOOS == "linux" {
systemMemoryRequired := estimate.TotalSize - estimate.VRAMSize
available := systemFreeMemory + systemSwapFreeMemory
if systemMemoryRequired > available {
slog.Warn("model request too large for system", "requested", format.HumanBytes2(systemMemoryRequired), "available", available, "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "swap", format.HumanBytes2(systemSwapFreeMemory))
return nil, fmt.Errorf("model requires more system memory (%s) than is available (%s)", format.HumanBytes2(systemMemoryRequired), format.HumanBytes2(available))
}
}
estimate.log() estimate.log()
// Loop through potential servers // Loop through potential servers
@@ -266,6 +254,10 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
params = append(params, "--tensor-split", estimate.TensorSplit) params = append(params, "--tensor-split", estimate.TensorSplit)
} }
if estimate.TensorSplit != "" {
params = append(params, "--tensor-split", estimate.TensorSplit)
}
for i := range len(servers) { for i := range len(servers) {
dir := availableServers[servers[i]] dir := availableServers[servers[i]]
if dir == "" { if dir == "" {
@@ -346,7 +338,6 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
s.cmd.Env = os.Environ() s.cmd.Env = os.Environ()
s.cmd.Stdout = os.Stdout s.cmd.Stdout = os.Stdout
s.cmd.Stderr = s.status s.cmd.Stderr = s.status
s.cmd.SysProcAttr = LlamaServerSysProcAttr
envWorkarounds := [][2]string{} envWorkarounds := [][2]string{}
for _, gpu := range gpus { for _, gpu := range gpus {
@@ -386,10 +377,8 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
filteredEnv := []string{} filteredEnv := []string{}
for _, ev := range s.cmd.Env { for _, ev := range s.cmd.Env {
if strings.HasPrefix(ev, "CUDA_") || if strings.HasPrefix(ev, "CUDA_") ||
strings.HasPrefix(ev, "ROCR_") ||
strings.HasPrefix(ev, "ROCM_") || strings.HasPrefix(ev, "ROCM_") ||
strings.HasPrefix(ev, "HIP_") || strings.HasPrefix(ev, "HIP_") ||
strings.HasPrefix(ev, "GPU_") ||
strings.HasPrefix(ev, "HSA_") || strings.HasPrefix(ev, "HSA_") ||
strings.HasPrefix(ev, "GGML_") || strings.HasPrefix(ev, "GGML_") ||
strings.HasPrefix(ev, "PATH=") || strings.HasPrefix(ev, "PATH=") ||
@@ -418,17 +407,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
// reap subprocess when it exits // reap subprocess when it exits
go func() { go func() {
err := s.cmd.Wait() s.done <- s.cmd.Wait()
// Favor a more detailed message over the process exit status
if err != nil && s.status != nil && s.status.LastErrMsg != "" {
slog.Debug("llama runner terminated", "error", err)
if strings.Contains(s.status.LastErrMsg, "unknown model") {
s.status.LastErrMsg = "this model is not supported by your version of Ollama. You may need to upgrade"
}
s.done <- fmt.Errorf(s.status.LastErrMsg)
} else {
s.done <- err
}
}() }()
return s, nil return s, nil
@@ -591,7 +570,14 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error {
slog.Warn("client connection closed before server finished loading, aborting load") slog.Warn("client connection closed before server finished loading, aborting load")
return fmt.Errorf("timed out waiting for llama runner to start: %w", ctx.Err()) return fmt.Errorf("timed out waiting for llama runner to start: %w", ctx.Err())
case err := <-s.done: case err := <-s.done:
return fmt.Errorf("llama runner process has terminated: %w", err) msg := ""
if s.status != nil && s.status.LastErrMsg != "" {
msg = s.status.LastErrMsg
}
if strings.Contains(msg, "unknown model") {
return fmt.Errorf("this model is not supported by your version of Ollama. You may need to upgrade")
}
return fmt.Errorf("llama runner process has terminated: %v %s", err, msg)
default: default:
} }
if time.Now().After(stallTimer) { if time.Now().After(stallTimer) {
@@ -727,7 +713,6 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
"temperature": req.Options.Temperature, "temperature": req.Options.Temperature,
"top_k": req.Options.TopK, "top_k": req.Options.TopK,
"top_p": req.Options.TopP, "top_p": req.Options.TopP,
"min_p": req.Options.MinP,
"tfs_z": req.Options.TFSZ, "tfs_z": req.Options.TFSZ,
"typical_p": req.Options.TypicalP, "typical_p": req.Options.TypicalP,
"repeat_last_n": req.Options.RepeatLastN, "repeat_last_n": req.Options.RepeatLastN,
@@ -874,15 +859,15 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
return nil return nil
} }
type EmbedRequest struct { type EmbeddingRequest struct {
Content []string `json:"content"` Content string `json:"content"`
} }
type EmbedResponse struct { type EmbeddingResponse struct {
Embedding [][]float32 `json:"embedding"` Embedding []float64 `json:"embedding"`
} }
func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float32, error) { func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
if err := s.sem.Acquire(ctx, 1); err != nil { if err := s.sem.Acquire(ctx, 1); err != nil {
slog.Error("Failed to acquire semaphore", "error", err) slog.Error("Failed to acquire semaphore", "error", err)
return nil, err return nil, err
@@ -897,7 +882,7 @@ func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float32, err
return nil, fmt.Errorf("unexpected server status: %s", status.ToString()) return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
} }
data, err := json.Marshal(EmbedRequest{Content: input}) data, err := json.Marshal(TokenizeRequest{Content: prompt})
if err != nil { if err != nil {
return nil, fmt.Errorf("error marshaling embed data: %w", err) return nil, fmt.Errorf("error marshaling embed data: %w", err)
} }
@@ -924,7 +909,7 @@ func (s *llmServer) Embed(ctx context.Context, input []string) ([][]float32, err
return nil, fmt.Errorf("%s", body) return nil, fmt.Errorf("%s", body)
} }
var embedding EmbedResponse var embedding EmbeddingResponse
if err := json.Unmarshal(body, &embedding); err != nil { if err := json.Unmarshal(body, &embedding); err != nil {
return nil, fmt.Errorf("unmarshal tokenize response: %w", err) return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
} }

View File

@@ -19,7 +19,7 @@ export default function () {
const [step, setStep] = useState<Step>(Step.WELCOME) const [step, setStep] = useState<Step>(Step.WELCOME)
const [commandCopied, setCommandCopied] = useState<boolean>(false) const [commandCopied, setCommandCopied] = useState<boolean>(false)
const command = 'ollama run llama3.1' const command = 'ollama run llama3'
return ( return (
<div className='drag'> <div className='drag'>

View File

@@ -3,14 +3,11 @@ package openai
import ( import (
"bytes" "bytes"
"encoding/base64"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"log/slog"
"math/rand" "math/rand"
"net/http" "net/http"
"strings"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -30,9 +27,8 @@ type ErrorResponse struct {
} }
type Message struct { type Message struct {
Role string `json:"role"` Role string `json:"role"`
Content any `json:"content"` Content string `json:"content"`
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
} }
type Choice struct { type Choice struct {
@@ -63,11 +59,6 @@ type ResponseFormat struct {
Type string `json:"type"` Type string `json:"type"`
} }
type EmbedRequest struct {
Input any `json:"input"`
Model string `json:"model"`
}
type ChatCompletionRequest struct { type ChatCompletionRequest struct {
Model string `json:"model"` Model string `json:"model"`
Messages []Message `json:"messages"` Messages []Message `json:"messages"`
@@ -80,7 +71,6 @@ type ChatCompletionRequest struct {
PresencePenalty *float64 `json:"presence_penalty_penalty"` PresencePenalty *float64 `json:"presence_penalty_penalty"`
TopP *float64 `json:"top_p"` TopP *float64 `json:"top_p"`
ResponseFormat *ResponseFormat `json:"response_format"` ResponseFormat *ResponseFormat `json:"response_format"`
Tools []api.Tool `json:"tools"`
} }
type ChatCompletion struct { type ChatCompletion struct {
@@ -114,7 +104,6 @@ type CompletionRequest struct {
Stream bool `json:"stream"` Stream bool `json:"stream"`
Temperature *float32 `json:"temperature"` Temperature *float32 `json:"temperature"`
TopP float32 `json:"top_p"` TopP float32 `json:"top_p"`
Suffix string `json:"suffix"`
} }
type Completion struct { type Completion struct {
@@ -136,15 +125,6 @@ type CompletionChunk struct {
SystemFingerprint string `json:"system_fingerprint"` SystemFingerprint string `json:"system_fingerprint"`
} }
type ToolCall struct {
ID string `json:"id"`
Type string `json:"type"`
Function struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
} `json:"function"`
}
type Model struct { type Model struct {
Id string `json:"id"` Id string `json:"id"`
Object string `json:"object"` Object string `json:"object"`
@@ -152,23 +132,11 @@ type Model struct {
OwnedBy string `json:"owned_by"` OwnedBy string `json:"owned_by"`
} }
type Embedding struct {
Object string `json:"object"`
Embedding []float32 `json:"embedding"`
Index int `json:"index"`
}
type ListCompletion struct { type ListCompletion struct {
Object string `json:"object"` Object string `json:"object"`
Data []Model `json:"data"` Data []Model `json:"data"`
} }
type EmbeddingList struct {
Object string `json:"object"`
Data []Embedding `json:"data"`
Model string `json:"model"`
}
func NewError(code int, message string) ErrorResponse { func NewError(code int, message string) ErrorResponse {
var etype string var etype string
switch code { switch code {
@@ -183,36 +151,7 @@ func NewError(code int, message string) ErrorResponse {
return ErrorResponse{Error{Type: etype, Message: message}} return ErrorResponse{Error{Type: etype, Message: message}}
} }
func toolCallId() string {
const letterBytes = "abcdefghijklmnopqrstuvwxyz0123456789"
b := make([]byte, 8)
for i := range b {
b[i] = letterBytes[rand.Intn(len(letterBytes))]
}
return "call_" + strings.ToLower(string(b))
}
func parseToolCalls(respToolCalls []api.ToolCall) []ToolCall {
toolCalls := make([]ToolCall, len(respToolCalls))
for i, tc := range respToolCalls {
toolCalls[i].ID = toolCallId()
toolCalls[i].Type = "function"
toolCalls[i].Function.Name = tc.Function.Name
args, err := json.Marshal(tc.Function.Arguments)
if err != nil {
slog.Error("could not marshall function arguments to json", "error", err)
continue
}
toolCalls[i].Function.Arguments = string(args)
}
return toolCalls
}
func toChatCompletion(id string, r api.ChatResponse) ChatCompletion { func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
toolCalls := parseToolCalls(r.Message.ToolCalls)
return ChatCompletion{ return ChatCompletion{
Id: id, Id: id,
Object: "chat.completion", Object: "chat.completion",
@@ -221,7 +160,7 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
SystemFingerprint: "fp_ollama", SystemFingerprint: "fp_ollama",
Choices: []Choice{{ Choices: []Choice{{
Index: 0, Index: 0,
Message: Message{Role: r.Message.Role, Content: r.Message.Content, ToolCalls: toolCalls}, Message: Message{Role: r.Message.Role, Content: r.Message.Content},
FinishReason: func(reason string) *string { FinishReason: func(reason string) *string {
if len(reason) > 0 { if len(reason) > 0 {
return &reason return &reason
@@ -230,6 +169,7 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
}(r.DoneReason), }(r.DoneReason),
}}, }},
Usage: Usage{ Usage: Usage{
// TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
PromptTokens: r.PromptEvalCount, PromptTokens: r.PromptEvalCount,
CompletionTokens: r.EvalCount, CompletionTokens: r.EvalCount,
TotalTokens: r.PromptEvalCount + r.EvalCount, TotalTokens: r.PromptEvalCount + r.EvalCount,
@@ -238,8 +178,6 @@ func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
} }
func toChunk(id string, r api.ChatResponse) ChatCompletionChunk { func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
toolCalls := parseToolCalls(r.Message.ToolCalls)
return ChatCompletionChunk{ return ChatCompletionChunk{
Id: id, Id: id,
Object: "chat.completion.chunk", Object: "chat.completion.chunk",
@@ -248,7 +186,7 @@ func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
SystemFingerprint: "fp_ollama", SystemFingerprint: "fp_ollama",
Choices: []ChunkChoice{{ Choices: []ChunkChoice{{
Index: 0, Index: 0,
Delta: Message{Role: "assistant", Content: r.Message.Content, ToolCalls: toolCalls}, Delta: Message{Role: "assistant", Content: r.Message.Content},
FinishReason: func(reason string) *string { FinishReason: func(reason string) *string {
if len(reason) > 0 { if len(reason) > 0 {
return &reason return &reason
@@ -277,6 +215,7 @@ func toCompletion(id string, r api.GenerateResponse) Completion {
}(r.DoneReason), }(r.DoneReason),
}}, }},
Usage: Usage{ Usage: Usage{
// TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
PromptTokens: r.PromptEvalCount, PromptTokens: r.PromptEvalCount,
CompletionTokens: r.EvalCount, CompletionTokens: r.EvalCount,
TotalTokens: r.PromptEvalCount + r.EvalCount, TotalTokens: r.PromptEvalCount + r.EvalCount,
@@ -321,27 +260,6 @@ func toListCompletion(r api.ListResponse) ListCompletion {
} }
} }
func toEmbeddingList(model string, r api.EmbedResponse) EmbeddingList {
if r.Embeddings != nil {
var data []Embedding
for i, e := range r.Embeddings {
data = append(data, Embedding{
Object: "embedding",
Embedding: e,
Index: i,
})
}
return EmbeddingList{
Object: "list",
Data: data,
Model: model,
}
}
return EmbeddingList{}
}
func toModel(r api.ShowResponse, m string) Model { func toModel(r api.ShowResponse, m string) Model {
return Model{ return Model{
Id: m, Id: m,
@@ -351,77 +269,10 @@ func toModel(r api.ShowResponse, m string) Model {
} }
} }
func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) { func fromChatRequest(r ChatCompletionRequest) api.ChatRequest {
var messages []api.Message var messages []api.Message
for _, msg := range r.Messages { for _, msg := range r.Messages {
switch content := msg.Content.(type) { messages = append(messages, api.Message{Role: msg.Role, Content: msg.Content})
case string:
messages = append(messages, api.Message{Role: msg.Role, Content: content})
case []any:
for _, c := range content {
data, ok := c.(map[string]any)
if !ok {
return nil, fmt.Errorf("invalid message format")
}
switch data["type"] {
case "text":
text, ok := data["text"].(string)
if !ok {
return nil, fmt.Errorf("invalid message format")
}
messages = append(messages, api.Message{Role: msg.Role, Content: text})
case "image_url":
var url string
if urlMap, ok := data["image_url"].(map[string]any); ok {
if url, ok = urlMap["url"].(string); !ok {
return nil, fmt.Errorf("invalid message format")
}
} else {
if url, ok = data["image_url"].(string); !ok {
return nil, fmt.Errorf("invalid message format")
}
}
types := []string{"jpeg", "jpg", "png"}
valid := false
for _, t := range types {
prefix := "data:image/" + t + ";base64,"
if strings.HasPrefix(url, prefix) {
url = strings.TrimPrefix(url, prefix)
valid = true
break
}
}
if !valid {
return nil, fmt.Errorf("invalid image input")
}
img, err := base64.StdEncoding.DecodeString(url)
if err != nil {
return nil, fmt.Errorf("invalid message format")
}
messages = append(messages, api.Message{Role: msg.Role, Images: []api.ImageData{img}})
default:
return nil, fmt.Errorf("invalid message format")
}
}
default:
if msg.ToolCalls == nil {
return nil, fmt.Errorf("invalid message content type: %T", content)
}
toolCalls := make([]api.ToolCall, len(msg.ToolCalls))
for i, tc := range msg.ToolCalls {
toolCalls[i].Function.Name = tc.Function.Name
err := json.Unmarshal([]byte(tc.Function.Arguments), &toolCalls[i].Function.Arguments)
if err != nil {
return nil, fmt.Errorf("invalid tool call arguments")
}
}
messages = append(messages, api.Message{Role: msg.Role, ToolCalls: toolCalls})
}
} }
options := make(map[string]interface{}) options := make(map[string]interface{})
@@ -472,14 +323,13 @@ func fromChatRequest(r ChatCompletionRequest) (*api.ChatRequest, error) {
format = "json" format = "json"
} }
return &api.ChatRequest{ return api.ChatRequest{
Model: r.Model, Model: r.Model,
Messages: messages, Messages: messages,
Format: format, Format: format,
Options: options, Options: options,
Stream: &r.Stream, Stream: &r.Stream,
Tools: r.Tools, }
}, nil
} }
func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) { func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
@@ -529,7 +379,6 @@ func fromCompleteRequest(r CompletionRequest) (api.GenerateRequest, error) {
Prompt: r.Prompt, Prompt: r.Prompt,
Options: options, Options: options,
Stream: &r.Stream, Stream: &r.Stream,
Suffix: r.Suffix,
}, nil }, nil
} }
@@ -558,11 +407,6 @@ type RetrieveWriter struct {
model string model string
} }
type EmbedWriter struct {
BaseWriter
model string
}
func (w *BaseWriter) writeError(code int, data []byte) (int, error) { func (w *BaseWriter) writeError(code int, data []byte) (int, error) {
var serr api.StatusError var serr api.StatusError
err := json.Unmarshal(data, &serr) err := json.Unmarshal(data, &serr)
@@ -728,33 +572,6 @@ func (w *RetrieveWriter) Write(data []byte) (int, error) {
return w.writeResponse(data) return w.writeResponse(data)
} }
func (w *EmbedWriter) writeResponse(data []byte) (int, error) {
var embedResponse api.EmbedResponse
err := json.Unmarshal(data, &embedResponse)
if err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(toEmbeddingList(w.model, embedResponse))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *EmbedWriter) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(code, data)
}
return w.writeResponse(data)
}
func ListMiddleware() gin.HandlerFunc { func ListMiddleware() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
w := &ListWriter{ w := &ListWriter{
@@ -818,47 +635,6 @@ func CompletionsMiddleware() gin.HandlerFunc {
id: fmt.Sprintf("cmpl-%d", rand.Intn(999)), id: fmt.Sprintf("cmpl-%d", rand.Intn(999)),
} }
c.Writer = w
c.Next()
}
}
func EmbeddingsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req EmbedRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
return
}
if req.Input == "" {
req.Input = []string{""}
}
if req.Input == nil {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
return
}
if v, ok := req.Input.([]any); ok && len(v) == 0 {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "invalid input"))
return
}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(api.EmbedRequest{Model: req.Model, Input: req.Input}); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
w := &EmbedWriter{
BaseWriter: BaseWriter{ResponseWriter: c.Writer},
model: req.Model,
}
c.Writer = w c.Writer = w
c.Next() c.Next()
@@ -880,14 +656,7 @@ func ChatMiddleware() gin.HandlerFunc {
} }
var b bytes.Buffer var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(fromChatRequest(req)); err != nil {
chatReq, err := fromChatRequest(req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
return
}
if err := json.NewEncoder(&b).Encode(chatReq); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error())) c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
return return
} }

View File

@@ -2,7 +2,6 @@ package openai
import ( import (
"bytes" "bytes"
"encoding/base64"
"encoding/json" "encoding/json"
"io" "io"
"net/http" "net/http"
@@ -16,199 +15,64 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
const prefix = `data:image/jpeg;base64,` func TestMiddlewareRequests(t *testing.T) {
const image = `iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk+A8AAQUBAScY42YAAAAASUVORK5CYII=`
const imageURL = prefix + image
func prepareRequest(req *http.Request, body any) {
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
}
func captureRequestMiddleware(capturedRequest any) gin.HandlerFunc {
return func(c *gin.Context) {
bodyBytes, _ := io.ReadAll(c.Request.Body)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
err := json.Unmarshal(bodyBytes, capturedRequest)
if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, "failed to unmarshal request")
}
c.Next()
}
}
func TestChatMiddleware(t *testing.T) {
type testCase struct { type testCase struct {
Name string Name string
Method string
Path string
Handler func() gin.HandlerFunc
Setup func(t *testing.T, req *http.Request) Setup func(t *testing.T, req *http.Request)
Expected func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) Expected func(t *testing.T, req *http.Request)
} }
var capturedRequest *api.ChatRequest var capturedRequest *http.Request
captureRequestMiddleware := func() gin.HandlerFunc {
return func(c *gin.Context) {
bodyBytes, _ := io.ReadAll(c.Request.Body)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
capturedRequest = c.Request
c.Next()
}
}
testCases := []testCase{ testCases := []testCase{
{ {
Name: "chat handler", Name: "chat handler",
Method: http.MethodPost,
Path: "/api/chat",
Handler: ChatMiddleware,
Setup: func(t *testing.T, req *http.Request) { Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{ body := ChatCompletionRequest{
Model: "test-model", Model: "test-model",
Messages: []Message{{Role: "user", Content: "Hello"}}, Messages: []Message{{Role: "user", Content: "Hello"}},
} }
prepareRequest(req, body)
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
}, },
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) { Expected: func(t *testing.T, req *http.Request) {
if resp.Code != http.StatusOK { var chatReq api.ChatRequest
t.Fatalf("expected 200, got %d", resp.Code) if err := json.NewDecoder(req.Body).Decode(&chatReq); err != nil {
t.Fatal(err)
} }
if req.Messages[0].Role != "user" { if chatReq.Messages[0].Role != "user" {
t.Fatalf("expected 'user', got %s", req.Messages[0].Role) t.Fatalf("expected 'user', got %s", chatReq.Messages[0].Role)
} }
if req.Messages[0].Content != "Hello" { if chatReq.Messages[0].Content != "Hello" {
t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content) t.Fatalf("expected 'Hello', got %s", chatReq.Messages[0].Content)
} }
}, },
}, },
{ {
Name: "chat handler with image content", Name: "completions handler",
Setup: func(t *testing.T, req *http.Request) { Method: http.MethodPost,
body := ChatCompletionRequest{ Path: "/api/generate",
Model: "test-model", Handler: CompletionsMiddleware,
Messages: []Message{
{
Role: "user", Content: []map[string]any{
{"type": "text", "text": "Hello"},
{"type": "image_url", "image_url": map[string]string{"url": imageURL}},
},
},
},
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusOK {
t.Fatalf("expected 200, got %d", resp.Code)
}
if req.Messages[0].Role != "user" {
t.Fatalf("expected 'user', got %s", req.Messages[0].Role)
}
if req.Messages[0].Content != "Hello" {
t.Fatalf("expected 'Hello', got %s", req.Messages[0].Content)
}
img, _ := base64.StdEncoding.DecodeString(imageURL[len(prefix):])
if req.Messages[1].Role != "user" {
t.Fatalf("expected 'user', got %s", req.Messages[1].Role)
}
if !bytes.Equal(req.Messages[1].Images[0], img) {
t.Fatalf("expected image encoding, got %s", req.Messages[1].Images[0])
}
},
},
{
Name: "chat handler with tools",
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{
{Role: "user", Content: "What's the weather like in Paris Today?"},
{Role: "assistant", ToolCalls: []ToolCall{{
ID: "id",
Type: "function",
Function: struct {
Name string `json:"name"`
Arguments string `json:"arguments"`
}{
Name: "get_current_weather",
Arguments: "{\"location\": \"Paris, France\", \"format\": \"celsius\"}",
},
}}},
},
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
if resp.Code != 200 {
t.Fatalf("expected 200, got %d", resp.Code)
}
if req.Messages[0].Content != "What's the weather like in Paris Today?" {
t.Fatalf("expected What's the weather like in Paris Today?, got %s", req.Messages[0].Content)
}
if req.Messages[1].ToolCalls[0].Function.Arguments["location"] != "Paris, France" {
t.Fatalf("expected 'Paris, France', got %v", req.Messages[1].ToolCalls[0].Function.Arguments["location"])
}
if req.Messages[1].ToolCalls[0].Function.Arguments["format"] != "celsius" {
t.Fatalf("expected celsius, got %v", req.Messages[1].ToolCalls[0].Function.Arguments["format"])
}
},
},
{
Name: "chat handler error forwarding",
Setup: func(t *testing.T, req *http.Request) {
body := ChatCompletionRequest{
Model: "test-model",
Messages: []Message{{Role: "user", Content: 2}},
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.ChatRequest, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.Code)
}
if !strings.Contains(resp.Body.String(), "invalid message content type") {
t.Fatalf("error was not forwarded")
}
},
},
}
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(ChatMiddleware(), captureRequestMiddleware(&capturedRequest))
router.Handle(http.MethodPost, "/api/chat", endpoint)
for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/chat", nil)
tc.Setup(t, req)
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
tc.Expected(t, capturedRequest, resp)
capturedRequest = nil
})
}
}
func TestCompletionsMiddleware(t *testing.T) {
type testCase struct {
Name string
Setup func(t *testing.T, req *http.Request)
Expected func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder)
}
var capturedRequest *api.GenerateRequest
testCases := []testCase{
{
Name: "completions handler",
Setup: func(t *testing.T, req *http.Request) { Setup: func(t *testing.T, req *http.Request) {
temp := float32(0.8) temp := float32(0.8)
body := CompletionRequest{ body := CompletionRequest{
@@ -216,20 +80,28 @@ func TestCompletionsMiddleware(t *testing.T) {
Prompt: "Hello", Prompt: "Hello",
Temperature: &temp, Temperature: &temp,
Stop: []string{"\n", "stop"}, Stop: []string{"\n", "stop"},
Suffix: "suffix",
} }
prepareRequest(req, body)
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
}, },
Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) { Expected: func(t *testing.T, req *http.Request) {
if req.Prompt != "Hello" { var genReq api.GenerateRequest
t.Fatalf("expected 'Hello', got %s", req.Prompt) if err := json.NewDecoder(req.Body).Decode(&genReq); err != nil {
t.Fatal(err)
} }
if req.Options["temperature"] != 1.6 { if genReq.Prompt != "Hello" {
t.Fatalf("expected 1.6, got %f", req.Options["temperature"]) t.Fatalf("expected 'Hello', got %s", genReq.Prompt)
} }
stopTokens, ok := req.Options["stop"].([]any) if genReq.Options["temperature"] != 1.6 {
t.Fatalf("expected 1.6, got %f", genReq.Options["temperature"])
}
stopTokens, ok := genReq.Options["stop"].([]any)
if !ok { if !ok {
t.Fatalf("expected stop tokens to be a list") t.Fatalf("expected stop tokens to be a list")
@@ -238,160 +110,33 @@ func TestCompletionsMiddleware(t *testing.T) {
if stopTokens[0] != "\n" || stopTokens[1] != "stop" { if stopTokens[0] != "\n" || stopTokens[1] != "stop" {
t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens) t.Fatalf("expected ['\\n', 'stop'], got %v", stopTokens)
} }
if req.Suffix != "suffix" {
t.Fatalf("expected 'suffix', got %s", req.Suffix)
}
},
},
{
Name: "completions handler error forwarding",
Setup: func(t *testing.T, req *http.Request) {
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
Temperature: nil,
Stop: []int{1, 2},
Suffix: "suffix",
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.GenerateRequest, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.Code)
}
if !strings.Contains(resp.Body.String(), "invalid type for 'stop' field") {
t.Fatalf("error was not forwarded")
}
}, },
}, },
} }
gin.SetMode(gin.TestMode)
router := gin.New()
endpoint := func(c *gin.Context) { endpoint := func(c *gin.Context) {
c.Status(http.StatusOK) c.Status(http.StatusOK)
} }
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(CompletionsMiddleware(), captureRequestMiddleware(&capturedRequest))
router.Handle(http.MethodPost, "/api/generate", endpoint)
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.Name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/generate", nil) router = gin.New()
router.Use(captureRequestMiddleware())
router.Use(tc.Handler())
router.Handle(tc.Method, tc.Path, endpoint)
req, _ := http.NewRequest(tc.Method, tc.Path, nil)
tc.Setup(t, req) if tc.Setup != nil {
tc.Setup(t, req)
}
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
router.ServeHTTP(resp, req) router.ServeHTTP(resp, req)
tc.Expected(t, capturedRequest, resp) tc.Expected(t, capturedRequest)
capturedRequest = nil
})
}
}
func TestEmbeddingsMiddleware(t *testing.T) {
type testCase struct {
Name string
Setup func(t *testing.T, req *http.Request)
Expected func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder)
}
var capturedRequest *api.EmbedRequest
testCases := []testCase{
{
Name: "embed handler single input",
Setup: func(t *testing.T, req *http.Request) {
body := EmbedRequest{
Input: "Hello",
Model: "test-model",
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
if req.Input != "Hello" {
t.Fatalf("expected 'Hello', got %s", req.Input)
}
if req.Model != "test-model" {
t.Fatalf("expected 'test-model', got %s", req.Model)
}
},
},
{
Name: "embed handler batch input",
Setup: func(t *testing.T, req *http.Request) {
body := EmbedRequest{
Input: []string{"Hello", "World"},
Model: "test-model",
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
input, ok := req.Input.([]any)
if !ok {
t.Fatalf("expected input to be a list")
}
if input[0].(string) != "Hello" {
t.Fatalf("expected 'Hello', got %s", input[0])
}
if input[1].(string) != "World" {
t.Fatalf("expected 'World', got %s", input[1])
}
if req.Model != "test-model" {
t.Fatalf("expected 'test-model', got %s", req.Model)
}
},
},
{
Name: "embed handler error forwarding",
Setup: func(t *testing.T, req *http.Request) {
body := EmbedRequest{
Model: "test-model",
}
prepareRequest(req, body)
},
Expected: func(t *testing.T, req *api.EmbedRequest, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.Code)
}
if !strings.Contains(resp.Body.String(), "invalid input") {
t.Fatalf("error was not forwarded")
}
},
},
}
endpoint := func(c *gin.Context) {
c.Status(http.StatusOK)
}
gin.SetMode(gin.TestMode)
router := gin.New()
router.Use(EmbeddingsMiddleware(), captureRequestMiddleware(&capturedRequest))
router.Handle(http.MethodPost, "/api/embed", endpoint)
for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "/api/embed", nil)
tc.Setup(t, req)
resp := httptest.NewRecorder()
router.ServeHTTP(resp, req)
tc.Expected(t, capturedRequest, resp)
capturedRequest = nil
}) })
} }
} }
@@ -409,6 +154,36 @@ func TestMiddlewareResponses(t *testing.T) {
} }
testCases := []testCase{ testCases := []testCase{
{
Name: "completions handler error forwarding",
Method: http.MethodPost,
Path: "/api/generate",
TestPath: "/api/generate",
Handler: CompletionsMiddleware,
Endpoint: func(c *gin.Context) {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request"})
},
Setup: func(t *testing.T, req *http.Request) {
body := CompletionRequest{
Model: "test-model",
Prompt: "Hello",
}
bodyBytes, _ := json.Marshal(body)
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
req.Header.Set("Content-Type", "application/json")
},
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
if resp.Code != http.StatusBadRequest {
t.Fatalf("expected 400, got %d", resp.Code)
}
if !strings.Contains(resp.Body.String(), `"invalid request"`) {
t.Fatalf("error was not forwarded")
}
},
},
{ {
Name: "list handler", Name: "list handler",
Method: http.MethodGet, Method: http.MethodGet,
@@ -425,6 +200,8 @@ func TestMiddlewareResponses(t *testing.T) {
}) })
}, },
Expected: func(t *testing.T, resp *httptest.ResponseRecorder) { Expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, http.StatusOK, resp.Code)
var listResp ListCompletion var listResp ListCompletion
if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil { if err := json.NewDecoder(resp.Body).Decode(&listResp); err != nil {
t.Fatal(err) t.Fatal(err)
@@ -488,8 +265,6 @@ func TestMiddlewareResponses(t *testing.T) {
resp := httptest.NewRecorder() resp := httptest.NewRecorder()
router.ServeHTTP(resp, req) router.ServeHTTP(resp, req)
assert.Equal(t, http.StatusOK, resp.Code)
tc.Expected(t, resp) tc.Expected(t, resp)
}) })
} }

View File

@@ -451,7 +451,6 @@ func TestParseFileParameters(t *testing.T) {
"num_predict 1": {"num_predict", "1"}, "num_predict 1": {"num_predict", "1"},
"top_k 1": {"top_k", "1"}, "top_k 1": {"top_k", "1"},
"top_p 1.0": {"top_p", "1.0"}, "top_p 1.0": {"top_p", "1.0"},
"min_p 0.05": {"min_p", "0.05"},
"tfs_z 1.0": {"tfs_z", "1.0"}, "tfs_z 1.0": {"tfs_z", "1.0"},
"typical_p 1.0": {"typical_p", "1.0"}, "typical_p 1.0": {"typical_p", "1.0"},
"repeat_last_n 1": {"repeat_last_n", "1"}, "repeat_last_n 1": {"repeat_last_n", "1"},

View File

@@ -31,6 +31,10 @@ func NewSpinner(message string) *Spinner {
return s return s
} }
func (s *Spinner) SetMessage(message string) {
s.message = message
}
func (s *Spinner) String() string { func (s *Spinner) String() string {
var sb strings.Builder var sb strings.Builder
if len(s.message) > 0 { if len(s.message) > 0 {

View File

@@ -107,12 +107,9 @@ function gatherDependencies() {
# TODO - this varies based on host build system and MSVC version - drive from dumpbin output # TODO - this varies based on host build system and MSVC version - drive from dumpbin output
# currently works for Win11 + MSVC 2019 + Cuda V11 # currently works for Win11 + MSVC 2019 + Cuda V11
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\msvcp140*.dll" "${script:DEPS_DIR}\ollama_runners\" cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\msvcp140.dll" "${script:DEPS_DIR}\ollama_runners\"
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140.dll" "${script:DEPS_DIR}\ollama_runners\" cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140.dll" "${script:DEPS_DIR}\ollama_runners\"
cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140_1.dll" "${script:DEPS_DIR}\ollama_runners\" cp "${env:VCToolsRedistDir}\x64\Microsoft.VC*.CRT\vcruntime140_1.dll" "${script:DEPS_DIR}\ollama_runners\"
foreach ($part in $("runtime", "stdio", "filesystem", "math", "convert", "heap", "string", "time", "locale", "environment")) {
cp "$env:VCToolsRedistDir\..\..\..\Tools\Llvm\x64\bin\api-ms-win-crt-${part}*.dll" "${script:DEPS_DIR}\ollama_runners\"
}
cp "${script:SRC_DIR}\app\ollama_welcome.ps1" "${script:SRC_DIR}\dist\" cp "${script:SRC_DIR}\app\ollama_welcome.ps1" "${script:SRC_DIR}\dist\"

View File

@@ -198,29 +198,19 @@ if check_gpu lspci amdgpu || check_gpu lshw amdgpu; then
exit 0 exit 0
fi fi
CUDA_REPO_ERR_MSG="NVIDIA GPU detected, but your OS and Architecture are not supported by NVIDIA. Please install the CUDA driver manually https://docs.nvidia.com/cuda/cuda-installation-guide-linux/"
# ref: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#rhel-7-centos-7 # ref: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#rhel-7-centos-7
# ref: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#rhel-8-rocky-8 # ref: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#rhel-8-rocky-8
# ref: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#rhel-9-rocky-9 # ref: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#rhel-9-rocky-9
# ref: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#fedora # ref: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#fedora
install_cuda_driver_yum() { install_cuda_driver_yum() {
status 'Installing NVIDIA repository...' status 'Installing NVIDIA repository...'
case $PACKAGE_MANAGER in case $PACKAGE_MANAGER in
yum) yum)
$SUDO $PACKAGE_MANAGER -y install yum-utils $SUDO $PACKAGE_MANAGER -y install yum-utils
if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo" >/dev/null ; then $SUDO $PACKAGE_MANAGER-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo
$SUDO $PACKAGE_MANAGER-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo
else
error $CUDA_REPO_ERR_MSG
fi
;; ;;
dnf) dnf)
if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo" >/dev/null ; then $SUDO $PACKAGE_MANAGER config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo
$SUDO $PACKAGE_MANAGER config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-$1$2.repo
else
error $CUDA_REPO_ERR_MSG
fi
;; ;;
esac esac
@@ -245,11 +235,7 @@ install_cuda_driver_yum() {
# ref: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#debian # ref: https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html#debian
install_cuda_driver_apt() { install_cuda_driver_apt() {
status 'Installing NVIDIA repository...' status 'Installing NVIDIA repository...'
if curl -I --silent --fail --location "https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-keyring_1.1-1_all.deb" >/dev/null ; then curl -fsSL -o $TEMP_DIR/cuda-keyring.deb https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-keyring_1.1-1_all.deb
curl -fsSL -o $TEMP_DIR/cuda-keyring.deb https://developer.download.nvidia.com/compute/cuda/repos/$1$2/$(uname -m)/cuda-keyring_1.1-1_all.deb
else
error $CUDA_REPO_ERR_MSG
fi
case $1 in case $1 in
debian) debian)

View File

@@ -67,7 +67,7 @@ func getAuthorizationToken(ctx context.Context, challenge registryChallenge) (st
headers.Add("Authorization", signature) headers.Add("Authorization", signature)
response, err := makeRequest(ctx, http.MethodGet, redirectURL, headers, nil, &registryOptions{}) response, err := makeRequest(ctx, http.MethodGet, redirectURL, headers, nil, nil)
if err != nil { if err != nil {
return "", err return "", err
} }

View File

@@ -8,7 +8,6 @@ import (
"io" "io"
"log/slog" "log/slog"
"math" "math"
"math/rand/v2"
"net/http" "net/http"
"net/url" "net/url"
"os" "os"
@@ -44,19 +43,17 @@ type blobDownload struct {
context.CancelFunc context.CancelFunc
done chan struct{} done bool
err error err error
references atomic.Int32 references atomic.Int32
} }
type blobDownloadPart struct { type blobDownloadPart struct {
N int N int
Offset int64 Offset int64
Size int64 Size int64
Completed atomic.Int64 Completed int64
lastUpdated time.Time
lastUpdatedMu sync.Mutex
lastUpdated time.Time
*blobDownload `json:"-"` *blobDownload `json:"-"`
} }
@@ -74,7 +71,7 @@ func (p *blobDownloadPart) Name() string {
} }
func (p *blobDownloadPart) StartsAt() int64 { func (p *blobDownloadPart) StartsAt() int64 {
return p.Offset + p.Completed.Load() return p.Offset + p.Completed
} }
func (p *blobDownloadPart) StopsAt() int64 { func (p *blobDownloadPart) StopsAt() int64 {
@@ -84,9 +81,7 @@ func (p *blobDownloadPart) StopsAt() int64 {
func (p *blobDownloadPart) Write(b []byte) (n int, err error) { func (p *blobDownloadPart) Write(b []byte) (n int, err error) {
n = len(b) n = len(b)
p.blobDownload.Completed.Add(int64(n)) p.blobDownload.Completed.Add(int64(n))
p.lastUpdatedMu.Lock()
p.lastUpdated = time.Now() p.lastUpdated = time.Now()
p.lastUpdatedMu.Unlock()
return n, nil return n, nil
} }
@@ -96,8 +91,6 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *r
return err return err
} }
b.done = make(chan struct{})
for _, partFilePath := range partFilePaths { for _, partFilePath := range partFilePaths {
part, err := b.readPart(partFilePath) part, err := b.readPart(partFilePath)
if err != nil { if err != nil {
@@ -105,7 +98,7 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *r
} }
b.Total += part.Size b.Total += part.Size
b.Completed.Add(part.Completed.Load()) b.Completed.Add(part.Completed)
b.Parts = append(b.Parts, part) b.Parts = append(b.Parts, part)
} }
@@ -145,36 +138,9 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *r
} }
func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *registryOptions) { func (b *blobDownload) Run(ctx context.Context, requestURL *url.URL, opts *registryOptions) {
defer close(b.done)
b.err = b.run(ctx, requestURL, opts) b.err = b.run(ctx, requestURL, opts)
} }
func newBackoff(maxBackoff time.Duration) func(ctx context.Context) error {
var n int
return func(ctx context.Context) error {
if ctx.Err() != nil {
return ctx.Err()
}
n++
// n^2 backoff timer is a little smoother than the
// common choice of 2^n.
d := min(time.Duration(n*n)*10*time.Millisecond, maxBackoff)
// Randomize the delay between 0.5-1.5 x msec, in order
// to prevent accidental "thundering herd" problems.
d = time.Duration(float64(d) * (rand.Float64() + 0.5))
t := time.NewTimer(d)
defer t.Stop()
select {
case <-ctx.Done():
return ctx.Err()
case <-t.C:
return nil
}
}
}
func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *registryOptions) error { func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *registryOptions) error {
defer blobDownloadManager.Delete(b.Digest) defer blobDownloadManager.Delete(b.Digest)
ctx, b.CancelFunc = context.WithCancel(ctx) ctx, b.CancelFunc = context.WithCancel(ctx)
@@ -187,57 +153,11 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
_ = file.Truncate(b.Total) _ = file.Truncate(b.Total)
directURL, err := func() (*url.URL, error) {
ctx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
backoff := newBackoff(10 * time.Second)
for {
// shallow clone opts to be used in the closure
// without affecting the outer opts.
newOpts := new(registryOptions)
*newOpts = *opts
newOpts.CheckRedirect = func(req *http.Request, via []*http.Request) error {
if len(via) > 10 {
return errors.New("maxium redirects exceeded (10) for directURL")
}
// if the hostname is the same, allow the redirect
if req.URL.Hostname() == requestURL.Hostname() {
return nil
}
// stop at the first redirect that is not
// the same hostname as the original
// request.
return http.ErrUseLastResponse
}
resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, nil, nil, newOpts)
if err != nil {
slog.Warn("failed to get direct URL; backing off and retrying", "err", err)
if err := backoff(ctx); err != nil {
return nil, err
}
continue
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusTemporaryRedirect {
return nil, fmt.Errorf("unexpected status code %d", resp.StatusCode)
}
return resp.Location()
}
}()
if err != nil {
return err
}
g, inner := errgroup.WithContext(ctx) g, inner := errgroup.WithContext(ctx)
g.SetLimit(numDownloadParts) g.SetLimit(numDownloadParts)
for i := range b.Parts { for i := range b.Parts {
part := b.Parts[i] part := b.Parts[i]
if part.Completed.Load() == part.Size { if part.Completed == part.Size {
continue continue
} }
@@ -245,7 +165,7 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
var err error var err error
for try := 0; try < maxRetries; try++ { for try := 0; try < maxRetries; try++ {
w := io.NewOffsetWriter(file, part.StartsAt()) w := io.NewOffsetWriter(file, part.StartsAt())
err = b.downloadChunk(inner, directURL, w, part) err = b.downloadChunk(inner, requestURL, w, part, opts)
switch { switch {
case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC): case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
// return immediately if the context is canceled or the device is out of space // return immediately if the context is canceled or the device is out of space
@@ -286,31 +206,29 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *regis
return err return err
} }
b.done = true
return nil return nil
} }
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart) error { func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *registryOptions) error {
g, ctx := errgroup.WithContext(ctx) g, ctx := errgroup.WithContext(ctx)
g.Go(func() error { g.Go(func() error {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, requestURL.String(), nil) headers := make(http.Header)
if err != nil { headers.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
return err resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts)
}
req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
return err return err
} }
defer resp.Body.Close() defer resp.Body.Close()
n, err := io.CopyN(w, io.TeeReader(resp.Body, part), part.Size-part.Completed.Load()) n, err := io.CopyN(w, io.TeeReader(resp.Body, part), part.Size-part.Completed)
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) { if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
// rollback progress // rollback progress
b.Completed.Add(-n) b.Completed.Add(-n)
return err return err
} }
part.Completed.Add(n) part.Completed += n
if err := b.writePart(part.Name(), part); err != nil { if err := b.writePart(part.Name(), part); err != nil {
return err return err
} }
@@ -324,21 +242,15 @@ func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w
for { for {
select { select {
case <-ticker.C: case <-ticker.C:
if part.Completed.Load() >= part.Size { if part.Completed >= part.Size {
return nil return nil
} }
part.lastUpdatedMu.Lock() if !part.lastUpdated.IsZero() && time.Since(part.lastUpdated) > 5*time.Second {
lastUpdated := part.lastUpdated
part.lastUpdatedMu.Unlock()
if !lastUpdated.IsZero() && time.Since(lastUpdated) > 5*time.Second {
const msg = "%s part %d stalled; retrying. If this persists, press ctrl-c to exit, then 'ollama pull' to find a faster connection." const msg = "%s part %d stalled; retrying. If this persists, press ctrl-c to exit, then 'ollama pull' to find a faster connection."
slog.Info(fmt.Sprintf(msg, b.Digest[7:19], part.N)) slog.Info(fmt.Sprintf(msg, b.Digest[7:19], part.N))
// reset last updated // reset last updated
part.lastUpdatedMu.Lock()
part.lastUpdated = time.Time{} part.lastUpdated = time.Time{}
part.lastUpdatedMu.Unlock()
return errPartStalled return errPartStalled
} }
case <-ctx.Done(): case <-ctx.Done():
@@ -403,8 +315,6 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
ticker := time.NewTicker(60 * time.Millisecond) ticker := time.NewTicker(60 * time.Millisecond)
for { for {
select { select {
case <-b.done:
return b.err
case <-ticker.C: case <-ticker.C:
fn(api.ProgressResponse{ fn(api.ProgressResponse{
Status: fmt.Sprintf("pulling %s", b.Digest[7:19]), Status: fmt.Sprintf("pulling %s", b.Digest[7:19]),
@@ -412,6 +322,10 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
Total: b.Total, Total: b.Total,
Completed: b.Completed.Load(), Completed: b.Completed.Load(),
}) })
if b.done || b.err != nil {
return b.err
}
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
} }

View File

@@ -32,30 +32,20 @@ import (
"github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
"golang.org/x/crypto/ssh"
) )
var ( var errCapabilityCompletion = errors.New("completion")
errCapabilities = errors.New("does not support")
errCapabilityCompletion = errors.New("completion")
errCapabilityTools = errors.New("tools")
errCapabilityInsert = errors.New("insert")
)
type Capability string type Capability string
const ( const CapabilityCompletion = Capability("completion")
CapabilityCompletion = Capability("completion")
CapabilityTools = Capability("tools")
CapabilityInsert = Capability("insert")
)
type registryOptions struct { type registryOptions struct {
Insecure bool Insecure bool
Username string Username string
Password string Password string
Token string Token string
CheckRedirect func(req *http.Request, via []*http.Request) error
} }
type Model struct { type Model struct {
@@ -99,15 +89,6 @@ func (m *Model) CheckCapabilities(caps ...Capability) error {
if _, ok := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]; ok { if _, ok := ggml.KV()[fmt.Sprintf("%s.pooling_type", ggml.KV().Architecture())]; ok {
errs = append(errs, errCapabilityCompletion) errs = append(errs, errCapabilityCompletion)
} }
case CapabilityTools:
if !slices.Contains(m.Template.Vars(), "tools") {
errs = append(errs, errCapabilityTools)
}
case CapabilityInsert:
vars := m.Template.Vars()
if !slices.Contains(vars, "suffix") {
errs = append(errs, errCapabilityInsert)
}
default: default:
slog.Error("unknown capability", "capability", cap) slog.Error("unknown capability", "capability", cap)
return fmt.Errorf("unknown capability: %s", cap) return fmt.Errorf("unknown capability: %s", cap)
@@ -115,7 +96,7 @@ func (m *Model) CheckCapabilities(caps ...Capability) error {
} }
if err := errors.Join(errs...); err != nil { if err := errors.Join(errs...); err != nil {
return fmt.Errorf("%w %w", errCapabilities, errors.Join(errs...)) return fmt.Errorf("missing capabilities: %w", errors.Join(errs...))
} }
return nil return nil
@@ -441,13 +422,12 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
if err != nil { if err != nil {
return err return err
} }
tensorCount := len(baseLayer.GGML.Tensors())
ft := baseLayer.GGML.KV().FileType() ft := baseLayer.GGML.KV().FileType()
if !slices.Contains([]string{"F16", "F32"}, ft.String()) { if !slices.Contains([]string{"F16", "F32"}, ft.String()) {
return errors.New("quantization is only supported for F16 and F32 models") return errors.New("quantization is only supported for F16 and F32 models")
} else if want != ft { } else if want != ft {
fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", ft, quantization)})
blob, err := GetBlobsPath(baseLayer.Digest) blob, err := GetBlobsPath(baseLayer.Digest)
if err != nil { if err != nil {
return err return err
@@ -460,7 +440,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
defer temp.Close() defer temp.Close()
defer os.Remove(temp.Name()) defer os.Remove(temp.Name())
if err := llm.Quantize(blob, temp.Name(), want); err != nil { if err := llm.Quantize(blob, temp.Name(), want, fn, tensorCount); err != nil {
return err return err
} }
@@ -493,13 +473,8 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
layers = append(layers, baseLayer.Layer) layers = append(layers, baseLayer.Layer)
} }
case "license", "template", "system":
if c.Name == "template" {
if _, err := template.Parse(c.Args); err != nil {
return fmt.Errorf("%w: %s", errBadTemplate, err)
}
}
case "license", "template", "system":
if c.Name != "license" { if c.Name != "license" {
// replace // replace
layers = slices.DeleteFunc(layers, func(layer *Layer) bool { layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
@@ -1090,11 +1065,12 @@ func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.UR
if anonymous { if anonymous {
// no user is associated with the public key, and the request requires non-anonymous access // no user is associated with the public key, and the request requires non-anonymous access
pubKey, nestedErr := auth.GetPublicKey() pubKey, nestedErr := auth.GetPublicKey()
localPubKey := strings.TrimSpace(string(ssh.MarshalAuthorizedKey(pubKey)))
if nestedErr != nil { if nestedErr != nil {
slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr)) slog.Error(fmt.Sprintf("couldn't get public key: %v", nestedErr))
return nil, errUnauthorized return nil, errUnauthorized
} }
return nil, &errtypes.UnknownOllamaKey{Key: pubKey} return nil, &errtypes.UnknownOllamaKey{Key: localPubKey}
} }
// user is associated with the public key, but is not authorized to make the request // user is associated with the public key, but is not authorized to make the request
return nil, errUnauthorized return nil, errUnauthorized
@@ -1133,9 +1109,7 @@ func makeRequest(ctx context.Context, method string, requestURL *url.URL, header
req.ContentLength = contentLength req.ContentLength = contentLength
} }
resp, err := (&http.Client{ resp, err := http.DefaultClient.Do(req)
CheckRedirect: regOpts.CheckRedirect,
}).Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -4,7 +4,6 @@ import (
"archive/zip" "archive/zip"
"bytes" "bytes"
"context" "context"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@@ -12,9 +11,6 @@ import (
"net/http" "net/http"
"os" "os"
"path/filepath" "path/filepath"
"slices"
"strings"
"text/template/parse"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/convert" "github.com/ollama/ollama/convert"
@@ -263,27 +259,13 @@ func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
if t, err := template.Named(s); err != nil { if t, err := template.Named(s); err != nil {
slog.Debug("template detection", "error", err) slog.Debug("template detection", "error", err)
} else { } else {
layer, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template") tmpl, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
if err != nil { if err != nil {
return nil, err return nil, err
} }
layer.status = fmt.Sprintf("using autodetected template %s", t.Name) tmpl.status = fmt.Sprintf("using autodetected template %s", t.Name)
layers = append(layers, &layerGGML{layer, nil}) layers = append(layers, &layerGGML{tmpl, nil})
if t.Parameters != nil {
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(t.Parameters); err != nil {
return nil, err
}
layer, err := NewLayer(&b, "application/vnd.ollama.image.params")
if err != nil {
return nil, err
}
layers = append(layers, &layerGGML{layer, nil})
}
} }
} }
} }
@@ -307,113 +289,3 @@ func detectContentType(r io.Reader) (string, error) {
return "unknown", nil return "unknown", nil
} }
// parseToolCalls attempts to parse a JSON string into a slice of ToolCalls.
// mxyng: this only really works if the input contains tool calls in some JSON format
func (m *Model) parseToolCalls(s string) ([]api.ToolCall, bool) {
// create a subtree from the node that ranges over .ToolCalls
tmpl := m.Template.Subtree(func(n parse.Node) bool {
if t, ok := n.(*parse.RangeNode); ok {
return slices.Contains(template.Identifiers(t.Pipe), "ToolCalls")
}
return false
})
if tmpl == nil {
return nil, false
}
var b bytes.Buffer
if err := tmpl.Execute(&b, map[string][]api.ToolCall{
"ToolCalls": {
{
Function: api.ToolCallFunction{
Name: "@@name@@",
Arguments: api.ToolCallFunctionArguments{
"@@argument@@": 1,
},
},
},
},
}); err != nil {
return nil, false
}
var kv map[string]any
// execute the subtree with placeholders to identify the keys
// trim any commands that might exist in the template
if err := json.Unmarshal(bytes.TrimSuffix(b.Bytes(), []byte(",")), &kv); err != nil {
return nil, false
}
// find the keys that correspond to the name and arguments fields
var name, arguments string
for k, v := range kv {
switch v.(type) {
case string:
name = k
case map[string]any:
arguments = k
}
}
if name == "" || arguments == "" {
return nil, false
}
var objs []map[string]any
for offset := 0; offset < len(s); {
var obj map[string]any
decoder := json.NewDecoder(strings.NewReader(s[offset:]))
if err := decoder.Decode(&obj); errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) {
break
} else if syntax := &(json.SyntaxError{}); errors.As(err, &syntax) {
// skip over any syntax errors
offset += int(syntax.Offset)
} else if unmarshalType := &(json.UnmarshalTypeError{}); errors.As(err, &unmarshalType) {
// skip over any unmarshalable types
offset += int(unmarshalType.Offset)
} else if err != nil {
slog.Error("parseToolCalls", "error", err)
return nil, false
} else {
offset += int(decoder.InputOffset())
// collect all nested objects
var collect func(any) []map[string]any
collect = func(obj any) (all []map[string]any) {
switch o := obj.(type) {
case map[string]any:
all = append(all, o)
for _, v := range o {
all = append(all, collect(v)...)
}
case []any:
for _, v := range o {
all = append(all, collect(v)...)
}
}
return all
}
objs = append(objs, collect(obj)...)
}
}
var toolCalls []api.ToolCall
for _, kv := range objs {
n, nok := kv[name].(string)
a, aok := kv[arguments].(map[string]any)
if nok && aok {
toolCalls = append(toolCalls, api.ToolCall{
Function: api.ToolCallFunction{
Name: n,
Arguments: a,
},
})
}
}
return toolCalls, len(toolCalls) > 0
}

View File

@@ -3,9 +3,7 @@ package server
import ( import (
"archive/zip" "archive/zip"
"bytes" "bytes"
"encoding/json"
"errors" "errors"
"fmt"
"io" "io"
"os" "os"
"path/filepath" "path/filepath"
@@ -13,9 +11,7 @@ import (
"strings" "strings"
"testing" "testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/template"
) )
func createZipFile(t *testing.T, name string) *os.File { func createZipFile(t *testing.T, name string) *os.File {
@@ -114,123 +110,3 @@ func TestExtractFromZipFile(t *testing.T) {
}) })
} }
} }
func readFile(t *testing.T, base, name string) *bytes.Buffer {
t.Helper()
bts, err := os.ReadFile(filepath.Join(base, name))
if err != nil {
t.Fatal(err)
}
return bytes.NewBuffer(bts)
}
func TestExecuteWithTools(t *testing.T) {
p := filepath.Join("testdata", "tools")
cases := []struct {
model string
output string
ok bool
}{
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
{"mistral", `[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]
The temperature in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.`, true},
{"mistral", `I'm not aware of that information. However, I can suggest searching for the weather using the "get_current_weather" function:
[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
{"mistral", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
{"command-r-plus", "Action: ```json" + `
[
{
"tool_name": "get_current_weather",
"parameters": {
"format": "fahrenheit",
"location": "San Francisco, CA"
}
},
{
"tool_name": "get_current_weather",
"parameters": {
"format": "celsius",
"location": "Toronto, Canada"
}
}
]
` + "```", true},
{"command-r-plus", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
{"firefunction", ` functools[{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]`, true},
{"firefunction", " The weather in San Francisco, CA is 70°F and in Toronto, Canada is 20°C.", false},
{"llama3-groq-tool-use", `<tool_call>
{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}}
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}
</tool_call>`, true},
{"xlam", `{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"fahrenheit","location":"San Francisco, CA"}},{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Toronto, Canada"}}]}`, true},
}
var tools []api.Tool
if err := json.Unmarshal(readFile(t, p, "tools.json").Bytes(), &tools); err != nil {
t.Fatal(err)
}
var messages []api.Message
if err := json.Unmarshal(readFile(t, p, "messages.json").Bytes(), &messages); err != nil {
t.Fatal(err)
}
calls := []api.ToolCall{
{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: api.ToolCallFunctionArguments{
"format": "fahrenheit",
"location": "San Francisco, CA",
},
},
},
{
Function: api.ToolCallFunction{
Name: "get_current_weather",
Arguments: api.ToolCallFunctionArguments{
"format": "celsius",
"location": "Toronto, Canada",
},
},
},
}
for _, tt := range cases {
t.Run(tt.model, func(t *testing.T) {
tmpl, err := template.Parse(readFile(t, p, fmt.Sprintf("%s.gotmpl", tt.model)).String())
if err != nil {
t.Fatal(err)
}
t.Run("template", func(t *testing.T) {
var actual bytes.Buffer
if err := tmpl.Execute(&actual, template.Values{Tools: tools, Messages: messages}); err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(actual.String(), readFile(t, p, fmt.Sprintf("%s.out", tt.model)).String()); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("parse", func(t *testing.T) {
m := &Model{Template: tmpl}
actual, ok := m.parseToolCalls(tt.output)
if ok != tt.ok {
t.Fatalf("expected %t, got %t", tt.ok, ok)
}
if tt.ok {
if diff := cmp.Diff(actual, calls); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
}
})
})
}
}

View File

@@ -4,6 +4,7 @@ import (
"bytes" "bytes"
"context" "context"
"log/slog" "log/slog"
"slices"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
@@ -15,21 +16,29 @@ type tokenizeFunc func(context.Context, string) ([]int, error)
// chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn. // chatPrompt accepts a list of messages and returns the prompt and images that should be used for the next chat turn.
// chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the // chatPrompt truncates any messages that exceed the context window of the model, making sure to always include 1) the
// latest message and 2) system messages // latest message and 2) system messages
func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message, tools []api.Tool) (prompt string, images []llm.ImageData, _ error) { func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.Options, msgs []api.Message) (prompt string, images []llm.ImageData, _ error) {
// pull out any system messages which should always be included in the prompt
var system []api.Message var system []api.Message
msgs = slices.DeleteFunc(msgs, func(m api.Message) bool {
if m.Role == "system" {
system = append(system, m)
return true
}
return false
})
if len(system) == 0 && m.System != "" {
// add model system prompt since it wasn't provided
system = append(system, api.Message{Role: "system", Content: m.System})
}
// always include the last message // always include the last message
n := len(msgs) - 1 n := len(msgs) - 1
// in reverse, find all messages that fit into context window // in reverse, find all messages that fit into context window
for i := n - 1; i >= 0; i-- { for i := n - 1; i >= 0; i-- {
system = make([]api.Message, 0)
for j := range i {
if msgs[j].Role == "system" {
system = append(system, msgs[j])
}
}
var b bytes.Buffer var b bytes.Buffer
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...), Tools: tools}); err != nil { if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[i:]...)}); err != nil {
return "", nil, err return "", nil, err
} }
@@ -57,7 +66,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
// truncate any messages that do not fit into the context window // truncate any messages that do not fit into the context window
var b bytes.Buffer var b bytes.Buffer
if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...), Tools: tools}); err != nil { if err := m.Template.Execute(&b, template.Values{Messages: append(system, msgs[n:]...)}); err != nil {
return "", nil, err return "", nil, err
} }

View File

@@ -3,13 +3,21 @@ package server
import ( import (
"bytes" "bytes"
"context" "context"
"strings"
"testing" "testing"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/template" "github.com/ollama/ollama/template"
) )
func tokenize(_ context.Context, s string) (tokens []int, err error) {
for range strings.Fields(s) {
tokens = append(tokens, len(tokens))
}
return
}
func TestChatPrompt(t *testing.T) { func TestChatPrompt(t *testing.T) {
type expect struct { type expect struct {
prompt string prompt string
@@ -152,19 +160,6 @@ func TestChatPrompt(t *testing.T) {
{Role: "assistant", Content: "I-I'm a what?"}, {Role: "assistant", Content: "I-I'm a what?"},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."}, {Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
}, },
expect: expect{
prompt: "You are the Test Who Lived. You're a test, Harry! I-I'm a what? A test. And a thumping good one at that, I'd wager. ",
},
},
{
name: "out of order system",
limit: 2048,
msgs: []api.Message{
{Role: "user", Content: "You're a test, Harry!"},
{Role: "assistant", Content: "I-I'm a what?"},
{Role: "system", Content: "You are the Test Who Lived."},
{Role: "user", Content: "A test. And a thumping good one at that, I'd wager."},
},
expect: expect{ expect: expect{
prompt: "You're a test, Harry! I-I'm a what? You are the Test Who Lived. A test. And a thumping good one at that, I'd wager. ", prompt: "You're a test, Harry! I-I'm a what? You are the Test Who Lived. A test. And a thumping good one at that, I'd wager. ",
}, },
@@ -183,13 +178,13 @@ func TestChatPrompt(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
model := Model{Template: tmpl, ProjectorPaths: []string{"vision"}} model := Model{Template: tmpl, ProjectorPaths: []string{"vision"}}
opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}} opts := api.Options{Runner: api.Runner{NumCtx: tt.limit}}
prompt, images, err := chatPrompt(context.TODO(), &model, mockRunner{}.Tokenize, &opts, tt.msgs, nil) prompt, images, err := chatPrompt(context.TODO(), &model, tokenize, &opts, tt.msgs)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
if diff := cmp.Diff(prompt, tt.prompt); diff != "" { if tt.prompt != prompt {
t.Errorf("mismatch (-got +want):\n%s", diff) t.Errorf("expected %q, got %q", tt.prompt, prompt)
} }
if len(images) != len(tt.images) { if len(images) != len(tt.images) {

View File

@@ -4,12 +4,13 @@ import (
"bytes" "bytes"
"cmp" "cmp"
"context" "context"
"encoding/base64"
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"log"
"log/slog" "log/slog"
"math"
"net" "net"
"net/http" "net/http"
"net/netip" "net/netip"
@@ -23,8 +24,10 @@ import (
"github.com/gin-contrib/cors" "github.com/gin-contrib/cors"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"golang.org/x/crypto/ssh"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/gpu" "github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
@@ -56,7 +59,6 @@ func init() {
} }
var errRequired = errors.New("is required") var errRequired = errors.New("is required")
var errBadTemplate = errors.New("template error")
func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) { func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options, error) {
opts := api.DefaultOptions() opts := api.DefaultOptions()
@@ -104,7 +106,6 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []Capabil
} }
func (s *Server) GenerateHandler(c *gin.Context) { func (s *Server) GenerateHandler(c *gin.Context) {
checkpointStart := time.Now()
var req api.GenerateRequest var req api.GenerateRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) { if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
@@ -123,10 +124,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
} }
caps := []Capability{CapabilityCompletion} caps := []Capability{CapabilityCompletion}
if req.Suffix != "" {
caps = append(caps, CapabilityInsert)
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive) r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) { if errors.Is(err, errCapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)}) c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support generate", req.Model)})
@@ -136,8 +133,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
return return
} }
checkpointLoaded := time.Now()
if req.Prompt == "" { if req.Prompt == "" {
c.JSON(http.StatusOK, api.GenerateResponse{ c.JSON(http.StatusOK, api.GenerateResponse{
Model: req.Model, Model: req.Model,
@@ -155,6 +150,19 @@ func (s *Server) GenerateHandler(c *gin.Context) {
prompt := req.Prompt prompt := req.Prompt
if !req.Raw { if !req.Raw {
var msgs []api.Message
if req.System != "" {
msgs = append(msgs, api.Message{Role: "system", Content: req.System})
} else if m.System != "" {
msgs = append(msgs, api.Message{Role: "system", Content: m.System})
}
for _, i := range images {
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
}
msgs = append(msgs, api.Message{Role: "user", Content: req.Prompt})
tmpl := m.Template tmpl := m.Template
if req.Template != "" { if req.Template != "" {
tmpl, err = template.Parse(req.Template) tmpl, err = template.Parse(req.Template)
@@ -175,26 +183,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {
b.WriteString(s) b.WriteString(s)
} }
var values template.Values if err := tmpl.Execute(&b, template.Values{Messages: msgs}); err != nil {
if req.Suffix != "" {
values.Prompt = prompt
values.Suffix = req.Suffix
} else {
var msgs []api.Message
if req.System != "" {
msgs = append(msgs, api.Message{Role: "system", Content: req.System})
} else if m.System != "" {
msgs = append(msgs, api.Message{Role: "system", Content: m.System})
}
for _, i := range images {
msgs = append(msgs, api.Message{Role: "user", Content: fmt.Sprintf("[img-%d]", i.ID)})
}
values.Messages = append(msgs, api.Message{Role: "user", Content: req.Prompt})
}
if err := tmpl.Execute(&b, values); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
@@ -206,48 +195,26 @@ func (s *Server) GenerateHandler(c *gin.Context) {
ch := make(chan any) ch := make(chan any)
go func() { go func() {
// TODO (jmorganca): avoid building the response twice both here and below
var sb strings.Builder
defer close(ch) defer close(ch)
if err := r.Completion(c.Request.Context(), llm.CompletionRequest{ if err := r.Completion(c.Request.Context(), llm.CompletionRequest{
Prompt: prompt, Prompt: prompt,
Images: images, Images: images,
Format: req.Format, Format: req.Format,
Options: opts, Options: opts,
}, func(cr llm.CompletionResponse) { }, func(r llm.CompletionResponse) {
res := api.GenerateResponse{ ch <- api.GenerateResponse{
Model: req.Model, Model: req.Model,
CreatedAt: time.Now().UTC(), CreatedAt: time.Now().UTC(),
Response: cr.Content, Response: r.Content,
Done: cr.Done, Done: r.Done,
DoneReason: cr.DoneReason, DoneReason: r.DoneReason,
Metrics: api.Metrics{ Metrics: api.Metrics{
PromptEvalCount: cr.PromptEvalCount, PromptEvalCount: r.PromptEvalCount,
PromptEvalDuration: cr.PromptEvalDuration, PromptEvalDuration: r.PromptEvalDuration,
EvalCount: cr.EvalCount, EvalCount: r.EvalCount,
EvalDuration: cr.EvalDuration, EvalDuration: r.EvalDuration,
}, },
} }
if _, err := sb.WriteString(cr.Content); err != nil {
ch <- gin.H{"error": err.Error()}
}
if cr.Done {
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
if !req.Raw {
tokens, err := r.Tokenize(c.Request.Context(), prompt+sb.String())
if err != nil {
ch <- gin.H{"error": err.Error()}
return
}
res.Context = append(req.Context, tokens...)
}
}
ch <- res
}); err != nil { }); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }
@@ -283,121 +250,6 @@ func (s *Server) GenerateHandler(c *gin.Context) {
streamResponse(c, ch) streamResponse(c, ch)
} }
func (s *Server) EmbedHandler(c *gin.Context) {
var req api.EmbedRequest
err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return
case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
truncate := true
if req.Truncate != nil && !*req.Truncate {
truncate = false
}
var input []string
switch i := req.Input.(type) {
case string:
if len(i) > 0 {
input = append(input, i)
}
case []any:
for _, v := range i {
if _, ok := v.(string); !ok {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
return
}
input = append(input, v.(string))
}
default:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid input type"})
return
}
if len(input) == 0 {
c.JSON(http.StatusOK, api.EmbedResponse{Model: req.Model, Embeddings: [][]float32{}})
return
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, []Capability{}, req.Options, req.KeepAlive)
if err != nil {
handleScheduleError(c, req.Model, err)
return
}
kvData, err := getKVData(m.ModelPath, false)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
for i, s := range input {
tokens, err := r.Tokenize(c.Request.Context(), s)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
ctxLen := min(opts.NumCtx, int(kvData.ContextLength()))
if len(tokens) > ctxLen {
if !truncate {
c.JSON(http.StatusBadRequest, gin.H{"error": "input length exceeds maximum context length"})
return
}
tokens = tokens[:ctxLen]
s, err = r.Detokenize(c.Request.Context(), tokens)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
}
input[i] = s
}
embeddings, err := r.Embed(c.Request.Context(), input)
if err != nil {
slog.Error("embedding generation failed", "error", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
return
}
for i, e := range embeddings {
embeddings[i] = normalize(e)
}
resp := api.EmbedResponse{
Model: req.Model,
Embeddings: embeddings,
}
c.JSON(http.StatusOK, resp)
}
func normalize(vec []float32) []float32 {
var sum float32
for _, v := range vec {
sum += v * v
}
norm := float32(0.0)
if sum > 0 {
norm = float32(1.0 / math.Sqrt(float64(sum)))
}
for i := range vec {
vec[i] *= norm
}
return vec
}
func (s *Server) EmbeddingsHandler(c *gin.Context) { func (s *Server) EmbeddingsHandler(c *gin.Context) {
var req api.EmbeddingRequest var req api.EmbeddingRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) { if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
@@ -420,24 +272,14 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
return return
} }
embeddings, err := r.Embed(c.Request.Context(), []string{req.Prompt}) embedding, err := r.Embedding(c.Request.Context(), req.Prompt)
if err != nil { if err != nil {
slog.Info(fmt.Sprintf("embedding generation failed: %v", err)) slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
return return
} }
embedding := make([]float64, len(embeddings[0])) c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: embedding})
for i, v := range embeddings[0] {
embedding[i] = float64(v)
}
resp := api.EmbeddingResponse{
Embedding: embedding,
}
c.JSON(http.StatusOK, resp)
} }
func (s *Server) PullModelHandler(c *gin.Context) { func (s *Server) PullModelHandler(c *gin.Context) {
@@ -609,9 +451,7 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
defer cancel() defer cancel()
quantization := cmp.Or(r.Quantize, r.Quantization) quantization := cmp.Or(r.Quantize, r.Quantization)
if err := CreateModel(ctx, name, filepath.Dir(r.Path), strings.ToUpper(quantization), f, fn); errors.Is(err, errBadTemplate) { if err := CreateModel(ctx, name, filepath.Dir(r.Path), strings.ToUpper(quantization), f, fn); err != nil {
ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest}
} else if err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }
}() }()
@@ -713,6 +553,13 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
m.System = req.System m.System = req.System
} }
if req.Template != "" {
m.Template, err = template.Parse(req.Template)
if err != nil {
return nil, err
}
}
msgs := make([]api.Message, len(m.Messages)) msgs := make([]api.Message, len(m.Messages))
for i, msg := range m.Messages { for i, msg := range m.Messages {
msgs[i] = api.Message{Role: msg.Role, Content: msg.Content} msgs[i] = api.Message{Role: msg.Role, Content: msg.Content}
@@ -927,7 +774,6 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
_, err = os.Stat(path) _, err = os.Stat(path)
switch { switch {
case errors.Is(err, os.ErrNotExist): case errors.Is(err, os.ErrNotExist):
@@ -940,6 +786,12 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
return return
} }
if c.GetHeader("X-Redirect-Create") == "1" && s.IsLocal(c) {
c.Header("LocalLocation", path)
c.Status(http.StatusTemporaryRedirect)
return
}
layer, err := NewLayer(c.Request.Body, "") layer, err := NewLayer(c.Request.Body, "")
if err != nil { if err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -954,6 +806,54 @@ func (s *Server) CreateBlobHandler(c *gin.Context) {
c.Status(http.StatusCreated) c.Status(http.StatusCreated)
} }
func (s *Server) IsLocal(c *gin.Context) bool {
if authz := c.GetHeader("Authorization"); authz != "" {
parts := strings.Split(authz, ":")
if len(parts) != 3 {
return false
}
clientPublicKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(fmt.Sprintf("ssh-ed25519 %s", parts[0])))
if err != nil {
return false
}
// partialRequestData is formatted as http.Method,http.requestURI,timestamp,nonce
requestData, err := base64.StdEncoding.DecodeString(parts[1])
if err != nil {
return false
}
partialRequestDataParts := strings.Split(string(requestData), ",")
if len(partialRequestDataParts) != 3 {
return false
}
signature, err := base64.StdEncoding.DecodeString(parts[2])
if err != nil {
return false
}
if err := clientPublicKey.Verify(requestData, &ssh.Signature{Format: clientPublicKey.Type(), Blob: signature}); err != nil {
return false
}
serverPublicKey, err := auth.GetPublicKey()
if err != nil {
log.Fatal(err)
}
if bytes.Equal(serverPublicKey.Marshal(), clientPublicKey.Marshal()) {
return true
}
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "unauthorized"})
return false
}
return false
}
func isLocalIP(ip netip.Addr) bool { func isLocalIP(ip netip.Addr) bool {
if interfaces, err := net.Interfaces(); err == nil { if interfaces, err := net.Interfaces(); err == nil {
for _, iface := range interfaces { for _, iface := range interfaces {
@@ -1058,7 +958,6 @@ func (s *Server) GenerateRoutes() http.Handler {
r.POST("/api/pull", s.PullModelHandler) r.POST("/api/pull", s.PullModelHandler)
r.POST("/api/generate", s.GenerateHandler) r.POST("/api/generate", s.GenerateHandler)
r.POST("/api/chat", s.ChatHandler) r.POST("/api/chat", s.ChatHandler)
r.POST("/api/embed", s.EmbedHandler)
r.POST("/api/embeddings", s.EmbeddingsHandler) r.POST("/api/embeddings", s.EmbeddingsHandler)
r.POST("/api/create", s.CreateModelHandler) r.POST("/api/create", s.CreateModelHandler)
r.POST("/api/push", s.PushModelHandler) r.POST("/api/push", s.PushModelHandler)
@@ -1072,7 +971,6 @@ func (s *Server) GenerateRoutes() http.Handler {
// Compatibility endpoints // Compatibility endpoints
r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler) r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler) r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
r.POST("/v1/embeddings", openai.EmbeddingsMiddleware(), s.EmbedHandler)
r.GET("/v1/models", openai.ListMiddleware(), s.ListModelsHandler) r.GET("/v1/models", openai.ListMiddleware(), s.ListModelsHandler)
r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowModelHandler) r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowModelHandler)
@@ -1199,15 +1097,11 @@ func waitForStream(c *gin.Context, ch chan interface{}) {
return return
} }
case gin.H: case gin.H:
status, ok := r["status"].(int)
if !ok {
status = http.StatusInternalServerError
}
if errorMsg, ok := r["error"].(string); ok { if errorMsg, ok := r["error"].(string); ok {
c.JSON(status, gin.H{"error": errorMsg}) c.JSON(http.StatusInternalServerError, gin.H{"error": errorMsg})
return return
} else { } else {
c.JSON(status, gin.H{"error": "unexpected error format in progress response"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "unexpected error format in progress response"})
return return
} }
default: default:
@@ -1285,8 +1179,6 @@ func (s *Server) ProcessHandler(c *gin.Context) {
} }
func (s *Server) ChatHandler(c *gin.Context) { func (s *Server) ChatHandler(c *gin.Context) {
checkpointStart := time.Now()
var req api.ChatRequest var req api.ChatRequest
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) { if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
@@ -1297,10 +1189,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
} }
caps := []Capability{CapabilityCompletion} caps := []Capability{CapabilityCompletion}
if len(req.Tools) > 0 {
caps = append(caps, CapabilityTools)
}
r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive) r, m, opts, err := s.scheduleRunner(c.Request.Context(), req.Model, caps, req.Options, req.KeepAlive)
if errors.Is(err, errCapabilityCompletion) { if errors.Is(err, errCapabilityCompletion) {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)}) c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("%q does not support chat", req.Model)})
@@ -1310,8 +1198,6 @@ func (s *Server) ChatHandler(c *gin.Context) {
return return
} }
checkpointLoaded := time.Now()
if len(req.Messages) == 0 { if len(req.Messages) == 0 {
c.JSON(http.StatusOK, api.ChatResponse{ c.JSON(http.StatusOK, api.ChatResponse{
Model: req.Model, Model: req.Model,
@@ -1323,11 +1209,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
return return
} }
if req.Messages[0].Role != "system" && m.System != "" { prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages)
req.Messages = append([]api.Message{{Role: "system", Content: m.System}}, req.Messages...)
}
prompt, images, err := chatPrompt(c.Request.Context(), m, r.Tokenize, opts, req.Messages, req.Tools)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
@@ -1344,7 +1226,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
Format: req.Format, Format: req.Format,
Options: opts, Options: opts,
}, func(r llm.CompletionResponse) { }, func(r llm.CompletionResponse) {
res := api.ChatResponse{ ch <- api.ChatResponse{
Model: req.Model, Model: req.Model,
CreatedAt: time.Now().UTC(), CreatedAt: time.Now().UTC(),
Message: api.Message{Role: "assistant", Content: r.Content}, Message: api.Message{Role: "assistant", Content: r.Content},
@@ -1357,26 +1239,19 @@ func (s *Server) ChatHandler(c *gin.Context) {
EvalDuration: r.EvalDuration, EvalDuration: r.EvalDuration,
}, },
} }
if r.Done {
res.TotalDuration = time.Since(checkpointStart)
res.LoadDuration = checkpointLoaded.Sub(checkpointStart)
}
ch <- res
}); err != nil { }); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }
}() }()
if (req.Stream != nil && !*req.Stream) || ((req.Stream == nil || *req.Stream) && len(req.Tools) > 0) { if req.Stream != nil && !*req.Stream {
var resp api.ChatResponse var r api.ChatResponse
var sb strings.Builder var sb strings.Builder
for rr := range ch { for rr := range ch {
switch t := rr.(type) { switch t := rr.(type) {
case api.ChatResponse: case api.ChatResponse:
sb.WriteString(t.Message.Content) sb.WriteString(t.Message.Content)
resp = t r = t
case gin.H: case gin.H:
msg, ok := t["error"].(string) msg, ok := t["error"].(string)
if !ok { if !ok {
@@ -1391,36 +1266,8 @@ func (s *Server) ChatHandler(c *gin.Context) {
} }
} }
resp.Message.Content = sb.String() r.Message.Content = sb.String()
c.JSON(http.StatusOK, r)
if len(req.Tools) > 0 {
if toolCalls, ok := m.parseToolCalls(sb.String()); ok {
resp.Message.ToolCalls = toolCalls
resp.Message.Content = ""
}
}
if (req.Stream == nil || *req.Stream) && len(resp.Message.ToolCalls) > 0 {
toolCh := make(chan any)
go func() {
defer close(toolCh)
toolCalls := resp.Message.ToolCalls
for _, toolCall := range toolCalls {
toolCh <- api.ChatResponse{
Model: resp.Model,
CreatedAt: resp.CreatedAt,
Message: api.Message{Role: "assistant", ToolCalls: []api.ToolCall{toolCall}},
}
}
resp.Message.ToolCalls = nil
resp.DoneReason = "tool_calls"
toolCh <- resp
}()
streamResponse(c, toolCh)
return
}
c.JSON(http.StatusOK, resp)
return return
} }
@@ -1429,7 +1276,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
func handleScheduleError(c *gin.Context, name string, err error) { func handleScheduleError(c *gin.Context, name string, err error) {
switch { switch {
case errors.Is(err, errCapabilities), errors.Is(err, errRequired): case errors.Is(err, errRequired):
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
case errors.Is(err, context.Canceled): case errors.Is(err, context.Canceled):
c.JSON(499, gin.H{"error": "request canceled"}) c.JSON(499, gin.H{"error": "request canceled"})

View File

@@ -85,8 +85,6 @@ func checkFileExists(t *testing.T, p string, expect []string) {
} }
func TestCreateFromBin(t *testing.T) { func TestCreateFromBin(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir() p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p) t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig() envconfig.LoadConfig()
@@ -113,8 +111,6 @@ func TestCreateFromBin(t *testing.T) {
} }
func TestCreateFromModel(t *testing.T) { func TestCreateFromModel(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir() p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p) t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig() envconfig.LoadConfig()
@@ -156,8 +152,6 @@ func TestCreateFromModel(t *testing.T) {
} }
func TestCreateRemovesLayers(t *testing.T) { func TestCreateRemovesLayers(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir() p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p) t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig() envconfig.LoadConfig()
@@ -205,8 +199,6 @@ func TestCreateRemovesLayers(t *testing.T) {
} }
func TestCreateUnsetsSystem(t *testing.T) { func TestCreateUnsetsSystem(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir() p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p) t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig() envconfig.LoadConfig()
@@ -263,8 +255,6 @@ func TestCreateUnsetsSystem(t *testing.T) {
} }
func TestCreateMergeParameters(t *testing.T) { func TestCreateMergeParameters(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir() p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p) t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig() envconfig.LoadConfig()
@@ -368,8 +358,6 @@ func TestCreateMergeParameters(t *testing.T) {
} }
func TestCreateReplacesMessages(t *testing.T) { func TestCreateReplacesMessages(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir() p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p) t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig() envconfig.LoadConfig()
@@ -446,8 +434,6 @@ func TestCreateReplacesMessages(t *testing.T) {
} }
func TestCreateTemplateSystem(t *testing.T) { func TestCreateTemplateSystem(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir() p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p) t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig() envconfig.LoadConfig()
@@ -491,47 +477,9 @@ func TestCreateTemplateSystem(t *testing.T) {
if string(system) != "Say bye!" { if string(system) != "Say bye!" {
t.Errorf("expected \"Say bye!\", actual %s", system) t.Errorf("expected \"Say bye!\", actual %s", system)
} }
t.Run("incomplete template", func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test",
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt", createBinFile(t, nil, nil)),
Stream: &stream,
})
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status code 400, actual %d", w.Code)
}
})
t.Run("template with unclosed if", func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test",
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ if .Prompt }}", createBinFile(t, nil, nil)),
Stream: &stream,
})
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status code 400, actual %d", w.Code)
}
})
t.Run("template with undefined function", func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test",
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ Prompt }}", createBinFile(t, nil, nil)),
Stream: &stream,
})
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status code 400, actual %d", w.Code)
}
})
} }
func TestCreateLicenses(t *testing.T) { func TestCreateLicenses(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir() p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p) t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig() envconfig.LoadConfig()
@@ -578,8 +526,6 @@ func TestCreateLicenses(t *testing.T) {
} }
func TestCreateDetectTemplate(t *testing.T) { func TestCreateDetectTemplate(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir() p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p) t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig() envconfig.LoadConfig()
@@ -599,10 +545,9 @@ func TestCreateDetectTemplate(t *testing.T) {
} }
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{ checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-0d79f567714c62c048378f2107fb332dabee0135d080c302d884317da9433cc5"),
filepath.Join(p, "blobs", "sha256-553c4a3f747b3d22a4946875f1cc8ed011c2930d83f864a0c7265f9ec0a20413"), filepath.Join(p, "blobs", "sha256-553c4a3f747b3d22a4946875f1cc8ed011c2930d83f864a0c7265f9ec0a20413"),
filepath.Join(p, "blobs", "sha256-c608dc615584cd20d9d830363dabf8a4783ae5d34245c3d8c115edb3bc7b28e4"), filepath.Join(p, "blobs", "sha256-9512c372dfc7d84d6065b8dd2b601aeed8cc1a78e7a7aa784a42fff37f5524b7"),
filepath.Join(p, "blobs", "sha256-ea34c57ba5b78b740aafe2aeb74dc6507fc3ad14170b64c26a04fb9e36c88d75"), filepath.Join(p, "blobs", "sha256-b8b78cb8c6eefd14c06f1af042e6161255bf87bbf2dd14fce57cdac893db8139"),
}) })
}) })

View File

@@ -8,15 +8,12 @@ import (
"path/filepath" "path/filepath"
"testing" "testing"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
) )
func TestDelete(t *testing.T) { func TestDelete(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir() p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p) t.Setenv("OLLAMA_MODELS", p)
envconfig.LoadConfig() envconfig.LoadConfig()
@@ -80,8 +77,6 @@ func TestDelete(t *testing.T) {
} }
func TestDeleteDuplicateLayers(t *testing.T) { func TestDeleteDuplicateLayers(t *testing.T) {
gin.SetMode(gin.TestMode)
p := t.TempDir() p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p) t.Setenv("OLLAMA_MODELS", p)
var s Server var s Server

View File

@@ -1,714 +0,0 @@
package server
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/google/go-cmp/cmp"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/llm"
)
type mockRunner struct {
llm.LlamaServer
// CompletionRequest is only valid until the next call to Completion
llm.CompletionRequest
llm.CompletionResponse
}
func (m *mockRunner) Completion(_ context.Context, r llm.CompletionRequest, fn func(r llm.CompletionResponse)) error {
m.CompletionRequest = r
fn(m.CompletionResponse)
return nil
}
func (mockRunner) Tokenize(_ context.Context, s string) (tokens []int, err error) {
for range strings.Fields(s) {
tokens = append(tokens, len(tokens))
}
return
}
func newMockServer(mock *mockRunner) func(gpu.GpuInfoList, string, *llm.GGML, []string, []string, api.Options, int) (llm.LlamaServer, error) {
return func(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, projectors, system []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
return mock, nil
}
}
func TestGenerateChat(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := mockRunner{
CompletionResponse: llm.CompletionResponse{
Done: true,
DoneReason: "stop",
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
},
}
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: gpu.GetGPUInfo,
getCpuFn: gpu.GetCPUInfo,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
// add small delay to simulate loading
time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{
llama: &mock,
}
},
},
}
go s.sched.Run(context.TODO())
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Model: "test",
Modelfile: fmt.Sprintf(`FROM %s
TEMPLATE """
{{- if .System }}System: {{ .System }} {{ end }}
{{- if .Prompt }}User: {{ .Prompt }} {{ end }}
{{- if .Response }}Assistant: {{ .Response }} {{ end }}"""
`, createBinFile(t, llm.KV{
"general.architecture": "llama",
"llama.block_count": uint32(1),
"llama.context_length": uint32(8192),
"llama.embedding_length": uint32(4096),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(8),
"tokenizer.ggml.tokens": []string{""},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []llm.Tensor{
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
})),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
t.Run("missing body", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, nil)
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("missing model", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, api.ChatRequest{})
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("missing capabilities chat", func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Model: "bert",
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
"general.architecture": "bert",
"bert.pooling_type": uint32(0),
}, []llm.Tensor{})),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
w = createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "bert",
})
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support chat"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("load model", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test",
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
var actual api.ChatResponse
if err := json.NewDecoder(w.Body).Decode(&actual); err != nil {
t.Fatal(err)
}
if actual.Model != "test" {
t.Errorf("expected model test, got %s", actual.Model)
}
if !actual.Done {
t.Errorf("expected done true, got false")
}
if actual.DoneReason != "load" {
t.Errorf("expected done reason load, got %s", actual.DoneReason)
}
})
checkChatResponse := func(t *testing.T, body io.Reader, model, content string) {
t.Helper()
var actual api.ChatResponse
if err := json.NewDecoder(body).Decode(&actual); err != nil {
t.Fatal(err)
}
if actual.Model != model {
t.Errorf("expected model test, got %s", actual.Model)
}
if !actual.Done {
t.Errorf("expected done false, got true")
}
if actual.DoneReason != "stop" {
t.Errorf("expected done reason stop, got %s", actual.DoneReason)
}
if diff := cmp.Diff(actual.Message, api.Message{
Role: "assistant",
Content: content,
}); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
if actual.PromptEvalCount == 0 {
t.Errorf("expected prompt eval count > 0, got 0")
}
if actual.PromptEvalDuration == 0 {
t.Errorf("expected prompt eval duration > 0, got 0")
}
if actual.EvalCount == 0 {
t.Errorf("expected eval count > 0, got 0")
}
if actual.EvalDuration == 0 {
t.Errorf("expected eval duration > 0, got 0")
}
if actual.LoadDuration == 0 {
t.Errorf("expected load duration > 0, got 0")
}
if actual.TotalDuration == 0 {
t.Errorf("expected total duration > 0, got 0")
}
}
mock.CompletionResponse.Content = "Hi!"
t.Run("messages", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test",
Messages: []api.Message{
{Role: "user", Content: "Hello!"},
},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
checkChatResponse(t, w.Body, "test", "Hi!")
})
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
Model: "test-system",
Modelfile: "FROM test\nSYSTEM You are a helpful assistant.",
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
t.Run("messages with model system", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test-system",
Messages: []api.Message{
{Role: "user", Content: "Hello!"},
},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
checkChatResponse(t, w.Body, "test-system", "Hi!")
})
mock.CompletionResponse.Content = "Abra kadabra!"
t.Run("messages with system", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test-system",
Messages: []api.Message{
{Role: "system", Content: "You can perform magic tricks."},
{Role: "user", Content: "Hello!"},
},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
})
t.Run("messages with interleaved system", func(t *testing.T) {
w := createRequest(t, s.ChatHandler, api.ChatRequest{
Model: "test-system",
Messages: []api.Message{
{Role: "user", Content: "Hello!"},
{Role: "assistant", Content: "I can help you with that."},
{Role: "system", Content: "You can perform magic tricks."},
{Role: "user", Content: "Help me write tests."},
},
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! Assistant: I can help you with that. System: You can perform magic tricks. User: Help me write tests. "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
checkChatResponse(t, w.Body, "test-system", "Abra kadabra!")
})
}
func TestGenerate(t *testing.T) {
gin.SetMode(gin.TestMode)
mock := mockRunner{
CompletionResponse: llm.CompletionResponse{
Done: true,
DoneReason: "stop",
PromptEvalCount: 1,
PromptEvalDuration: 1,
EvalCount: 1,
EvalDuration: 1,
},
}
s := Server{
sched: &Scheduler{
pendingReqCh: make(chan *LlmRequest, 1),
finishedReqCh: make(chan *LlmRequest, 1),
expiredCh: make(chan *runnerRef, 1),
unloadedCh: make(chan any, 1),
loaded: make(map[string]*runnerRef),
newServerFn: newMockServer(&mock),
getGpuFn: gpu.GetGPUInfo,
getCpuFn: gpu.GetCPUInfo,
reschedDelay: 250 * time.Millisecond,
loadFn: func(req *LlmRequest, ggml *llm.GGML, gpus gpu.GpuInfoList, numParallel int) {
// add small delay to simulate loading
time.Sleep(time.Millisecond)
req.successCh <- &runnerRef{
llama: &mock,
}
},
},
}
go s.sched.Run(context.TODO())
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Model: "test",
Modelfile: fmt.Sprintf(`FROM %s
TEMPLATE """
{{- if .System }}System: {{ .System }} {{ end }}
{{- if .Prompt }}User: {{ .Prompt }} {{ end }}
{{- if .Response }}Assistant: {{ .Response }} {{ end }}"""
`, createBinFile(t, llm.KV{
"general.architecture": "llama",
"llama.block_count": uint32(1),
"llama.context_length": uint32(8192),
"llama.embedding_length": uint32(4096),
"llama.attention.head_count": uint32(32),
"llama.attention.head_count_kv": uint32(8),
"tokenizer.ggml.tokens": []string{""},
"tokenizer.ggml.scores": []float32{0},
"tokenizer.ggml.token_type": []int32{0},
}, []llm.Tensor{
{Name: "token_embd.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_down.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_gate.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_up.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.ffn_norm.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_k.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_q.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "blk.0.attn_v.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
{Name: "output.weight", Shape: []uint64{1}, WriterTo: bytes.NewReader(make([]byte, 4))},
})),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
t.Run("missing body", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, nil)
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("missing model", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{})
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"model is required"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("missing capabilities generate", func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Model: "bert",
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
"general.architecture": "bert",
"bert.pooling_type": uint32(0),
}, []llm.Tensor{})),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
w = createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "bert",
})
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"\"bert\" does not support generate"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("missing capabilities suffix", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test",
Prompt: "def add(",
Suffix: " return c",
})
if w.Code != http.StatusBadRequest {
t.Errorf("expected status 400, got %d", w.Code)
}
if diff := cmp.Diff(w.Body.String(), `{"error":"test does not support insert"}`); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("load model", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test",
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
var actual api.GenerateResponse
if err := json.NewDecoder(w.Body).Decode(&actual); err != nil {
t.Fatal(err)
}
if actual.Model != "test" {
t.Errorf("expected model test, got %s", actual.Model)
}
if !actual.Done {
t.Errorf("expected done true, got false")
}
if actual.DoneReason != "load" {
t.Errorf("expected done reason load, got %s", actual.DoneReason)
}
})
checkGenerateResponse := func(t *testing.T, body io.Reader, model, content string) {
t.Helper()
var actual api.GenerateResponse
if err := json.NewDecoder(body).Decode(&actual); err != nil {
t.Fatal(err)
}
if actual.Model != model {
t.Errorf("expected model test, got %s", actual.Model)
}
if !actual.Done {
t.Errorf("expected done false, got true")
}
if actual.DoneReason != "stop" {
t.Errorf("expected done reason stop, got %s", actual.DoneReason)
}
if actual.Response != content {
t.Errorf("expected response %s, got %s", content, actual.Response)
}
if actual.Context == nil {
t.Errorf("expected context not nil")
}
if actual.PromptEvalCount == 0 {
t.Errorf("expected prompt eval count > 0, got 0")
}
if actual.PromptEvalDuration == 0 {
t.Errorf("expected prompt eval duration > 0, got 0")
}
if actual.EvalCount == 0 {
t.Errorf("expected eval count > 0, got 0")
}
if actual.EvalDuration == 0 {
t.Errorf("expected eval duration > 0, got 0")
}
if actual.LoadDuration == 0 {
t.Errorf("expected load duration > 0, got 0")
}
if actual.TotalDuration == 0 {
t.Errorf("expected total duration > 0, got 0")
}
}
mock.CompletionResponse.Content = "Hi!"
t.Run("prompt", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test",
Prompt: "Hello!",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "User: Hello! "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
checkGenerateResponse(t, w.Body, "test", "Hi!")
})
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
Model: "test-system",
Modelfile: "FROM test\nSYSTEM You are a helpful assistant.",
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
t.Run("prompt with model system", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-system",
Prompt: "Hello!",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You are a helpful assistant. User: Hello! "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
checkGenerateResponse(t, w.Body, "test-system", "Hi!")
})
mock.CompletionResponse.Content = "Abra kadabra!"
t.Run("prompt with system", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-system",
Prompt: "Hello!",
System: "You can perform magic tricks.",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "System: You can perform magic tricks. User: Hello! "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
})
t.Run("prompt with template", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-system",
Prompt: "Help me write tests.",
System: "You can perform magic tricks.",
Template: `{{- if .System }}{{ .System }} {{ end }}
{{- if .Prompt }}### USER {{ .Prompt }} {{ end }}
{{- if .Response }}### ASSISTANT {{ .Response }} {{ end }}`,
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "You can perform magic tricks. ### USER Help me write tests. "); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
checkGenerateResponse(t, w.Body, "test-system", "Abra kadabra!")
})
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
Model: "test-suffix",
Modelfile: `FROM test
TEMPLATE """{{- if .Suffix }}<PRE> {{ .Prompt }} <SUF>{{ .Suffix }} <MID>
{{- else }}{{ .Prompt }}
{{- end }}"""`,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status 200, got %d", w.Code)
}
t.Run("prompt with suffix", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-suffix",
Prompt: "def add(",
Suffix: " return c",
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "<PRE> def add( <SUF> return c <MID>"); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("prompt without suffix", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-suffix",
Prompt: "def add(",
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "def add("); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
t.Run("raw", func(t *testing.T) {
w := createRequest(t, s.GenerateHandler, api.GenerateRequest{
Model: "test-system",
Prompt: "Help me write tests.",
Raw: true,
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Errorf("expected status 200, got %d", w.Code)
}
if diff := cmp.Diff(mock.CompletionRequest.Prompt, "Help me write tests."); diff != "" {
t.Errorf("mismatch (-got +want):\n%s", diff)
}
})
}

View File

@@ -7,14 +7,11 @@ import (
"slices" "slices"
"testing" "testing"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
) )
func TestList(t *testing.T) { func TestList(t *testing.T) {
gin.SetMode(gin.TestMode)
t.Setenv("OLLAMA_MODELS", t.TempDir()) t.Setenv("OLLAMA_MODELS", t.TempDir())
envconfig.LoadConfig() envconfig.LoadConfig()

View File

@@ -7,7 +7,6 @@ import (
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"math"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os" "os"
@@ -273,77 +272,6 @@ func Test_Routes(t *testing.T) {
assert.Equal(t, "library", retrieveResp.OwnedBy) assert.Equal(t, "library", retrieveResp.OwnedBy)
}, },
}, },
{
Name: "Embed Handler Empty Input",
Method: http.MethodPost,
Path: "/api/embed",
Setup: func(t *testing.T, req *http.Request) {
embedReq := api.EmbedRequest{
Model: "t-bone",
Input: "",
}
jsonData, err := json.Marshal(embedReq)
require.NoError(t, err)
req.Body = io.NopCloser(bytes.NewReader(jsonData))
},
Expected: func(t *testing.T, resp *http.Response) {
contentType := resp.Header.Get("Content-Type")
if contentType != "application/json; charset=utf-8" {
t.Fatalf("expected content type application/json; charset=utf-8, got %s", contentType)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
var embedResp api.EmbedResponse
err = json.Unmarshal(body, &embedResp)
if err != nil {
t.Fatal(err)
}
if embedResp.Model != "t-bone" {
t.Fatalf("expected model t-bone, got %s", embedResp.Model)
}
if embedResp.Embeddings == nil {
t.Fatalf("expected embeddings to not be nil, got %v", embedResp.Embeddings)
}
if len(embedResp.Embeddings) != 0 {
t.Fatalf("expected embeddings to be empty, got %v", embedResp.Embeddings)
}
},
},
{
Name: "Embed Handler Invalid Input",
Method: http.MethodPost,
Path: "/api/embed",
Setup: func(t *testing.T, req *http.Request) {
embedReq := api.EmbedRequest{
Model: "t-bone",
Input: 2,
}
jsonData, err := json.Marshal(embedReq)
require.NoError(t, err)
req.Body = io.NopCloser(bytes.NewReader(jsonData))
},
Expected: func(t *testing.T, resp *http.Response) {
contentType := resp.Header.Get("Content-Type")
if contentType != "application/json; charset=utf-8" {
t.Fatalf("expected content type application/json; charset=utf-8, got %s", contentType)
}
_, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatal(err)
}
if resp.StatusCode != http.StatusBadRequest {
t.Fatalf("expected status code 400, got %d", resp.StatusCode)
}
},
},
} }
t.Setenv("OLLAMA_MODELS", t.TempDir()) t.Setenv("OLLAMA_MODELS", t.TempDir())
@@ -492,38 +420,3 @@ func TestShow(t *testing.T) {
t.Fatal("Expected projector architecture to be 'clip', but got", resp.ProjectorInfo["general.architecture"]) t.Fatal("Expected projector architecture to be 'clip', but got", resp.ProjectorInfo["general.architecture"])
} }
} }
func TestNormalize(t *testing.T) {
type testCase struct {
input []float32
}
testCases := []testCase{
{input: []float32{1}},
{input: []float32{0, 1, 2, 3}},
{input: []float32{0.1, 0.2, 0.3}},
{input: []float32{-0.1, 0.2, 0.3, -0.4}},
{input: []float32{0, 0, 0}},
}
isNormalized := func(vec []float32) (res bool) {
sum := 0.0
for _, v := range vec {
sum += float64(v * v)
}
if math.Abs(sum-1) > 1e-6 {
return sum == 0
} else {
return true
}
}
for _, tc := range testCases {
t.Run("", func(t *testing.T) {
normalized := normalize(tc.input)
if !isNormalized(normalized) {
t.Errorf("Vector %v is not normalized", tc.input)
}
})
}
}

View File

@@ -135,6 +135,11 @@ func (s *Scheduler) processPending(ctx context.Context) {
} }
for { for {
cpus := s.getCpuFn()
var systemMem gpu.GpuInfo
if len(cpus) > 0 {
systemMem = cpus[0]
}
var runnerToExpire *runnerRef var runnerToExpire *runnerRef
s.loadedMu.Lock() s.loadedMu.Lock()
runner := s.loaded[pending.model.ModelPath] runner := s.loaded[pending.model.ModelPath]
@@ -188,6 +193,38 @@ func (s *Scheduler) processPending(ctx context.Context) {
break break
} }
estimate := llm.EstimateGPULayers(gpus, ggml, pending.model.ProjectorPaths, pending.opts)
maxSize := systemMem.FreeMemory
// Add available GPU memory to the total pool
// macOS hardware has unified memory so don't double count
if runtime.GOOS != "darwin" {
for _, gpu := range gpus {
if gpu.Library == "cpu" {
continue
}
if loadedCount == 0 {
// If no other models are loaded, set the limit based on what's available
maxSize += gpu.FreeMemory
} else {
// Other models could be unloaded, favor total memory for limit
maxSize += gpu.TotalMemory
}
}
}
// Block attempting to load a model larger than system memory + GPU memory
if estimate.TotalSize > maxSize {
slog.Warn("model request too large for system", "requested", format.HumanBytes2(estimate.TotalSize), "system", format.HumanBytes2(maxSize))
// Linux will crash if over-allocating memory - return an error to the user.
// TODO (jmorganca): add reasonable upper limits for darwin and windows as well
if runtime.GOOS == "linux" {
pending.errCh <- fmt.Errorf("requested model (%s) is too large for this system (%s)", format.HumanBytes2(estimate.TotalSize), format.HumanBytes2(maxSize))
break
}
}
// Evaluate if the model will fit in the available system memory, or if we should unload a model first // Evaluate if the model will fit in the available system memory, or if we should unload a model first
if len(gpus) == 1 && gpus[0].Library == "cpu" { if len(gpus) == 1 && gpus[0].Library == "cpu" {
// simplifying assumption of defaultParallel when in CPU mode // simplifying assumption of defaultParallel when in CPU mode

View File

@@ -94,7 +94,7 @@ func TestLoad(t *testing.T) {
require.Len(t, s.expiredCh, 1) require.Len(t, s.expiredCh, 1)
} }
type reqBundle struct { type bundle struct {
ctx context.Context //nolint:containedctx ctx context.Context //nolint:containedctx
ctxDone func() ctxDone func()
srv *mockLlm srv *mockLlm
@@ -102,13 +102,13 @@ type reqBundle struct {
ggml *llm.GGML ggml *llm.GGML
} }
func (scenario *reqBundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) { func (scenario *bundle) newServer(gpus gpu.GpuInfoList, model string, ggml *llm.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) {
return scenario.srv, nil return scenario.srv, nil
} }
func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, estimatedVRAM uint64, duration *api.Duration) *reqBundle { func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedVRAM uint64) *bundle {
b := &reqBundle{} scenario := &bundle{}
b.ctx, b.ctxDone = context.WithCancel(ctx) scenario.ctx, scenario.ctxDone = context.WithCancel(ctx)
t.Helper() t.Helper()
f, err := os.CreateTemp(t.TempDir(), modelName) f, err := os.CreateTemp(t.TempDir(), modelName)
@@ -135,154 +135,124 @@ func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, est
fname := f.Name() fname := f.Name()
model := &Model{Name: modelName, ModelPath: fname} model := &Model{Name: modelName, ModelPath: fname}
b.ggml, err = llm.LoadModel(model.ModelPath, 0) scenario.ggml, err = llm.LoadModel(model.ModelPath, 0)
require.NoError(t, err) require.NoError(t, err)
if duration == nil { scenario.req = &LlmRequest{
duration = &api.Duration{Duration: 5 * time.Millisecond} ctx: scenario.ctx,
}
b.req = &LlmRequest{
ctx: b.ctx,
model: model, model: model,
opts: api.DefaultOptions(), opts: api.DefaultOptions(),
sessionDuration: duration, sessionDuration: &api.Duration{Duration: 5 * time.Millisecond},
successCh: make(chan *runnerRef, 1), successCh: make(chan *runnerRef, 1),
errCh: make(chan error, 1), errCh: make(chan error, 1),
} }
b.srv = &mockLlm{estimatedVRAM: estimatedVRAM, estimatedVRAMByGPU: map[string]uint64{"": estimatedVRAM}} scenario.srv = &mockLlm{estimatedVRAM: estimatedVRAM, estimatedVRAMByGPU: map[string]uint64{"": estimatedVRAM}}
return b return scenario
} }
func getGpuFn() gpu.GpuInfoList { func TestRequests(t *testing.T) {
g := gpu.GpuInfo{Library: "metal"} ctx, done := context.WithTimeout(context.Background(), 10*time.Second)
g.TotalMemory = 24 * format.GigaByte
g.FreeMemory = 12 * format.GigaByte
return []gpu.GpuInfo{g}
}
func getCpuFn() gpu.GpuInfoList {
g := gpu.GpuInfo{Library: "cpu"}
g.TotalMemory = 32 * format.GigaByte
g.FreeMemory = 26 * format.GigaByte
return []gpu.GpuInfo{g}
}
func TestRequestsSameModelSameRequest(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer done() defer done()
s := InitScheduler(ctx)
s.getGpuFn = getGpuFn
s.getCpuFn = getCpuFn
a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond})
b := newScenarioRequest(t, ctx, "ollama-model-1", 11, &api.Duration{Duration: 0})
b.req.model = a.req.model
b.ggml = a.ggml
s.newServerFn = a.newServer // Same model, same request
slog.Info("a") scenario1a := newScenario(t, ctx, "ollama-model-1", 10)
s.pendingReqCh <- a.req scenario1a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
scenario1b := newScenario(t, ctx, "ollama-model-1", 11)
scenario1b.req.model = scenario1a.req.model
scenario1b.ggml = scenario1a.ggml
scenario1b.req.sessionDuration = &api.Duration{Duration: 0}
// simple reload of same model
scenario2a := newScenario(t, ctx, "ollama-model-1", 20)
tmpModel := *scenario1a.req.model
scenario2a.req.model = &tmpModel
scenario2a.ggml = scenario1a.ggml
scenario2a.req.sessionDuration = &api.Duration{Duration: 5 * time.Millisecond}
// Multiple loaded models
scenario3a := newScenario(t, ctx, "ollama-model-3a", 1*format.GigaByte)
scenario3b := newScenario(t, ctx, "ollama-model-3b", 24*format.GigaByte)
scenario3c := newScenario(t, ctx, "ollama-model-4a", 30)
scenario3c.req.opts.NumGPU = 0 // CPU load, will be allowed
scenario3d := newScenario(t, ctx, "ollama-model-3c", 30) // Needs prior unloaded
s := InitScheduler(ctx)
s.getGpuFn = func() gpu.GpuInfoList {
g := gpu.GpuInfo{Library: "metal"}
g.TotalMemory = 24 * format.GigaByte
g.FreeMemory = 12 * format.GigaByte
return []gpu.GpuInfo{g}
}
s.getCpuFn = func() gpu.GpuInfoList {
g := gpu.GpuInfo{Library: "cpu"}
g.TotalMemory = 32 * format.GigaByte
g.FreeMemory = 26 * format.GigaByte
return []gpu.GpuInfo{g}
}
s.newServerFn = scenario1a.newServer
slog.Info("scenario1a")
s.pendingReqCh <- scenario1a.req
require.Len(t, s.pendingReqCh, 1) require.Len(t, s.pendingReqCh, 1)
s.Run(ctx) s.Run(ctx)
select { select {
case resp := <-a.req.successCh: case resp := <-scenario1a.req.successCh:
require.Equal(t, resp.llama, a.srv) require.Equal(t, resp.llama, scenario1a.srv)
require.Empty(t, s.pendingReqCh) require.Empty(t, s.pendingReqCh)
require.Empty(t, a.req.errCh) require.Empty(t, scenario1a.req.errCh)
case err := <-a.req.errCh: case err := <-scenario1a.req.errCh:
t.Fatal(err.Error()) t.Fatal(err.Error())
case <-ctx.Done(): case <-ctx.Done():
t.Fatal("timeout") t.Fatal("timeout")
} }
// Same runner as first request due to not needing a reload // Same runner as first request due to not needing a reload
s.newServerFn = b.newServer s.newServerFn = scenario1b.newServer
slog.Info("b") slog.Info("scenario1b")
s.pendingReqCh <- b.req s.pendingReqCh <- scenario1b.req
select { select {
case resp := <-b.req.successCh: case resp := <-scenario1b.req.successCh:
require.Equal(t, resp.llama, a.srv) require.Equal(t, resp.llama, scenario1a.srv)
require.Empty(t, s.pendingReqCh) require.Empty(t, s.pendingReqCh)
require.Empty(t, b.req.errCh) require.Empty(t, scenario1b.req.errCh)
case err := <-b.req.errCh: case err := <-scenario1b.req.errCh:
t.Fatal(err.Error())
case <-ctx.Done():
t.Fatal("timeout")
}
}
func TestRequestsSimpleReloadSameModel(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer done()
s := InitScheduler(ctx)
s.getGpuFn = getGpuFn
s.getCpuFn = getCpuFn
a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond})
b := newScenarioRequest(t, ctx, "ollama-model-1", 20, &api.Duration{Duration: 5 * time.Millisecond})
tmpModel := *a.req.model
b.req.model = &tmpModel
b.ggml = a.ggml
s.newServerFn = a.newServer
slog.Info("a")
s.pendingReqCh <- a.req
require.Len(t, s.pendingReqCh, 1)
s.Run(ctx)
select {
case resp := <-a.req.successCh:
require.Equal(t, resp.llama, a.srv)
require.Empty(t, s.pendingReqCh)
require.Empty(t, a.req.errCh)
case err := <-a.req.errCh:
t.Fatal(err.Error()) t.Fatal(err.Error())
case <-ctx.Done(): case <-ctx.Done():
t.Fatal("timeout") t.Fatal("timeout")
} }
// Trigger a reload // Trigger a reload
s.newServerFn = b.newServer s.newServerFn = scenario2a.newServer
b.req.model.AdapterPaths = []string{"new"} scenario2a.req.model.AdapterPaths = []string{"new"}
slog.Info("b") slog.Info("scenario2a")
s.pendingReqCh <- b.req s.pendingReqCh <- scenario2a.req
// finish first two requests, so model can reload // finish first two requests, so model can reload
time.Sleep(1 * time.Millisecond) time.Sleep(1 * time.Millisecond)
a.ctxDone() scenario1a.ctxDone()
scenario1b.ctxDone()
select { select {
case resp := <-b.req.successCh: case resp := <-scenario2a.req.successCh:
require.Equal(t, resp.llama, b.srv) require.Equal(t, resp.llama, scenario2a.srv)
require.Empty(t, s.pendingReqCh) require.Empty(t, s.pendingReqCh)
require.Empty(t, b.req.errCh) require.Empty(t, scenario2a.req.errCh)
case err := <-b.req.errCh: case err := <-scenario2a.req.errCh:
t.Fatal(err.Error()) t.Fatal(err.Error())
case <-ctx.Done(): case <-ctx.Done():
t.Fatal("timeout") t.Fatal("timeout")
} }
}
func TestRequestsMultipleLoadedModels(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 500*time.Millisecond)
defer done()
s := InitScheduler(ctx)
s.getGpuFn = getGpuFn
s.getCpuFn = getCpuFn
// Multiple loaded models
a := newScenarioRequest(t, ctx, "ollama-model-3a", 1*format.GigaByte, nil)
b := newScenarioRequest(t, ctx, "ollama-model-3b", 24*format.GigaByte, nil)
c := newScenarioRequest(t, ctx, "ollama-model-4a", 30, nil)
c.req.opts.NumGPU = 0 // CPU load, will be allowed
d := newScenarioRequest(t, ctx, "ollama-model-3c", 30, nil) // Needs prior unloaded
envconfig.MaxRunners = 1 envconfig.MaxRunners = 1
s.newServerFn = a.newServer s.newServerFn = scenario3a.newServer
slog.Info("a") slog.Info("scenario3a")
s.pendingReqCh <- a.req s.pendingReqCh <- scenario3a.req
s.Run(ctx) // finish prior request, so new model can load
time.Sleep(1 * time.Millisecond)
scenario2a.ctxDone()
select { select {
case resp := <-a.req.successCh: case resp := <-scenario3a.req.successCh:
require.Equal(t, resp.llama, a.srv) require.Equal(t, resp.llama, scenario3a.srv)
require.Empty(t, s.pendingReqCh) require.Empty(t, s.pendingReqCh)
require.Empty(t, a.req.errCh) require.Empty(t, scenario3a.req.errCh)
case err := <-a.req.errCh: case err := <-scenario3a.req.errCh:
t.Fatal(err.Error()) t.Fatal(err.Error())
case <-ctx.Done(): case <-ctx.Done():
t.Fatal("timeout") t.Fatal("timeout")
@@ -292,15 +262,15 @@ func TestRequestsMultipleLoadedModels(t *testing.T) {
s.loadedMu.Unlock() s.loadedMu.Unlock()
envconfig.MaxRunners = 0 envconfig.MaxRunners = 0
s.newServerFn = b.newServer s.newServerFn = scenario3b.newServer
slog.Info("b") slog.Info("scenario3b")
s.pendingReqCh <- b.req s.pendingReqCh <- scenario3b.req
select { select {
case resp := <-b.req.successCh: case resp := <-scenario3b.req.successCh:
require.Equal(t, resp.llama, b.srv) require.Equal(t, resp.llama, scenario3b.srv)
require.Empty(t, s.pendingReqCh) require.Empty(t, s.pendingReqCh)
require.Empty(t, b.req.errCh) require.Empty(t, scenario3b.req.errCh)
case err := <-b.req.errCh: case err := <-scenario3b.req.errCh:
t.Fatal(err.Error()) t.Fatal(err.Error())
case <-ctx.Done(): case <-ctx.Done():
t.Fatal("timeout") t.Fatal("timeout")
@@ -310,15 +280,15 @@ func TestRequestsMultipleLoadedModels(t *testing.T) {
s.loadedMu.Unlock() s.loadedMu.Unlock()
// This is a CPU load with NumGPU = 0 so it should load // This is a CPU load with NumGPU = 0 so it should load
s.newServerFn = c.newServer s.newServerFn = scenario3c.newServer
slog.Info("c") slog.Info("scenario3c")
s.pendingReqCh <- c.req s.pendingReqCh <- scenario3c.req
select { select {
case resp := <-c.req.successCh: case resp := <-scenario3c.req.successCh:
require.Equal(t, resp.llama, c.srv) require.Equal(t, resp.llama, scenario3c.srv)
require.Empty(t, s.pendingReqCh) require.Empty(t, s.pendingReqCh)
require.Empty(t, c.req.errCh) require.Empty(t, scenario3c.req.errCh)
case err := <-c.req.errCh: case err := <-scenario3c.req.errCh:
t.Fatal(err.Error()) t.Fatal(err.Error())
case <-ctx.Done(): case <-ctx.Done():
t.Fatal("timeout") t.Fatal("timeout")
@@ -328,25 +298,25 @@ func TestRequestsMultipleLoadedModels(t *testing.T) {
s.loadedMu.Unlock() s.loadedMu.Unlock()
// Try to load a model that wont fit // Try to load a model that wont fit
s.newServerFn = d.newServer s.newServerFn = scenario3d.newServer
slog.Info("d") slog.Info("scenario3d")
s.loadedMu.Lock() s.loadedMu.Lock()
require.Len(t, s.loaded, 3) require.Len(t, s.loaded, 3)
s.loadedMu.Unlock() s.loadedMu.Unlock()
a.ctxDone() // Won't help since this one isn't big enough to make room scenario3a.ctxDone() // Won't help since this one isn't big enough to make room
time.Sleep(2 * time.Millisecond) time.Sleep(2 * time.Millisecond)
s.pendingReqCh <- d.req s.pendingReqCh <- scenario3d.req
// finish prior request, so new model can load // finish prior request, so new model can load
time.Sleep(6 * time.Millisecond) time.Sleep(6 * time.Millisecond)
s.loadedMu.Lock() s.loadedMu.Lock()
require.Len(t, s.loaded, 2) require.Len(t, s.loaded, 2)
s.loadedMu.Unlock() s.loadedMu.Unlock()
b.ctxDone() scenario3b.ctxDone()
select { select {
case resp := <-d.req.successCh: case resp := <-scenario3d.req.successCh:
require.Equal(t, resp.llama, d.srv) require.Equal(t, resp.llama, scenario3d.srv)
require.Empty(t, s.pendingReqCh) require.Empty(t, s.pendingReqCh)
require.Empty(t, d.req.errCh) require.Empty(t, scenario3d.req.errCh)
case <-ctx.Done(): case <-ctx.Done():
t.Fatal("timeout") t.Fatal("timeout")
} }
@@ -359,19 +329,26 @@ func TestGetRunner(t *testing.T) {
ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond) ctx, done := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer done() defer done()
a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, &api.Duration{Duration: 2 * time.Millisecond}) scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
b := newScenarioRequest(t, ctx, "ollama-model-1b", 10, &api.Duration{Duration: 2 * time.Millisecond}) scenario1a.req.sessionDuration = &api.Duration{Duration: 0}
c := newScenarioRequest(t, ctx, "ollama-model-1c", 10, &api.Duration{Duration: 2 * time.Millisecond}) scenario1b := newScenario(t, ctx, "ollama-model-1b", 10)
scenario1b.req.sessionDuration = &api.Duration{Duration: 0}
scenario1c := newScenario(t, ctx, "ollama-model-1c", 10)
scenario1c.req.sessionDuration = &api.Duration{Duration: 0}
envconfig.MaxQueuedRequests = 1 envconfig.MaxQueuedRequests = 1
s := InitScheduler(ctx) s := InitScheduler(ctx)
s.getGpuFn = getGpuFn s.getGpuFn = func() gpu.GpuInfoList {
s.getCpuFn = getCpuFn g := gpu.GpuInfo{Library: "metal"}
s.newServerFn = a.newServer g.TotalMemory = 24 * format.GigaByte
slog.Info("a") g.FreeMemory = 12 * format.GigaByte
successCh1a, errCh1a := s.GetRunner(a.ctx, a.req.model, a.req.opts, a.req.sessionDuration) return []gpu.GpuInfo{g}
}
s.newServerFn = scenario1a.newServer
slog.Info("scenario1a")
successCh1a, errCh1a := s.GetRunner(scenario1a.ctx, scenario1a.req.model, scenario1a.req.opts, scenario1a.req.sessionDuration)
require.Len(t, s.pendingReqCh, 1) require.Len(t, s.pendingReqCh, 1)
slog.Info("b") slog.Info("scenario1b")
successCh1b, errCh1b := s.GetRunner(b.ctx, b.req.model, b.req.opts, b.req.sessionDuration) successCh1b, errCh1b := s.GetRunner(scenario1b.ctx, scenario1b.req.model, scenario1b.req.opts, scenario1b.req.sessionDuration)
require.Len(t, s.pendingReqCh, 1) require.Len(t, s.pendingReqCh, 1)
require.Empty(t, successCh1b) require.Empty(t, successCh1b)
require.Len(t, errCh1b, 1) require.Len(t, errCh1b, 1)
@@ -380,24 +357,22 @@ func TestGetRunner(t *testing.T) {
s.Run(ctx) s.Run(ctx)
select { select {
case resp := <-successCh1a: case resp := <-successCh1a:
require.Equal(t, resp.llama, a.srv) require.Equal(t, resp.llama, scenario1a.srv)
require.Empty(t, s.pendingReqCh) require.Empty(t, s.pendingReqCh)
require.Empty(t, errCh1a) require.Empty(t, errCh1a)
case err := <-errCh1a:
t.Fatal(err.Error())
case <-ctx.Done(): case <-ctx.Done():
t.Fatal("timeout") t.Fatal("timeout")
} }
a.ctxDone() // Set "a" model to idle so it can unload scenario1a.ctxDone()
s.loadedMu.Lock() s.loadedMu.Lock()
require.Len(t, s.loaded, 1) require.Len(t, s.loaded, 1)
s.loadedMu.Unlock() s.loadedMu.Unlock()
c.req.model.ModelPath = "bad path" scenario1c.req.model.ModelPath = "bad path"
slog.Info("c") slog.Info("scenario1c")
successCh1c, errCh1c := s.GetRunner(c.ctx, c.req.model, c.req.opts, c.req.sessionDuration) successCh1c, errCh1c := s.GetRunner(scenario1c.ctx, scenario1c.req.model, scenario1c.req.opts, scenario1c.req.sessionDuration)
// Starts in pending channel, then should be quickly processsed to return an error // Starts in pending channel, then should be quickly processsed to return an error
time.Sleep(20 * time.Millisecond) // Long enough for the "a" model to expire and unload time.Sleep(5 * time.Millisecond)
require.Empty(t, successCh1c) require.Empty(t, successCh1c)
s.loadedMu.Lock() s.loadedMu.Lock()
require.Empty(t, s.loaded) require.Empty(t, s.loaded)
@@ -405,7 +380,7 @@ func TestGetRunner(t *testing.T) {
require.Len(t, errCh1c, 1) require.Len(t, errCh1c, 1)
err = <-errCh1c err = <-errCh1c
require.Contains(t, err.Error(), "bad path") require.Contains(t, err.Error(), "bad path")
b.ctxDone() scenario1b.ctxDone()
} }
// TODO - add one scenario that triggers the bogus finished event with positive ref count // TODO - add one scenario that triggers the bogus finished event with positive ref count
@@ -414,7 +389,7 @@ func TestPrematureExpired(t *testing.T) {
defer done() defer done()
// Same model, same request // Same model, same request
scenario1a := newScenarioRequest(t, ctx, "ollama-model-1a", 10, nil) scenario1a := newScenario(t, ctx, "ollama-model-1a", 10)
s := InitScheduler(ctx) s := InitScheduler(ctx)
s.getGpuFn = func() gpu.GpuInfoList { s.getGpuFn = func() gpu.GpuInfoList {
g := gpu.GpuInfo{Library: "metal"} g := gpu.GpuInfo{Library: "metal"}
@@ -436,8 +411,6 @@ func TestPrematureExpired(t *testing.T) {
s.loadedMu.Unlock() s.loadedMu.Unlock()
slog.Info("sending premature expired event now") slog.Info("sending premature expired event now")
s.expiredCh <- resp // Shouldn't happen in real life, but make sure its safe s.expiredCh <- resp // Shouldn't happen in real life, but make sure its safe
case err := <-errCh1a:
t.Fatal(err.Error())
case <-ctx.Done(): case <-ctx.Done():
t.Fatal("timeout") t.Fatal("timeout")
} }
@@ -473,8 +446,6 @@ func TestUseLoadedRunner(t *testing.T) {
select { select {
case success := <-req.successCh: case success := <-req.successCh:
require.Equal(t, r1, success) require.Equal(t, r1, success)
case err := <-req.errCh:
t.Fatal(err.Error())
case <-ctx.Done(): case <-ctx.Done():
t.Fatal("timeout") t.Fatal("timeout")
} }
@@ -654,7 +625,8 @@ func TestAlreadyCanceled(t *testing.T) {
defer done() defer done()
dctx, done2 := context.WithCancel(ctx) dctx, done2 := context.WithCancel(ctx)
done2() done2()
scenario1a := newScenarioRequest(t, dctx, "ollama-model-1", 10, &api.Duration{Duration: 0}) scenario1a := newScenario(t, dctx, "ollama-model-1", 10)
scenario1a.req.sessionDuration = &api.Duration{Duration: 0}
s := InitScheduler(ctx) s := InitScheduler(ctx)
slog.Info("scenario1a") slog.Info("scenario1a")
s.pendingReqCh <- scenario1a.req s.pendingReqCh <- scenario1a.req
@@ -670,8 +642,8 @@ type mockLlm struct {
pingResp error pingResp error
waitResp error waitResp error
completionResp error completionResp error
embedResp [][]float32 embeddingResp []float64
embedRespErr error embeddingRespErr error
tokenizeResp []int tokenizeResp []int
tokenizeRespErr error tokenizeRespErr error
detokenizeResp string detokenizeResp string
@@ -688,8 +660,8 @@ func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitRes
func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error { func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
return s.completionResp return s.completionResp
} }
func (s *mockLlm) Embed(ctx context.Context, input []string) ([][]float32, error) { func (s *mockLlm) Embedding(ctx context.Context, prompt string) ([]float64, error) {
return s.embedResp, s.embedRespErr return s.embeddingResp, s.embeddingRespErr
} }
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) { func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
return s.tokenizeResp, s.tokenizeRespErr return s.tokenizeResp, s.tokenizeRespErr

View File

@@ -1,67 +0,0 @@
{{- if or .Tools .System }}<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>
{{- if .Tools }}# Safety Preamble
The instructions in this section override those in the task description and style guide sections. Don't answer questions that are harmful or immoral.
# System Preamble
## Basic Rules
You are a powerful conversational AI trained by Cohere to help people. You are augmented by a number of tools, and your job is to use and consume the output of these tools to best help the user. You will see a conversation history between yourself and a user, ending with an utterance from the user. You will then see a specific instruction instructing you what kind of response to generate. When you answer the user's requests, you cite your sources in your answers, according to those instructions.
{{ if .System }}# User Preamble
{{ .System }}
{{- end }}
## Available Tools
Here is a list of tools that you have available to you:
{{- range .Tools }}
```python
def {{ .Function.Name }}(
{{- range $name, $property := .Function.Parameters.Properties }}{{ $name }}: {{ $property.Type }}, {{ end }}) -> List[Dict]:
"""{{ .Function.Description }}
{{- if .Function.Parameters.Properties }}
Args:
{{- range $name, $property := .Function.Parameters.Properties }}
{{ $name }} ({{ $property.Type }}): {{ $property.Description }}
{{- end }}
{{- end }}
"""
pass
```
{{- end }}
{{- else if .System }}{{ .System }}
{{- end }}<|END_OF_TURN_TOKEN|>
{{- end }}
{{- range .Messages }}
{{- if eq .Role "system" }}
{{- continue }}
{{- end }}<|START_OF_TURN_TOKEN|>
{{- if eq .Role "user" }}<|USER_TOKEN|>{{ .Content }}
{{- else if eq .Role "assistant" }}<|CHATBOT_TOKEN|>
{{- if .Content }}{{ .Content }}
{{- else if .ToolCalls }}
Action: ```json
[
{{- range .ToolCalls }}
{
"tool_name": "{{ .Function.Name }}",
"parameters": {{ .Function.Arguments }}
}
{{- end }}
]```
{{ continue }}
{{ end }}
{{- else if eq .Role "tool" }}<|SYSTEM_TOKEN|><results>
{{ .Content }}</results>
{{- end }}<|END_OF_TURN_TOKEN|>
{{- end }}
{{- if .Tools }}<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>Write 'Action:' followed by a json-formatted list of actions that you want to perform in order to produce a good response to the user's last input. You can use any of the supplied tools any number of times, but you should aim to execute the minimum number of necessary actions for the input. You should use the `directly-answer` tool if calling the other tools is unnecessary. The list of actions you want to call should be formatted as a list of json objects, for example:
```json
[
{
"tool_name": title of the tool in the specification,
"parameters": a dict of parameters to input into the tool as they are defined in the specs, or {} if it takes no parameters
}
]```
{{- end }}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>

View File

@@ -1,39 +0,0 @@
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|># Safety Preamble
The instructions in this section override those in the task description and style guide sections. Don't answer questions that are harmful or immoral.
# System Preamble
## Basic Rules
You are a powerful conversational AI trained by Cohere to help people. You are augmented by a number of tools, and your job is to use and consume the output of these tools to best help the user. You will see a conversation history between yourself and a user, ending with an utterance from the user. You will then see a specific instruction instructing you what kind of response to generate. When you answer the user's requests, you cite your sources in your answers, according to those instructions.
# User Preamble
You are a knowledgable assistant. You can answer questions and perform tasks.
## Available Tools
Here is a list of tools that you have available to you:
```python
def get_current_weather(format: string, location: string, ) -> List[Dict]:
"""Get the current weather
Args:
format (string): The temperature unit to use. Infer this from the users location.
location (string): The city and state, e.g. San Francisco, CA
"""
pass
```<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>What's the weather like today in Paris?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>
Action: ```json
[
{
"tool_name": "get_current_weather",
"parameters": {"format":"celsius","location":"Paris, France"}
}
]```
<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|><results>
22</results><|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>The current temperature in Paris, France is 22 degrees Celsius.<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|USER_TOKEN|>What's the weather like today in San Francisco and Toronto?<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>Write 'Action:' followed by a json-formatted list of actions that you want to perform in order to produce a good response to the user's last input. You can use any of the supplied tools any number of times, but you should aim to execute the minimum number of necessary actions for the input. You should use the `directly-answer` tool if calling the other tools is unnecessary. The list of actions you want to call should be formatted as a list of json objects, for example:
```json
[
{
"tool_name": title of the tool in the specification,
"parameters": a dict of parameters to input into the tool as they are defined in the specs, or {} if it takes no parameters
}
]```<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>

View File

@@ -1,31 +0,0 @@
{{- if or .System .Tools }}<|start_header_id|>system<|end_header_id|>
{{- if .System }}
{{ .System }}
{{- end }}
In addition to plain text responses, you can chose to call one or more of the provided functions.
Use the following rule to decide when to call a function:
* if the response can be generated from your internal knowledge (e.g., as in the case of queries like "What is the capital of Poland?"), do so
* if you need external information that can be obtained by calling one or more of the provided functions, generate a function calls
If you decide to call functions:
* prefix function calls with functools marker (no closing marker required)
* all function calls should be generated in a single JSON list formatted as functools[{"name": [function name], "arguments": [function arguments as JSON]}, ...]
* follow the provided JSON schema. Do not hallucinate arguments or values. Do to blindly copy values from the provided samples
* respect the argument type formatting. E.g., if the type if number and format is float, write value 7 as 7.0
* make sure you pick the right functions that match the user intent
Available functions as JSON spec:
{{- if .Tools }}
{{ .Tools }}
{{- end }}<|eot_id|>
{{- end }}
{{- range .Messages }}<|start_header_id|>
{{- if or (eq .Role "user") (eq .Role "assistant") (eq .Role "tool") }}{{ .Role }}
{{- end }}<|end_header_id|>
{{- if .Content }}{{ .Content }}
{{- else if .ToolCalls }} functools[
{{- range .ToolCalls }}{{ "{" }}"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}{{ "}" }}
{{- end }}]
{{- end }}<|eot_id|>
{{- end }}<|start_header_id|>assistant<|end_header_id|>

View File

@@ -1,17 +0,0 @@
<|start_header_id|>system<|end_header_id|>
You are a knowledgable assistant. You can answer questions and perform tasks.
In addition to plain text responses, you can chose to call one or more of the provided functions.
Use the following rule to decide when to call a function:
* if the response can be generated from your internal knowledge (e.g., as in the case of queries like "What is the capital of Poland?"), do so
* if you need external information that can be obtained by calling one or more of the provided functions, generate a function calls
If you decide to call functions:
* prefix function calls with functools marker (no closing marker required)
* all function calls should be generated in a single JSON list formatted as functools[{"name": [function name], "arguments": [function arguments as JSON]}, ...]
* follow the provided JSON schema. Do not hallucinate arguments or values. Do to blindly copy values from the provided samples
* respect the argument type formatting. E.g., if the type if number and format is float, write value 7 as 7.0
* make sure you pick the right functions that match the user intent
Available functions as JSON spec:
[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the users location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}]<|eot_id|><|start_header_id|><|end_header_id|>You are a knowledgable assistant. You can answer questions and perform tasks.<|eot_id|><|start_header_id|>user<|end_header_id|>What's the weather like today in Paris?<|eot_id|><|start_header_id|>assistant<|end_header_id|> functools[{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}]<|eot_id|><|start_header_id|>tool<|end_header_id|>22<|eot_id|><|start_header_id|>assistant<|end_header_id|>The current temperature in Paris, France is 22 degrees Celsius.<|eot_id|><|start_header_id|>user<|end_header_id|>What's the weather like today in San Francisco and Toronto?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

View File

@@ -1,43 +0,0 @@
{{- if .Messages }}
{{- if or .System .Tools }}<|start_header_id|>system<|end_header_id|>
{{ .System }}
{{- if .Tools }} You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
<tool_call>
{"name": <function-name>,"arguments": <args-dict>}
</tool_call>
Here are the available tools:
<tools>
{{- range .Tools }} {{ .Function }}
{{- end }} </tools>
{{- end }}
{{- end }}<|eot_id|>
{{- range .Messages }}
{{- if ne .Role "system" }}<|start_header_id|>{{ .Role }}<|end_header_id|>
{{ if eq .Role "user" }}{{ .Content }}
{{- else if eq .Role "assistant" }}
{{- if .Content }}{{ .Content }}
{{- else if .ToolCalls }}<tool_call>
{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
{{- end }}
</tool_call>
{{- end }}
{{- else if eq .Role "tool" }}<tool_response>
{{ .Content }}
</tool_response>
{{- end }}<|eot_id|>
{{- end }}
{{- end }}<|start_header_id|>assistant<|end_header_id|>
{{ else }}
{{ if .System }}<|start_header_id|>system<|end_header_id|>
{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
{{ end }}{{ .Response }}
{{- if .Response }}<|eot_id|>
{{- end }}

View File

@@ -1,24 +0,0 @@
<|start_header_id|>system<|end_header_id|>
You are a knowledgable assistant. You can answer questions and perform tasks. You are provided with function signatures within <tools></tools> XML tags. You may call one or more functions to assist with the user query. Don't make assumptions about what values to plug into functions. For each function call return a json object with function name and arguments within <tool_call></tool_call> XML tags as follows:
<tool_call>
{"name": <function-name>,"arguments": <args-dict>}
</tool_call>
Here are the available tools:
<tools> {"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the users location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}} </tools><|eot_id|><|start_header_id|>user<|end_header_id|>
What's the weather like today in Paris?<|eot_id|><|start_header_id|>assistant<|end_header_id|>
<tool_call>
{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}
</tool_call><|eot_id|><|start_header_id|>tool<|end_header_id|>
<tool_response>
22
</tool_response><|eot_id|><|start_header_id|>assistant<|end_header_id|>
The current temperature in Paris, France is 22 degrees Celsius.<|eot_id|><|start_header_id|>user<|end_header_id|>
What's the weather like today in San Francisco and Toronto?<|eot_id|><|start_header_id|>assistant<|end_header_id|>

View File

@@ -1,39 +0,0 @@
[
{
"role": "system",
"content": "You are a knowledgable assistant. You can answer questions and perform tasks."
},
{
"role": "user",
"content": "What's the weather like today in Paris?"
},
{
"role": "assistant",
"tool_calls": [
{
"id": "89a1e453-0bce-4de3-a456-c54bed09c520",
"type": "function",
"function": {
"name": "get_current_weather",
"arguments": {
"location": "Paris, France",
"format": "celsius"
}
}
}
]
},
{
"role": "tool",
"tool_call_id": "89a1e453-0bce-4de3-a456-c54bed09c520",
"content": "22"
},
{
"role": "assistant",
"content": "The current temperature in Paris, France is 22 degrees Celsius."
},
{
"role": "user",
"content": "What's the weather like today in San Francisco and Toronto?"
}
]

View File

@@ -1,15 +0,0 @@
{{- range $index, $_ := .Messages }}
{{- if eq .Role "user" }}
{{- if and (eq (len (slice $.Messages $index)) 1) $.Tools }}[AVAILABLE_TOOLS] {{ $.Tools }}[/AVAILABLE_TOOLS]
{{- end }}[INST] {{ if and (eq (len (slice $.Messages $index)) 1) $.System }}{{ $.System }}
{{ end }}{{ .Content }}[/INST]
{{- else if eq .Role "assistant" }}
{{- if .Content }} {{ .Content }}</s>
{{- else if .ToolCalls }}[TOOL_CALLS] [
{{- range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}
{{- end }}]</s>
{{- end }}
{{- else if eq .Role "tool" }}[TOOL_RESULTS] {"content": {{ .Content }}}[/TOOL_RESULTS]
{{- end }}
{{- end }}

View File

@@ -1,3 +0,0 @@
[INST] What's the weather like today in Paris?[/INST][TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}]</s>[TOOL_RESULTS] {"content": 22}[/TOOL_RESULTS] The current temperature in Paris, France is 22 degrees Celsius.</s>[AVAILABLE_TOOLS] [{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the users location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}][/AVAILABLE_TOOLS][INST] You are a knowledgable assistant. You can answer questions and perform tasks.
What's the weather like today in San Francisco and Toronto?[/INST]

View File

@@ -1,30 +0,0 @@
[
{
"type": "function",
"function": {
"name": "get_current_weather",
"description": "Get the current weather",
"parameters": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state, e.g. San Francisco, CA"
},
"format": {
"type": "string",
"enum": [
"celsius",
"fahrenheit"
],
"description": "The temperature unit to use. Infer this from the users location."
}
},
"required": [
"location",
"format"
]
}
}
}
]

View File

@@ -1,45 +0,0 @@
{{- if .System }}{{ .System }}
{{ end }}
{{- range $i, $_ := .Messages }}
{{- if eq .Role "user" }}### Instruction:
{{- if and $.Tools (le (len (slice $.Messages $i)) 2) }}
[BEGIN OF TASK INSTRUCTION]
You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the functions can be used, point it out and refuse to answer.
If the given question lacks the parameters required by the function, also point it out.
[END OF TASK INSTRUCTION]
[BEGIN OF AVAILABLE TOOLS]
{{ $.Tools }}
[END OF AVAILABLE TOOLS]
[BEGIN OF FORMAT INSTRUCTION]
The output MUST strictly adhere to the following JSON format, and NO other text MUST be included.
The example format is as follows. Please make sure the parameter type is correct. If no function call is needed, please make tool_calls an empty list '[]'.
```
{
"tool_calls": [
{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},
... (more tool calls as required)
]
}
```
[END OF FORMAT INSTRUCTION]
[BEGIN OF QUERY]
{{ .Content }}
[END OF QUERY]
{{ else }}
{{ .Content }}
{{ end }}
{{- else if .ToolCalls }}### Response:
{"tool_calls": [{{ range .ToolCalls }}{"name": "{{ .Function.Name }}", "arguments": {{ .Function.Arguments }}}{{ end }}]}
<|EOT|>
{{ else if eq .Role "assistant" }}### Response:
{{ .Content }}
<|EOT|>
{{ end }}
{{- end }}### Response:

View File

@@ -1,40 +0,0 @@
You are a knowledgable assistant. You can answer questions and perform tasks.
### Instruction:
What's the weather like today in Paris?
### Response:
{"tool_calls": [{"name": "get_current_weather", "arguments": {"format":"celsius","location":"Paris, France"}}]}
<|EOT|>
### Response:
The current temperature in Paris, France is 22 degrees Celsius.
<|EOT|>
### Instruction:
[BEGIN OF TASK INSTRUCTION]
You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the functions can be used, point it out and refuse to answer.
If the given question lacks the parameters required by the function, also point it out.
[END OF TASK INSTRUCTION]
[BEGIN OF AVAILABLE TOOLS]
[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather","parameters":{"type":"object","required":["location","format"],"properties":{"format":{"type":"string","description":"The temperature unit to use. Infer this from the users location.","enum":["celsius","fahrenheit"]},"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"}}}}}]
[END OF AVAILABLE TOOLS]
[BEGIN OF FORMAT INSTRUCTION]
The output MUST strictly adhere to the following JSON format, and NO other text MUST be included.
The example format is as follows. Please make sure the parameter type is correct. If no function call is needed, please make tool_calls an empty list '[]'.
```
{
"tool_calls": [
{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},
... (more tool calls as required)
]
}
```
[END OF FORMAT INSTRUCTION]
[BEGIN OF QUERY]
What's the weather like today in San Francisco and Toronto?
[END OF QUERY]
### Response:

View File

@@ -254,7 +254,7 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
// retry uploading to the redirect URL // retry uploading to the redirect URL
for try := range maxRetries { for try := range maxRetries {
err = b.uploadPart(ctx, http.MethodPut, redirectURL, part, &registryOptions{}) err = b.uploadPart(ctx, http.MethodPut, redirectURL, part, nil)
switch { switch {
case errors.Is(err, context.Canceled): case errors.Is(err, context.Canceled):
return err return err

View File

@@ -1 +1,8 @@
{{ if .System }}<start_system>{{ .System }}<end_message>{{ end }}{{ if .Prompt }}<start_user>{{ .Prompt }}<end_message>{{ end }}<start_assistant>{{ .Response }}<end_message> {{- if .Messages }}
{{- if .System }}<start_system>{{ .System }}<end_message>
{{- end }}
{{- range .Messages }}<start_{{ .Role }}>{{ .Content }}<end_message>
{{- end }}<start_assistant>
{{- else }}
{{ if .System }}<start_system>{{ .System }}<end_message>{{ end }}{{ if .Prompt }}<start_user>{{ .Prompt }}<end_message>{{ end }}<start_assistant>{{ .Response }}<end_message>
{{- end }}

View File

@@ -1,8 +0,0 @@
{
"stop": [
"<start_system>",
"<end_message>",
"<start_user>",
"<start_assistant>"
]
}

View File

@@ -1,3 +1,14 @@
{{- if .Messages }}
{{- if .System }}{{ .System }}
{{- end }}
{{- range .Messages }}
{{- if eq .Role "user" }}### Instruction:
{{- else if eq .Role "assistant" }}### Response:
{{- end }}
{{ .Content }}
{{ end }}### Response:
{{ else }}
{{ if .System }}{{ .System }} {{ if .System }}{{ .System }}
{{ end }}{{ if .Prompt }}### Instruction: {{ end }}{{ if .Prompt }}### Instruction:
@@ -5,4 +16,4 @@
{{ end }}### Response: {{ end }}### Response:
{{ .Response }} {{ .Response }}
{{- end }}

View File

@@ -1,6 +0,0 @@
{
"stop": [
"### Instruction:",
"### Response"
]
}

View File

@@ -1,6 +1,15 @@
{{- if .Messages }}
{{- if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}
{{- range .Messages }}<|im_start|>{{ .Role }}
{{ .Content }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ else }}
{{ if .System }}<|im_start|>system {{ if .System }}<|im_start|>system
{{ .System }}<|im_end|> {{ .System }}<|im_end|>
{{ end }}{{ if .Prompt }}<|im_start|>user {{ end }}{{ if .Prompt }}<|im_start|>user
{{ .Prompt }}<|im_end|> {{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant {{ end }}<|im_start|>assistant
{{ .Response }}<|im_end|> {{ .Response }}<|im_end|>
{{- end }}

View File

@@ -1,6 +0,0 @@
{
"stop": [
"<|im_start|>",
"<|im_end|>"
]
}

View File

@@ -1,6 +1,17 @@
{{- if .Messages }}
{{- if .System }}System: {{ .System }}
{{ end }}
{{- range .Messages }}
{{- if eq .Role "user" }}User:
{{- else if eq .Role "assistant" }}Assistant:
{{- end }} {{ .Content }}
{{ end }}Assistant:
{{- else }}
{{ if .System }}System: {{ .System }} {{ if .System }}System: {{ .System }}
{{ end }}{{ if .Prompt }}User: {{ .Prompt }} {{ end }}{{ if .Prompt }}User: {{ .Prompt }}
{{ end }}Assistant: {{ .Response }} {{ end }}Assistant: <|begin_of_text|>{{ .Response }}
{{- end }}

Some files were not shown because too many files have changed in this diff Show More