Compare commits

..

7 Commits

Author SHA1 Message Date
Roy Han
5f0403d208 Isolated Deletions 2024-05-31 17:40:11 -07:00
Roy Han
5a3cb1064a Clean Up Delete Tests 2024-05-31 16:44:54 -07:00
Roy Han
77487ada72 Err Check 2024-05-31 13:12:26 -07:00
Roy Han
a946b6f020 Adjust Response and Blob Check 2024-05-31 13:08:59 -07:00
Roy Han
c62df6b3bf Check Blob 2024-05-31 12:07:52 -07:00
Roy Han
e8788ae8dd Specify DNE error 2024-05-31 09:45:47 -07:00
Roy Han
8774e5d6a9 Deletion Unit Test 2024-05-30 16:44:17 -07:00
99 changed files with 818 additions and 2167 deletions

View File

@@ -34,13 +34,13 @@ jobs:
git diff-tree -r --no-commit-id --name-only \ git diff-tree -r --no-commit-id --name-only \
$(git merge-base ${{ github.event.pull_request.base.sha }} ${{ github.event.pull_request.head.sha }}) \ $(git merge-base ${{ github.event.pull_request.base.sha }} ${{ github.event.pull_request.head.sha }}) \
${{ github.event.pull_request.head.sha }} \ ${{ github.event.pull_request.head.sha }} \
| xargs python3 -c "import sys; from pathlib import Path; print(any(Path(x).match(glob) for x in sys.argv[1:] for glob in '$*'.split(' ')))" | xargs python3 -c "import sys; print(any([x.startswith('$1') for x in sys.argv[1:]]))"
} }
{ {
echo GENERATE=$(changed 'llm/llama.cpp' 'llm/patches/**' 'llm/ext_server/**' 'llm/generate/**') echo GENERATE=$(changed llm/)
echo GENERATE_CUDA=$(changed 'llm/llama.cpp' 'llm/patches/**' 'llm/ext_server/**' 'llm/generate/**') echo GENERATE_CUDA=$(changed llm/)
echo GENERATE_ROCM=$(changed 'llm/llama.cpp' 'llm/patches/**' 'llm/ext_server/**' 'llm/generate/**') echo GENERATE_ROCM=$(changed llm/)
} >>$GITHUB_OUTPUT } >>$GITHUB_OUTPUT
generate: generate:
@@ -269,9 +269,9 @@ jobs:
mkdir -p llm/build/darwin/$ARCH/stub/bin mkdir -p llm/build/darwin/$ARCH/stub/bin
touch llm/build/darwin/$ARCH/stub/bin/ollama_llama_server touch llm/build/darwin/$ARCH/stub/bin/ollama_llama_server
if: ${{ startsWith(matrix.os, 'macos-') }} if: ${{ startsWith(matrix.os, 'macos-') }}
- uses: golangci/golangci-lint-action@v6 - uses: golangci/golangci-lint-action@v4
with: with:
args: --timeout 8m0s -v ${{ startsWith(matrix.os, 'windows-') && '' || '--disable gofmt --disable goimports' }} args: --timeout 8m0s -v
test: test:
strategy: strategy:
matrix: matrix:
@@ -287,8 +287,6 @@ jobs:
GOARCH: ${{ matrix.arch }} GOARCH: ${{ matrix.arch }}
CGO_ENABLED: '1' CGO_ENABLED: '1'
OLLAMA_CPU_TARGET: 'static' OLLAMA_CPU_TARGET: 'static'
OLLAMA_SKIP_CPU_GENERATE: '1'
OLLAMA_SKIP_METAL_GENERATE: '1'
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
with: with:

View File

@@ -9,26 +9,9 @@ linters:
- contextcheck - contextcheck
- exportloopref - exportloopref
- gocheckcompilerdirectives - gocheckcompilerdirectives
# conditionally enable this on linux/macos # FIXME: for some reason this errors on windows
# - gofmt # - gofmt
# - goimports # - goimports
- intrange
- misspell - misspell
- nilerr - nilerr
- nolintlint
- nosprintfhostport
- testifylint
- unconvert
- unused - unused
- wastedassign
- whitespace
- usestdlibvars
severity:
default-severity: error
rules:
- linters:
- gofmt
- goimports
- intrange
- usestdlibvars
severity: info

View File

@@ -6,7 +6,7 @@
[![Discord](https://dcbadge.vercel.app/api/server/ollama?style=flat&compact=true)](https://discord.gg/ollama) [![Discord](https://dcbadge.vercel.app/api/server/ollama?style=flat&compact=true)](https://discord.gg/ollama)
Get up and running with large language models. Get up and running with large language models locally.
### macOS ### macOS
@@ -285,7 +285,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [macai](https://github.com/Renset/macai) (macOS client for Ollama, ChatGPT, and other compatible API back-ends) - [macai](https://github.com/Renset/macai) (macOS client for Ollama, ChatGPT, and other compatible API back-ends)
- [Olpaka](https://github.com/Otacon/olpaka) (User-friendly Flutter Web App for Ollama) - [Olpaka](https://github.com/Otacon/olpaka) (User-friendly Flutter Web App for Ollama)
- [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS) - [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS)
- [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama)
### Terminal ### Terminal
@@ -308,7 +307,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [ShellOracle](https://github.com/djcopley/ShellOracle) - [ShellOracle](https://github.com/djcopley/ShellOracle)
- [tlm](https://github.com/yusufcanb/tlm) - [tlm](https://github.com/yusufcanb/tlm)
- [podman-ollama](https://github.com/ericcurtin/podman-ollama) - [podman-ollama](https://github.com/ericcurtin/podman-ollama)
- [gollama](https://github.com/sammcj/gollama)
### Database ### Database
@@ -326,7 +324,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [LangChain](https://python.langchain.com/docs/integrations/llms/ollama) and [LangChain.js](https://js.langchain.com/docs/modules/model_io/models/llms/integrations/ollama) with [example](https://js.langchain.com/docs/use_cases/question_answering/local_retrieval_qa) - [LangChain](https://python.langchain.com/docs/integrations/llms/ollama) and [LangChain.js](https://js.langchain.com/docs/modules/model_io/models/llms/integrations/ollama) with [example](https://js.langchain.com/docs/use_cases/question_answering/local_retrieval_qa)
- [LangChainGo](https://github.com/tmc/langchaingo/) with [example](https://github.com/tmc/langchaingo/tree/main/examples/ollama-completion-example) - [LangChainGo](https://github.com/tmc/langchaingo/) with [example](https://github.com/tmc/langchaingo/tree/main/examples/ollama-completion-example)
- [LangChain4j](https://github.com/langchain4j/langchain4j) with [example](https://github.com/langchain4j/langchain4j-examples/tree/main/ollama-examples/src/main/java) - [LangChain4j](https://github.com/langchain4j/langchain4j) with [example](https://github.com/langchain4j/langchain4j-examples/tree/main/ollama-examples/src/main/java)
- [LangChainRust](https://github.com/Abraxas-365/langchain-rust) with [example](https://github.com/Abraxas-365/langchain-rust/blob/main/examples/llm_ollama.rs)
- [LlamaIndex](https://gpt-index.readthedocs.io/en/stable/examples/llm/ollama.html) - [LlamaIndex](https://gpt-index.readthedocs.io/en/stable/examples/llm/ollama.html)
- [LiteLLM](https://github.com/BerriAI/litellm) - [LiteLLM](https://github.com/BerriAI/litellm)
- [OllamaSharp for .NET](https://github.com/awaescher/OllamaSharp) - [OllamaSharp for .NET](https://github.com/awaescher/OllamaSharp)
@@ -349,7 +346,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Portkey](https://portkey.ai/docs/welcome/integration-guides/ollama) - [Portkey](https://portkey.ai/docs/welcome/integration-guides/ollama)
- [PromptingTools.jl](https://github.com/svilupp/PromptingTools.jl) with an [example](https://svilupp.github.io/PromptingTools.jl/dev/examples/working_with_ollama) - [PromptingTools.jl](https://github.com/svilupp/PromptingTools.jl) with an [example](https://svilupp.github.io/PromptingTools.jl/dev/examples/working_with_ollama)
- [LlamaScript](https://github.com/Project-Llama/llamascript) - [LlamaScript](https://github.com/Project-Llama/llamascript)
### Mobile ### Mobile
- [Enchanted](https://github.com/AugustDev/enchanted) - [Enchanted](https://github.com/AugustDev/enchanted)
@@ -382,9 +378,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [AI ST Completion](https://github.com/yaroslavyaroslav/OpenAI-sublime-text) (Sublime Text 4 AI assistant plugin with Ollama support) - [AI ST Completion](https://github.com/yaroslavyaroslav/OpenAI-sublime-text) (Sublime Text 4 AI assistant plugin with Ollama support)
- [Discord-Ollama Chat Bot](https://github.com/kevinthedang/discord-ollama) (Generalized TypeScript Discord Bot w/ Tuning Documentation) - [Discord-Ollama Chat Bot](https://github.com/kevinthedang/discord-ollama) (Generalized TypeScript Discord Bot w/ Tuning Documentation)
- [Discord AI chat/moderation bot](https://github.com/rapmd73/Companion) Chat/moderation bot written in python. Uses Ollama to create personalities. - [Discord AI chat/moderation bot](https://github.com/rapmd73/Companion) Chat/moderation bot written in python. Uses Ollama to create personalities.
- [Headless Ollama](https://github.com/nischalj10/headless-ollama) (Scripts to automatically install ollama client & models on any OS for apps that depends on ollama server)
### Supported backends
### Supported backends
- [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov. - [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov.

View File

@@ -355,8 +355,8 @@ func (c *Client) List(ctx context.Context) (*ListResponse, error) {
} }
// List running models. // List running models.
func (c *Client) ListRunning(ctx context.Context) (*ProcessResponse, error) { func (c *Client) ListRunning(ctx context.Context) (*ListResponse, error) {
var lr ProcessResponse var lr ListResponse
if err := c.do(ctx, http.MethodGet, "/api/ps", nil, &lr); err != nil { if err := c.do(ctx, http.MethodGet, "/api/ps", nil, &lr); err != nil {
return nil, err return nil, err
} }

View File

@@ -282,33 +282,19 @@ type PushRequest struct {
// ListResponse is the response from [Client.List]. // ListResponse is the response from [Client.List].
type ListResponse struct { type ListResponse struct {
Models []ListModelResponse `json:"models"` Models []ModelResponse `json:"models"`
} }
// ProcessResponse is the response from [Client.Process]. // ModelResponse is a single model description in [ListResponse].
type ProcessResponse struct { type ModelResponse struct {
Models []ProcessModelResponse `json:"models"`
}
// ListModelResponse is a single model description in [ListResponse].
type ListModelResponse struct {
Name string `json:"name"` Name string `json:"name"`
Model string `json:"model"` Model string `json:"model"`
ModifiedAt time.Time `json:"modified_at"` ModifiedAt time.Time `json:"modified_at,omitempty"`
Size int64 `json:"size"` Size int64 `json:"size"`
Digest string `json:"digest"` Digest string `json:"digest"`
Details ModelDetails `json:"details,omitempty"` Details ModelDetails `json:"details,omitempty"`
} ExpiresAt time.Time `json:"expires_at,omitempty"`
SizeVRAM int64 `json:"size_vram,omitempty"`
// ProcessModelResponse is a single model description in [ProcessResponse].
type ProcessModelResponse struct {
Name string `json:"name"`
Model string `json:"model"`
Size int64 `json:"size"`
Digest string `json:"digest"`
Details ModelDetails `json:"details,omitempty"`
ExpiresAt time.Time `json:"expires_at"`
SizeVRAM int64 `json:"size_vram"`
} }
type TokenResponse struct { type TokenResponse struct {
@@ -320,7 +306,7 @@ type GenerateResponse struct {
// Model is the model name that generated the response. // Model is the model name that generated the response.
Model string `json:"model"` Model string `json:"model"`
// CreatedAt is the timestamp of the response. //CreatedAt is the timestamp of the response.
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
// Response is the textual response itself. // Response is the textual response itself.

View File

@@ -72,13 +72,13 @@ func TestDurationMarshalUnmarshal(t *testing.T) {
}, },
{ {
"positive duration", "positive duration",
42 * time.Second, time.Duration(42 * time.Second),
42 * time.Second, time.Duration(42 * time.Second),
}, },
{ {
"another positive duration", "another positive duration",
42 * time.Minute, time.Duration(42 * time.Minute),
42 * time.Minute, time.Duration(42 * time.Minute),
}, },
{ {
"zero duration", "zero duration",

View File

@@ -69,6 +69,7 @@ func init() {
slog.Error(fmt.Sprintf("create ollama dir %s: %v", AppDataDir, err)) slog.Error(fmt.Sprintf("create ollama dir %s: %v", AppDataDir, err))
} }
} }
} else if runtime.GOOS == "darwin" { } else if runtime.GOOS == "darwin" {
// TODO // TODO
AppName += ".app" AppName += ".app"

View File

@@ -15,7 +15,7 @@ import (
) )
func getCLIFullPath(command string) string { func getCLIFullPath(command string) string {
var cmdPath string cmdPath := ""
appExe, err := os.Executable() appExe, err := os.Executable()
if err == nil { if err == nil {
cmdPath = filepath.Join(filepath.Dir(appExe), command) cmdPath = filepath.Join(filepath.Dir(appExe), command)
@@ -65,6 +65,7 @@ func start(ctx context.Context, command string) (*exec.Cmd, error) {
if err != nil { if err != nil {
if !errors.Is(err, os.ErrNotExist) { if !errors.Is(err, os.ErrNotExist) {
return nil, fmt.Errorf("stat ollama server log dir %s: %v", logDir, err) return nil, fmt.Errorf("stat ollama server log dir %s: %v", logDir, err)
} }
if err := os.MkdirAll(logDir, 0o755); err != nil { if err := os.MkdirAll(logDir, 0o755); err != nil {

View File

@@ -24,8 +24,7 @@ func terminate(cmd *exec.Cmd) error {
if err != nil { if err != nil {
return err return err
} }
//nolint:errcheck defer dll.Release() // nolint: errcheck
defer dll.Release()
pid := cmd.Process.Pid pid := cmd.Process.Pid
@@ -74,8 +73,7 @@ func isProcessExited(pid int) (bool, error) {
if err != nil { if err != nil {
return false, fmt.Errorf("failed to open process: %v", err) return false, fmt.Errorf("failed to open process: %v", err)
} }
//nolint:errcheck defer windows.CloseHandle(hProcess) // nolint: errcheck
defer windows.CloseHandle(hProcess)
var exitCode uint32 var exitCode uint32
err = windows.GetExitCodeProcess(hProcess, &exitCode) err = windows.GetExitCodeProcess(hProcess, &exitCode)

View File

@@ -78,7 +78,7 @@ func IsNewReleaseAvailable(ctx context.Context) (bool, UpdateResponse) {
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode == http.StatusNoContent { if resp.StatusCode == 204 {
slog.Debug("check update response 204 (current version is up to date)") slog.Debug("check update response 204 (current version is up to date)")
return false, updateResp return false, updateResp
} }
@@ -87,7 +87,7 @@ func IsNewReleaseAvailable(ctx context.Context) (bool, UpdateResponse) {
slog.Warn(fmt.Sprintf("failed to read body response: %s", err)) slog.Warn(fmt.Sprintf("failed to read body response: %s", err))
} }
if resp.StatusCode != http.StatusOK { if resp.StatusCode != 200 {
slog.Info(fmt.Sprintf("check update error %d - %.96s", resp.StatusCode, string(body))) slog.Info(fmt.Sprintf("check update error %d - %.96s", resp.StatusCode, string(body)))
return false, updateResp return false, updateResp
} }
@@ -114,7 +114,7 @@ func DownloadNewRelease(ctx context.Context, updateResp UpdateResponse) error {
if err != nil { if err != nil {
return fmt.Errorf("error checking update: %w", err) return fmt.Errorf("error checking update: %w", err)
} }
if resp.StatusCode != http.StatusOK { if resp.StatusCode != 200 {
return fmt.Errorf("unexpected status attempting to download update %d", resp.StatusCode) return fmt.Errorf("unexpected status attempting to download update %d", resp.StatusCode)
} }
resp.Body.Close() resp.Body.Close()

View File

@@ -4,5 +4,5 @@ write-host "Welcome to Ollama!"
write-host "" write-host ""
write-host "Run your first model:" write-host "Run your first model:"
write-host "" write-host ""
write-host "`tollama run llama3" write-host "`tollama run llama2"
write-host "" write-host ""

View File

@@ -29,6 +29,7 @@ func GetID() string {
initStore() initStore()
} }
return store.ID return store.ID
} }
func GetFirstTimeRun() bool { func GetFirstTimeRun() bool {

View File

@@ -47,6 +47,7 @@ func nativeLoop() {
default: default:
pTranslateMessage.Call(uintptr(unsafe.Pointer(m))) //nolint:errcheck pTranslateMessage.Call(uintptr(unsafe.Pointer(m))) //nolint:errcheck
pDispatchMessage.Call(uintptr(unsafe.Pointer(m))) //nolint:errcheck pDispatchMessage.Call(uintptr(unsafe.Pointer(m))) //nolint:errcheck
} }
} }
} }
@@ -159,8 +160,8 @@ func (t *winTray) wndProc(hWnd windows.Handle, message uint32, wParam, lParam ui
lResult, _, _ = pDefWindowProc.Call( lResult, _, _ = pDefWindowProc.Call(
uintptr(hWnd), uintptr(hWnd),
uintptr(message), uintptr(message),
wParam, uintptr(wParam),
lParam, uintptr(lParam),
) )
} }
return return

View File

@@ -186,7 +186,7 @@ func (t *winTray) initInstance() error {
t.muNID.Lock() t.muNID.Lock()
defer t.muNID.Unlock() defer t.muNID.Unlock()
t.nid = &notifyIconData{ t.nid = &notifyIconData{
Wnd: t.window, Wnd: windows.Handle(t.window),
ID: 100, ID: 100,
Flags: NIF_MESSAGE, Flags: NIF_MESSAGE,
CallbackMessage: t.wmSystrayMessage, CallbackMessage: t.wmSystrayMessage,
@@ -197,6 +197,7 @@ func (t *winTray) initInstance() error {
} }
func (t *winTray) createMenu() error { func (t *winTray) createMenu() error {
menuHandle, _, err := pCreatePopupMenu.Call() menuHandle, _, err := pCreatePopupMenu.Call()
if menuHandle == 0 { if menuHandle == 0 {
return err return err
@@ -245,7 +246,7 @@ func (t *winTray) addOrUpdateMenuItem(menuItemId uint32, parentId uint32, title
mi := menuItemInfo{ mi := menuItemInfo{
Mask: MIIM_FTYPE | MIIM_STRING | MIIM_ID | MIIM_STATE, Mask: MIIM_FTYPE | MIIM_STRING | MIIM_ID | MIIM_STATE,
Type: MFT_STRING, Type: MFT_STRING,
ID: menuItemId, ID: uint32(menuItemId),
TypeData: titlePtr, TypeData: titlePtr,
Cch: uint32(len(title)), Cch: uint32(len(title)),
} }
@@ -301,10 +302,11 @@ func (t *winTray) addOrUpdateMenuItem(menuItemId uint32, parentId uint32, title
} }
func (t *winTray) addSeparatorMenuItem(menuItemId, parentId uint32) error { func (t *winTray) addSeparatorMenuItem(menuItemId, parentId uint32) error {
mi := menuItemInfo{ mi := menuItemInfo{
Mask: MIIM_FTYPE | MIIM_ID | MIIM_STATE, Mask: MIIM_FTYPE | MIIM_ID | MIIM_STATE,
Type: MFT_SEPARATOR, Type: MFT_SEPARATOR,
ID: menuItemId, ID: uint32(menuItemId),
} }
mi.Size = uint32(unsafe.Sizeof(mi)) mi.Size = uint32(unsafe.Sizeof(mi))
@@ -424,6 +426,7 @@ func iconBytesToFilePath(iconBytes []byte) (string, error) {
// Loads an image from file and shows it in tray. // Loads an image from file and shows it in tray.
// Shell_NotifyIcon: https://msdn.microsoft.com/en-us/library/windows/desktop/bb762159(v=vs.85).aspx // Shell_NotifyIcon: https://msdn.microsoft.com/en-us/library/windows/desktop/bb762159(v=vs.85).aspx
func (t *winTray) setIcon(src string) error { func (t *winTray) setIcon(src string) error {
h, err := t.loadIconFrom(src) h, err := t.loadIconFrom(src)
if err != nil { if err != nil {
return err return err
@@ -441,6 +444,7 @@ func (t *winTray) setIcon(src string) error {
// Loads an image from file to be shown in tray or menu item. // Loads an image from file to be shown in tray or menu item.
// LoadImage: https://msdn.microsoft.com/en-us/library/windows/desktop/ms648045(v=vs.85).aspx // LoadImage: https://msdn.microsoft.com/en-us/library/windows/desktop/ms648045(v=vs.85).aspx
func (t *winTray) loadIconFrom(src string) (windows.Handle, error) { func (t *winTray) loadIconFrom(src string) (windows.Handle, error) {
// Save and reuse handles of loaded images // Save and reuse handles of loaded images
t.muLoadedImages.RLock() t.muLoadedImages.RLock()
h, ok := t.loadedImages[src] h, ok := t.loadedImages[src]

View File

@@ -20,7 +20,6 @@ import (
"path/filepath" "path/filepath"
"regexp" "regexp"
"runtime" "runtime"
"slices"
"strings" "strings"
"syscall" "syscall"
"time" "time"
@@ -30,6 +29,7 @@ import (
"github.com/olekukonko/tablewriter" "github.com/olekukonko/tablewriter"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"golang.org/x/exp/slices"
"golang.org/x/term" "golang.org/x/term"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
@@ -746,6 +746,7 @@ func displayResponse(content string, wordWrap bool, state *displayResponseState)
if wordWrap && termWidth >= 10 { if wordWrap && termWidth >= 10 {
for _, ch := range content { for _, ch := range content {
if state.lineLength+1 > termWidth-5 { if state.lineLength+1 > termWidth-5 {
if runewidth.StringWidth(state.wordBuffer) > termWidth-10 { if runewidth.StringWidth(state.wordBuffer) > termWidth-10 {
fmt.Printf("%s%c", state.wordBuffer, ch) fmt.Printf("%s%c", state.wordBuffer, ch)
state.wordBuffer = "" state.wordBuffer = ""
@@ -754,11 +755,7 @@ func displayResponse(content string, wordWrap bool, state *displayResponseState)
} }
// backtrack the length of the last word and clear to the end of the line // backtrack the length of the last word and clear to the end of the line
a := runewidth.StringWidth(state.wordBuffer) fmt.Printf("\x1b[%dD\x1b[K\n", runewidth.StringWidth(state.wordBuffer))
if a > 0 {
fmt.Printf("\x1b[%dD", a)
}
fmt.Printf("\x1b[K\n")
fmt.Printf("%s%c", state.wordBuffer, ch) fmt.Printf("%s%c", state.wordBuffer, ch)
chWidth := runewidth.RuneWidth(ch) chWidth := runewidth.RuneWidth(ch)
@@ -1029,6 +1026,24 @@ func initializeKeypair() error {
return nil return nil
} }
// waitForServer polls the server's heartbeat endpoint every 500ms and returns
// nil as soon as one heartbeat succeeds. It returns an error if no heartbeat
// succeeds within 5 seconds, or ctx.Err() if ctx is cancelled first.
//
//nolint:unused
func waitForServer(ctx context.Context, client *api.Client) error {
	// Use an explicit timer and ticker so both are released on return;
	// time.Tick provides no way to stop the ticker it creates.
	timeout := time.NewTimer(5 * time.Second)
	defer timeout.Stop()

	tick := time.NewTicker(500 * time.Millisecond)
	defer tick.Stop()

	for {
		select {
		case <-ctx.Done():
			// The caller gave up; do not keep polling until the timeout.
			return ctx.Err()
		case <-timeout.C:
			return errors.New("timed out waiting for server to start")
		case <-tick.C:
			if err := client.Heartbeat(ctx); err == nil {
				return nil // server has started
			}
		}
	}
}
func checkServerHeartbeat(cmd *cobra.Command, _ []string) error { func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
client, err := api.ClientFromEnvironment() client, err := api.ClientFromEnvironment()
if err != nil { if err != nil {
@@ -1236,9 +1251,6 @@ func NewCLI() *cobra.Command {
envVars["OLLAMA_NOPRUNE"], envVars["OLLAMA_NOPRUNE"],
envVars["OLLAMA_ORIGINS"], envVars["OLLAMA_ORIGINS"],
envVars["OLLAMA_TMPDIR"], envVars["OLLAMA_TMPDIR"],
envVars["OLLAMA_FLASH_ATTENTION"],
envVars["OLLAMA_LLM_LIBRARY"],
envVars["OLLAMA_MAX_VRAM"],
}) })
default: default:
appendEnvDocs(cmd, envs) appendEnvDocs(cmd, envs)

View File

@@ -8,11 +8,11 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"regexp" "regexp"
"slices"
"sort" "sort"
"strings" "strings"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"golang.org/x/exp/slices"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"

View File

@@ -6,7 +6,6 @@ import (
"text/template" "text/template"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
) )
@@ -86,11 +85,11 @@ MESSAGE assistant """Yes it is true, I am half horse, half shark."""
` `
tmpl, err := template.New("").Parse(expectedModelfile) tmpl, err := template.New("").Parse(expectedModelfile)
require.NoError(t, err) assert.Nil(t, err)
var buf bytes.Buffer var buf bytes.Buffer
err = tmpl.Execute(&buf, opts) err = tmpl.Execute(&buf, opts)
require.NoError(t, err) assert.Nil(t, err)
assert.Equal(t, buf.String(), mf) assert.Equal(t, buf.String(), mf)
opts.ParentModel = "horseshark" opts.ParentModel = "horseshark"
@@ -108,10 +107,10 @@ MESSAGE assistant """Yes it is true, I am half horse, half shark."""
` `
tmpl, err = template.New("").Parse(expectedModelfile) tmpl, err = template.New("").Parse(expectedModelfile)
require.NoError(t, err) assert.Nil(t, err)
var parentBuf bytes.Buffer var parentBuf bytes.Buffer
err = tmpl.Execute(&parentBuf, opts) err = tmpl.Execute(&parentBuf, opts)
require.NoError(t, err) assert.Nil(t, err)
assert.Equal(t, parentBuf.String(), mf) assert.Equal(t, parentBuf.String(), mf)
} }

View File

@@ -1,27 +0,0 @@
//go:build darwin || windows
package cmd
import (
"context"
"errors"
"time"
"github.com/ollama/ollama/api"
)
// waitForServer polls the server's heartbeat endpoint every 500ms and returns
// nil as soon as one heartbeat succeeds. It returns an error if no heartbeat
// succeeds within 5 seconds, or ctx.Err() if ctx is cancelled first.
func waitForServer(ctx context.Context, client *api.Client) error {
	// Use an explicit timer and ticker so both are released on return;
	// time.Tick provides no way to stop the ticker it creates.
	timeout := time.NewTimer(5 * time.Second)
	defer timeout.Stop()

	tick := time.NewTicker(500 * time.Millisecond)
	defer tick.Stop()

	for {
		select {
		case <-ctx.Done():
			// The caller gave up; do not keep polling until the timeout.
			return ctx.Err()
		case <-timeout.C:
			return errors.New("timed out waiting for server to start")
		case <-tick.C:
			if err := client.Heartbeat(ctx); err == nil {
				return nil // server has started
			}
		}
	}
}

View File

@@ -189,7 +189,7 @@ func LoadSentencePieceTokens(dirpath string, params *Params) (*Vocab, error) {
if params.VocabSize > len(v.Tokens) { if params.VocabSize > len(v.Tokens) {
missingTokens := params.VocabSize - len(v.Tokens) missingTokens := params.VocabSize - len(v.Tokens)
slog.Warn(fmt.Sprintf("vocab is missing %d tokens", missingTokens)) slog.Warn(fmt.Sprintf("vocab is missing %d tokens", missingTokens))
for cnt := range missingTokens { for cnt := 0; cnt < missingTokens; cnt++ {
v.Tokens = append(v.Tokens, fmt.Sprintf("<dummy%05d>", cnt+1)) v.Tokens = append(v.Tokens, fmt.Sprintf("<dummy%05d>", cnt+1))
v.Scores = append(v.Scores, -1) v.Scores = append(v.Scores, -1)
v.Types = append(v.Types, tokenTypeUserDefined) v.Types = append(v.Types, tokenTypeUserDefined)

View File

@@ -35,6 +35,7 @@ func addOnes(data []float32, vectorSize int) ([]float32, error) {
f32s = append(f32s, t...) f32s = append(f32s, t...)
} }
return f32s, nil return f32s, nil
} }

View File

@@ -119,12 +119,11 @@ func llamaRepack(name string, params *Params, data []float32, shape []uint64) ([
} }
var heads int var heads int
switch { if strings.HasSuffix(name, "attn_q.weight") {
case strings.HasSuffix(name, "attn_q.weight"):
heads = params.AttentionHeads heads = params.AttentionHeads
case strings.HasSuffix(name, "attn_k.weight"): } else if strings.HasSuffix(name, "attn_k.weight") {
heads = cmp.Or(params.KeyValHeads, params.AttentionHeads) heads = cmp.Or(params.KeyValHeads, params.AttentionHeads)
default: } else {
return nil, fmt.Errorf("unknown tensor name: %s", name) return nil, fmt.Errorf("unknown tensor name: %s", name)
} }

View File

@@ -120,7 +120,7 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params)
Name: name, Name: name,
Kind: kind, Kind: kind,
Offset: offset, Offset: offset,
Shape: shape, Shape: shape[:],
} }
t.WriterTo = safetensorWriterTo{ t.WriterTo = safetensorWriterTo{

View File

@@ -85,8 +85,11 @@ func parseTokens(dirpath string) (pre string, tokens []Token, merges []string, e
sha256sum := sha256.New() sha256sum := sha256.New()
for _, pt := range t.PreTokenizer.PreTokenizers { for _, pt := range t.PreTokenizer.PreTokenizers {
if pt.Type == "Split" && pt.Pattern.Regex != "" { switch pt.Type {
sha256sum.Write([]byte(pt.Pattern.Regex)) case "Split":
if pt.Pattern.Regex != "" {
sha256sum.Write([]byte(pt.Pattern.Regex))
}
} }
} }

View File

@@ -88,7 +88,7 @@ func (tf *TorchFormat) GetTensors(dirpath string, params *Params) ([]llm.Tensor,
Name: ggufName, Name: ggufName,
Kind: kind, Kind: kind,
Offset: offset, // calculate the offset Offset: offset, // calculate the offset
Shape: shape, Shape: shape[:],
} }
tensor.WriterTo = torchWriterTo{ tensor.WriterTo = torchWriterTo{
@@ -104,6 +104,7 @@ func (tf *TorchFormat) GetTensors(dirpath string, params *Params) ([]llm.Tensor,
} }
return tensors, nil return tensors, nil
} }
func getAltParams(dirpath string) (*Params, error) { func getAltParams(dirpath string) (*Params, error) {

View File

@@ -12,7 +12,6 @@
- [Pull a Model](#pull-a-model) - [Pull a Model](#pull-a-model)
- [Push a Model](#push-a-model) - [Push a Model](#push-a-model)
- [Generate Embeddings](#generate-embeddings) - [Generate Embeddings](#generate-embeddings)
- [List Running Models](#list-running-models)
## Conventions ## Conventions
@@ -1036,48 +1035,3 @@ curl http://localhost:11434/api/embeddings -d '{
] ]
} }
``` ```
## List Running Models
```shell
GET /api/ps
```
List models that are currently loaded into memory.
\* If a model is loaded completely into system memory, `size_vram` is omitted from the response.
#### Examples
### Request
```shell
curl http://localhost:11434/api/ps
```
#### Response
A single JSON object will be returned.
```json
{
"models": [
{
"name": "mistral:latest",
"model": "mistral:latest",
"size": 5137025024,
"digest": "2ae6f6dd7a3dd734790bbbf58b8909a606e0e7e97e94b7604e0aa7ae4490e6d8",
"details": {
"parent_model": "",
"format": "gguf",
"family": "llama",
"families": [
"llama"
],
"parameter_size": "7.2B",
"quantization_level": "Q4_0"
},
"expires_at": "2024-06-04T14:38:31.83753-07:00",
"size_vram": 5137025024
}
]
}
```

View File

@@ -76,7 +76,6 @@ Make sure you've set up the container runtime first as described in [docker.md](
Sometimes the container runtime can have difficulties initializing the GPU. When you check the server logs, this can show up as various error codes, such as "3" (not initialized), "46" (device unavailable), "100" (no device), "999" (unknown), or others. The following troubleshooting techniques may help resolve the problem Sometimes the container runtime can have difficulties initializing the GPU. When you check the server logs, this can show up as various error codes, such as "3" (not initialized), "46" (device unavailable), "100" (no device), "999" (unknown), or others. The following troubleshooting techniques may help resolve the problem
- Is the container runtime working? Try `docker run --gpus all ubuntu nvidia-smi` - if this doesn't work, Ollama won't be able to see your NVIDIA GPU.
- Is the uvm driver not loaded? `sudo nvidia-modprobe -u` - Is the uvm driver not loaded? `sudo nvidia-modprobe -u`
- Try reloading the nvidia_uvm driver - `sudo rmmod nvidia_uvm` then `sudo modprobe nvidia_uvm` - Try reloading the nvidia_uvm driver - `sudo rmmod nvidia_uvm` then `sudo modprobe nvidia_uvm`
- Try rebooting - Try rebooting

View File

@@ -45,7 +45,7 @@ all_splits = text_splitter.split_documents(data)
``` ```
It's split up, but we have to find the relevant splits and then submit those to the model. We can do this by creating embeddings and storing them in a vector database. We can use Ollama directly to instantiate an embedding model. We will use ChromaDB in this example for a vector database. `pip install chromadb` It's split up, but we have to find the relevant splits and then submit those to the model. We can do this by creating embeddings and storing them in a vector database. We can use Ollama directly to instantiate an embedding model. We will use ChromaDB in this example for a vector database. `pip install chromadb`
We also need to pull the embedding model: `ollama pull nomic-embed-text`
```python ```python
from langchain.embeddings import OllamaEmbeddings from langchain.embeddings import OllamaEmbeddings
from langchain.vectorstores import Chroma from langchain.vectorstores import Chroma
@@ -68,8 +68,7 @@ The next thing is to send the question and the relevant parts of the docs to the
```python ```python
from langchain.chains import RetrievalQA from langchain.chains import RetrievalQA
qachain=RetrievalQA.from_chain_type(ollama, retriever=vectorstore.as_retriever()) qachain=RetrievalQA.from_chain_type(ollama, retriever=vectorstore.as_retriever())
res = qachain.invoke({"query": question}) qachain.invoke({"query": question})
print(res['result'])
``` ```
The answer received from this chain was: The answer received from this chain was:

View File

@@ -3,7 +3,6 @@ package envconfig
import ( import (
"fmt" "fmt"
"log/slog" "log/slog"
"net"
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
@@ -52,16 +51,16 @@ func AsMap() map[string]EnvVar {
"OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention, "Enabled flash attention"}, "OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention, "Enabled flash attention"},
"OLLAMA_HOST": {"OLLAMA_HOST", "", "IP Address for the ollama server (default 127.0.0.1:11434)"}, "OLLAMA_HOST": {"OLLAMA_HOST", "", "IP Address for the ollama server (default 127.0.0.1:11434)"},
"OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive, "The duration that models stay loaded in memory (default \"5m\")"}, "OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive, "The duration that models stay loaded in memory (default \"5m\")"},
"OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary, "Set LLM library to bypass autodetection"}, "OLLAMA_LLM_LIBRARY": {"OLLAMA_ORIGINS", LLMLibrary, ""},
"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models (default 1)"}, "OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models (default 1)"},
"OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueuedRequests, "Maximum number of queued requests"}, "OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueuedRequests, "Maximum number of queued requests"},
"OLLAMA_MAX_VRAM": {"OLLAMA_MAX_VRAM", MaxVRAM, "Maximum VRAM"}, "OLLAMA_MAX_VRAM": {"OLLAMA_MAX_VRAM", MaxVRAM, ""},
"OLLAMA_MODELS": {"OLLAMA_MODELS", "", "The path to the models directory"}, "OLLAMA_MODELS": {"OLLAMA_MODELS", "", "The path to the models directory"},
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory, "Do not preserve readline history"}, "OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory, "Do not preserve readline history"},
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune, "Do not prune model blobs on startup"}, "OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune, "Do not prune model blobs on startup"},
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel, "Maximum number of parallel requests (default 1)"}, "OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel, "Maximum number of parallel requests (default 1)"},
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowOrigins, "A comma separated list of allowed origins"}, "OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowOrigins, "A comma separated list of allowed origins"},
"OLLAMA_RUNNERS_DIR": {"OLLAMA_RUNNERS_DIR", RunnersDir, "Location for runners"}, "OLLAMA_RUNNERS_DIR": {"OLLAMA_RUNNERS_DIR", RunnersDir, ""},
"OLLAMA_TMPDIR": {"OLLAMA_TMPDIR", TmpDir, "Location for temporary files"}, "OLLAMA_TMPDIR": {"OLLAMA_TMPDIR", TmpDir, "Location for temporary files"},
} }
} }
@@ -90,7 +89,6 @@ func init() {
NumParallel = 1 NumParallel = 1
MaxRunners = 1 MaxRunners = 1
MaxQueuedRequests = 512 MaxQueuedRequests = 512
FlashAttention = true
LoadConfig() LoadConfig()
} }
@@ -128,7 +126,7 @@ func LoadConfig() {
var paths []string var paths []string
for _, root := range []string{filepath.Dir(appExe), cwd} { for _, root := range []string{filepath.Dir(appExe), cwd} {
paths = append(paths, paths = append(paths,
root, filepath.Join(root),
filepath.Join(root, "windows-"+runtime.GOARCH), filepath.Join(root, "windows-"+runtime.GOARCH),
filepath.Join(root, "dist", "windows-"+runtime.GOARCH), filepath.Join(root, "dist", "windows-"+runtime.GOARCH),
) )
@@ -186,17 +184,11 @@ func LoadConfig() {
AllowOrigins = append(AllowOrigins, AllowOrigins = append(AllowOrigins,
fmt.Sprintf("http://%s", allowOrigin), fmt.Sprintf("http://%s", allowOrigin),
fmt.Sprintf("https://%s", allowOrigin), fmt.Sprintf("https://%s", allowOrigin),
fmt.Sprintf("http://%s", net.JoinHostPort(allowOrigin, "*")), fmt.Sprintf("http://%s:*", allowOrigin),
fmt.Sprintf("https://%s", net.JoinHostPort(allowOrigin, "*")), fmt.Sprintf("https://%s:*", allowOrigin),
) )
} }
AllowOrigins = append(AllowOrigins,
"app://*",
"file://*",
"tauri://*",
)
maxRunners := clean("OLLAMA_MAX_LOADED_MODELS") maxRunners := clean("OLLAMA_MAX_LOADED_MODELS")
if maxRunners != "" { if maxRunners != "" {
m, err := strconv.Atoi(maxRunners) m, err := strconv.Atoi(maxRunners)

View File

@@ -5,6 +5,7 @@ import (
) )
func TestHumanNumber(t *testing.T) { func TestHumanNumber(t *testing.T) {
type testCase struct { type testCase struct {
input uint64 input uint64
expected string expected string

1
go.mod
View File

@@ -16,7 +16,6 @@ require (
) )
require ( require (
github.com/agnivade/levenshtein v1.1.1
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1 github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
github.com/mattn/go-runewidth v0.0.14 github.com/mattn/go-runewidth v0.0.14
github.com/nlpodyssey/gopickle v0.3.0 github.com/nlpodyssey/gopickle v0.3.0

6
go.sum
View File

@@ -4,14 +4,10 @@ dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7
gioui.org v0.0.0-20210308172011-57750fc8a0a6/go.mod h1:RSH6KIUZ0p2xy5zHDxgAM4zumjgTw83q2ge/PI+yyw8= gioui.org v0.0.0-20210308172011-57750fc8a0a6/go.mod h1:RSH6KIUZ0p2xy5zHDxgAM4zumjgTw83q2ge/PI+yyw8=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
github.com/agnivade/levenshtein v1.1.1 h1:QY8M92nrzkmr798gCo3kmMyqXFzdQVpxLlGPRBij0P8=
github.com/agnivade/levenshtein v1.1.1/go.mod h1:veldBMzWxcCG2ZvUTKD2kJNRdCk5hVbJomOvKkmgYbo=
github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw= github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw=
github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY=
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 h1:q4dksr6ICHXqG5hm0ZW5IHyeEJXoIJSOZeBLmWPNeIQ= github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 h1:q4dksr6ICHXqG5hm0ZW5IHyeEJXoIJSOZeBLmWPNeIQ=
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40/go.mod h1:Q7yQnSMnLvcXlZ8RV+jwz/6y1rQTqbX6C82SndT52Zs= github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40/go.mod h1:Q7yQnSMnLvcXlZ8RV+jwz/6y1rQTqbX6C82SndT52Zs=
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q=
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE=
github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8= github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0= github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4= github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
@@ -40,8 +36,6 @@ github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1/go.mod h1:uw2gLc
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48 h1:fRzb/w+pyskVMQ+UbP35JkH8yB7MYb4q/qhBarqZE6g=
github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48/go.mod h1:if7Fbed8SFyPtHLHbg49SI7NAdJiC5WIA09pe59rfAA=
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc= github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ= github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=

View File

@@ -65,7 +65,7 @@ func AMDGetGPUInfo() []GpuInfo {
slog.Debug("detected hip devices", "count", count) slog.Debug("detected hip devices", "count", count)
// TODO how to determine the underlying device ID when visible devices is causing this to subset? // TODO how to determine the underlying device ID when visible devices is causing this to subset?
for i := range count { for i := 0; i < count; i++ {
err = hl.HipSetDevice(i) err = hl.HipSetDevice(i)
if err != nil { if err != nil {
slog.Warn("set device", "id", i, "error", err) slog.Warn("set device", "id", i, "error", err)

View File

@@ -80,7 +80,7 @@ func cleanupTmpDirs() {
if err == nil { if err == nil {
pid, err := strconv.Atoi(string(raw)) pid, err := strconv.Atoi(string(raw))
if err == nil { if err == nil {
if proc, err := os.FindProcess(pid); err == nil && !errors.Is(proc.Signal(syscall.Signal(0)), os.ErrProcessDone) { if proc, err := os.FindProcess(int(pid)); err == nil && !errors.Is(proc.Signal(syscall.Signal(0)), os.ErrProcessDone) {
// Another running ollama, ignore this tmpdir // Another running ollama, ignore this tmpdir
continue continue
} }

View File

@@ -18,4 +18,5 @@ func cudaGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
ids = append(ids, info.ID) ids = append(ids, info.ID)
} }
return "CUDA_VISIBLE_DEVICES", strings.Join(ids, ",") return "CUDA_VISIBLE_DEVICES", strings.Join(ids, ",")
} }

View File

@@ -16,12 +16,13 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
"strconv"
"strings" "strings"
"sync" "sync"
"unsafe" "unsafe"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/ollama/ollama/envconfig"
) )
type handles struct { type handles struct {
@@ -104,6 +105,8 @@ func initGPUHandles() *handles {
var cudartMgmtPatterns []string var cudartMgmtPatterns []string
var nvcudaMgmtName string var nvcudaMgmtName string
var nvcudaMgmtPatterns []string var nvcudaMgmtPatterns []string
var oneapiMgmtName string
var oneapiMgmtPatterns []string
tmpDir, _ := PayloadsDir() tmpDir, _ := PayloadsDir()
switch runtime.GOOS { switch runtime.GOOS {
@@ -115,6 +118,8 @@ func initGPUHandles() *handles {
// Aligned with driver, we can't carry as payloads // Aligned with driver, we can't carry as payloads
nvcudaMgmtName = "nvcuda.dll" nvcudaMgmtName = "nvcuda.dll"
nvcudaMgmtPatterns = NvcudaWindowsGlobs nvcudaMgmtPatterns = NvcudaWindowsGlobs
oneapiMgmtName = "ze_intel_gpu64.dll"
oneapiMgmtPatterns = OneapiWindowsGlobs
case "linux": case "linux":
cudartMgmtName = "libcudart.so*" cudartMgmtName = "libcudart.so*"
if tmpDir != "" { if tmpDir != "" {
@@ -125,6 +130,8 @@ func initGPUHandles() *handles {
// Aligned with driver, we can't carry as payloads // Aligned with driver, we can't carry as payloads
nvcudaMgmtName = "libcuda.so*" nvcudaMgmtName = "libcuda.so*"
nvcudaMgmtPatterns = NvcudaLinuxGlobs nvcudaMgmtPatterns = NvcudaLinuxGlobs
oneapiMgmtName = "libze_intel_gpu.so"
oneapiMgmtPatterns = OneapiLinuxGlobs
default: default:
return gpuHandles return gpuHandles
} }
@@ -152,6 +159,17 @@ func initGPUHandles() *handles {
} }
} }
oneapiLibPaths := FindGPULibs(oneapiMgmtName, oneapiMgmtPatterns)
if len(oneapiLibPaths) > 0 {
deviceCount, oneapi, libPath := LoadOneapiMgmt(oneapiLibPaths)
if oneapi != nil {
slog.Debug("detected Intel GPUs", "library", libPath, "count", deviceCount)
gpuHandles.oneapi = oneapi
gpuHandles.deviceCount = deviceCount
return gpuHandles
}
}
return gpuHandles return gpuHandles
} }
@@ -187,7 +205,7 @@ func GetGPUInfo() GpuInfoList {
resp := []GpuInfo{} resp := []GpuInfo{}
// NVIDIA first // NVIDIA first
for i := range gpuHandles.deviceCount { for i := 0; i < gpuHandles.deviceCount; i++ {
// TODO once we support CPU compilation variants of GPU libraries refine this... // TODO once we support CPU compilation variants of GPU libraries refine this...
if cpuVariant == "" && runtime.GOARCH == "amd64" { if cpuVariant == "" && runtime.GOARCH == "amd64" {
continue continue
@@ -221,12 +239,24 @@ func GetGPUInfo() GpuInfoList {
gpuInfo.MinimumMemory = cudaMinimumMemory gpuInfo.MinimumMemory = cudaMinimumMemory
gpuInfo.DependencyPath = depPath gpuInfo.DependencyPath = depPath
gpuInfo.Name = C.GoString(&memInfo.gpu_name[0]) gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
gpuInfo.DriverMajor = driverMajor gpuInfo.DriverMajor = int(driverMajor)
gpuInfo.DriverMinor = driverMinor gpuInfo.DriverMinor = int(driverMinor)
// TODO potentially sort on our own algorithm instead of what the underlying GPU library does... // TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
resp = append(resp, gpuInfo) resp = append(resp, gpuInfo)
} }
if gpuHandles.oneapi != nil {
gpuInfo := GpuInfo{
Library: "oneapi",
}
C.oneapi_check_vram(*gpuHandles.oneapi, &memInfo)
var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend.
memInfo.free = C.uint64_t(totalFreeMem)
gpuInfo.TotalMemory = uint64(memInfo.total)
gpuInfo.FreeMemory = uint64(memInfo.free)
gpuInfo.ID = strconv.Itoa(i)
resp = append(resp, gpuInfo)
}
} }
// Then AMD // Then AMD

View File

@@ -5,12 +5,11 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestBasicGetGPUInfo(t *testing.T) { func TestBasicGetGPUInfo(t *testing.T) {
info := GetGPUInfo() info := GetGPUInfo()
assert.NotEmpty(t, len(info)) assert.Greater(t, len(info), 0)
assert.Contains(t, "cuda rocm cpu metal", info[0].Library) assert.Contains(t, "cuda rocm cpu metal", info[0].Library)
if info[0].Library != "cpu" { if info[0].Library != "cpu" {
assert.Greater(t, info[0].TotalMemory, uint64(0)) assert.Greater(t, info[0].TotalMemory, uint64(0))
@@ -20,7 +19,7 @@ func TestBasicGetGPUInfo(t *testing.T) {
func TestCPUMemInfo(t *testing.T) { func TestCPUMemInfo(t *testing.T) {
info, err := GetCPUMem() info, err := GetCPUMem()
require.NoError(t, err) assert.NoError(t, err)
switch runtime.GOOS { switch runtime.GOOS {
case "darwin": case "darwin":
t.Skip("CPU memory not populated on darwin") t.Skip("CPU memory not populated on darwin")

View File

@@ -140,6 +140,7 @@ struct server_slot {
std::vector<llama_token> cache_tokens; std::vector<llama_token> cache_tokens;
std::vector<completion_token_output> generated_token_probs; std::vector<completion_token_output> generated_token_probs;
bool infill = false;
bool embedding = false; bool embedding = false;
bool has_next_token = true; bool has_next_token = true;
bool truncated = false; bool truncated = false;
@@ -186,6 +187,7 @@ struct server_slot {
n_past = 0; n_past = 0;
n_sent_text = 0; n_sent_text = 0;
n_sent_token_probs = 0; n_sent_token_probs = 0;
infill = false;
ga_i = 0; ga_i = 0;
n_past_se = 0; n_past_se = 0;
@@ -598,6 +600,16 @@ struct llama_server_context
slot->params.n_predict = slot->n_predict; slot->params.n_predict = slot->n_predict;
} }
// infill
if (data.count("input_prefix") != 0)
{
slot->params.input_prefix = data["input_prefix"];
}
else
{
slot->params.input_prefix = "";
}
if (data.count("input_suffix") != 0) if (data.count("input_suffix") != 0)
{ {
slot->params.input_suffix = data["input_suffix"]; slot->params.input_suffix = data["input_suffix"];
@@ -835,7 +847,7 @@ struct llama_server_context
system_tokens.clear(); system_tokens.clear();
if (!system_prompt.empty()) { if (!system_prompt.empty()) {
system_tokens = ::llama_tokenize(ctx, system_prompt, true); system_tokens = ::llama_tokenize(ctx, system_prompt, add_bos_token);
llama_batch_clear(batch); llama_batch_clear(batch);
@@ -885,6 +897,15 @@ struct llama_server_context
system_need_update = true; system_need_update = true;
} }
void system_prompt_process(const json &sys_props) {
system_prompt = sys_props.value("prompt", "");
name_user = sys_props.value("anti_prompt", "");
name_assistant = sys_props.value("assistant_name", "");
system_prompt_notify();
}
static size_t find_stopping_strings(const std::string &text, const size_t last_token_size, static size_t find_stopping_strings(const std::string &text, const size_t last_token_size,
const stop_type type, server_slot &slot) const stop_type type, server_slot &slot)
{ {
@@ -1242,12 +1263,13 @@ struct llama_server_context
queue_results.send(res); queue_results.send(res);
} }
void request_completion(int task_id, json data, bool embedding, int multitask_id) void request_completion(int task_id, json data, bool infill, bool embedding, int multitask_id)
{ {
task_server task; task_server task;
task.id = task_id; task.id = task_id;
task.target_id = 0; task.target_id = 0;
task.data = std::move(data); task.data = std::move(data);
task.infill_mode = infill;
task.embedding_mode = embedding; task.embedding_mode = embedding;
task.type = TASK_TYPE_COMPLETION; task.type = TASK_TYPE_COMPLETION;
task.multitask_id = multitask_id; task.multitask_id = multitask_id;
@@ -1393,8 +1415,8 @@ struct llama_server_context
json subtask_data = multiprompt_task.data; json subtask_data = multiprompt_task.data;
subtask_data["prompt"] = subtask_data["prompt"][i]; subtask_data["prompt"] = subtask_data["prompt"][i];
// subtasks inherit everything else (embedding mode, etc.) // subtasks inherit everything else (infill mode, embedding mode, etc.)
request_completion(subtask_ids[i], subtask_data, multiprompt_task.embedding_mode, multitask_id); request_completion(subtask_ids[i], subtask_data, multiprompt_task.infill_mode, multiprompt_task.embedding_mode, multitask_id);
} }
} }
@@ -1412,8 +1434,26 @@ struct llama_server_context
break; break;
} }
if (task.data.contains("system_prompt"))
{
if (!all_slots_are_idle) {
send_error(task, "system prompt can only be updated when all slots are idle");
break;
}
system_prompt_process(task.data["system_prompt"]);
// reset cache_tokens for all slots
for (server_slot &slot : slots)
{
slot.cache_tokens.clear();
slot.n_past = 0;
slot.n_past_se = 0;
}
}
slot->reset(); slot->reset();
slot->infill = task.infill_mode;
slot->embedding = task.embedding_mode; slot->embedding = task.embedding_mode;
slot->task_id = task.id; slot->task_id = task.id;
slot->multitask_id = task.multitask_id; slot->multitask_id = task.multitask_id;
@@ -1639,7 +1679,8 @@ struct llama_server_context
const bool has_prompt = slot.prompt.is_array() || (slot.prompt.is_string() && !slot.prompt.get<std::string>().empty()) || !slot.images.empty(); const bool has_prompt = slot.prompt.is_array() || (slot.prompt.is_string() && !slot.prompt.get<std::string>().empty()) || !slot.images.empty();
// empty prompt passed -> release the slot and send empty response // empty prompt passed -> release the slot and send empty response
if (slot.state == IDLE && slot.command == LOAD_PROMPT && !has_prompt) // note: infill mode allows empty prompt
if (slot.state == IDLE && slot.command == LOAD_PROMPT && !has_prompt && !slot.infill)
{ {
slot.release(); slot.release();
slot.print_timings(); slot.print_timings();
@@ -1656,7 +1697,33 @@ struct llama_server_context
slot.t_start_process_prompt = ggml_time_us(); slot.t_start_process_prompt = ggml_time_us();
slot.t_start_genereration = 0; slot.t_start_genereration = 0;
prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt if (slot.infill)
{
bool suff_rm_leading_spc = true;
if (params.input_suffix.find_first_of(' ') == 0 && params.input_suffix.size() > 1)
{
params.input_suffix.erase(0, 1);
suff_rm_leading_spc = false;
}
auto prefix_tokens = tokenize(slot.params.input_prefix, false);
auto suffix_tokens = tokenize(slot.params.input_suffix, false);
const int space_token = 29871; // TODO: this should not be hardcoded
if (suff_rm_leading_spc && !suffix_tokens.empty() && suffix_tokens[0] == space_token) {
suffix_tokens.erase(suffix_tokens.begin());
}
prefix_tokens.insert(prefix_tokens.begin(), llama_token_prefix(model));
prefix_tokens.insert(prefix_tokens.begin(), llama_token_bos(model)); // always add BOS
prefix_tokens.insert(prefix_tokens.end(), llama_token_suffix(model));
prefix_tokens.insert(prefix_tokens.end(), suffix_tokens.begin(), suffix_tokens.end());
prefix_tokens.push_back(llama_token_middle(model));
prompt_tokens = prefix_tokens;
}
else
{
prompt_tokens = tokenize(slot.prompt, system_prompt.empty() && add_bos_token); // add BOS if there isn't system prompt
}
slot.n_prompt_tokens = prompt_tokens.size(); slot.n_prompt_tokens = prompt_tokens.size();
@@ -2063,7 +2130,8 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
printf("\n"); printf("\n");
} }
static void server_params_parse(int argc, char **argv, server_params &sparams, gpt_params &params) static void server_params_parse(int argc, char **argv, server_params &sparams,
gpt_params &params, llama_server_context& llama)
{ {
gpt_params default_params; gpt_params default_params;
server_params default_sparams; server_params default_sparams;
@@ -2478,6 +2546,27 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, g
} }
params.n_predict = std::stoi(argv[i]); params.n_predict = std::stoi(argv[i]);
} }
else if (arg == "-spf" || arg == "--system-prompt-file")
{
if (++i >= argc)
{
invalid_param = true;
break;
}
std::ifstream file(argv[i]);
if (!file) {
fprintf(stderr, "error: failed to open file '%s'\n", argv[i]);
invalid_param = true;
break;
}
std::string systm_content;
std::copy(
std::istreambuf_iterator<char>(file),
std::istreambuf_iterator<char>(),
std::back_inserter(systm_content)
);
llama.system_prompt_process(json::parse(systm_content));
}
else if (arg == "-ctk" || arg == "--cache-type-k") { else if (arg == "-ctk" || arg == "--cache-type-k") {
params.cache_type_k = argv[++i]; params.cache_type_k = argv[++i];
} }
@@ -2729,7 +2818,7 @@ int main(int argc, char **argv) {
// struct that contains llama context and inference // struct that contains llama context and inference
llama_server_context llama; llama_server_context llama;
server_params_parse(argc, argv, sparams, params); server_params_parse(argc, argv, sparams, params, llama);
if (params.model_alias == "unknown") if (params.model_alias == "unknown")
{ {
@@ -3061,7 +3150,7 @@ int main(int argc, char **argv) {
json data = json::parse(req.body); json data = json::parse(req.body);
const int task_id = llama.queue_tasks.get_new_id(); const int task_id = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(task_id); llama.queue_results.add_waiting_task_id(task_id);
llama.request_completion(task_id, data, false, -1); llama.request_completion(task_id, data, false, false, -1);
if (!json_value(data, "stream", false)) { if (!json_value(data, "stream", false)) {
std::string completion_text; std::string completion_text;
task_result result = llama.queue_results.recv(task_id); task_result result = llama.queue_results.recv(task_id);
@@ -3183,7 +3272,7 @@ int main(int argc, char **argv) {
// create and queue the task // create and queue the task
const int task_id = llama.queue_tasks.get_new_id(); const int task_id = llama.queue_tasks.get_new_id();
llama.queue_results.add_waiting_task_id(task_id); llama.queue_results.add_waiting_task_id(task_id);
llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, true, -1); llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, false, true, -1);
// get the result // get the result
task_result result = llama.queue_results.recv(task_id); task_result result = llama.queue_results.recv(task_id);

View File

@@ -32,43 +32,42 @@ case "${GOARCH}" in
echo "Building static library" echo "Building static library"
build build
if [ -z "$OLLAMA_SKIP_CPU_GENERATE" ]; then
#
# CPU first for the default library, set up as lowest common denominator for maximum compatibility (including Rosetta)
#
init_vars
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
BUILD_DIR="../build/darwin/${ARCH}/cpu"
echo "Building LCD CPU"
build
sign ${BUILD_DIR}/bin/ollama_llama_server
compress
# #
# ~2011 CPU Dynamic library with more capabilities turned on to optimize performance # CPU first for the default library, set up as lowest common denominator for maximum compatibility (including Rosetta)
# Approximately 400% faster than LCD on same CPU #
# init_vars
init_vars CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=off -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}" BUILD_DIR="../build/darwin/${ARCH}/cpu"
BUILD_DIR="../build/darwin/${ARCH}/cpu_avx" echo "Building LCD CPU"
echo "Building AVX CPU" build
build sign ${BUILD_DIR}/bin/ollama_llama_server
sign ${BUILD_DIR}/bin/ollama_llama_server compress
compress
# #
# ~2013 CPU Dynamic library # ~2011 CPU Dynamic library with more capabilities turned on to optimize performance
# Approximately 10% faster than AVX on same CPU # Approximately 400% faster than LCD on same CPU
# #
init_vars init_vars
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=on -DLLAMA_AVX=on -DLLAMA_AVX2=on -DLLAMA_AVX512=off -DLLAMA_FMA=on -DLLAMA_F16C=on ${CMAKE_DEFS}" CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=off -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
BUILD_DIR="../build/darwin/${ARCH}/cpu_avx2" BUILD_DIR="../build/darwin/${ARCH}/cpu_avx"
echo "Building AVX2 CPU" echo "Building AVX CPU"
EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation" build
build sign ${BUILD_DIR}/bin/ollama_llama_server
sign ${BUILD_DIR}/bin/ollama_llama_server compress
compress
fi #
# ~2013 CPU Dynamic library
# Approximately 10% faster than AVX on same CPU
#
init_vars
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=on -DLLAMA_AVX=on -DLLAMA_AVX2=on -DLLAMA_AVX512=off -DLLAMA_FMA=on -DLLAMA_F16C=on ${CMAKE_DEFS}"
BUILD_DIR="../build/darwin/${ARCH}/cpu_avx2"
echo "Building AVX2 CPU"
EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation"
build
sign ${BUILD_DIR}/bin/ollama_llama_server
compress
;; ;;
"arm64") "arm64")
@@ -80,15 +79,13 @@ case "${GOARCH}" in
echo "Building static library" echo "Building static library"
build build
if [ -z "$OLLAMA_SKIP_METAL_GENERATE" ]; then init_vars
init_vars CMAKE_DEFS="${COMMON_DARWIN_DEFS} -DLLAMA_ACCELERATE=on -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=on ${CMAKE_DEFS}"
CMAKE_DEFS="${COMMON_DARWIN_DEFS} -DLLAMA_ACCELERATE=on -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=on ${CMAKE_DEFS}" BUILD_DIR="../build/darwin/${ARCH}/metal"
BUILD_DIR="../build/darwin/${ARCH}/metal" EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders"
EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders" build
build sign ${BUILD_DIR}/bin/ollama_llama_server
sign ${BUILD_DIR}/bin/ollama_llama_server compress
compress
fi
;; ;;
*) *)
echo "GOARCH must be set" echo "GOARCH must be set"

View File

@@ -211,7 +211,7 @@ if [ -z "${ONEAPI_ROOT}" ]; then
ONEAPI_ROOT=/opt/intel/oneapi ONEAPI_ROOT=/opt/intel/oneapi
fi fi
if [ -z "${OLLAMA_SKIP_ONEAPI_GENERATE}" -a -d "${ONEAPI_ROOT}" ]; then if [ -d "${ONEAPI_ROOT}" ]; then
echo "OneAPI libraries detected - building dynamic OneAPI library" echo "OneAPI libraries detected - building dynamic OneAPI library"
init_vars init_vars
source ${ONEAPI_ROOT}/setvars.sh --force # set up environment variables for oneAPI source ${ONEAPI_ROOT}/setvars.sh --force # set up environment variables for oneAPI

View File

@@ -290,7 +290,7 @@ function build_cuda() {
} }
function build_oneapi() { function build_oneapi() {
if ((-not "${env:OLLAMA_SKIP_ONEAPI_GENERATE}") -and ("${env:ONEAPI_ROOT}")) { if ((-not "${env:OLLAMA_SKIP_CUDA_GENERATE}") -and ("${env:ONEAPI_ROOT}")) {
# Get oneAPI version # Get oneAPI version
$script:ONEAPI_VERSION = icpx --version $script:ONEAPI_VERSION = icpx --version
$script:ONEAPI_VERSION = [regex]::Match($script:ONEAPI_VERSION, '(?<=oneAPI DPC\+\+/C\+\+ Compiler )(?<version>\d+\.\d+\.\d+)').Value $script:ONEAPI_VERSION = [regex]::Match($script:ONEAPI_VERSION, '(?<=oneAPI DPC\+\+/C\+\+ Compiler )(?<version>\d+\.\d+\.\d+)').Value

View File

@@ -81,11 +81,6 @@ func (kv KV) ContextLength() uint64 {
return kv.u64(fmt.Sprintf("%s.context_length", kv.Architecture())) return kv.u64(fmt.Sprintf("%s.context_length", kv.Architecture()))
} }
func (kv KV) ChatTemplate() string {
s, _ := kv["tokenizer.chat_template"].(string)
return s
}
type Tensors []*Tensor type Tensors []*Tensor
func (ts Tensors) Layers() map[string]Layer { func (ts Tensors) Layers() map[string]Layer {

View File

@@ -592,8 +592,8 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
return err return err
} }
var dims int dims := 0
for cnt := range len(tensor.Shape) { for cnt := 0; cnt < len(tensor.Shape); cnt++ {
if tensor.Shape[cnt] > 0 { if tensor.Shape[cnt] > 0 {
dims++ dims++
} }
@@ -603,8 +603,8 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
return err return err
} }
for i := range dims { for i := 0; i < dims; i++ {
if err := binary.Write(ws, llm.ByteOrder, tensor.Shape[dims-1-i]); err != nil { if err := binary.Write(ws, llm.ByteOrder, uint64(tensor.Shape[dims-1-i])); err != nil {
return err return err
} }
} }
@@ -618,8 +618,22 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
} }
} }
offset, err := ws.Seek(0, io.SeekCurrent)
if err != nil {
return err
}
var alignment int64 = 32 var alignment int64 = 32
padding := llm.padding(offset, alignment)
if err := binary.Write(ws, llm.ByteOrder, bytes.Repeat([]byte{0}, int(padding))); err != nil {
return err
}
for _, tensor := range tensors { for _, tensor := range tensors {
if _, err := tensor.WriteTo(ws); err != nil {
return err
}
offset, err := ws.Seek(0, io.SeekCurrent) offset, err := ws.Seek(0, io.SeekCurrent)
if err != nil { if err != nil {
return err return err
@@ -629,10 +643,6 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
if err := binary.Write(ws, llm.ByteOrder, bytes.Repeat([]byte{0}, int(padding))); err != nil { if err := binary.Write(ws, llm.ByteOrder, bytes.Repeat([]byte{0}, int(padding))); err != nil {
return err return err
} }
if _, err := tensor.WriteTo(ws); err != nil {
return err
}
} }
return nil return nil

View File

@@ -5,9 +5,9 @@ import (
"log/slog" "log/slog"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/ollama/ollama/gpu" "github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/envconfig"
) )
// This algorithm looks for a complete fit to determine if we need to unload other models // This algorithm looks for a complete fit to determine if we need to unload other models
@@ -103,7 +103,7 @@ func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts
} }
var layerCount int var layerCount int
for i := range int(ggml.KV().BlockCount()) { for i := 0; i < int(ggml.KV().BlockCount()); i++ {
if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok { if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok {
memoryLayer := blk.size() memoryLayer := blk.size()

View File

@@ -1,32 +1,35 @@
From d02a06f3f45a09255ace8684a66590e06ce44605 Mon Sep 17 00:00:00 2001
From: Michael Yang <mxyng@pm.me>
Date: Thu, 23 May 2024 11:33:20 -0700
Subject: [PATCH] default pretokenizer on unrecognized type
---
llama.cpp | 5 +----
1 file changed, 1 insertion(+), 4 deletions(-)
diff --git a/llama.cpp b/llama.cpp diff --git a/llama.cpp b/llama.cpp
index 40d2ec2c..74f3ee9c 100644 index 15c66077..af1aede3 100644
--- a/llama.cpp --- a/llama.cpp
+++ b/llama.cpp +++ b/llama.cpp
@@ -4642,16 +4642,7 @@ static void llm_load_vocab( @@ -4504,9 +4504,6 @@ static void llm_load_vocab(
LLAMA_LOG_WARN("%s: ************************************ \n", __func__);
// for now, only BPE models have pre-tokenizers LLAMA_LOG_WARN("%s: \n", __func__);
if (vocab.type == LLAMA_VOCAB_TYPE_BPE) {
- if (tokenizer_pre.empty()) {
- LLAMA_LOG_WARN("%s: missing pre-tokenizer type, using: 'default'\n", __func__);
- LLAMA_LOG_WARN("%s: \n", __func__);
- LLAMA_LOG_WARN("%s: ************************************ \n", __func__);
- LLAMA_LOG_WARN("%s: GENERATION QUALITY WILL BE DEGRADED! \n", __func__);
- LLAMA_LOG_WARN("%s: CONSIDER REGENERATING THE MODEL \n", __func__);
- LLAMA_LOG_WARN("%s: ************************************ \n", __func__);
- LLAMA_LOG_WARN("%s: \n", __func__);
- vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
- } else if (
+ if (
tokenizer_pre == "default") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
- } else if (
- tokenizer_pre == "default") {
- vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
} else if ( } else if (
@@ -4703,7 +4694,8 @@ static void llm_load_vocab( tokenizer_pre == "llama3" ||
tokenizer_pre == "smaug-bpe") { tokenizer_pre == "llama-v3" ||
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_SMAUG; @@ -4553,7 +4550,7 @@ static void llm_load_vocab(
tokenizer_pre == "dbrx") {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DBRX;
} else { } else {
- throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); - throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
+ LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__);
+ vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; + vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
} }
} else { } else {
vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT; vocab.type_pre = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
--
2.45.1

View File

@@ -1,13 +0,0 @@
diff --git a/llama.cpp b/llama.cpp
index 40d2ec2c..f34eb79a 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -6943,7 +6943,7 @@ static struct ggml_tensor * llm_build_kqv(
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
cb(kq, "kq", il);
- if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
+ if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) {
// for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
// ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);

View File

@@ -10,9 +10,9 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
"slices"
"strings" "strings"
"golang.org/x/exp/slices"
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"github.com/ollama/ollama/gpu" "github.com/ollama/ollama/gpu"

View File

@@ -85,6 +85,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
var systemMemory uint64 var systemMemory uint64
gpuCount := len(gpus) gpuCount := len(gpus)
if (len(gpus) == 1 && gpus[0].Library == "cpu") || opts.NumGPU == 0 { if (len(gpus) == 1 && gpus[0].Library == "cpu") || opts.NumGPU == 0 {
// TODO evaluate system memory to see if we should block the load, or force an unload of another CPU runner // TODO evaluate system memory to see if we should block the load, or force an unload of another CPU runner
cpuRunner = serverForCpu() cpuRunner = serverForCpu()
@@ -103,22 +104,21 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
var layers int var layers int
layers, estimatedVRAM, estimatedTotal = EstimateGPULayers(gpus, ggml, projectors, opts) layers, estimatedVRAM, estimatedTotal = EstimateGPULayers(gpus, ggml, projectors, opts)
switch { if gpus[0].Library == "metal" && estimatedVRAM > systemMemory {
case gpus[0].Library == "metal" && estimatedVRAM > systemMemory:
// disable partial offloading when model is greater than total system memory as this // disable partial offloading when model is greater than total system memory as this
// can lead to locking up the system // can lead to locking up the system
opts.NumGPU = 0 opts.NumGPU = 0
case gpus[0].Library != "metal" && layers == 0: } else if gpus[0].Library != "metal" && layers == 0 {
// Don't bother loading into the GPU if no layers can fit // Don't bother loading into the GPU if no layers can fit
cpuRunner = serverForCpu() cpuRunner = serverForCpu()
gpuCount = 0 gpuCount = 0
case opts.NumGPU < 0 && layers > 0 && gpus[0].Library != "cpu": } else if opts.NumGPU < 0 && layers > 0 && gpus[0].Library != "cpu" {
opts.NumGPU = layers opts.NumGPU = layers
} }
} }
// Loop through potential servers // Loop through potential servers
finalErr := errors.New("no suitable llama servers found") finalErr := fmt.Errorf("no suitable llama servers found")
if len(adapters) > 1 { if len(adapters) > 1 {
return nil, errors.New("ollama supports only one lora adapter, but multiple were provided") return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
@@ -189,38 +189,35 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
params = append(params, "--memory-f32") params = append(params, "--memory-f32")
} }
flashAttnEnabled := envconfig.FlashAttention if opts.UseMLock {
params = append(params, "--mlock")
for _, g := range gpus {
// only cuda (compute capability 7+) and metal support flash attention
if g.Library != "metal" && (g.Library != "cuda" || g.DriverMajor < 7) {
flashAttnEnabled = false
}
// mmap has issues with partial offloading on metal
if g.Library == "metal" &&
uint64(opts.NumGPU) > 0 &&
uint64(opts.NumGPU) < ggml.KV().BlockCount()+1 {
opts.UseMMap = false
}
}
if flashAttnEnabled {
params = append(params, "--flash-attn")
} }
if !opts.UseMMap { if !opts.UseMMap {
params = append(params, "--no-mmap") params = append(params, "--no-mmap")
} }
if opts.UseMLock {
params = append(params, "--mlock")
}
if opts.UseNUMA { if opts.UseNUMA {
params = append(params, "--numa") params = append(params, "--numa")
} }
flashAttnEnabled := envconfig.FlashAttention
// partial offloading does not support flash attention
if uint64(opts.NumGPU) < ggml.KV().BlockCount()+1 {
flashAttnEnabled = false
}
// only cuda (compute capability 7+) and metal support flash attention
for _, g := range gpus {
if g.Library != "metal" && (g.Library != "cuda" || g.DriverMajor < 7) {
flashAttnEnabled = false
}
}
if flashAttnEnabled {
params = append(params, "--flash-attn")
}
numParallel := envconfig.NumParallel numParallel := envconfig.NumParallel
// TODO (jmorganca): multimodal models don't support parallel yet // TODO (jmorganca): multimodal models don't support parallel yet
@@ -232,7 +229,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
params = append(params, "--parallel", fmt.Sprintf("%d", numParallel)) params = append(params, "--parallel", fmt.Sprintf("%d", numParallel))
for i := range len(servers) { for i := 0; i < len(servers); i++ {
dir := availableServers[servers[i]] dir := availableServers[servers[i]]
if dir == "" { if dir == "" {
// Shouldn't happen // Shouldn't happen
@@ -284,7 +281,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
server := filepath.Join(dir, "ollama_llama_server") server := filepath.Join(dir, "ollama_llama_server")
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
server += ".exe" server = server + ".exe"
} }
// Detect tmp cleaners wiping out the file // Detect tmp cleaners wiping out the file
@@ -315,7 +312,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
s.cmd.Stdout = os.Stdout s.cmd.Stdout = os.Stdout
s.cmd.Stderr = s.status s.cmd.Stderr = s.status
visibleDevicesEnv, visibleDevicesEnvVal := gpus.GetVisibleDevicesEnv() visibleDevicesEnv, visibleDevicesEnvVal := gpu.GpuInfoList(gpus).GetVisibleDevicesEnv()
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator)) pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
// Update or add the path and visible devices variable with our adjusted version // Update or add the path and visible devices variable with our adjusted version
@@ -459,7 +456,7 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
resp, err := http.DefaultClient.Do(req) resp, err := http.DefaultClient.Do(req)
if err != nil { if err != nil {
if errors.Is(err, context.DeadlineExceeded) { if errors.Is(err, context.DeadlineExceeded) {
return ServerStatusNotResponding, errors.New("server not responding") return ServerStatusNotResponding, fmt.Errorf("server not responding")
} }
return ServerStatusError, fmt.Errorf("health resp: %w", err) return ServerStatusError, fmt.Errorf("health resp: %w", err)
} }

View File

@@ -245,6 +245,7 @@ func (w *writer) writeResponse(data []byte) (int, error) {
d, err := json.Marshal(toChunk(w.id, chatResponse)) d, err := json.Marshal(toChunk(w.id, chatResponse))
if err != nil { if err != nil {
return 0, err return 0, err
} }
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream") w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")

View File

@@ -10,7 +10,6 @@ import (
"unicode/utf16" "unicode/utf16"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestParseFileFile(t *testing.T) { func TestParseFileFile(t *testing.T) {
@@ -26,7 +25,7 @@ TEMPLATE template1
reader := strings.NewReader(input) reader := strings.NewReader(input)
modelfile, err := ParseFile(reader) modelfile, err := ParseFile(reader)
require.NoError(t, err) assert.NoError(t, err)
expectedCommands := []Command{ expectedCommands := []Command{
{Name: "model", Args: "model1"}, {Name: "model", Args: "model1"},
@@ -89,7 +88,7 @@ func TestParseFileFrom(t *testing.T) {
for _, c := range cases { for _, c := range cases {
t.Run("", func(t *testing.T) { t.Run("", func(t *testing.T) {
modelfile, err := ParseFile(strings.NewReader(c.input)) modelfile, err := ParseFile(strings.NewReader(c.input))
require.ErrorIs(t, err, c.err) assert.ErrorIs(t, err, c.err)
if modelfile != nil { if modelfile != nil {
assert.Equal(t, c.expected, modelfile.Commands) assert.Equal(t, c.expected, modelfile.Commands)
} }
@@ -106,7 +105,7 @@ PARAMETER param1
reader := strings.NewReader(input) reader := strings.NewReader(input)
_, err := ParseFile(reader) _, err := ParseFile(reader)
require.ErrorIs(t, err, io.ErrUnexpectedEOF) assert.ErrorIs(t, err, io.ErrUnexpectedEOF)
} }
func TestParseFileBadCommand(t *testing.T) { func TestParseFileBadCommand(t *testing.T) {
@@ -115,7 +114,8 @@ FROM foo
BADCOMMAND param1 value1 BADCOMMAND param1 value1
` `
_, err := ParseFile(strings.NewReader(input)) _, err := ParseFile(strings.NewReader(input))
require.ErrorIs(t, err, errInvalidCommand) assert.ErrorIs(t, err, errInvalidCommand)
} }
func TestParseFileMessages(t *testing.T) { func TestParseFileMessages(t *testing.T) {
@@ -201,7 +201,7 @@ MESSAGE system`,
for _, c := range cases { for _, c := range cases {
t.Run("", func(t *testing.T) { t.Run("", func(t *testing.T) {
modelfile, err := ParseFile(strings.NewReader(c.input)) modelfile, err := ParseFile(strings.NewReader(c.input))
require.ErrorIs(t, err, c.err) assert.ErrorIs(t, err, c.err)
if modelfile != nil { if modelfile != nil {
assert.Equal(t, c.expected, modelfile.Commands) assert.Equal(t, c.expected, modelfile.Commands)
} }
@@ -355,7 +355,7 @@ TEMPLATE """
for _, c := range cases { for _, c := range cases {
t.Run("", func(t *testing.T) { t.Run("", func(t *testing.T) {
modelfile, err := ParseFile(strings.NewReader(c.multiline)) modelfile, err := ParseFile(strings.NewReader(c.multiline))
require.ErrorIs(t, err, c.err) assert.ErrorIs(t, err, c.err)
if modelfile != nil { if modelfile != nil {
assert.Equal(t, c.expected, modelfile.Commands) assert.Equal(t, c.expected, modelfile.Commands)
} }
@@ -413,7 +413,7 @@ func TestParseFileParameters(t *testing.T) {
fmt.Fprintln(&b, "FROM foo") fmt.Fprintln(&b, "FROM foo")
fmt.Fprintln(&b, "PARAMETER", k) fmt.Fprintln(&b, "PARAMETER", k)
modelfile, err := ParseFile(&b) modelfile, err := ParseFile(&b)
require.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, []Command{ assert.Equal(t, []Command{
{Name: "model", Args: "foo"}, {Name: "model", Args: "foo"},
@@ -442,7 +442,7 @@ FROM foo
for _, c := range cases { for _, c := range cases {
t.Run("", func(t *testing.T) { t.Run("", func(t *testing.T) {
modelfile, err := ParseFile(strings.NewReader(c.input)) modelfile, err := ParseFile(strings.NewReader(c.input))
require.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, c.expected, modelfile.Commands) assert.Equal(t, c.expected, modelfile.Commands)
}) })
} }
@@ -501,14 +501,15 @@ SYSTEM ""
for _, c := range cases { for _, c := range cases {
t.Run("", func(t *testing.T) { t.Run("", func(t *testing.T) {
modelfile, err := ParseFile(strings.NewReader(c)) modelfile, err := ParseFile(strings.NewReader(c))
require.NoError(t, err) assert.NoError(t, err)
modelfile2, err := ParseFile(strings.NewReader(modelfile.String())) modelfile2, err := ParseFile(strings.NewReader(modelfile.String()))
require.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, modelfile, modelfile2) assert.Equal(t, modelfile, modelfile2)
}) })
} }
} }
func TestParseFileUTF16ParseFile(t *testing.T) { func TestParseFileUTF16ParseFile(t *testing.T) {
@@ -521,10 +522,10 @@ SYSTEM You are a utf16 file.
utf16File := utf16.Encode(append([]rune{'\ufffe'}, []rune(data)...)) utf16File := utf16.Encode(append([]rune{'\ufffe'}, []rune(data)...))
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
err := binary.Write(buf, binary.LittleEndian, utf16File) err := binary.Write(buf, binary.LittleEndian, utf16File)
require.NoError(t, err) assert.NoError(t, err)
actual, err := ParseFile(buf) actual, err := ParseFile(buf)
require.NoError(t, err) assert.NoError(t, err)
expected := []Command{ expected := []Command{
{Name: "model", Args: "bob"}, {Name: "model", Args: "bob"},
@@ -538,9 +539,9 @@ SYSTEM You are a utf16 file.
// simulate a utf16 be file // simulate a utf16 be file
buf = new(bytes.Buffer) buf = new(bytes.Buffer)
err = binary.Write(buf, binary.BigEndian, utf16File) err = binary.Write(buf, binary.BigEndian, utf16File)
require.NoError(t, err) assert.NoError(t, err)
actual, err = ParseFile(buf) actual, err = ParseFile(buf)
require.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, expected, actual.Commands) assert.Equal(t, expected, actual.Commands)
} }

View File

@@ -59,7 +59,7 @@ func (p *Progress) StopAndClear() bool {
stopped := p.stop() stopped := p.stop()
if stopped { if stopped {
// clear all progress lines // clear all progress lines
for i := range p.pos { for i := 0; i < p.pos; i++ {
if i > 0 { if i > 0 {
fmt.Fprint(p.w, "\033[A") fmt.Fprint(p.w, "\033[A")
} }
@@ -85,7 +85,7 @@ func (p *Progress) render() {
defer fmt.Fprint(p.w, "\033[?25h") defer fmt.Fprint(p.w, "\033[?25h")
// clear already rendered progress lines // clear already rendered progress lines
for i := range p.pos { for i := 0; i < p.pos; i++ {
if i > 0 { if i > 0 {
fmt.Fprint(p.w, "\033[A") fmt.Fprint(p.w, "\033[A")
} }

View File

@@ -52,6 +52,7 @@ func (b *Buffer) GetLineSpacing(line int) bool {
} }
return hasSpace.(bool) return hasSpace.(bool)
} }
func (b *Buffer) MoveLeft() { func (b *Buffer) MoveLeft() {
@@ -116,12 +117,15 @@ func (b *Buffer) MoveRight() {
if b.DisplayPos%b.LineWidth == 0 { if b.DisplayPos%b.LineWidth == 0 {
fmt.Printf(CursorDown + CursorBOL + cursorRightN(len(b.Prompt.prompt()))) fmt.Printf(CursorDown + CursorBOL + cursorRightN(len(b.Prompt.prompt())))
} else if (b.DisplayPos-rLength)%b.LineWidth == b.LineWidth-1 && hasSpace { } else if (b.DisplayPos-rLength)%b.LineWidth == b.LineWidth-1 && hasSpace {
fmt.Printf(CursorDown + CursorBOL + cursorRightN(len(b.Prompt.prompt())+rLength)) fmt.Printf(CursorDown + CursorBOL + cursorRightN(len(b.Prompt.prompt())+rLength))
b.DisplayPos += 1 b.DisplayPos += 1
} else if b.LineHasSpace.Size() > 0 && b.DisplayPos%b.LineWidth == b.LineWidth-1 && hasSpace { } else if b.LineHasSpace.Size() > 0 && b.DisplayPos%b.LineWidth == b.LineWidth-1 && hasSpace {
fmt.Printf(CursorDown + CursorBOL + cursorRightN(len(b.Prompt.prompt()))) fmt.Printf(CursorDown + CursorBOL + cursorRightN(len(b.Prompt.prompt())))
b.DisplayPos += 1 b.DisplayPos += 1
} else { } else {
fmt.Print(cursorRightN(rLength)) fmt.Print(cursorRightN(rLength))
} }
@@ -150,7 +154,7 @@ func (b *Buffer) MoveToStart() {
if b.Pos > 0 { if b.Pos > 0 {
currLine := b.DisplayPos / b.LineWidth currLine := b.DisplayPos / b.LineWidth
if currLine > 0 { if currLine > 0 {
for range currLine { for cnt := 0; cnt < currLine; cnt++ {
fmt.Print(CursorUp) fmt.Print(CursorUp)
} }
} }
@@ -165,7 +169,7 @@ func (b *Buffer) MoveToEnd() {
currLine := b.DisplayPos / b.LineWidth currLine := b.DisplayPos / b.LineWidth
totalLines := b.DisplaySize() / b.LineWidth totalLines := b.DisplaySize() / b.LineWidth
if currLine < totalLines { if currLine < totalLines {
for range totalLines - currLine { for cnt := 0; cnt < totalLines-currLine; cnt++ {
fmt.Print(CursorDown) fmt.Print(CursorDown)
} }
remainder := b.DisplaySize() % b.LineWidth remainder := b.DisplaySize() % b.LineWidth
@@ -181,7 +185,7 @@ func (b *Buffer) MoveToEnd() {
func (b *Buffer) DisplaySize() int { func (b *Buffer) DisplaySize() int {
sum := 0 sum := 0
for i := range b.Buf.Size() { for i := 0; i < b.Buf.Size(); i++ {
if e, ok := b.Buf.Get(i); ok { if e, ok := b.Buf.Get(i); ok {
if r, ok := e.(rune); ok { if r, ok := e.(rune); ok {
sum += runewidth.RuneWidth(r) sum += runewidth.RuneWidth(r)
@@ -193,6 +197,7 @@ func (b *Buffer) DisplaySize() int {
} }
func (b *Buffer) Add(r rune) { func (b *Buffer) Add(r rune) {
if b.Pos == b.Buf.Size() { if b.Pos == b.Buf.Size() {
b.AddChar(r, false) b.AddChar(r, false)
} else { } else {
@@ -205,6 +210,7 @@ func (b *Buffer) AddChar(r rune, insert bool) {
b.DisplayPos += rLength b.DisplayPos += rLength
if b.Pos > 0 { if b.Pos > 0 {
if b.DisplayPos%b.LineWidth == 0 { if b.DisplayPos%b.LineWidth == 0 {
fmt.Printf("%c", r) fmt.Printf("%c", r)
fmt.Printf("\n%s", b.Prompt.AltPrompt) fmt.Printf("\n%s", b.Prompt.AltPrompt)
@@ -229,6 +235,7 @@ func (b *Buffer) AddChar(r rune, insert bool) {
} else { } else {
b.LineHasSpace.Add(true) b.LineHasSpace.Add(true)
} }
} else { } else {
fmt.Printf("%c", r) fmt.Printf("%c", r)
} }
@@ -349,6 +356,7 @@ func (b *Buffer) drawRemaining() {
func (b *Buffer) Remove() { func (b *Buffer) Remove() {
if b.Buf.Size() > 0 && b.Pos > 0 { if b.Buf.Size() > 0 && b.Pos > 0 {
if e, ok := b.Buf.Get(b.Pos - 1); ok { if e, ok := b.Buf.Get(b.Pos - 1); ok {
if r, ok := e.(rune); ok { if r, ok := e.(rune); ok {
rLength := runewidth.RuneWidth(r) rLength := runewidth.RuneWidth(r)
@@ -374,6 +382,7 @@ func (b *Buffer) Remove() {
} else { } else {
fmt.Print(" " + CursorLeft) fmt.Print(" " + CursorLeft)
} }
} else if (b.DisplayPos-rLength)%b.LineWidth == 0 && hasSpace { } else if (b.DisplayPos-rLength)%b.LineWidth == 0 && hasSpace {
fmt.Printf(CursorBOL + ClearToEOL) fmt.Printf(CursorBOL + ClearToEOL)
fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width)) fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width))
@@ -382,9 +391,10 @@ func (b *Buffer) Remove() {
b.LineHasSpace.Remove(b.DisplayPos/b.LineWidth - 1) b.LineHasSpace.Remove(b.DisplayPos/b.LineWidth - 1)
} }
b.DisplayPos -= 1 b.DisplayPos -= 1
} else { } else {
fmt.Print(cursorLeftN(rLength)) fmt.Print(cursorLeftN(rLength))
for range rLength { for i := 0; i < rLength; i++ {
fmt.Print(" ") fmt.Print(" ")
} }
fmt.Print(cursorLeftN(rLength)) fmt.Print(cursorLeftN(rLength))
@@ -441,7 +451,7 @@ func (b *Buffer) DeleteBefore() {
func (b *Buffer) DeleteRemaining() { func (b *Buffer) DeleteRemaining() {
if b.DisplaySize() > 0 && b.Pos < b.DisplaySize() { if b.DisplaySize() > 0 && b.Pos < b.DisplaySize() {
charsToDel := b.Buf.Size() - b.Pos charsToDel := b.Buf.Size() - b.Pos
for range charsToDel { for cnt := 0; cnt < charsToDel; cnt++ {
b.Delete() b.Delete()
} }
} }
@@ -485,7 +495,7 @@ func (b *Buffer) ClearScreen() {
if currPos > 0 { if currPos > 0 {
targetLine := currPos / b.LineWidth targetLine := currPos / b.LineWidth
if targetLine > 0 { if targetLine > 0 {
for range targetLine { for cnt := 0; cnt < targetLine; cnt++ {
fmt.Print(CursorDown) fmt.Print(CursorDown)
} }
} }
@@ -515,7 +525,7 @@ func (b *Buffer) Replace(r []rune) {
fmt.Printf(CursorBOL + ClearToEOL) fmt.Printf(CursorBOL + ClearToEOL)
for range lineNums { for i := 0; i < lineNums; i++ {
fmt.Print(CursorUp + CursorBOL + ClearToEOL) fmt.Print(CursorUp + CursorBOL + ClearToEOL)
} }

View File

@@ -91,7 +91,7 @@ func (h *History) Add(l []rune) {
func (h *History) Compact() { func (h *History) Compact() {
s := h.Buf.Size() s := h.Buf.Size()
if s > h.Limit { if s > h.Limit {
for range s - h.Limit { for cnt := 0; cnt < s-h.Limit; cnt++ {
h.Buf.Remove(0) h.Buf.Remove(0)
} }
} }
@@ -139,7 +139,7 @@ func (h *History) Save() error {
defer f.Close() defer f.Close()
buf := bufio.NewWriter(f) buf := bufio.NewWriter(f)
for cnt := range h.Size() { for cnt := 0; cnt < h.Size(); cnt++ {
v, _ := h.Buf.Get(cnt) v, _ := h.Buf.Get(cnt)
line, _ := v.([]rune) line, _ := v.([]rune)
if _, err := buf.WriteString(string(line) + "\n"); err != nil { if _, err := buf.WriteString(string(line) + "\n"); err != nil {

View File

@@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"io" "io"
"os" "os"
"syscall"
) )
type Prompt struct { type Prompt struct {
@@ -62,7 +63,7 @@ func New(prompt Prompt) (*Instance, error) {
func (i *Instance) Readline() (string, error) { func (i *Instance) Readline() (string, error) {
if !i.Terminal.rawmode { if !i.Terminal.rawmode {
fd := os.Stdin.Fd() fd := int(syscall.Stdin)
termios, err := SetRawMode(fd) termios, err := SetRawMode(fd)
if err != nil { if err != nil {
return "", err return "", err
@@ -79,8 +80,8 @@ func (i *Instance) Readline() (string, error) {
fmt.Print(prompt) fmt.Print(prompt)
defer func() { defer func() {
fd := os.Stdin.Fd() fd := int(syscall.Stdin)
//nolint:errcheck // nolint: errcheck
UnsetRawMode(fd, i.Terminal.termios) UnsetRawMode(fd, i.Terminal.termios)
i.Terminal.rawmode = false i.Terminal.rawmode = false
}() }()
@@ -135,7 +136,7 @@ func (i *Instance) Readline() (string, error) {
buf.MoveRight() buf.MoveRight()
case CharBracketedPaste: case CharBracketedPaste:
var code string var code string
for range 3 { for cnt := 0; cnt < 3; cnt++ {
r, err = i.Terminal.Read() r, err = i.Terminal.Read()
if err != nil { if err != nil {
return "", io.EOF return "", io.EOF
@@ -197,7 +198,7 @@ func (i *Instance) Readline() (string, error) {
buf.Remove() buf.Remove()
case CharTab: case CharTab:
// todo: convert back to real tabs // todo: convert back to real tabs
for range 8 { for cnt := 0; cnt < 8; cnt++ {
buf.Add(' ') buf.Add(' ')
} }
case CharDelete: case CharDelete:
@@ -215,7 +216,7 @@ func (i *Instance) Readline() (string, error) {
case CharCtrlW: case CharCtrlW:
buf.DeleteWord() buf.DeleteWord()
case CharCtrlZ: case CharCtrlZ:
fd := os.Stdin.Fd() fd := int(syscall.Stdin)
return handleCharCtrlZ(fd, i.Terminal.termios) return handleCharCtrlZ(fd, i.Terminal.termios)
case CharEnter, CharCtrlJ: case CharEnter, CharCtrlJ:
output := buf.String() output := buf.String()
@@ -247,7 +248,7 @@ func (i *Instance) HistoryDisable() {
} }
func NewTerminal() (*Terminal, error) { func NewTerminal() (*Terminal, error) {
fd := os.Stdin.Fd() fd := int(syscall.Stdin)
termios, err := SetRawMode(fd) termios, err := SetRawMode(fd)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -6,7 +6,7 @@ import (
"syscall" "syscall"
) )
func handleCharCtrlZ(fd uintptr, termios any) (string, error) { func handleCharCtrlZ(fd int, termios any) (string, error) {
t := termios.(*Termios) t := termios.(*Termios)
if err := UnsetRawMode(fd, t); err != nil { if err := UnsetRawMode(fd, t); err != nil {
return "", err return "", err

View File

@@ -1,6 +1,6 @@
package readline package readline
func handleCharCtrlZ(fd uintptr, state any) (string, error) { func handleCharCtrlZ(fd int, state any) (string, error) {
// not supported // not supported
return "", nil return "", nil
} }

View File

@@ -8,7 +8,7 @@ import (
type Termios syscall.Termios type Termios syscall.Termios
func SetRawMode(fd uintptr) (*Termios, error) { func SetRawMode(fd int) (*Termios, error) {
termios, err := getTermios(fd) termios, err := getTermios(fd)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -25,13 +25,13 @@ func SetRawMode(fd uintptr) (*Termios, error) {
return termios, setTermios(fd, &newTermios) return termios, setTermios(fd, &newTermios)
} }
func UnsetRawMode(fd uintptr, termios any) error { func UnsetRawMode(fd int, termios any) error {
t := termios.(*Termios) t := termios.(*Termios)
return setTermios(fd, t) return setTermios(fd, t)
} }
// IsTerminal returns true if the given file descriptor is a terminal. // IsTerminal returns true if the given file descriptor is a terminal.
func IsTerminal(fd uintptr) bool { func IsTerminal(fd int) bool {
_, err := getTermios(fd) _, err := getTermios(fd)
return err == nil return err == nil
} }

View File

@@ -7,17 +7,17 @@ import (
"unsafe" "unsafe"
) )
func getTermios(fd uintptr) (*Termios, error) { func getTermios(fd int) (*Termios, error) {
termios := new(Termios) termios := new(Termios)
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, fd, syscall.TIOCGETA, uintptr(unsafe.Pointer(termios)), 0, 0, 0) _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), syscall.TIOCGETA, uintptr(unsafe.Pointer(termios)), 0, 0, 0)
if err != 0 { if err != 0 {
return nil, err return nil, err
} }
return termios, nil return termios, nil
} }
func setTermios(fd uintptr, termios *Termios) error { func setTermios(fd int, termios *Termios) error {
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, fd, syscall.TIOCSETA, uintptr(unsafe.Pointer(termios)), 0, 0, 0) _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), syscall.TIOCSETA, uintptr(unsafe.Pointer(termios)), 0, 0, 0)
if err != 0 { if err != 0 {
return err return err
} }

View File

@@ -10,17 +10,17 @@ import (
const tcgets = 0x5401 const tcgets = 0x5401
const tcsets = 0x5402 const tcsets = 0x5402
func getTermios(fd uintptr) (*Termios, error) { func getTermios(fd int) (*Termios, error) {
termios := new(Termios) termios := new(Termios)
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, fd, tcgets, uintptr(unsafe.Pointer(termios)), 0, 0, 0) _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), tcgets, uintptr(unsafe.Pointer(termios)), 0, 0, 0)
if err != 0 { if err != 0 {
return nil, err return nil, err
} }
return termios, nil return termios, nil
} }
func setTermios(fd uintptr, termios *Termios) error { func setTermios(fd int, termios *Termios) error {
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, fd, tcsets, uintptr(unsafe.Pointer(termios)), 0, 0, 0) _, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), tcsets, uintptr(unsafe.Pointer(termios)), 0, 0, 0)
if err != 0 { if err != 0 {
return err return err
} }

View File

@@ -9,13 +9,13 @@ type State struct {
} }
// IsTerminal checks if the given file descriptor is associated with a terminal // IsTerminal checks if the given file descriptor is associated with a terminal
func IsTerminal(fd uintptr) bool { func IsTerminal(fd int) bool {
var st uint32 var st uint32
err := windows.GetConsoleMode(windows.Handle(fd), &st) err := windows.GetConsoleMode(windows.Handle(fd), &st)
return err == nil return err == nil
} }
func SetRawMode(fd uintptr) (*State, error) { func SetRawMode(fd int) (*State, error) {
var st uint32 var st uint32
if err := windows.GetConsoleMode(windows.Handle(fd), &st); err != nil { if err := windows.GetConsoleMode(windows.Handle(fd), &st); err != nil {
return nil, err return nil, err
@@ -32,7 +32,7 @@ func SetRawMode(fd uintptr) (*State, error) {
return &State{st}, nil return &State{st}, nil
} }
func UnsetRawMode(fd uintptr, state any) error { func UnsetRawMode(fd int, state any) error {
s := state.(*State) s := state.(*State)
return windows.SetConsoleMode(windows.Handle(fd), s.mode) return windows.SetConsoleMode(windows.Handle(fd), s.mode)
} }

View File

@@ -340,17 +340,17 @@ type downloadOpts struct {
} }
// downloadBlob downloads a blob from the registry and stores it in the blobs directory // downloadBlob downloads a blob from the registry and stores it in the blobs directory
func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ error) { func downloadBlob(ctx context.Context, opts downloadOpts) error {
fp, err := GetBlobsPath(opts.digest) fp, err := GetBlobsPath(opts.digest)
if err != nil { if err != nil {
return false, err return err
} }
fi, err := os.Stat(fp) fi, err := os.Stat(fp)
switch { switch {
case errors.Is(err, os.ErrNotExist): case errors.Is(err, os.ErrNotExist):
case err != nil: case err != nil:
return false, err return err
default: default:
opts.fn(api.ProgressResponse{ opts.fn(api.ProgressResponse{
Status: fmt.Sprintf("pulling %s", opts.digest[7:19]), Status: fmt.Sprintf("pulling %s", opts.digest[7:19]),
@@ -359,7 +359,7 @@ func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ erro
Completed: fi.Size(), Completed: fi.Size(),
}) })
return true, nil return nil
} }
data, ok := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: fp, Digest: opts.digest}) data, ok := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: fp, Digest: opts.digest})
@@ -369,12 +369,12 @@ func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ erro
requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", opts.digest) requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", opts.digest)
if err := download.Prepare(ctx, requestURL, opts.regOpts); err != nil { if err := download.Prepare(ctx, requestURL, opts.regOpts); err != nil {
blobDownloadManager.Delete(opts.digest) blobDownloadManager.Delete(opts.digest)
return false, err return err
} }
//nolint:contextcheck // nolint: contextcheck
go download.Run(context.Background(), requestURL, opts.regOpts) go download.Run(context.Background(), requestURL, opts.regOpts)
} }
return false, download.Wait(ctx, opts.fn) return download.Wait(ctx, opts.fn)
} }

View File

@@ -18,17 +18,17 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
"slices"
"strconv" "strconv"
"strings" "strings"
"golang.org/x/exp/slices"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/auth" "github.com/ollama/ollama/auth"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/parser" "github.com/ollama/ollama/parser"
"github.com/ollama/ollama/templates" "github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/types/errtypes" "github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model" "github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
@@ -315,7 +315,7 @@ func realpath(rel, from string) string {
return abspath return abspath
} }
func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantization string, modelfile *parser.File, fn func(resp api.ProgressResponse)) (err error) { func CreateModel(ctx context.Context, name, modelFileDir, quantization string, modelfile *parser.File, fn func(resp api.ProgressResponse)) (err error) {
config := ConfigV2{ config := ConfigV2{
OS: "linux", OS: "linux",
Architecture: "amd64", Architecture: "amd64",
@@ -435,46 +435,24 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
config.ModelType = cmp.Or(config.ModelType, format.HumanNumber(baseLayer.GGML.KV().ParameterCount())) config.ModelType = cmp.Or(config.ModelType, format.HumanNumber(baseLayer.GGML.KV().ParameterCount()))
config.FileType = cmp.Or(config.FileType, baseLayer.GGML.KV().FileType().String()) config.FileType = cmp.Or(config.FileType, baseLayer.GGML.KV().FileType().String())
config.ModelFamilies = append(config.ModelFamilies, baseLayer.GGML.KV().Architecture()) config.ModelFamilies = append(config.ModelFamilies, baseLayer.GGML.KV().Architecture())
if s := baseLayer.GGML.KV().ChatTemplate(); s != "" {
if t, err := templates.NamedTemplate(s); err != nil {
slog.Debug("template detection", "error", err)
} else {
layer, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
if err != nil {
return err
}
layer.status = fmt.Sprintf("using autodetected template %s", t.Name)
layers = append(layers, layer)
}
}
} }
layers = append(layers, baseLayer.Layer) layers = append(layers, baseLayer.Layer)
} }
case "license", "template", "system": case "license", "template", "system":
if c.Name != "license" {
// replace
layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
if layer.MediaType != mediatype {
return false
}
if err := layer.Remove(); err != nil {
return false
}
return true
})
}
blob := strings.NewReader(c.Args) blob := strings.NewReader(c.Args)
layer, err := NewLayer(blob, mediatype) layer, err := NewLayer(blob, mediatype)
if err != nil { if err != nil {
return err return err
} }
if c.Name != "license" {
// replace
layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
return layer.MediaType == mediatype
})
}
layers = append(layers, layer) layers = append(layers, layer)
case "message": case "message":
role, content, ok := strings.Cut(c.Args, ": ") role, content, ok := strings.Cut(c.Args, ": ")
@@ -593,15 +571,26 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
} }
} }
old, _ := ParseNamedManifest(name) unref := make(map[string]struct{})
if manifest, _, err := GetManifest(ParseModelPath(name)); err == nil {
for _, layer := range manifest.Layers {
if !slices.Contains(digests, layer.Digest) {
unref[layer.Digest] = struct{}{}
}
}
if manifest.Config.Digest != layer.Digest {
unref[manifest.Config.Digest] = struct{}{}
}
}
fn(api.ProgressResponse{Status: "writing manifest"}) fn(api.ProgressResponse{Status: "writing manifest"})
if err := WriteManifest(name, layer, layers); err != nil { if err := WriteManifest(name, layer, layers); err != nil {
return err return err
} }
if !envconfig.NoPrune && old != nil { if !envconfig.NoPrune {
if err := old.RemoveLayers(); err != nil { if err := deleteUnusedLayers(nil, unref); err != nil {
return err return err
} }
} }
@@ -673,7 +662,7 @@ func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]struct{})
// save (i.e. delete from the deleteMap) any files used in other manifests // save (i.e. delete from the deleteMap) any files used in other manifests
manifest, _, err := GetManifest(fmp) manifest, _, err := GetManifest(fmp)
if err != nil { if err != nil {
//nolint:nilerr // nolint: nilerr
return nil return nil
} }
@@ -782,6 +771,37 @@ func PruneDirectory(path string) error {
return nil return nil
} }
// DeleteModel removes the locally stored model identified by name.
//
// It loads the model's manifest, gathers the digest of every layer plus the
// config digest, and passes that set to deleteUnusedLayers together with this
// model's own path as the skipModelPath argument. Afterwards the manifest file
// itself is removed. Errors from any step are returned as-is (not wrapped).
func DeleteModel(name string) error {
	mp := ParseModelPath(name)

	manifest, _, err := GetManifest(mp)
	if err != nil {
		return err
	}

	// Candidate blobs for deletion: all layers first, then the config.
	candidates := make(map[string]struct{})
	for _, l := range manifest.Layers {
		candidates[l.Digest] = struct{}{}
	}
	candidates[manifest.Config.Digest] = struct{}{}

	if err := deleteUnusedLayers(&mp, candidates); err != nil {
		return err
	}

	manifestFile, err := mp.GetManifestPath()
	if err != nil {
		return err
	}

	if err := os.Remove(manifestFile); err != nil {
		slog.Info(fmt.Sprintf("couldn't remove manifest file '%s': %v", manifestFile, err))
		return err
	}

	return nil
}
func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error { func PushModel(ctx context.Context, name string, regOpts *registryOptions, fn func(api.ProgressResponse)) error {
mp := ParseModelPath(name) mp := ParseModelPath(name)
fn(api.ProgressResponse{Status: "retrieving manifest"}) fn(api.ProgressResponse{Status: "retrieving manifest"})
@@ -868,27 +888,23 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
layers = append(layers, manifest.Layers...) layers = append(layers, manifest.Layers...)
layers = append(layers, manifest.Config) layers = append(layers, manifest.Config)
skipVerify := make(map[string]bool)
for _, layer := range layers { for _, layer := range layers {
cacheHit, err := downloadBlob(ctx, downloadOpts{ if err := downloadBlob(
mp: mp, ctx,
digest: layer.Digest, downloadOpts{
regOpts: regOpts, mp: mp,
fn: fn, digest: layer.Digest,
}) regOpts: regOpts,
if err != nil { fn: fn,
}); err != nil {
return err return err
} }
skipVerify[layer.Digest] = cacheHit
delete(deleteMap, layer.Digest) delete(deleteMap, layer.Digest)
} }
delete(deleteMap, manifest.Config.Digest) delete(deleteMap, manifest.Config.Digest)
fn(api.ProgressResponse{Status: "verifying sha256 digest"}) fn(api.ProgressResponse{Status: "verifying sha256 digest"})
for _, layer := range layers { for _, layer := range layers {
if skipVerify[layer.Digest] {
continue
}
if err := verifyBlob(layer.Digest); err != nil { if err := verifyBlob(layer.Digest); err != nil {
if errors.Is(err, errDigestMismatch) { if errors.Is(err, errDigestMismatch) {
// something went wrong, delete the blob // something went wrong, delete the blob
@@ -1003,7 +1019,7 @@ func getTokenSubject(token string) string {
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) { func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) {
anonymous := true // access will default to anonymous if no user is found associated with the public key anonymous := true // access will default to anonymous if no user is found associated with the public key
for range 2 { for i := 0; i < 2; i++ {
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts) resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
if err != nil { if err != nil {
if !errors.Is(err, context.Canceled) { if !errors.Is(err, context.Canceled) {

View File

@@ -88,26 +88,3 @@ func (l *Layer) Open() (io.ReadSeekCloser, error) {
return os.Open(blob) return os.Open(blob)
} }
// Remove deletes the blob backing this layer unless a manifest still uses it.
//
// Every manifest returned by Manifests is scanned. If any config or layer
// entry carries the same digest as l, the blob is left untouched and nil is
// returned. Otherwise the blob path is resolved with GetBlobsPath and removed
// with os.Remove; that call's error, if any, is returned unchanged.
//
// Nil manifests, nil config entries and nil layer entries are skipped rather
// than dereferenced, so a manifest without a config cannot panic the scan.
func (l *Layer) Remove() error {
	ms, err := Manifests()
	if err != nil {
		return err
	}

	for _, m := range ms {
		if m == nil {
			continue
		}

		// Check the config and the layers separately instead of
		// append(m.Layers, m.Config): append may write into spare capacity
		// of the manifest's own Layers slice and allocates on every pass.
		if m.Config != nil && m.Config.Digest == l.Digest {
			// something is using this layer
			return nil
		}

		for _, layer := range m.Layers {
			if layer != nil && layer.Digest == l.Digest {
				// something is using this layer
				return nil
			}
		}
	}

	blob, err := GetBlobsPath(l.Digest)
	if err != nil {
		return err
	}

	return os.Remove(blob)
}

View File

@@ -1,11 +1,11 @@
package server package server
import ( import (
"bytes"
"crypto/sha256" "crypto/sha256"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
"log/slog"
"os" "os"
"path/filepath" "path/filepath"
@@ -14,10 +14,7 @@ import (
type Manifest struct { type Manifest struct {
ManifestV2 ManifestV2
Digest string `json:"-"`
filepath string
fi os.FileInfo
digest string
} }
func (m *Manifest) Size() (size int64) { func (m *Manifest) Size() (size int64) {
@@ -28,32 +25,9 @@ func (m *Manifest) Size() (size int64) {
return return
} }
func (m *Manifest) Remove() error { func ParseNamedManifest(name model.Name) (*Manifest, error) {
if err := os.Remove(m.filepath); err != nil { if !name.IsFullyQualified() {
return err return nil, model.Unqualified(name)
}
manifests, err := GetManifestPath()
if err != nil {
return err
}
return PruneDirectory(manifests)
}
func (m *Manifest) RemoveLayers() error {
for _, layer := range append(m.Layers, m.Config) {
if err := layer.Remove(); err != nil {
return err
}
}
return nil
}
func ParseNamedManifest(n model.Name) (*Manifest, error) {
if !n.IsFullyQualified() {
return nil, model.Unqualified(n)
} }
manifests, err := GetManifestPath() manifests, err := GetManifestPath()
@@ -61,101 +35,45 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
return nil, err return nil, err
} }
p := filepath.Join(manifests, n.Filepath()) var manifest ManifestV2
manifestfile, err := os.Open(filepath.Join(manifests, name.Filepath()))
var m ManifestV2
f, err := os.Open(p)
if err != nil {
return nil, err
}
defer f.Close()
fi, err := f.Stat()
if err != nil { if err != nil {
return nil, err return nil, err
} }
sha256sum := sha256.New() sha256sum := sha256.New()
if err := json.NewDecoder(io.TeeReader(f, sha256sum)).Decode(&m); err != nil { if err := json.NewDecoder(io.TeeReader(manifestfile, sha256sum)).Decode(&manifest); err != nil {
return nil, err return nil, err
} }
return &Manifest{ return &Manifest{
ManifestV2: m, ManifestV2: manifest,
filepath: p, Digest: fmt.Sprintf("%x", sha256sum.Sum(nil)),
fi: fi,
digest: fmt.Sprintf("%x", sha256sum.Sum(nil)),
}, nil }, nil
} }
func WriteManifest(name model.Name, config *Layer, layers []*Layer) error { func WriteManifest(name string, config *Layer, layers []*Layer) error {
manifests, err := GetManifestPath() manifest := ManifestV2{
if err != nil {
return err
}
p := filepath.Join(manifests, name.Filepath())
if err := os.MkdirAll(filepath.Dir(p), 0o755); err != nil {
return err
}
f, err := os.Create(p)
if err != nil {
return err
}
defer f.Close()
m := ManifestV2{
SchemaVersion: 2, SchemaVersion: 2,
MediaType: "application/vnd.docker.distribution.manifest.v2+json", MediaType: "application/vnd.docker.distribution.manifest.v2+json",
Config: config, Config: config,
Layers: layers, Layers: layers,
} }
return json.NewEncoder(f).Encode(m) var b bytes.Buffer
} if err := json.NewEncoder(&b).Encode(manifest); err != nil {
return err
}
func Manifests() (map[model.Name]*Manifest, error) { modelpath := ParseModelPath(name)
manifests, err := GetManifestPath() manifestPath, err := modelpath.GetManifestPath()
if err != nil { if err != nil {
return nil, err return err
} }
// TODO(mxyng): use something less brittle if err := os.MkdirAll(filepath.Dir(manifestPath), 0o755); err != nil {
matches, err := filepath.Glob(filepath.Join(manifests, "*", "*", "*", "*")) return err
if err != nil {
return nil, err
} }
ms := make(map[model.Name]*Manifest) return os.WriteFile(manifestPath, b.Bytes(), 0o644)
for _, match := range matches {
fi, err := os.Stat(match)
if err != nil {
return nil, err
}
if !fi.IsDir() {
rel, err := filepath.Rel(manifests, match)
if err != nil {
slog.Warn("bad filepath", "path", match, "error", err)
continue
}
n := model.ParseNameFromFilepath(rel)
if !n.IsValid() {
slog.Warn("bad manifest name", "path", rel, "error", err)
continue
}
m, err := ParseNamedManifest(n)
if err != nil {
slog.Warn("bad manifest", "name", n, "error", err)
continue
}
ms[n] = m
}
}
return ms, nil
} }

View File

@@ -1,150 +0,0 @@
package server
import (
"encoding/json"
"os"
"path/filepath"
"slices"
"testing"
"github.com/ollama/ollama/types/model"
)
// createManifest writes a zero-value ManifestV2, JSON-encoded, to
// <path>/manifests/<name>, creating any missing parent directories first.
// Any failure aborts the calling test via t.Fatal.
func createManifest(t *testing.T, path, name string) {
	t.Helper()

	target := filepath.Join(path, "manifests", name)
	parent := filepath.Dir(target)
	if err := os.MkdirAll(parent, 0755); err != nil {
		t.Fatal(err)
	}

	out, err := os.Create(target)
	if err != nil {
		t.Fatal(err)
	}
	defer out.Close()

	enc := json.NewEncoder(out)
	if err := enc.Encode(ManifestV2{}); err != nil {
		t.Fatal(err)
	}
}
// TestManifests is a table-driven test for Manifests.
//
// Each case lists manifest file paths (relative to the manifests directory)
// that are created under a fresh OLLAMA_MODELS directory. The test then checks
// that every path which parses to a valid model name is present in the result
// of Manifests, that no invalid one is, and that the number of valid/invalid
// input paths matches the expectation recorded in the case.
func TestManifests(t *testing.T) {
	cases := map[string]struct {
		ps               []string
		wantValidCount   int
		wantInvalidCount int
	}{
		"empty": {},
		"single": {
			ps: []string{
				filepath.Join("host", "namespace", "model", "tag"),
			},
			wantValidCount: 1,
		},
		"multiple": {
			ps: []string{
				filepath.Join("registry.ollama.ai", "library", "llama3", "latest"),
				filepath.Join("registry.ollama.ai", "library", "llama3", "q4_0"),
				filepath.Join("registry.ollama.ai", "library", "llama3", "q4_1"),
				filepath.Join("registry.ollama.ai", "library", "llama3", "q8_0"),
				filepath.Join("registry.ollama.ai", "library", "llama3", "q5_0"),
				filepath.Join("registry.ollama.ai", "library", "llama3", "q5_1"),
				filepath.Join("registry.ollama.ai", "library", "llama3", "q2_K"),
				filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_S"),
				filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_M"),
				filepath.Join("registry.ollama.ai", "library", "llama3", "q3_K_L"),
				filepath.Join("registry.ollama.ai", "library", "llama3", "q4_K_S"),
				filepath.Join("registry.ollama.ai", "library", "llama3", "q4_K_M"),
				filepath.Join("registry.ollama.ai", "library", "llama3", "q5_K_S"),
				filepath.Join("registry.ollama.ai", "library", "llama3", "q5_K_M"),
				filepath.Join("registry.ollama.ai", "library", "llama3", "q6_K"),
			},
			wantValidCount: 15,
		},
		"hidden": {
			ps: []string{
				filepath.Join("host", "namespace", "model", "tag"),
				filepath.Join("host", "namespace", "model", ".hidden"),
			},
			wantValidCount:   1,
			wantInvalidCount: 1,
		},
		"subdir": {
			ps: []string{
				filepath.Join("host", "namespace", "model", "tag", "one"),
				filepath.Join("host", "namespace", "model", "tag", "another", "one"),
			},
			wantInvalidCount: 2,
		},
		"upper tag": {
			ps: []string{
				filepath.Join("host", "namespace", "model", "TAG"),
			},
			wantValidCount: 1,
		},
		"upper model": {
			ps: []string{
				filepath.Join("host", "namespace", "MODEL", "tag"),
			},
			wantValidCount: 1,
		},
		"upper namespace": {
			ps: []string{
				filepath.Join("host", "NAMESPACE", "model", "tag"),
			},
			wantValidCount: 1,
		},
		"upper host": {
			ps: []string{
				filepath.Join("HOST", "namespace", "model", "tag"),
			},
			wantValidCount: 1,
		},
	}

	for label, tc := range cases {
		t.Run(label, func(t *testing.T) {
			root := t.TempDir()
			t.Setenv("OLLAMA_MODELS", root)

			for _, p := range tc.ps {
				createManifest(t, root, p)
			}

			ms, err := Manifests()
			if err != nil {
				t.Fatal(err)
			}

			// Flatten the returned map's keys so membership can be tested
			// with slices.Contains.
			var found []model.Name
			for k := range ms {
				found = append(found, k)
			}

			var gotValidCount, gotInvalidCount int
			for _, p := range tc.ps {
				parsed := model.ParseNameFromFilepath(p)
				valid := parsed.IsValid()
				listed := slices.Contains(found, parsed)

				if valid {
					gotValidCount++
				} else {
					gotInvalidCount++
				}

				switch {
				case !valid && listed:
					t.Errorf("unexpected invalid name: %s", p)
				case valid && !listed:
					t.Errorf("missing valid name: %s", p)
				}
			}

			if gotValidCount != tc.wantValidCount {
				t.Errorf("got valid count %d, want %d", gotValidCount, tc.wantValidCount)
			}

			if gotInvalidCount != tc.wantInvalidCount {
				t.Errorf("got invalid count %d, want %d", gotInvalidCount, tc.wantInvalidCount)
			}
		})
	}
}

View File

@@ -25,14 +25,16 @@ type layerWithGGML struct {
} }
func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) { func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
m, err := ParseNamedManifest(name) modelpath := ParseModelPath(name.String())
manifest, _, err := GetManifest(modelpath)
switch { switch {
case errors.Is(err, os.ErrNotExist): case errors.Is(err, os.ErrNotExist):
if err := PullModel(ctx, name.String(), &registryOptions{}, fn); err != nil { if err := PullModel(ctx, name.String(), &registryOptions{}, fn); err != nil {
return nil, err return nil, err
} }
m, err = ParseNamedManifest(name) modelpath = ParseModelPath(name.String())
manifest, _, err = GetManifest(modelpath)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -40,8 +42,8 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
return nil, err return nil, err
} }
for _, layer := range m.Layers { for _, layer := range manifest.Layers {
layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, name.DisplayShortest()) layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, modelpath.GetShortTagname())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -70,6 +72,7 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
default: default:
layers = append(layers, &layerWithGGML{layer, nil}) layers = append(layers, &layerWithGGML{layer, nil})
} }
} }
return layers, nil return layers, nil

View File

@@ -6,13 +6,12 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestGetBlobsPath(t *testing.T) { func TestGetBlobsPath(t *testing.T) {
// GetBlobsPath expects an actual directory to exist // GetBlobsPath expects an actual directory to exist
dir, err := os.MkdirTemp("", "ollama-test") dir, err := os.MkdirTemp("", "ollama-test")
require.NoError(t, err) assert.Nil(t, err)
defer os.RemoveAll(dir) defer os.RemoveAll(dir)
tests := []struct { tests := []struct {
@@ -64,7 +63,7 @@ func TestGetBlobsPath(t *testing.T) {
got, err := GetBlobsPath(tc.digest) got, err := GetBlobsPath(tc.digest)
require.ErrorIs(t, tc.err, err, tc.name) assert.ErrorIs(t, tc.err, err, tc.name)
assert.Equal(t, tc.expected, got, tc.name) assert.Equal(t, tc.expected, got, tc.name)
}) })
} }

View File

@@ -16,7 +16,6 @@ import (
"os" "os"
"os/signal" "os/signal"
"path/filepath" "path/filepath"
"slices"
"strconv" "strconv"
"strings" "strings"
"syscall" "syscall"
@@ -24,6 +23,7 @@ import (
"github.com/gin-contrib/cors" "github.com/gin-contrib/cors"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"golang.org/x/exp/slices"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig" "github.com/ollama/ollama/envconfig"
@@ -77,6 +77,7 @@ func isSupportedImageType(image []byte) bool {
} }
func (s *Server) GenerateHandler(c *gin.Context) { func (s *Server) GenerateHandler(c *gin.Context) {
checkpointStart := time.Now() checkpointStart := time.Now()
var req api.GenerateRequest var req api.GenerateRequest
err := c.ShouldBindJSON(&req) err := c.ShouldBindJSON(&req)
@@ -420,14 +421,13 @@ func (s *Server) PullModelHandler(c *gin.Context) {
return return
} }
name := model.ParseName(cmp.Or(req.Model, req.Name)) var model string
if !name.IsValid() { if req.Model != "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "invalid model name"}) model = req.Model
return } else if req.Name != "" {
} model = req.Name
} else {
if err := checkNameExists(name); err != nil { c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
@@ -445,7 +445,7 @@ func (s *Server) PullModelHandler(c *gin.Context) {
ctx, cancel := context.WithCancel(c.Request.Context()) ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel() defer cancel()
if err := PullModel(ctx, name.DisplayShortest(), regOpts, fn); err != nil { if err := PullModel(ctx, model, regOpts, fn); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }
}() }()
@@ -507,24 +507,9 @@ func (s *Server) PushModelHandler(c *gin.Context) {
streamResponse(c, ch) streamResponse(c, ch)
} }
// checkNameExists reports an error when a different model name already maps to
// the same manifest file path as name, compared case-insensitively.
//
// The set of known names comes from Manifests; an error from Manifests is
// returned unchanged. A name that is exactly equal to an existing one is not a
// conflict.
func checkNameExists(name model.Name) error {
	known, err := Manifests()
	if err != nil {
		return err
	}

	for other := range known {
		samePath := strings.EqualFold(other.Filepath(), name.Filepath())
		if samePath && other != name {
			return fmt.Errorf("a model with that name already exists")
		}
	}

	return nil
}
func (s *Server) CreateModelHandler(c *gin.Context) { func (s *Server) CreateModelHandler(c *gin.Context) {
var r api.CreateRequest var req api.CreateRequest
if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) { if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return return
} else if err != nil { } else if err != nil {
@@ -532,35 +517,30 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
return return
} }
name := model.ParseName(cmp.Or(r.Model, r.Name)) name := model.ParseName(cmp.Or(req.Model, req.Name))
if !name.IsValid() { if !name.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
return return
} }
if err := checkNameExists(name); err != nil { if req.Path == "" && req.Modelfile == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if r.Path == "" && r.Modelfile == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"})
return return
} }
var sr io.Reader = strings.NewReader(r.Modelfile) var r io.Reader = strings.NewReader(req.Modelfile)
if r.Path != "" && r.Modelfile == "" { if req.Path != "" && req.Modelfile == "" {
f, err := os.Open(r.Path) f, err := os.Open(req.Path)
if err != nil { if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("error reading modelfile: %s", err)}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("error reading modelfile: %s", err)})
return return
} }
defer f.Close() defer f.Close()
sr = f r = f
} }
f, err := parser.ParseFile(sr) modelfile, err := parser.ParseFile(r)
if err != nil { if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
@@ -576,13 +556,17 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
ctx, cancel := context.WithCancel(c.Request.Context()) ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel() defer cancel()
quantization := cmp.Or(r.Quantize, r.Quantization) quantization := req.Quantization
if err := CreateModel(ctx, name, filepath.Dir(r.Path), strings.ToUpper(quantization), f, fn); err != nil { if req.Quantize != "" {
quantization = req.Quantize
}
if err := CreateModel(ctx, name.String(), filepath.Dir(req.Path), strings.ToUpper(quantization), modelfile, fn); err != nil {
ch <- gin.H{"error": err.Error()} ch <- gin.H{"error": err.Error()}
} }
}() }()
if r.Stream != nil && !*r.Stream { if req.Stream != nil && !*req.Stream {
waitForStream(c, ch) waitForStream(c, ch)
return return
} }
@@ -591,36 +575,48 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
} }
func (s *Server) DeleteModelHandler(c *gin.Context) { func (s *Server) DeleteModelHandler(c *gin.Context) {
var r api.DeleteRequest var req api.DeleteRequest
if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) { err := c.ShouldBindJSON(&req)
switch {
case errors.Is(err, io.EOF):
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
return return
} else if err != nil { case err != nil:
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
n := model.ParseName(cmp.Or(r.Model, r.Name)) var model string
if !n.IsValid() { if req.Model != "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("name %q is invalid", cmp.Or(r.Model, r.Name))}) model = req.Model
} else if req.Name != "" {
model = req.Name
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return return
} }
m, err := ParseNamedManifest(n) if err := DeleteModel(model); err != nil {
if os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", model)})
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
return
}
manifestsPath, err := GetManifestPath()
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
if err := m.Remove(); err != nil { if err := PruneDirectory(manifestsPath); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
if err := m.RemoveLayers(); err != nil { c.JSON(http.StatusOK, nil)
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
} }
func (s *Server) ShowModelHandler(c *gin.Context) { func (s *Server) ShowModelHandler(c *gin.Context) {
@@ -724,45 +720,75 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
} }
func (s *Server) ListModelsHandler(c *gin.Context) { func (s *Server) ListModelsHandler(c *gin.Context) {
ms, err := Manifests() manifests, err := GetManifestPath()
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
models := []api.ListModelResponse{} models := []api.ModelResponse{}
for n, m := range ms { if err := filepath.Walk(manifests, func(path string, info os.FileInfo, _ error) error {
f, err := m.Config.Open() if !info.IsDir() {
if err != nil { rel, err := filepath.Rel(manifests, path)
slog.Warn("bad manifest filepath", "name", n, "error", err) if err != nil {
continue return err
} }
defer f.Close()
var cf ConfigV2 if hidden, err := filepath.Match(".*", filepath.Base(rel)); err != nil {
if err := json.NewDecoder(f).Decode(&cf); err != nil { return err
slog.Warn("bad manifest config", "name", n, "error", err) } else if hidden {
continue return nil
}
n := model.ParseNameFromFilepath(rel)
if !n.IsValid() {
slog.Warn("bad manifest filepath", "path", rel)
return nil
}
m, err := ParseNamedManifest(n)
if err != nil {
slog.Warn("bad manifest", "name", n, "error", err)
return nil
}
f, err := m.Config.Open()
if err != nil {
slog.Warn("bad manifest config filepath", "name", n, "error", err)
return nil
}
defer f.Close()
var c ConfigV2
if err := json.NewDecoder(f).Decode(&c); err != nil {
slog.Warn("bad manifest config", "name", n, "error", err)
return nil
}
// tag should never be masked
models = append(models, api.ModelResponse{
Model: n.DisplayShortest(),
Name: n.DisplayShortest(),
Size: m.Size(),
Digest: m.Digest,
ModifiedAt: info.ModTime(),
Details: api.ModelDetails{
Format: c.ModelFormat,
Family: c.ModelFamily,
Families: c.ModelFamilies,
ParameterSize: c.ModelType,
QuantizationLevel: c.FileType,
},
})
} }
// tag should never be masked return nil
models = append(models, api.ListModelResponse{ }); err != nil {
Model: n.DisplayShortest(), c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
Name: n.DisplayShortest(), return
Size: m.Size(),
Digest: m.digest,
ModifiedAt: m.fi.ModTime(),
Details: api.ModelDetails{
Format: cf.ModelFormat,
Family: cf.ModelFamily,
Families: cf.ModelFamilies,
ParameterSize: cf.ModelType,
QuantizationLevel: cf.FileType,
},
})
} }
slices.SortStableFunc(models, func(i, j api.ListModelResponse) int { slices.SortStableFunc(models, func(i, j api.ModelResponse) int {
// most recently modified first // most recently modified first
return cmp.Compare(j.ModifiedAt.Unix(), i.ModifiedAt.Unix()) return cmp.Compare(j.ModifiedAt.Unix(), i.ModifiedAt.Unix())
}) })
@@ -792,11 +818,6 @@ func (s *Server) CopyModelHandler(c *gin.Context) {
return return
} }
if err := checkNameExists(dst); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
if err := CopyModel(src, dst); errors.Is(err, os.ErrNotExist) { if err := CopyModel(src, dst); errors.Is(err, os.ErrNotExist) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found", r.Source)}) c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model %q not found", r.Source)})
} else if err != nil { } else if err != nil {
@@ -942,7 +963,7 @@ func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
} }
if allowedHost(host) { if allowedHost(host) {
if c.Request.Method == http.MethodOptions { if c.Request.Method == "OPTIONS" {
c.AbortWithStatus(http.StatusNoContent) c.AbortWithStatus(http.StatusNoContent)
return return
} }
@@ -960,10 +981,6 @@ func (s *Server) GenerateRoutes() http.Handler {
config.AllowWildcard = true config.AllowWildcard = true
config.AllowBrowserExtensions = true config.AllowBrowserExtensions = true
config.AllowHeaders = []string{"Authorization", "Content-Type", "User-Agent", "Accept", "X-Requested-With"} config.AllowHeaders = []string{"Authorization", "Content-Type", "User-Agent", "Accept", "X-Requested-With"}
openAIProperties := []string{"lang", "package-version", "os", "arch", "runtime", "runtime-version", "async"}
for _, prop := range openAIProperties {
config.AllowHeaders = append(config.AllowHeaders, "x-stainless-"+prop)
}
config.AllowOrigins = envconfig.AllowOrigins config.AllowOrigins = envconfig.AllowOrigins
r := gin.Default() r := gin.Default()
@@ -1143,7 +1160,7 @@ func streamResponse(c *gin.Context, ch chan any) {
} }
func (s *Server) ProcessHandler(c *gin.Context) { func (s *Server) ProcessHandler(c *gin.Context) {
models := []api.ProcessModelResponse{} models := []api.ModelResponse{}
for _, v := range s.sched.loaded { for _, v := range s.sched.loaded {
model := v.model model := v.model
@@ -1155,7 +1172,7 @@ func (s *Server) ProcessHandler(c *gin.Context) {
QuantizationLevel: model.Config.FileType, QuantizationLevel: model.Config.FileType,
} }
mr := api.ProcessModelResponse{ mr := api.ModelResponse{
Model: model.ShortName, Model: model.ShortName,
Name: model.ShortName, Name: model.ShortName,
Size: int64(v.estimatedTotal), Size: int64(v.estimatedTotal),
@@ -1175,7 +1192,7 @@ func (s *Server) ProcessHandler(c *gin.Context) {
models = append(models, mr) models = append(models, mr)
} }
c.JSON(http.StatusOK, api.ProcessResponse{Models: models}) c.JSON(http.StatusOK, api.ListResponse{Models: models})
} }
// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model // ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
@@ -1310,6 +1327,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
defer close(ch) defer close(ch)
fn := func(r llm.CompletionResponse) { fn := func(r llm.CompletionResponse) {
resp := api.ChatResponse{ resp := api.ChatResponse{
Model: req.Model, Model: req.Model,
CreatedAt: time.Now().UTC(), CreatedAt: time.Now().UTC(),

View File

@@ -1,560 +0,0 @@
package server
import (
"bytes"
"encoding/binary"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"slices"
"testing"
"github.com/gin-gonic/gin"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/llm"
)
// stream is a package-level flag left at false; tests pass its address as the
// Stream field of the requests they build.
var stream bool
// createBinFile creates a file in a fresh test temp directory, writes kv and
// ti into it through llm.NewGGUFV3(binary.LittleEndian).Encode, and returns
// the file's path. Any failure aborts the calling test via t.Fatal.
func createBinFile(t *testing.T, kv map[string]any, ti []llm.Tensor) string {
	t.Helper()

	tmp, err := os.CreateTemp(t.TempDir(), "")
	if err != nil {
		t.Fatal(err)
	}
	defer tmp.Close()

	encodeErr := llm.NewGGUFV3(binary.LittleEndian).Encode(tmp, kv, ti)
	if encodeErr != nil {
		t.Fatal(encodeErr)
	}

	return tmp.Name()
}
// responseRecorder pairs an httptest.ResponseRecorder with the
// http.CloseNotifier interface so the value can be used where both are needed.
// The embedded CloseNotifier is never assigned; CloseNotify below is the
// method that actually gets called.
type responseRecorder struct {
	*httptest.ResponseRecorder
	http.CloseNotifier
}

// NewRecorder returns a responseRecorder wrapping a fresh
// httptest.ResponseRecorder.
func NewRecorder() *responseRecorder {
	rec := new(responseRecorder)
	rec.ResponseRecorder = httptest.NewRecorder()
	return rec
}

// CloseNotify returns a newly made unbuffered channel on every call. Nothing
// in this type ever sends on it. The method shadows the one promoted from the
// embedded http.CloseNotifier.
func (r *responseRecorder) CloseNotify() <-chan bool {
	ch := make(chan bool)
	return ch
}
// createRequest invokes the gin handler fn with a test context whose request
// body is the JSON encoding of body, and returns the recorder that captured
// the handler's response. A JSON encoding failure aborts the calling test.
func createRequest(t *testing.T, fn func(*gin.Context), body any) *httptest.ResponseRecorder {
	t.Helper()

	rec := NewRecorder()
	ctx, _ := gin.CreateTestContext(rec)

	payload := new(bytes.Buffer)
	if err := json.NewEncoder(payload).Encode(body); err != nil {
		t.Fatal(err)
	}

	// Only the body is populated; the remaining request fields stay zero.
	ctx.Request = &http.Request{Body: io.NopCloser(payload)}

	fn(ctx)

	return rec.ResponseRecorder
}
// checkFileExists globs p and fails the test unless the matches are exactly
// expect, in the same order. An empty result matches both a nil and an empty
// expect slice, since slices.Equal compares by length and elements.
func checkFileExists(t *testing.T, p string, expect []string) {
	t.Helper()

	actual, err := filepath.Glob(p)
	if err != nil {
		t.Fatal(err)
	}

	if !slices.Equal(actual, expect) {
		// Report the pattern and both sides; the previous message printed
		// only the actual matches, which made failures hard to diagnose.
		t.Fatalf("glob %q: expected %v, actual %v", p, expect, actual)
	}
}
// TestCreateFromBin creates a model named "test" from a generated binary file
// and checks which manifest and blob files exist under OLLAMA_MODELS afterwards.
func TestCreateFromBin(t *testing.T) {
	root := t.TempDir()
	t.Setenv("OLLAMA_MODELS", root)

	var s Server
	req := api.CreateRequest{
		Name:      "test",
		Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
		Stream:    &stream,
	}

	if rec := createRequest(t, s.CreateModelHandler, req); rec.Code != http.StatusOK {
		t.Fatalf("expected status code 200, actual %d", rec.Code)
	}

	wantManifests := []string{
		filepath.Join(root, "manifests", "registry.ollama.ai", "library", "test", "latest"),
	}
	checkFileExists(t, filepath.Join(root, "manifests", "*", "*", "*", "*"), wantManifests)

	wantBlobs := []string{
		filepath.Join(root, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
		filepath.Join(root, "blobs", "sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116"),
	}
	checkFileExists(t, filepath.Join(root, "blobs", "*"), wantBlobs)
}
// TestCreateFromModel creates "test" from a generated binary file, then
// creates "test2" FROM "test", and checks that both manifests exist while the
// blobs directory still holds the same two blobs.
func TestCreateFromModel(t *testing.T) {
	root := t.TempDir()
	t.Setenv("OLLAMA_MODELS", root)

	manifestGlob := filepath.Join(root, "manifests", "*", "*", "*", "*")
	manifest := func(model string) string {
		return filepath.Join(root, "manifests", "registry.ollama.ai", "library", model, "latest")
	}

	var s Server

	base := createRequest(t, s.CreateModelHandler, api.CreateRequest{
		Name:      "test",
		Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
		Stream:    &stream,
	})
	if base.Code != http.StatusOK {
		t.Fatalf("expected status code 200, actual %d", base.Code)
	}
	checkFileExists(t, manifestGlob, []string{manifest("test")})

	derived := createRequest(t, s.CreateModelHandler, api.CreateRequest{
		Name:      "test2",
		Modelfile: "FROM test",
		Stream:    &stream,
	})
	if derived.Code != http.StatusOK {
		t.Fatalf("expected status code 200, actual %d", derived.Code)
	}
	checkFileExists(t, manifestGlob, []string{manifest("test"), manifest("test2")})

	checkFileExists(t, filepath.Join(root, "blobs", "*"), []string{
		filepath.Join(root, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
		filepath.Join(root, "blobs", "sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116"),
	})
}
// TestCreateRemovesLayers creates "test" twice with different TEMPLATE values
// and checks that after the second create the blobs directory contains only
// the second expected set of blobs.
func TestCreateRemovesLayers(t *testing.T) {
	root := t.TempDir()
	t.Setenv("OLLAMA_MODELS", root)

	var (
		s            Server
		manifestGlob = filepath.Join(root, "manifests", "*", "*", "*", "*")
		blobGlob     = filepath.Join(root, "blobs", "*")
	)
	wantManifests := func() []string {
		return []string{filepath.Join(root, "manifests", "registry.ollama.ai", "library", "test", "latest")}
	}
	blob := func(name string) string { return filepath.Join(root, "blobs", name) }

	first := createRequest(t, s.CreateModelHandler, api.CreateRequest{
		Name:      "test",
		Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt }}", createBinFile(t, nil, nil)),
		Stream:    &stream,
	})
	if first.Code != http.StatusOK {
		t.Fatalf("expected status code 200, actual %d", first.Code)
	}
	checkFileExists(t, manifestGlob, wantManifests())
	checkFileExists(t, blobGlob, []string{
		blob("sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
		blob("sha256-b507b9c2f6ca642bffcd06665ea7c91f235fd32daeefdf875a0f938db05fb315"),
		blob("sha256-bc80b03733773e0728011b2f4adf34c458b400e1aad48cb28d61170f3a2ad2d6"),
	})

	second := createRequest(t, s.CreateModelHandler, api.CreateRequest{
		Name:      "test",
		Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .System }} {{ .Prompt }}", createBinFile(t, nil, nil)),
		Stream:    &stream,
	})
	if second.Code != http.StatusOK {
		t.Fatalf("expected status code 200, actual %d", second.Code)
	}
	checkFileExists(t, manifestGlob, wantManifests())
	checkFileExists(t, blobGlob, []string{
		blob("sha256-8f2c2167d789c6b2302dff965160fa5029f6a24096d262c1cbb469f21a045382"),
		blob("sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
		blob("sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"),
	})
}
// TestCreateUnsetsSystem creates "test" with a SYSTEM prompt, recreates it
// with SYSTEM "", and checks the resulting blobs, including that the blob
// named by the final digest in the second set is empty on disk.
func TestCreateUnsetsSystem(t *testing.T) {
	root := t.TempDir()
	t.Setenv("OLLAMA_MODELS", root)

	var (
		s            Server
		manifestGlob = filepath.Join(root, "manifests", "*", "*", "*", "*")
		blobGlob     = filepath.Join(root, "blobs", "*")
	)
	wantManifests := func() []string {
		return []string{filepath.Join(root, "manifests", "registry.ollama.ai", "library", "test", "latest")}
	}
	blob := func(name string) string { return filepath.Join(root, "blobs", name) }

	withSystem := createRequest(t, s.CreateModelHandler, api.CreateRequest{
		Name:      "test",
		Modelfile: fmt.Sprintf("FROM %s\nSYSTEM Say hi!", createBinFile(t, nil, nil)),
		Stream:    &stream,
	})
	if withSystem.Code != http.StatusOK {
		t.Fatalf("expected status code 200, actual %d", withSystem.Code)
	}
	checkFileExists(t, manifestGlob, wantManifests())
	checkFileExists(t, blobGlob, []string{
		blob("sha256-8585df945d1069bc78b79bd10bb73ba07fbc29b0f5479a31a601c0d12731416e"),
		blob("sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
		blob("sha256-f29e82a8284dbdf5910b1555580ff60b04238b8da9d5e51159ada67a4d0d5851"),
	})

	withoutSystem := createRequest(t, s.CreateModelHandler, api.CreateRequest{
		Name:      "test",
		Modelfile: fmt.Sprintf("FROM %s\nSYSTEM \"\"", createBinFile(t, nil, nil)),
		Stream:    &stream,
	})
	if withoutSystem.Code != http.StatusOK {
		t.Fatalf("expected status code 200, actual %d", withoutSystem.Code)
	}
	checkFileExists(t, manifestGlob, wantManifests())
	checkFileExists(t, blobGlob, []string{
		blob("sha256-67d4b8d106af2a5b100a46e9bdc038c71eef2a35c9abac784092654212f97cf5"),
		blob("sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
		blob("sha256-e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"),
	})

	contents, err := os.ReadFile(blob("sha256-e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"))
	if err != nil {
		t.Fatal(err)
	}
	if string(contents) != "" {
		t.Fatalf("expected empty string, actual %s", string(contents))
	}
}
// TestCreateMergeParameters creates "test" with temperature, top_k and two
// stop parameters, then creates "test2" FROM "test" twice with overriding
// parameters. After each "test2" create it reads one specific blob and
// compares it (whitespace-trimmed) with the JSON encoding of the expected
// parameter map.
func TestCreateMergeParameters(t *testing.T) {
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
var s Server
// Step 1: base model carrying temperature, top_k and two stop values.
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test",
Modelfile: fmt.Sprintf("FROM %s\nPARAMETER temperature 1\nPARAMETER top_k 10\nPARAMETER stop USER:\nPARAMETER stop ASSISTANT:", createBinFile(t, nil, nil)),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-1d0ad71299d48c2fb7ae2b98e683643e771f8a5b72be34942af90d97a91c1e37"),
filepath.Join(p, "blobs", "sha256-4a384beaf47a9cbe452dfa5ab70eea691790f3b35a832d12933a1996685bf2b6"),
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
})
// in order to merge parameters, the second model must be created FROM the first
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test2",
Modelfile: "FROM test\nPARAMETER temperature 0.6\nPARAMETER top_p 0.7",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-1d0ad71299d48c2fb7ae2b98e683643e771f8a5b72be34942af90d97a91c1e37"),
filepath.Join(p, "blobs", "sha256-4a384beaf47a9cbe452dfa5ab70eea691790f3b35a832d12933a1996685bf2b6"),
filepath.Join(p, "blobs", "sha256-4cd9d4ba6b734d9b4cbd1e5caa60374c00722e993fce5e1e2d15a33698f71187"),
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-e29a7b3c47287a2489c895d21fe413c20f859a85d20e749492f52a838e36e1ba"),
})
// The blob read here is expected to hold the merged parameters: temperature
// overridden to 0.6, top_p added, top_k and stop carried over from "test".
actual, err := os.ReadFile(filepath.Join(p, "blobs", "sha256-e29a7b3c47287a2489c895d21fe413c20f859a85d20e749492f52a838e36e1ba"))
if err != nil {
t.Fatal(err)
}
expect, err := json.Marshal(map[string]any{"temperature": 0.6, "top_k": 10, "top_p": 0.7, "stop": []string{"USER:", "ASSISTANT:"}})
if err != nil {
t.Fatal(err)
}
// Trim both sides so a trailing newline in the stored blob does not matter.
if !bytes.Equal(bytes.TrimSpace(expect), bytes.TrimSpace(actual)) {
t.Errorf("expected %s, actual %s", string(expect), string(actual))
}
// slices are replaced
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test2",
Modelfile: "FROM test\nPARAMETER temperature 0.6\nPARAMETER top_p 0.7\nPARAMETER stop <|endoftext|>",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-12f58bb75cb3042d69a7e013ab87fb3c3c7088f50ddc62f0c77bd332f0d44d35"),
filepath.Join(p, "blobs", "sha256-1d0ad71299d48c2fb7ae2b98e683643e771f8a5b72be34942af90d97a91c1e37"),
filepath.Join(p, "blobs", "sha256-257aa726584f24970a4f240765e75a7169bfbe7f4966c1f04513d6b6c860583a"),
filepath.Join(p, "blobs", "sha256-4a384beaf47a9cbe452dfa5ab70eea691790f3b35a832d12933a1996685bf2b6"),
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
})
// This time the expected stop list contains only the newly supplied value.
actual, err = os.ReadFile(filepath.Join(p, "blobs", "sha256-12f58bb75cb3042d69a7e013ab87fb3c3c7088f50ddc62f0c77bd332f0d44d35"))
if err != nil {
t.Fatal(err)
}
expect, err = json.Marshal(map[string]any{"temperature": 0.6, "top_k": 10, "top_p": 0.7, "stop": []string{"<|endoftext|>"}})
if err != nil {
t.Fatal(err)
}
if !bytes.Equal(bytes.TrimSpace(expect), bytes.TrimSpace(actual)) {
t.Errorf("expected %s, actual %s", string(expect), string(actual))
}
}
// TestCreateReplacesMessages creates "test" with three MESSAGE entries, then
// creates "test2" FROM "test" with three different MESSAGE entries, and
// verifies that one specific blob decodes to exactly the second set of
// messages.
func TestCreateReplacesMessages(t *testing.T) {
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
var s Server
// Step 1: base model with its own conversation.
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test",
Modelfile: fmt.Sprintf("FROM %s\nMESSAGE assistant \"What is my purpose?\"\nMESSAGE user \"You run tests.\"\nMESSAGE assistant \"Oh, my god.\"", createBinFile(t, nil, nil)),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-298baeaf6928a60cf666d88d64a1ba606feb43a2865687c39e40652e407bffc4"),
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-e0e27d47045063ccb167ae852c51d49a98eab33fabaee4633fdddf97213e40b5"),
})
// Step 2: derived model that supplies a different conversation.
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test2",
Modelfile: "FROM test\nMESSAGE assistant \"You're a test, Harry.\"\nMESSAGE user \"I-I'm a what?\"\nMESSAGE assistant \"A test. And a thumping good one at that, I'd wager.\"",
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
})
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-298baeaf6928a60cf666d88d64a1ba606feb43a2865687c39e40652e407bffc4"),
filepath.Join(p, "blobs", "sha256-4f48b25fe9969564c82f58eb1cedbdff6484cc0baf474bc6c2a9b37c8da3362a"),
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-a60ecc9da299ec7ede453f99236e5577fd125e143689b646d9f0ddc9971bf4db"),
filepath.Join(p, "blobs", "sha256-e0e27d47045063ccb167ae852c51d49a98eab33fabaee4633fdddf97213e40b5"),
})
// message is a local decoding target holding only the two JSON fields the
// assertion below compares.
type message struct {
Role string `json:"role"`
Content string `json:"content"`
}
f, err := os.Open(filepath.Join(p, "blobs", "sha256-a60ecc9da299ec7ede453f99236e5577fd125e143689b646d9f0ddc9971bf4db"))
if err != nil {
t.Fatal(err)
}
defer f.Close()
var actual []message
if err := json.NewDecoder(f).Decode(&actual); err != nil {
t.Fatal(err)
}
// Only the messages from the second Modelfile are expected; none from "test".
expect := []message{
{Role: "assistant", Content: "You're a test, Harry."},
{Role: "user", Content: "I-I'm a what?"},
{Role: "assistant", Content: "A test. And a thumping good one at that, I'd wager."},
}
if !slices.Equal(actual, expect) {
t.Errorf("expected %s, actual %s", expect, actual)
}
}
// TestCreateTemplateSystem creates "test" from a Modelfile that specifies
// TEMPLATE and SYSTEM twice each, then checks that the two blobs it reads
// contain the later TEMPLATE and later SYSTEM values.
func TestCreateTemplateSystem(t *testing.T) {
	root := t.TempDir()
	t.Setenv("OLLAMA_MODELS", root)

	blob := func(name string) string { return filepath.Join(root, "blobs", name) }

	var s Server
	rec := createRequest(t, s.CreateModelHandler, api.CreateRequest{
		Name:      "test",
		Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt }}\nSYSTEM Say hello!\nTEMPLATE {{ .System }} {{ .Prompt }}\nSYSTEM Say bye!", createBinFile(t, nil, nil)),
		Stream:    &stream,
	})
	if rec.Code != http.StatusOK {
		t.Fatalf("expected status code 200, actual %d", rec.Code)
	}

	checkFileExists(t, filepath.Join(root, "manifests", "*", "*", "*", "*"), []string{
		filepath.Join(root, "manifests", "registry.ollama.ai", "library", "test", "latest"),
	})
	checkFileExists(t, filepath.Join(root, "blobs", "*"), []string{
		blob("sha256-2b5e330885117c82f3fd75169ea323e141070a2947c11ddb9f79ee0b01c589c1"),
		blob("sha256-4c5f51faac758fecaff8db42f0b7382891a4d0c0bb885f7b86be88c814a7cc86"),
		blob("sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
		blob("sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"),
	})

	gotTemplate, err := os.ReadFile(blob("sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"))
	if err != nil {
		t.Fatal(err)
	}
	if string(gotTemplate) != "{{ .System }} {{ .Prompt }}" {
		t.Errorf("expected \"{{ .System }} {{ .Prompt }}\", actual %s", gotTemplate)
	}

	gotSystem, err := os.ReadFile(blob("sha256-4c5f51faac758fecaff8db42f0b7382891a4d0c0bb885f7b86be88c814a7cc86"))
	if err != nil {
		t.Fatal(err)
	}
	if string(gotSystem) != "Say bye!" {
		t.Errorf("expected \"Say bye!\", actual %s", gotSystem)
	}
}
// TestCreateLicenses creates "test" from a Modelfile with two LICENSE lines
// and checks that two of the resulting blobs contain "MIT" and "Apache-2.0"
// respectively.
func TestCreateLicenses(t *testing.T) {
	root := t.TempDir()
	t.Setenv("OLLAMA_MODELS", root)

	blob := func(name string) string { return filepath.Join(root, "blobs", name) }

	var s Server
	rec := createRequest(t, s.CreateModelHandler, api.CreateRequest{
		Name:      "test",
		Modelfile: fmt.Sprintf("FROM %s\nLICENSE MIT\nLICENSE Apache-2.0", createBinFile(t, nil, nil)),
		Stream:    &stream,
	})
	if rec.Code != http.StatusOK {
		t.Fatalf("expected status code 200, actual %d", rec.Code)
	}

	checkFileExists(t, filepath.Join(root, "manifests", "*", "*", "*", "*"), []string{
		filepath.Join(root, "manifests", "registry.ollama.ai", "library", "test", "latest"),
	})
	checkFileExists(t, filepath.Join(root, "blobs", "*"), []string{
		blob("sha256-2af71558e438db0b73a20beab92dc278a94e1bbe974c00c1a33e3ab62d53a608"),
		blob("sha256-79a39c37536ddee29cbadd5d5e2dcba8ed7f03e431f626ff38432c1c866bb7e2"),
		blob("sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
		blob("sha256-e5dcffe836b6ec8a58e492419b550e65fb8cbdc308503979e5dacb33ac7ea3b7"),
	})

	gotMIT, err := os.ReadFile(blob("sha256-e5dcffe836b6ec8a58e492419b550e65fb8cbdc308503979e5dacb33ac7ea3b7"))
	if err != nil {
		t.Fatal(err)
	}
	if string(gotMIT) != "MIT" {
		t.Errorf("expected MIT, actual %s", gotMIT)
	}

	gotApache, err := os.ReadFile(blob("sha256-2af71558e438db0b73a20beab92dc278a94e1bbe974c00c1a33e3ab62d53a608"))
	if err != nil {
		t.Fatal(err)
	}
	if string(gotApache) != "Apache-2.0" {
		t.Errorf("expected Apache-2.0, actual %s", gotApache)
	}
}
// TestCreateDetectTemplate creates "test" twice in the same models directory:
// once from a binary file whose KV carries a "tokenizer.chat_template" entry
// ("matched") and once from a binary file with no KV ("unmatched"). The first
// case expects three blobs, the second expects two.
func TestCreateDetectTemplate(t *testing.T) {
p := t.TempDir()
t.Setenv("OLLAMA_MODELS", p)
var s Server
// Both subtests share p and s and run in order, so the "unmatched" blob
// expectations describe the directory after "matched" has already run.
t.Run("matched", func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test",
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
"tokenizer.chat_template": "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}",
}, nil)),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-06cd2687a518d624073f125f1db1c5c727f77c75e84a138fe745186dbbbb4cd7"),
filepath.Join(p, "blobs", "sha256-542b217f179c7825eeb5bca3c77d2b75ed05bafbd3451d9188891a60a85337c6"),
filepath.Join(p, "blobs", "sha256-553c4a3f747b3d22a4946875f1cc8ed011c2930d83f864a0c7265f9ec0a20413"),
})
})
t.Run("unmatched", func(t *testing.T) {
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
Name: "test",
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
Stream: &stream,
})
if w.Code != http.StatusOK {
t.Fatalf("expected status code 200, actual %d", w.Code)
}
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
filepath.Join(p, "blobs", "sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116"),
})
})
}

View File

@@ -1,71 +0,0 @@
package server
import (
"fmt"
"net/http"
"path/filepath"
"testing"
"github.com/ollama/ollama/api"
)
// TestDelete creates "test" and "test2", deletes them one after the other,
// and checks the manifests and blobs directories after each step. After
// deleting "test" only its unshared blob is expected to be gone; after
// deleting "test2" both directories are expected to be empty.
func TestDelete(t *testing.T) {
	root := t.TempDir()
	t.Setenv("OLLAMA_MODELS", root)

	var (
		s            Server
		manifestGlob = filepath.Join(root, "manifests", "*", "*", "*", "*")
		blobGlob     = filepath.Join(root, "blobs", "*")
	)
	manifest := func(model string) string {
		return filepath.Join(root, "manifests", "registry.ollama.ai", "library", model, "latest")
	}
	blob := func(name string) string { return filepath.Join(root, "blobs", name) }

	rec := createRequest(t, s.CreateModelHandler, api.CreateRequest{
		Name:      "test",
		Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
	})
	if rec.Code != http.StatusOK {
		t.Fatalf("expected status code 200, actual %d", rec.Code)
	}

	rec = createRequest(t, s.CreateModelHandler, api.CreateRequest{
		Name:      "test2",
		Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .System }} {{ .Prompt }}", createBinFile(t, nil, nil)),
	})
	if rec.Code != http.StatusOK {
		t.Fatalf("expected status code 200, actual %d", rec.Code)
	}

	checkFileExists(t, manifestGlob, []string{manifest("test"), manifest("test2")})
	checkFileExists(t, blobGlob, []string{
		blob("sha256-8f2c2167d789c6b2302dff965160fa5029f6a24096d262c1cbb469f21a045382"),
		blob("sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
		blob("sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116"),
		blob("sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"),
	})

	rec = createRequest(t, s.DeleteModelHandler, api.DeleteRequest{Name: "test"})
	if rec.Code != http.StatusOK {
		t.Fatalf("expected status code 200, actual %d", rec.Code)
	}

	checkFileExists(t, manifestGlob, []string{manifest("test2")})
	checkFileExists(t, blobGlob, []string{
		blob("sha256-8f2c2167d789c6b2302dff965160fa5029f6a24096d262c1cbb469f21a045382"),
		blob("sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
		blob("sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"),
	})

	rec = createRequest(t, s.DeleteModelHandler, api.DeleteRequest{Name: "test2"})
	if rec.Code != http.StatusOK {
		t.Fatalf("expected status code 200, actual %d", rec.Code)
	}

	checkFileExists(t, manifestGlob, []string{})
	checkFileExists(t, blobGlob, []string{})
}

View File

@@ -1,61 +0,0 @@
package server
import (
"encoding/json"
"fmt"
"net/http"
"slices"
"testing"
"github.com/ollama/ollama/api"
)
// TestList creates one model per entry in expectNames and verifies that
// ListModelsHandler reports exactly that set of names (compared after
// sorting both sides).
func TestList(t *testing.T) {
	t.Setenv("OLLAMA_MODELS", t.TempDir())

	expectNames := []string{
		"mistral:7b-instruct-q4_0",
		"zephyr:7b-beta-q5_K_M",
		"apple/OpenELM:latest",
		"boreas:2b-code-v1.5-q6_K",
		"notus:7b-v1-IQ2_S",
		// TODO: host:port currently fails on windows (#4107)
		// "localhost:5000/library/eurus:700b-v0.5-iq3_XXS",
		"mynamespace/apeliotes:latest",
		"myhost/mynamespace/lips:code",
	}

	var s Server
	for _, n := range expectNames {
		// Fail at the create step, naming the model, instead of letting a
		// failed create surface later as an unexplained count mismatch.
		w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
			Name:      n,
			Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
		})
		if w.Code != http.StatusOK {
			t.Fatalf("create %q: expected status code 200, actual %d", n, w.Code)
		}
	}

	w := createRequest(t, s.ListModelsHandler, nil)
	if w.Code != http.StatusOK {
		t.Fatalf("expected status code 200, actual %d", w.Code)
	}

	var resp api.ListResponse
	if err := json.NewDecoder(w.Body).Decode(&resp); err != nil {
		t.Fatal(err)
	}

	if len(resp.Models) != len(expectNames) {
		t.Fatalf("expected %d models, actual %d", len(expectNames), len(resp.Models))
	}

	actualNames := make([]string, len(resp.Models))
	for i, m := range resp.Models {
		actualNames[i] = m.Name
	}

	// Order is not part of the contract checked here, so compare sorted copies.
	slices.Sort(actualNames)
	slices.Sort(expectNames)

	if !slices.Equal(actualNames, expectNames) {
		t.Fatalf("expected names %v, actual %v", expectNames, actualNames)
	}
}

View File

@@ -15,36 +15,12 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/parser" "github.com/ollama/ollama/parser"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version" "github.com/ollama/ollama/version"
) )
// createTestFile creates a temporary file (using name as the os.CreateTemp
// pattern) inside t.TempDir(), writes a minimal little-endian header to it,
// and returns its path. Any failure stops the test via require.
func createTestFile(t *testing.T, name string) string {
	t.Helper()

	f, err := os.CreateTemp(t.TempDir(), name)
	require.NoError(t, err)
	defer f.Close()

	// Header fields in write order: the bytes "GGUF", a uint32 3, then two
	// zero uint64 values.
	// NOTE(review): presumably magic, format version and two counts — confirm
	// against the GGUF decoder.
	header := []any{
		[]byte("GGUF"),
		uint32(3),
		uint64(0),
		uint64(0),
	}
	for _, field := range header {
		err = binary.Write(f, binary.LittleEndian, field)
		require.NoError(t, err)
	}

	return f.Name()
}
func Test_Routes(t *testing.T) { func Test_Routes(t *testing.T) {
type testCase struct { type testCase struct {
Name string Name string
@@ -54,19 +30,46 @@ func Test_Routes(t *testing.T) {
Expected func(t *testing.T, resp *http.Response) Expected func(t *testing.T, resp *http.Response)
} }
createTestModel := func(t *testing.T, name string) { createTestFile := func(t *testing.T, name string) string {
t.Helper() t.Helper()
f, err := os.CreateTemp(t.TempDir(), name)
assert.Nil(t, err)
defer f.Close()
err = binary.Write(f, binary.LittleEndian, []byte("GGUF"))
assert.Nil(t, err)
err = binary.Write(f, binary.LittleEndian, uint32(3))
assert.Nil(t, err)
err = binary.Write(f, binary.LittleEndian, uint64(0))
assert.Nil(t, err)
err = binary.Write(f, binary.LittleEndian, uint64(0))
assert.Nil(t, err)
return f.Name()
}
createTestModel := func(t *testing.T, name string) {
fname := createTestFile(t, "ollama-model") fname := createTestFile(t, "ollama-model")
r := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname)) r := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname))
modelfile, err := parser.ParseFile(r) modelfile, err := parser.ParseFile(r)
require.NoError(t, err) assert.Nil(t, err)
fn := func(resp api.ProgressResponse) { fn := func(resp api.ProgressResponse) {
t.Logf("Status: %s", resp.Status) t.Logf("Status: %s", resp.Status)
} }
err = CreateModel(context.TODO(), model.ParseName(name), "", "", modelfile, fn) err = CreateModel(context.TODO(), name, "", "", modelfile, fn)
require.NoError(t, err) assert.Nil(t, err)
}
// Test Model Digests
blobs := []string{
"sha256:a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99",
"sha256:4f9d252f34ae677363956ffc6dd2d10918a539c5c91f5ee2fe889d9178be6ae3",
"sha256:0f239b83e9e2aad7cd997a5bb44124937a32ac1f4e98e95a2f46e7b966bfc878",
} }
testCases := []testCase{ testCases := []testCase{
@@ -78,9 +81,9 @@ func Test_Routes(t *testing.T) {
}, },
Expected: func(t *testing.T, resp *http.Response) { Expected: func(t *testing.T, resp *http.Response) {
contentType := resp.Header.Get("Content-Type") contentType := resp.Header.Get("Content-Type")
assert.Equal(t, "application/json; charset=utf-8", contentType) assert.Equal(t, contentType, "application/json; charset=utf-8")
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
require.NoError(t, err) assert.Nil(t, err)
assert.Equal(t, fmt.Sprintf(`{"version":"%s"}`, version.Version), string(body)) assert.Equal(t, fmt.Sprintf(`{"version":"%s"}`, version.Version), string(body))
}, },
}, },
@@ -90,17 +93,17 @@ func Test_Routes(t *testing.T) {
Path: "/api/tags", Path: "/api/tags",
Expected: func(t *testing.T, resp *http.Response) { Expected: func(t *testing.T, resp *http.Response) {
contentType := resp.Header.Get("Content-Type") contentType := resp.Header.Get("Content-Type")
assert.Equal(t, "application/json; charset=utf-8", contentType) assert.Equal(t, contentType, "application/json; charset=utf-8")
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
require.NoError(t, err) assert.Nil(t, err)
var modelList api.ListResponse var modelList api.ListResponse
err = json.Unmarshal(body, &modelList) err = json.Unmarshal(body, &modelList)
require.NoError(t, err) assert.Nil(t, err)
assert.NotNil(t, modelList.Models) assert.NotNil(t, modelList.Models)
assert.Empty(t, len(modelList.Models)) assert.Equal(t, 0, len(modelList.Models))
}, },
}, },
{ {
@@ -112,69 +115,16 @@ func Test_Routes(t *testing.T) {
}, },
Expected: func(t *testing.T, resp *http.Response) { Expected: func(t *testing.T, resp *http.Response) {
contentType := resp.Header.Get("Content-Type") contentType := resp.Header.Get("Content-Type")
assert.Equal(t, "application/json; charset=utf-8", contentType) assert.Equal(t, contentType, "application/json; charset=utf-8")
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
require.NoError(t, err) assert.Nil(t, err)
assert.NotContains(t, string(body), "expires_at")
var modelList api.ListResponse var modelList api.ListResponse
err = json.Unmarshal(body, &modelList) err = json.Unmarshal(body, &modelList)
require.NoError(t, err) assert.Nil(t, err)
assert.Len(t, modelList.Models, 1) assert.Equal(t, 1, len(modelList.Models))
assert.Equal(t, "test-model:latest", modelList.Models[0].Name) assert.Equal(t, modelList.Models[0].Name, "test-model:latest")
},
},
{
Name: "Create Model Handler",
Method: http.MethodPost,
Path: "/api/create",
Setup: func(t *testing.T, req *http.Request) {
fname := createTestFile(t, "ollama-model")
stream := false
createReq := api.CreateRequest{
Name: "t-bone",
Modelfile: fmt.Sprintf("FROM %s", fname),
Stream: &stream,
}
jsonData, err := json.Marshal(createReq)
require.NoError(t, err)
req.Body = io.NopCloser(bytes.NewReader(jsonData))
},
Expected: func(t *testing.T, resp *http.Response) {
contentType := resp.Header.Get("Content-Type")
assert.Equal(t, "application/json", contentType)
_, err := io.ReadAll(resp.Body)
require.NoError(t, err)
assert.Equal(t, 200, resp.StatusCode)
model, err := GetModel("t-bone")
require.NoError(t, err)
assert.Equal(t, "t-bone:latest", model.ShortName)
},
},
{
Name: "Copy Model Handler",
Method: http.MethodPost,
Path: "/api/copy",
Setup: func(t *testing.T, req *http.Request) {
createTestModel(t, "hamshank")
copyReq := api.CopyRequest{
Source: "hamshank",
Destination: "beefsteak",
}
jsonData, err := json.Marshal(copyReq)
require.NoError(t, err)
req.Body = io.NopCloser(bytes.NewReader(jsonData))
},
Expected: func(t *testing.T, resp *http.Response) {
model, err := GetModel("beefsteak")
require.NoError(t, err)
assert.Equal(t, "beefsteak:latest", model.ShortName)
}, },
}, },
{ {
@@ -185,18 +135,18 @@ func Test_Routes(t *testing.T) {
createTestModel(t, "show-model") createTestModel(t, "show-model")
showReq := api.ShowRequest{Model: "show-model"} showReq := api.ShowRequest{Model: "show-model"}
jsonData, err := json.Marshal(showReq) jsonData, err := json.Marshal(showReq)
require.NoError(t, err) assert.Nil(t, err)
req.Body = io.NopCloser(bytes.NewReader(jsonData)) req.Body = io.NopCloser(bytes.NewReader(jsonData))
}, },
Expected: func(t *testing.T, resp *http.Response) { Expected: func(t *testing.T, resp *http.Response) {
contentType := resp.Header.Get("Content-Type") contentType := resp.Header.Get("Content-Type")
assert.Equal(t, "application/json; charset=utf-8", contentType) assert.Equal(t, contentType, "application/json; charset=utf-8")
body, err := io.ReadAll(resp.Body) body, err := io.ReadAll(resp.Body)
require.NoError(t, err) assert.Nil(t, err)
var showResp api.ShowResponse var showResp api.ShowResponse
err = json.Unmarshal(body, &showResp) err = json.Unmarshal(body, &showResp)
require.NoError(t, err) assert.Nil(t, err)
var params []string var params []string
paramsSplit := strings.Split(showResp.Parameters, "\n") paramsSplit := strings.Split(showResp.Parameters, "\n")
@@ -213,6 +163,109 @@ func Test_Routes(t *testing.T) {
assert.Equal(t, expectedParams, params) assert.Equal(t, expectedParams, params)
}, },
}, },
{
Name: "Delete Handler (multiple blob reference)",
Method: http.MethodDelete,
Path: "/api/delete",
Setup: func(t *testing.T, req *http.Request) {
deleteReq := api.DeleteRequest{Model: "test-model"}
jsonData, err := json.Marshal(deleteReq)
assert.Nil(t, err)
req.Body = io.NopCloser(bytes.NewReader(jsonData))
},
Expected: func(t *testing.T, resp *http.Response) {
_, err := io.ReadAll(resp.Body)
assert.Nil(t, err)
assert.Equal(t, resp.StatusCode, 200)
_, err = GetModel("test-model")
assert.True(t, os.IsNotExist(err))
model, _ := GetModel("show-model")
assert.Equal(t, "show-model:latest", model.ShortName)
for i, blob := range blobs {
blobPath, _ := GetBlobsPath(blob)
_, err := os.Stat(blobPath)
assert.False(t, os.IsNotExist(err))
blobs[i] = blobPath
}
},
},
{
Name: "Delete Handler (single blob reference)",
Method: http.MethodDelete,
Path: "/api/delete",
Setup: func(t *testing.T, req *http.Request) {
deleteReq := api.DeleteRequest{Model: "show-model"}
jsonData, err := json.Marshal(deleteReq)
assert.Nil(t, err)
req.Body = io.NopCloser(bytes.NewReader(jsonData))
},
Expected: func(t *testing.T, resp *http.Response) {
_, err := io.ReadAll(resp.Body)
assert.Nil(t, err)
_, err = GetModel("show-model")
assert.True(t, os.IsNotExist(err))
for _, blob := range blobs {
_, err := os.Stat(blob)
assert.True(t, os.IsNotExist(err))
}
},
},
{
Name: "Create Model Handler",
Method: http.MethodPost,
Path: "/api/create",
Setup: func(t *testing.T, req *http.Request) {
fname := createTestFile(t, "ollama-model")
stream := false
createReq := api.CreateRequest{
Name: "t-bone",
Modelfile: fmt.Sprintf("FROM %s", fname),
Stream: &stream,
}
jsonData, err := json.Marshal(createReq)
assert.Nil(t, err)
req.Body = io.NopCloser(bytes.NewReader(jsonData))
},
Expected: func(t *testing.T, resp *http.Response) {
contentType := resp.Header.Get("Content-Type")
assert.Equal(t, "application/json", contentType)
_, err := io.ReadAll(resp.Body)
assert.Nil(t, err)
assert.Equal(t, resp.StatusCode, 200)
model, err := GetModel("t-bone")
assert.Nil(t, err)
assert.Equal(t, "t-bone:latest", model.ShortName)
},
},
{
Name: "Copy Model Handler",
Method: http.MethodPost,
Path: "/api/copy",
Setup: func(t *testing.T, req *http.Request) {
createTestModel(t, "hamshank")
copyReq := api.CopyRequest{
Source: "hamshank",
Destination: "beefsteak",
}
jsonData, err := json.Marshal(copyReq)
assert.Nil(t, err)
req.Body = io.NopCloser(bytes.NewReader(jsonData))
},
Expected: func(t *testing.T, resp *http.Response) {
model, err := GetModel("beefsteak")
assert.Nil(t, err)
assert.Equal(t, "beefsteak:latest", model.ShortName)
},
},
} }
t.Setenv("OLLAMA_MODELS", t.TempDir()) t.Setenv("OLLAMA_MODELS", t.TempDir())
@@ -227,14 +280,14 @@ func Test_Routes(t *testing.T) {
t.Run(tc.Name, func(t *testing.T) { t.Run(tc.Name, func(t *testing.T) {
u := httpSrv.URL + tc.Path u := httpSrv.URL + tc.Path
req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil) req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil)
require.NoError(t, err) assert.Nil(t, err)
if tc.Setup != nil { if tc.Setup != nil {
tc.Setup(t, req) tc.Setup(t, req)
} }
resp, err := httpSrv.Client().Do(req) resp, err := httpSrv.Client().Do(req)
require.NoError(t, err) assert.Nil(t, err)
defer resp.Body.Close() defer resp.Body.Close()
if tc.Expected != nil { if tc.Expected != nil {
@@ -243,82 +296,3 @@ func Test_Routes(t *testing.T) {
}) })
} }
} }
// TestCase checks that model names collide regardless of letter case.
//
// For each name in the table it first creates a model under that name and
// expects 200. It then tries to create, pull, and copy to the upper-cased form
// of the same name; each attempt must be rejected with 400 Bad Request and the
// JSON error body {"error":"a model with that name already exists"}.
func TestCase(t *testing.T) {
	// Point the model store at a throwaway directory for this test.
	t.Setenv("OLLAMA_MODELS", t.TempDir())

	cases := []string{
		"mistral",
		"llama3:latest",
		"library/phi3:q4_0",
		"registry.ollama.ai/library/gemma:q5_K_M",
		// TODO: host:port currently fails on windows (#4107)
		// "localhost:5000/alice/bob:latest",
	}

	var s Server
	for _, tt := range cases {
		t.Run(tt, func(t *testing.T) {
			// Seed the store with the name in its original casing.
			w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
				Name:      tt,
				Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
				Stream:    &stream,
			})
			if w.Code != http.StatusOK {
				t.Fatalf("expected status 200 got %d", w.Code)
			}

			// Exact response body every conflicting request must produce.
			expect, err := json.Marshal(map[string]string{"error": "a model with that name already exists"})
			if err != nil {
				t.Fatal(err)
			}

			t.Run("create", func(t *testing.T) {
				// Use a subtest-local recorder, like the sibling subtests,
				// instead of overwriting the outer one.
				w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
					Name:      strings.ToUpper(tt),
					Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
					Stream:    &stream,
				})
				if w.Code != http.StatusBadRequest {
					t.Fatalf("expected status 400 got %d", w.Code)
				}
				if !bytes.Equal(w.Body.Bytes(), expect) {
					t.Fatalf("expected error %s got %s", expect, w.Body.String())
				}
			})

			t.Run("pull", func(t *testing.T) {
				w := createRequest(t, s.PullModelHandler, api.PullRequest{
					Name:   strings.ToUpper(tt),
					Stream: &stream,
				})
				if w.Code != http.StatusBadRequest {
					t.Fatalf("expected status 400 got %d", w.Code)
				}
				if !bytes.Equal(w.Body.Bytes(), expect) {
					t.Fatalf("expected error %s got %s", expect, w.Body.String())
				}
			})

			t.Run("copy", func(t *testing.T) {
				w := createRequest(t, s.CopyModelHandler, api.CopyRequest{
					Source:      tt,
					Destination: strings.ToUpper(tt),
				})
				if w.Code != http.StatusBadRequest {
					t.Fatalf("expected status 400 got %d", w.Code)
				}
				if !bytes.Equal(w.Body.Bytes(), expect) {
					t.Fatalf("expected error %s got %s", expect, w.Body.String())
				}
			})
		})
	}
}

View File

@@ -7,17 +7,17 @@ import (
"log/slog" "log/slog"
"reflect" "reflect"
"runtime" "runtime"
"slices"
"sort" "sort"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/ollama/ollama/gpu" "github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/envconfig"
"golang.org/x/exp/slices"
) )
type LlmRequest struct { type LlmRequest struct {
@@ -66,7 +66,7 @@ func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options,
opts.NumCtx = 4 opts.NumCtx = 4
} }
opts.NumCtx *= envconfig.NumParallel opts.NumCtx = opts.NumCtx * envconfig.NumParallel
req := &LlmRequest{ req := &LlmRequest{
ctx: c, ctx: c,
@@ -370,6 +370,7 @@ func (s *Scheduler) updateFreeSpace(allGpus gpu.GpuInfoList) {
r.refMu.Lock() r.refMu.Lock()
gpuIDs := make([]string, 0, len(r.gpus)) gpuIDs := make([]string, 0, len(r.gpus))
if r.llama != nil { if r.llama != nil {
// TODO this should be broken down by GPU instead of assuming uniform spread // TODO this should be broken down by GPU instead of assuming uniform spread
estimatedVRAMPerGPU := r.llama.EstimatedVRAM() / uint64(len(r.gpus)) estimatedVRAMPerGPU := r.llama.EstimatedVRAM() / uint64(len(r.gpus))
for _, gpu := range r.gpus { for _, gpu := range r.gpus {
@@ -528,6 +529,7 @@ func (runner *runnerRef) waitForVRAMRecovery() chan interface{} {
} }
}() }()
return finished return finished
} }
type ByDuration []*runnerRef type ByDuration []*runnerRef

View File

@@ -12,10 +12,11 @@ import (
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/app/lifecycle" "github.com/ollama/ollama/app/lifecycle"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format" "github.com/ollama/ollama/format"
"github.com/ollama/ollama/gpu" "github.com/ollama/ollama/gpu"
"github.com/ollama/ollama/llm" "github.com/ollama/ollama/llm"
"github.com/ollama/ollama/envconfig"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
@@ -52,10 +53,10 @@ func TestLoad(t *testing.T) {
} }
gpus := gpu.GpuInfoList{} gpus := gpu.GpuInfoList{}
s.load(req, ggml, gpus) s.load(req, ggml, gpus)
require.Empty(t, req.successCh) require.Len(t, req.successCh, 0)
require.Len(t, req.errCh, 1) require.Len(t, req.errCh, 1)
s.loadedMu.Lock() s.loadedMu.Lock()
require.Empty(t, s.loaded) require.Len(t, s.loaded, 0)
s.loadedMu.Unlock() s.loadedMu.Unlock()
err := <-req.errCh err := <-req.errCh
require.Contains(t, err.Error(), "this model may be incompatible") require.Contains(t, err.Error(), "this model may be incompatible")
@@ -112,7 +113,7 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
t.Helper() t.Helper()
f, err := os.CreateTemp(t.TempDir(), modelName) f, err := os.CreateTemp(t.TempDir(), modelName)
require.NoError(t, err) assert.Nil(t, err)
defer f.Close() defer f.Close()
gguf := llm.NewGGUFV3(binary.LittleEndian) gguf := llm.NewGGUFV3(binary.LittleEndian)
@@ -130,7 +131,7 @@ func newScenario(t *testing.T, ctx context.Context, modelName string, estimatedV
}, []llm.Tensor{ }, []llm.Tensor{
{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}}, {Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
}) })
require.NoError(t, err) assert.Nil(t, err)
fname := f.Name() fname := f.Name()
model := &Model{Name: modelName, ModelPath: fname} model := &Model{Name: modelName, ModelPath: fname}
@@ -189,8 +190,8 @@ func TestRequests(t *testing.T) {
select { select {
case resp := <-scenario1a.req.successCh: case resp := <-scenario1a.req.successCh:
require.Equal(t, resp.llama, scenario1a.srv) require.Equal(t, resp.llama, scenario1a.srv)
require.Empty(t, s.pendingReqCh) require.Len(t, s.pendingReqCh, 0)
require.Empty(t, scenario1a.req.errCh) require.Len(t, scenario1a.req.errCh, 0)
case <-ctx.Done(): case <-ctx.Done():
t.Errorf("timeout") t.Errorf("timeout")
} }
@@ -202,8 +203,8 @@ func TestRequests(t *testing.T) {
select { select {
case resp := <-scenario1b.req.successCh: case resp := <-scenario1b.req.successCh:
require.Equal(t, resp.llama, scenario1a.srv) require.Equal(t, resp.llama, scenario1a.srv)
require.Empty(t, s.pendingReqCh) require.Len(t, s.pendingReqCh, 0)
require.Empty(t, scenario1b.req.errCh) require.Len(t, scenario1b.req.errCh, 0)
case <-ctx.Done(): case <-ctx.Done():
t.Errorf("timeout") t.Errorf("timeout")
} }
@@ -220,8 +221,8 @@ func TestRequests(t *testing.T) {
select { select {
case resp := <-scenario2a.req.successCh: case resp := <-scenario2a.req.successCh:
require.Equal(t, resp.llama, scenario2a.srv) require.Equal(t, resp.llama, scenario2a.srv)
require.Empty(t, s.pendingReqCh) require.Len(t, s.pendingReqCh, 0)
require.Empty(t, scenario2a.req.errCh) require.Len(t, scenario2a.req.errCh, 0)
case <-ctx.Done(): case <-ctx.Done():
t.Errorf("timeout") t.Errorf("timeout")
} }
@@ -236,8 +237,8 @@ func TestRequests(t *testing.T) {
select { select {
case resp := <-scenario3a.req.successCh: case resp := <-scenario3a.req.successCh:
require.Equal(t, resp.llama, scenario3a.srv) require.Equal(t, resp.llama, scenario3a.srv)
require.Empty(t, s.pendingReqCh) require.Len(t, s.pendingReqCh, 0)
require.Empty(t, scenario3a.req.errCh) require.Len(t, scenario3a.req.errCh, 0)
case <-ctx.Done(): case <-ctx.Done():
t.Errorf("timeout") t.Errorf("timeout")
} }
@@ -252,8 +253,8 @@ func TestRequests(t *testing.T) {
select { select {
case resp := <-scenario3b.req.successCh: case resp := <-scenario3b.req.successCh:
require.Equal(t, resp.llama, scenario3b.srv) require.Equal(t, resp.llama, scenario3b.srv)
require.Empty(t, s.pendingReqCh) require.Len(t, s.pendingReqCh, 0)
require.Empty(t, scenario3b.req.errCh) require.Len(t, scenario3b.req.errCh, 0)
case <-ctx.Done(): case <-ctx.Done():
t.Errorf("timeout") t.Errorf("timeout")
} }
@@ -268,8 +269,8 @@ func TestRequests(t *testing.T) {
select { select {
case resp := <-scenario3c.req.successCh: case resp := <-scenario3c.req.successCh:
require.Equal(t, resp.llama, scenario3c.srv) require.Equal(t, resp.llama, scenario3c.srv)
require.Empty(t, s.pendingReqCh) require.Len(t, s.pendingReqCh, 0)
require.Empty(t, scenario3c.req.errCh) require.Len(t, scenario3c.req.errCh, 0)
case <-ctx.Done(): case <-ctx.Done():
t.Errorf("timeout") t.Errorf("timeout")
} }
@@ -295,8 +296,8 @@ func TestRequests(t *testing.T) {
select { select {
case resp := <-scenario3d.req.successCh: case resp := <-scenario3d.req.successCh:
require.Equal(t, resp.llama, scenario3d.srv) require.Equal(t, resp.llama, scenario3d.srv)
require.Empty(t, s.pendingReqCh) require.Len(t, s.pendingReqCh, 0)
require.Empty(t, scenario3d.req.errCh) require.Len(t, scenario3d.req.errCh, 0)
case <-ctx.Done(): case <-ctx.Done():
t.Errorf("timeout") t.Errorf("timeout")
} }
@@ -331,7 +332,7 @@ func TestGetRunner(t *testing.T) {
slog.Info("scenario1b") slog.Info("scenario1b")
successCh1b, errCh1b := s.GetRunner(scenario1b.ctx, scenario1b.req.model, scenario1b.req.opts, scenario1b.req.sessionDuration) successCh1b, errCh1b := s.GetRunner(scenario1b.ctx, scenario1b.req.model, scenario1b.req.opts, scenario1b.req.sessionDuration)
require.Len(t, s.pendingReqCh, 1) require.Len(t, s.pendingReqCh, 1)
require.Empty(t, successCh1b) require.Len(t, successCh1b, 0)
require.Len(t, errCh1b, 1) require.Len(t, errCh1b, 1)
err := <-errCh1b err := <-errCh1b
require.Contains(t, err.Error(), "server busy") require.Contains(t, err.Error(), "server busy")
@@ -339,8 +340,8 @@ func TestGetRunner(t *testing.T) {
select { select {
case resp := <-successCh1a: case resp := <-successCh1a:
require.Equal(t, resp.llama, scenario1a.srv) require.Equal(t, resp.llama, scenario1a.srv)
require.Empty(t, s.pendingReqCh) require.Len(t, s.pendingReqCh, 0)
require.Empty(t, errCh1a) require.Len(t, errCh1a, 0)
case <-ctx.Done(): case <-ctx.Done():
t.Errorf("timeout") t.Errorf("timeout")
} }
@@ -354,9 +355,9 @@ func TestGetRunner(t *testing.T) {
successCh1c, errCh1c := s.GetRunner(scenario1c.ctx, scenario1c.req.model, scenario1c.req.opts, scenario1c.req.sessionDuration) successCh1c, errCh1c := s.GetRunner(scenario1c.ctx, scenario1c.req.model, scenario1c.req.opts, scenario1c.req.sessionDuration)
// Starts in pending channel, then should be quickly processed to return an error // Starts in pending channel, then should be quickly processed to return an error
time.Sleep(5 * time.Millisecond) time.Sleep(5 * time.Millisecond)
require.Empty(t, successCh1c) require.Len(t, successCh1c, 0)
s.loadedMu.Lock() s.loadedMu.Lock()
require.Empty(t, s.loaded) require.Len(t, s.loaded, 0)
s.loadedMu.Unlock() s.loadedMu.Unlock()
require.Len(t, errCh1c, 1) require.Len(t, errCh1c, 1)
err = <-errCh1c err = <-errCh1c
@@ -385,8 +386,8 @@ func TestPrematureExpired(t *testing.T) {
select { select {
case resp := <-successCh1a: case resp := <-successCh1a:
require.Equal(t, resp.llama, scenario1a.srv) require.Equal(t, resp.llama, scenario1a.srv)
require.Empty(t, s.pendingReqCh) require.Len(t, s.pendingReqCh, 0)
require.Empty(t, errCh1a) require.Len(t, errCh1a, 0)
s.loadedMu.Lock() s.loadedMu.Lock()
require.Len(t, s.loaded, 1) require.Len(t, s.loaded, 1)
s.loadedMu.Unlock() s.loadedMu.Unlock()
@@ -400,9 +401,9 @@ func TestPrematureExpired(t *testing.T) {
time.Sleep(20 * time.Millisecond) time.Sleep(20 * time.Millisecond)
require.LessOrEqual(t, len(s.finishedReqCh), 1) require.LessOrEqual(t, len(s.finishedReqCh), 1)
time.Sleep(10 * time.Millisecond) time.Sleep(10 * time.Millisecond)
require.Empty(t, s.finishedReqCh) require.Len(t, s.finishedReqCh, 0)
s.loadedMu.Lock() s.loadedMu.Lock()
require.Empty(t, s.loaded) require.Len(t, s.loaded, 0)
s.loadedMu.Unlock() s.loadedMu.Unlock()
// also shouldn't happen in real life // also shouldn't happen in real life
@@ -486,6 +487,7 @@ func TestFindRunnerToUnload(t *testing.T) {
r2.refCount = 1 r2.refCount = 1
resp = s.findRunnerToUnload() resp = s.findRunnerToUnload()
require.Equal(t, r1, resp) require.Equal(t, r1, resp)
} }
func TestNeedsReload(t *testing.T) { func TestNeedsReload(t *testing.T) {

View File

@@ -146,7 +146,7 @@ func (b *blobUpload) Run(ctx context.Context, opts *registryOptions) {
case requestURL := <-b.nextURL: case requestURL := <-b.nextURL:
g.Go(func() error { g.Go(func() error {
var err error var err error
for try := range maxRetries { for try := 0; try < maxRetries; try++ {
err = b.uploadPart(inner, http.MethodPatch, requestURL, part, opts) err = b.uploadPart(inner, http.MethodPatch, requestURL, part, opts)
switch { switch {
case errors.Is(err, context.Canceled): case errors.Is(err, context.Canceled):
@@ -190,7 +190,7 @@ func (b *blobUpload) Run(ctx context.Context, opts *registryOptions) {
headers.Set("Content-Type", "application/octet-stream") headers.Set("Content-Type", "application/octet-stream")
headers.Set("Content-Length", "0") headers.Set("Content-Length", "0")
for try := range maxRetries { for try := 0; try < maxRetries; try++ {
var resp *http.Response var resp *http.Response
resp, err = makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, nil, opts) resp, err = makeRequestWithRetry(ctx, http.MethodPut, requestURL, headers, nil, opts)
if errors.Is(err, context.Canceled) { if errors.Is(err, context.Canceled) {
@@ -253,7 +253,7 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
} }
// retry uploading to the redirect URL // retry uploading to the redirect URL
for try := range maxRetries { for try := 0; try < maxRetries; try++ {
err = b.uploadPart(ctx, http.MethodPut, redirectURL, part, nil) err = b.uploadPart(ctx, http.MethodPut, redirectURL, part, nil)
switch { switch {
case errors.Is(err, context.Canceled): case errors.Is(err, context.Canceled):
@@ -391,7 +391,7 @@ func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *registryO
return err return err
} }
//nolint:contextcheck // nolint: contextcheck
go upload.Run(context.Background(), opts) go upload.Run(context.Background(), opts)
} }

View File

@@ -1 +0,0 @@
{{ if .System }}<start_system>{{ .System }}<end_message>{{ end }}{{ if .Prompt }}<start_user>{{ .Prompt }}<end_message>{{ end }}<start_assistant>{{ .Response }}<end_message>

View File

@@ -1,7 +0,0 @@
{{ if .System }}{{ .System }}
{{ end }}{{ if .Prompt }}### Instruction:
{{ .Prompt }}
{{ end }}### Response:
{{ .Response }}

View File

@@ -1,6 +0,0 @@
{{ if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}{{ if .Prompt }}<|im_start|>user
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ .Response }}<|im_end|>

View File

@@ -1,5 +0,0 @@
{{ if .System }}System: {{ .System }}
{{ end }}{{ if .Prompt }}User: {{ .Prompt }}
{{ end }}Assistant: <|begin_of_text|>{{ .Response }}

View File

@@ -1,8 +0,0 @@
{{ if .System }} Source: system
{{ .System }} <step>{{ end }} Source: user
{{ .Prompt }} <step> Source: assistant
Destination: user
{{ .Response }}<step>

View File

@@ -1,3 +0,0 @@
{{ if .System }}{{ .System }}
{{ end }}{{ if .Prompt }}User: {{ .Prompt }}
{{ end }}Assistant: {{ .Response }}

View File

@@ -1,4 +0,0 @@
<start_of_turn>user
{{ if .System }}{{ .System }} {{ end }}{{ .Prompt }}<end_of_turn>
<start_of_turn>model
{{ .Response }}<end_of_turn>

View File

@@ -1,9 +0,0 @@
{{ if .System }}
System:
{{ .System }}
{{ end }}{{ if .Prompt }}Question:
{{ .Prompt }}
{{ end }}Answer:
{{ .Response }}

View File

@@ -1,138 +0,0 @@
[
{
"template": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}{% if system_message is defined %}{{ system_message }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|im_start|>user\\n' + content + '<|im_end|>\\n<|im_start|>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '<|im_end|>' + '\\n' }}{% endif %}{% endfor %}",
"name": "chatml"
},
{
"template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
"name": "chatml"
},
{
"template": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}",
"name": "zephyr"
},
{
"template": "{% if messages[0]['role'] == 'user' or messages[0]['role'] == 'system' %}{{ bos_token }}{% endif %}{% for message in messages %}{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% elif messages[-1]['role'] == 'assistant' %}{{ eos_token }}{% endif %}",
"name": "chatml"
},
{
"template": "{{ bos_token }}{% for message in messages %}{{ 'GPT4 Correct ' + message['role'].title() + ': ' + message['content'] + '<|end_of_turn|>'}}{% endfor %}{% if add_generation_prompt %}{{ 'GPT4 Correct Assistant:' }}{% endif %}",
"name": "openchat"
},
{
"template": "{{bos_token}}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
"name": "chatml"
},
{
"template": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
"name": "chatml"
},
{
"template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
"name": "chatml"
},
{
"template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
"name": "chatml"
},
{
"template": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}",
"name": "zephyr"
},
{
"template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
"name": "mistral-instruct"
},
{
"template": "{{bos_token}}{{'You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions.\n\n'}}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n {{ raise_exception('System messages are not allowed in this template.') }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction\n' + message['content'] + '\n\n'}}\n {%- else %}\n{{'### Response\n' + message['content'] + eos_token + '\n\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{{'### Response\n'}}",
"name": "starcoder2-instruct"
},
{
"template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content | trim + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content | trim + ' ' + eos_token }}{% endif %}{% endfor %}",
"name": "llama2-chat"
},
{
"template": "{% if messages[0]['role'] == 'system' %}{% set user_index = 1 %}{% else %}{% set user_index = 0 %}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != ((loop.index0 + user_index) % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 %}{{ '<s>' }}{% endif %}{% set content = 'Source: ' + message['role'] + '\n\n ' + message['content'] | trim %}{{ content + ' <step> ' }}{% endfor %}{{'Source: assistant\nDestination: user\n\n '}}",
"name": "codellama-70b-instruct"
},
{
"template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",
"name": "mistral-instruct"
},
{
"template": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|im_start|>user\n' + message['content'] + '<|im_end|>' }}\n{% elif message['role'] == 'system' %}\n{{ '<|im_start|>system\n' + message['content'] + '<|im_end|>' }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|im_start|>assistant\n' + message['content'] + '<|im_end|>' }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|im_start|>assistant' }}\n{% endif %}\n{% endfor %}",
"name": "chatml"
},
{
"template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
"name": "chatml"
},
{
"template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful assistant.' %}{% endif %}{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{{'<|im_start|>system\n' + system_message + '<|im_end|>\n'}}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
"name": "chatml"
},
{
"template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif 'system' not in messages[0]['role'] %}{% set loop_messages = messages %}{% set system_message = 'You are DBRX, created by Databricks. You were last updated in December 2023. You answer questions based on information available up to that point.\nYOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough responses to more complex and open-ended questions.\nYou assist with various tasks, from writing to coding (using markdown for code blocks \u2014 remember to use ``` with code, JSON, and tables).\n(You do not have real-time data access or code execution capabilities. You avoid stereotyping and provide balanced perspectives on controversial topics. You do not provide song lyrics, poems, or news articles and do not divulge details of your training data.)\nThis is your system prompt, guiding your responses. Do not reference it, just respond to the user. If you find yourself talking about this message, stop. You should be responding appropriately and usually that means not mentioning this.\nYOU DO NOT MENTION ANY OF THIS INFORMATION ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER\\'S QUERY.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{% if system_message != false %}{{ '<|im_start|>system\n' + system_message | trim + '<|im_end|>\n'}}{% endif %}{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' }}{% else %}{{ '\n' + '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' }}{% endif %}{% if (add_generation_prompt == true and loop.last) %}{{ '\n' + '<|im_start|>' + 'assistant' + '\n' }}{% endif %}{% endfor %}",
"name": "chatml"
},
{
"template": "{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}",
"name": "alpaca"
},
{
"template": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}",
"name": "chatqa"
},
{
"template": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}",
"name": "gemma-instruct"
},
{
"template": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}",
"name": "llama3-instruct"
},
{
"template": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ 'Question:\n' + message['content'] + '\n\n' }}{% elif message['role'] == 'system' %}\n{{ 'System:\n' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Answer:\n' + message['content'] + '\n\n' }}{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ 'Answer:\n' }}{% endif %}{% endfor %}",
"name": "granite-instruct"
},
{
"template": "{{bos_token}}{{'You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions.\n\n'}}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n {{ raise_exception('System messages are not allowed in this template.') }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'@@ Instruction\n' + message['content'] + '\n\n'}}\n {%- else %}\n{{'@@ Response\n' + message['content'] + eos_token + '\n\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{{'@@ Response\n'}}",
"name": "magicoder"
},
{
"template": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '<start_user>' + message['content'].strip() + '<end_message>' }}{% elif message['role'] == 'system' %}{{ '<start_system>' + message['content'].strip() + '<end_message>' }}{% elif message['role'] == 'assistant' %}{{ '<start_assistant>' + message['content'] + '<end_message>' }}{% else %}{{ raise_exception('Only system, user and assistant roles are supported.') }}{% endif %}{% if loop.last and add_generation_prompt %}{{ '<start_assistant>' }}{% endif %}{% endfor %}",
"name": "alfred"
},
{
"template": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}",
"name": "llama2-chat"
},
{
"template": "{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
"name": "phi-3"
},
{
"template": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}",
"name": "phi-3"
},
{
"template": "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}",
"name": "phi-3"
},
{
"template": "{{ bos_token }}{%- if messages[0]['role'] == 'system' -%}{% set loop_messages = messages[1:] %}{%- else -%}{% set loop_messages = messages %}{% endif %}System: This is a chat between a user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions based on the context. The assistant should also indicate when the answer cannot be found in the context.\n\n{% for message in loop_messages %}{%- if message['role'] == 'user' -%}User: {{ message['content'].strip() + '\n\n' }}{%- else -%}Assistant: {{ message['content'].strip() + '\n\n' }}{%- endif %}{% if loop.last and message['role'] == 'user' %}Assistant:{% endif %}{% endfor %}",
"name": "chatqa"
},
{
"template": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ 'User: \n' + message['content'] }}\n{% elif message['role'] == 'system' %}\n{{ 'System: ' + message['content'] }}\n{% elif message['role'] == 'assistant' %}\n{{ 'Falcon:\n' + message['content']}}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ 'Falcon:' }}\n{% endif %}\n{% endfor %}",
"name": "falcon-instruct"
},
{
"template": "{% for message in messages %}{% if not loop.first %}{{ '\n' }}{% endif %}{% if message['role'] == 'system' %}{{ 'System: ' }}{% elif message['role'] == 'user' %}{{ 'User: ' }}{% elif message['role'] == 'assistant' %}{{ 'Falcon: ' }}{% endif %}{{ message['content'] }}{% endfor %}{% if add_generation_prompt %}{{ '\n' + 'Falcon:' }}{% endif %}",
"name": "falcon-instruct"
},
{
"template": "{% for message in messages %}{% if message['role'] == 'system' %}{% if message['content']%}{{'### System:\n' + message['content']+'\n\n'}}{% endif %}{% elif message['role'] == 'user' %}{{'### User:\n' + message['content']+'\n\n'}}{% elif message['role'] == 'assistant' %}{{'### Assistant:\n' + message['content']}}{% endif %}{% if loop.last and add_generation_prompt %}{{ '### Assistant:\n' }}{% endif %}{% endfor %}",
"name": "solar-instruct"
}
]

View File

@@ -1,3 +0,0 @@
[INST] <<SYS>>{{ .System }}<</SYS>>
{{ .Prompt }} [/INST] {{ .Response }}

View File

@@ -1,7 +0,0 @@
{{ if .System }}<|start_header_id|>system<|end_header_id|>
{{ .System }}<|eot_id|>{{ end }}{{ if .Prompt }}<|start_header_id|>user<|end_header_id|>
{{ .Prompt }}<|eot_id|>{{ end }}<|start_header_id|>assistant<|end_header_id|>
{{ .Response }}<|eot_id|>

View File

@@ -1,7 +0,0 @@
{{ if .System }}{{ .System }}
{{ end }}{{ if .Prompt }}@@ Instruction
{{ .Prompt }}
{{ end }}@@ Response
{{ .Response }}

View File

@@ -1,6 +0,0 @@
{{ if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}{{ if .Prompt }}<|im_start|>user
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ .Response }}<|im_end|>

View File

@@ -1 +0,0 @@
{{ .System }}<|end_of_turn|>GPT4 Correct User: {{ .Prompt }}<|end_of_turn|>GPT4 Correct Assistant: {{ .Response }}<|end_of_turn|>

View File

@@ -1,6 +0,0 @@
{{ if .System }}<|system|>
{{ .System }}<|end|>
{{ end }}{{ if .Prompt }}<|user|>
{{ .Prompt }}<|end|>
{{ end }}<|assistant|>
{{ .Response }}<|end|>

View File

@@ -1,8 +0,0 @@
{{ if .System }}### System:
{{ .System }}
{{ end }}{{ if .Prompt }}### User:
{{ .Prompt }}
{{ end }}### Assistant:
{{ .Response }}

View File

@@ -1,9 +0,0 @@
{{ if .System }}{{ .System }}
{{ end }}{{ if .Prompt }}### Instruction
{{ .Prompt }}
{{ end }}### Response
{{ .Response }}<|endoftext|>

View File

@@ -1,70 +0,0 @@
package templates
import (
"bytes"
"embed"
"encoding/json"
"errors"
"io"
"math"
"sync"
"github.com/agnivade/levenshtein"
)
//go:embed index.json
var indexBytes []byte
//go:embed *.gotmpl
var templatesFS embed.FS
// templatesOnce decodes the embedded index.json into a list of templates and
// attaches to each entry the contents of its "<name>.gotmpl" file read from
// templatesFS. The work is wrapped in sync.OnceValues, so the decoded list and
// error are computed on first call and returned unchanged on later calls.
var templatesOnce = sync.OnceValues(func() ([]*Template, error) {
	var index []*Template
	err := json.Unmarshal(indexBytes, &index)
	if err != nil {
		return nil, err
	}

	for i := range index {
		entry := index[i]

		raw, err := templatesFS.ReadFile(entry.Name + ".gotmpl")
		if err != nil {
			return nil, err
		}

		// Store the file body with CRLF line endings converted to LF.
		entry.Bytes = bytes.ReplaceAll(raw, []byte("\r\n"), []byte("\n"))
	}

	return index, nil
})
// Template is one entry of the embedded template index.
type Template struct {
// Name is decoded from the "name" key of an index entry; templatesOnce also
// uses it to build the "<Name>.gotmpl" file name it reads.
Name string `json:"name"`
// Template is decoded from the "template" key of an index entry; NamedTemplate
// compares its input against this string.
Template string `json:"template"`
// Bytes is filled in by templatesOnce with the contents of the matching
// .gotmpl file after "\r\n" has been replaced by "\n".
Bytes []byte
}
// Reader returns a new io.Reader positioned at the start of t.Bytes.
func (t Template) Reader() io.Reader {
	contents := t.Bytes
	return bytes.NewReader(contents)
}
// NamedTemplate returns the indexed template whose Template string has the
// smallest Levenshtein distance to s. When several entries tie, the earliest
// one in the index wins. An error is returned if the index cannot be loaded
// or if the smallest distance is 100 or more.
func NamedTemplate(s string) (*Template, error) {
	candidates, err := templatesOnce()
	if err != nil {
		return nil, err
	}

	var closest *Template
	best := math.MaxInt
	for _, candidate := range candidates {
		distance := levenshtein.ComputeDistance(s, candidate.Template)
		if distance >= best {
			// Only a strictly smaller distance replaces the current pick.
			continue
		}
		best, closest = distance, candidate
	}

	if best >= 100 {
		return nil, errors.New("no matching template found")
	}

	return closest, nil
}

View File

@@ -1,59 +0,0 @@
package templates
import (
"bufio"
"bytes"
"encoding/json"
"io"
"os"
"path/filepath"
"testing"
"text/template"
"github.com/ollama/ollama/llm"
)
// TestKVChatTemplate reads testdata/templates.jsonl line by line. Each line is
// a JSON object mapping an expected template name to a chat template string.
// For every pair it checks that NamedTemplate resolves the chat template to
// the expected name and that the resolved template body parses as a non-empty
// text/template.
func TestKVChatTemplate(t *testing.T) {
	f, err := os.Open(filepath.Join("testdata", "templates.jsonl"))
	if err != nil {
		t.Fatal(err)
	}
	defer f.Close()

	scanner := bufio.NewScanner(f)
	for scanner.Scan() {
		var ss map[string]string
		if err := json.Unmarshal(scanner.Bytes(), &ss); err != nil {
			t.Fatal(err)
		}

		for k, v := range ss {
			t.Run(k, func(t *testing.T) {
				kv := llm.KV{"tokenizer.chat_template": v}
				s := kv.ChatTemplate()

				r, err := NamedTemplate(s)
				if err != nil {
					t.Fatal(err)
				}

				if r.Name != k {
					t.Errorf("expected %q, got %q", k, r.Name)
				}

				var b bytes.Buffer
				if _, err := io.Copy(&b, r.Reader()); err != nil {
					t.Fatal(err)
				}

				tmpl, err := template.New(s).Parse(b.String())
				if err != nil {
					t.Fatal(err)
				}

				if tmpl.Tree.Root.String() == "" {
					t.Errorf("empty %s template", k)
				}
			})
		}
	}

	// Scan returns false on both EOF and failure (for example a read error or
	// a line that exceeds the scanner's buffer). Without this check a failure
	// would silently skip the remaining fixtures and the test would pass.
	if err := scanner.Err(); err != nil {
		t.Fatal(err)
	}
}

View File

@@ -1,35 +0,0 @@
{"chatml": "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}{% if system_message is defined %}{{ system_message }}{% endif %}{% for message in messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|im_start|>user\\n' + content + '<|im_end|>\\n<|im_start|>assistant\\n' }}{% elif message['role'] == 'assistant' %}{{ content + '<|im_end|>' + '\\n' }}{% endif %}{% endfor %}"}
{"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"}
{"zephyr": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"}
{"chatml": "{% if messages[0]['role'] == 'user' or messages[0]['role'] == 'system' %}{{ bos_token }}{% endif %}{% for message in messages %}{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% elif messages[-1]['role'] == 'assistant' %}{{ eos_token }}{% endif %}"}
{"openchat": "{{ bos_token }}{% for message in messages %}{{ 'GPT4 Correct ' + message['role'].title() + ': ' + message['content'] + '<|end_of_turn|>'}}{% endfor %}{% if add_generation_prompt %}{{ 'GPT4 Correct Assistant:' }}{% endif %}"}
{"chatml": "{{bos_token}}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"}
{"chatml": "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"}
{"chatml": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"}
{"chatml": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"}
{"zephyr": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"}
{"mistral-instruct": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"}
{"starcoder2-instruct": "{{bos_token}}{{'You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions.\n\n'}}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n {{ raise_exception('System messages are not allowed in this template.') }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction\n' + message['content'] + '\n\n'}}\n {%- else %}\n{{'### Response\n' + message['content'] + eos_token + '\n\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{{'### Response\n'}}"}
{"llama2-chat": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content | trim + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content | trim + ' ' + eos_token }}{% endif %}{% endfor %}"}
{"codellama-70b-instruct": "{% if messages[0]['role'] == 'system' %}{% set user_index = 1 %}{% else %}{% set user_index = 0 %}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != ((loop.index0 + user_index) % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 %}{{ '<s>' }}{% endif %}{% set content = 'Source: ' + message['role'] + '\n\n ' + message['content'] | trim %}{{ content + ' <step> ' }}{% endfor %}{{'Source: assistant\nDestination: user\n\n '}}"}
{"mistral-instruct": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}"}
{"chatml": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|im_start|>user\n' + message['content'] + '<|im_end|>' }}\n{% elif message['role'] == 'system' %}\n{{ '<|im_start|>system\n' + message['content'] + '<|im_end|>' }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|im_start|>assistant\n' + message['content'] + '<|im_end|>' }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|im_start|>assistant' }}\n{% endif %}\n{% endfor %}"}
{"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"}
{"chatml": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful assistant.' %}{% endif %}{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{{'<|im_start|>system\n' + system_message + '<|im_end|>\n'}}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"}
{"chatml": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif 'system' not in messages[0]['role'] %}{% set loop_messages = messages %}{% set system_message = 'You are DBRX, created by Databricks. You were last updated in December 2023. You answer questions based on information available up to that point.\nYOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough responses to more complex and open-ended questions.\nYou assist with various tasks, from writing to coding (using markdown for code blocks \u2014 remember to use ``` with code, JSON, and tables).\n(You do not have real-time data access or code execution capabilities. You avoid stereotyping and provide balanced perspectives on controversial topics. You do not provide song lyrics, poems, or news articles and do not divulge details of your training data.)\nThis is your system prompt, guiding your responses. Do not reference it, just respond to the user. If you find yourself talking about this message, stop. You should be responding appropriately and usually that means not mentioning this.\nYOU DO NOT MENTION ANY OF THIS INFORMATION ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER\\'S QUERY.' %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{% if system_message != false %}{{ '<|im_start|>system\n' + system_message | trim + '<|im_end|>\n'}}{% endif %}{{ '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' }}{% else %}{{ '\n' + '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' }}{% endif %}{% if (add_generation_prompt == true and loop.last) %}{{ '\n' + '<|im_start|>' + 'assistant' + '\n' }}{% endif %}{% endfor %}"}
{"alpaca": "{% if not add_generation_prompt is defined %}\n{% set add_generation_prompt = false %}\n{% endif %}\n{%- set ns = namespace(found=false) -%}\n{%- for message in messages -%}\n {%- if message['role'] == 'system' -%}\n {%- set ns.found = true -%}\n {%- endif -%}\n{%- endfor -%}\n{{bos_token}}{%- if not ns.found -%}\n{{'You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer\\n'}}\n{%- endif %}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n{{ message['content'] }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'### Instruction:\\n' + message['content'] + '\\n'}}\n {%- else %}\n{{'### Response:\\n' + message['content'] + '\\n<|EOT|>\\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{% if add_generation_prompt %}\n{{'### Response:'}}\n{% endif %}"}
{"chatqa": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}{{ 'User: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + eos_token }}{% elif message['role'] == 'system' %}{{ message['content'] + '\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"}
{"gemma-instruct": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '<start_of_turn>' + role + '\n' + message['content'] | trim + '<end_of_turn>\n' }}{% endfor %}{% if add_generation_prompt %}{{'<start_of_turn>model\n'}}{% endif %}"}
{"llama3-instruct": "{% set loop_messages = messages %}{% for message in loop_messages %}{% set content = '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' %}{% if loop.index0 == 0 %}{% set content = bos_token + content %}{% endif %}{{ content }}{% endfor %}{% if add_generation_prompt %}{{ '<|start_header_id|>assistant<|end_header_id|>\n\n' }}{% endif %}"}
{"granite-instruct": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ 'Question:\n' + message['content'] + '\n\n' }}{% elif message['role'] == 'system' %}\n{{ 'System:\n' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Answer:\n' + message['content'] + '\n\n' }}{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ 'Answer:\n' }}{% endif %}{% endfor %}"}
{"magicoder": "{{bos_token}}{{'You are an exceptionally intelligent coding assistant that consistently delivers accurate and reliable responses to user instructions.\n\n'}}\n{%- for message in messages %}\n {%- if message['role'] == 'system' %}\n {{ raise_exception('System messages are not allowed in this template.') }}\n {%- else %}\n {%- if message['role'] == 'user' %}\n{{'@@ Instruction\n' + message['content'] + '\n\n'}}\n {%- else %}\n{{'@@ Response\n' + message['content'] + eos_token + '\n\n'}}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{{'@@ Response\n'}}"}
{"alfred": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '<start_user>' + message['content'].strip() + '<end_message>' }}{% elif message['role'] == 'system' %}{{ '<start_system>' + message['content'].strip() + '<end_message>' }}{% elif message['role'] == 'assistant' %}{{ '<start_assistant>' + message['content'] + '<end_message>' }}{% else %}{{ raise_exception('Only system, user and assistant roles are supported.') }}{% endif %}{% if loop.last and add_generation_prompt %}{{ '<start_assistant>' }}{% endif %}{% endfor %}"}
{"llama2-chat": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if loop.index0 == 0 and system_message != false %}{% set content = '<<SYS>>\\n' + system_message + '\\n<</SYS>>\\n\\n' + message['content'] %}{% else %}{% set content = message['content'] %}{% endif %}{% if message['role'] == 'user' %}{{ bos_token + '[INST] ' + content.strip() + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + content.strip() + ' ' + eos_token }}{% endif %}{% endfor %}"}
{"phi-3": "{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}"}
{"phi-3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') %}{{'<|user|>' + '\n' + message['content'] + '<|end|>' + '\n' + '<|assistant|>' + '\n'}}{% elif (message['role'] == 'assistant') %}{{message['content'] + '<|end|>' + '\n'}}{% endif %}{% endfor %}"}
{"phi-3": "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}"}
{"chatqa": "{{ bos_token }}{%- if messages[0]['role'] == 'system' -%}{% set loop_messages = messages[1:] %}{%- else -%}{% set loop_messages = messages %}{% endif %}System: This is a chat between a user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions based on the context. The assistant should also indicate when the answer cannot be found in the context.\n\n{% for message in loop_messages %}{%- if message['role'] == 'user' -%}User: {{ message['content'].strip() + '\n\n' }}{%- else -%}Assistant: {{ message['content'].strip() + '\n\n' }}{%- endif %}{% if loop.last and message['role'] == 'user' %}Assistant:{% endif %}{% endfor %}"}
{"falcon-instruct": "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ 'User: \n' + message['content'] }}\n{% elif message['role'] == 'system' %}\n{{ 'System: ' + message['content'] }}\n{% elif message['role'] == 'assistant' %}\n{{ 'Falcon:\n' + message['content']}}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ 'Falcon:' }}\n{% endif %}\n{% endfor %}"}
{"falcon-instruct": "{% for message in messages %}{% if not loop.first %}{{ '\n' }}{% endif %}{% if message['role'] == 'system' %}{{ 'System: ' }}{% elif message['role'] == 'user' %}{{ 'User: ' }}{% elif message['role'] == 'assistant' %}{{ 'Falcon: ' }}{% endif %}{{ message['content'] }}{% endfor %}{% if add_generation_prompt %}{{ '\n' + 'Falcon:' }}{% endif %}"}
{"solar-instruct": "{% for message in messages %}{% if message['role'] == 'system' %}{% if message['content']%}{{'### System:\n' + message['content']+'\n\n'}}{% endif %}{% elif message['role'] == 'user' %}{{'### User:\n' + message['content']+'\n\n'}}{% elif message['role'] == 'assistant' %}{{'### Assistant:\n' + message['content']}}{% endif %}{% if loop.last and add_generation_prompt %}{{ '### Assistant:\n' }}{% endif %}{% endfor %}"}
{"chatml": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"}

View File

@@ -1,3 +0,0 @@
{{ if .System }}{{ .System }}
{{ end }}{{ if .Prompt }}USER: {{ .Prompt }}
{{ end }}ASSISTANT: {{ .Response }}

View File

@@ -1,6 +0,0 @@
{{ if .System }}<|system|>
{{ .System }}</s>
{{ end }}{{ if .Prompt }}<|user|>
{{ .Prompt }}</s>
{{ end }}<|assistant|>
{{ .Response }}</s>

View File

@@ -251,10 +251,6 @@ func (n Name) DisplayShortest() string {
return sb.String() return sb.String()
} }
// IsValidNamespace reports whether namespace passes isValidPart when checked
// as a kindNamespace part.
func IsValidNamespace(namespace string) bool {
	ok := isValidPart(kindNamespace, namespace)
	return ok
}
// IsValid reports whether all parts of the name are present and valid. The // IsValid reports whether all parts of the name are present and valid. The
// digest is a special case, and is checked for validity only if present. // digest is a special case, and is checked for validity only if present.
func (n Name) IsValid() bool { func (n Name) IsValid() bool {

View File

@@ -268,6 +268,7 @@ func TestNameIsValidPart(t *testing.T) {
} }
}) })
} }
} }
func TestFilepathAllocs(t *testing.T) { func TestFilepathAllocs(t *testing.T) {
@@ -324,7 +325,7 @@ func TestParseNameFromFilepath(t *testing.T) {
filepath.Join("host:port", "namespace", "model", "tag"): {Host: "host:port", Namespace: "namespace", Model: "model", Tag: "tag"}, filepath.Join("host:port", "namespace", "model", "tag"): {Host: "host:port", Namespace: "namespace", Model: "model", Tag: "tag"},
filepath.Join("namespace", "model", "tag"): {}, filepath.Join("namespace", "model", "tag"): {},
filepath.Join("model", "tag"): {}, filepath.Join("model", "tag"): {},
"model": {}, filepath.Join("model"): {},
filepath.Join("..", "..", "model", "tag"): {}, filepath.Join("..", "..", "model", "tag"): {},
filepath.Join("", "namespace", ".", "tag"): {}, filepath.Join("", "namespace", ".", "tag"): {},
filepath.Join(".", ".", ".", "."): {}, filepath.Join(".", ".", ".", "."): {},
@@ -381,32 +382,6 @@ func FuzzName(f *testing.F) {
t.Errorf("String() = %q; want %q", n.String(), s) t.Errorf("String() = %q; want %q", n.String(), s)
} }
} }
}) })
} }
// TestIsValidNamespace checks IsValidNamespace against a table of inputs,
// running each input as its own subtest.
func TestIsValidNamespace(t *testing.T) {
	cases := []struct {
		username string
		expected bool
	}{
		{"", false},
		{"a", true},
		{"a:b", false},
		{"a/b", false},
		{"a:b/c", false},
		{"a/b:c", false},
		{"a/b:c/d", false},
		{"a/b:c/d@e", false},
		{"a/b:c/d@sha256-100", false},
		{"himynameisjoe", true},
		{"himynameisreallyreallyreallyreallylongbutitshouldstillbevalid", true},
	}
	for _, tt := range cases {
		t.Run(tt.username, func(t *testing.T) {
			if got := IsValidNamespace(tt.username); got != tt.expected {
				// Name the function actually under test in the failure output.
				t.Errorf("IsValidNamespace(%q) = %v; want %v", tt.username, got, tt.expected)
			}
		})
	}
}