Compare commits
3 Commits
v0.1.45-rc
...
jyan/forma
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5dc5a295bf | ||
|
|
e21e6b2a33 | ||
|
|
a240ea3367 |
6
.github/workflows/test.yaml
vendored
6
.github/workflows/test.yaml
vendored
@@ -124,7 +124,7 @@ jobs:
|
||||
strategy:
|
||||
matrix:
|
||||
rocm-version:
|
||||
- '6.1.1'
|
||||
- '6.0.2'
|
||||
runs-on: linux
|
||||
container: rocm/dev-ubuntu-20.04:${{ matrix.rocm-version }}
|
||||
steps:
|
||||
@@ -269,9 +269,9 @@ jobs:
|
||||
mkdir -p llm/build/darwin/$ARCH/stub/bin
|
||||
touch llm/build/darwin/$ARCH/stub/bin/ollama_llama_server
|
||||
if: ${{ startsWith(matrix.os, 'macos-') }}
|
||||
- uses: golangci/golangci-lint-action@v6
|
||||
- uses: golangci/golangci-lint-action@v4
|
||||
with:
|
||||
args: --timeout 8m0s -v ${{ startsWith(matrix.os, 'windows-') && '' || '--disable gofmt --disable goimports' }}
|
||||
args: --timeout 8m0s -v
|
||||
test:
|
||||
strategy:
|
||||
matrix:
|
||||
|
||||
@@ -9,26 +9,9 @@ linters:
|
||||
- contextcheck
|
||||
- exportloopref
|
||||
- gocheckcompilerdirectives
|
||||
# conditionally enable this on linux/macos
|
||||
# FIXME: for some reason this errors on windows
|
||||
# - gofmt
|
||||
# - goimports
|
||||
- intrange
|
||||
- misspell
|
||||
- nilerr
|
||||
- nolintlint
|
||||
- nosprintfhostport
|
||||
- testifylint
|
||||
- unconvert
|
||||
- unused
|
||||
- wastedassign
|
||||
- whitespace
|
||||
- usestdlibvars
|
||||
severity:
|
||||
default-severity: error
|
||||
rules:
|
||||
- linters:
|
||||
- gofmt
|
||||
- goimports
|
||||
- intrange
|
||||
- usestdlibvars
|
||||
severity: info
|
||||
|
||||
@@ -2,7 +2,7 @@ ARG GOLANG_VERSION=1.22.1
|
||||
ARG CMAKE_VERSION=3.22.1
|
||||
# this CUDA_VERSION corresponds with the one specified in docs/gpu.md
|
||||
ARG CUDA_VERSION=11.3.1
|
||||
ARG ROCM_VERSION=6.1.1
|
||||
ARG ROCM_VERSION=6.0.2
|
||||
|
||||
# Copy the minimal context we need to run the generate scripts
|
||||
FROM scratch AS llm-code
|
||||
|
||||
11
README.md
11
README.md
@@ -6,7 +6,7 @@
|
||||
|
||||
[](https://discord.gg/ollama)
|
||||
|
||||
Get up and running with large language models.
|
||||
Get up and running with large language models locally.
|
||||
|
||||
### macOS
|
||||
|
||||
@@ -285,7 +285,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [macai](https://github.com/Renset/macai) (macOS client for Ollama, ChatGPT, and other compatible API back-ends)
|
||||
- [Olpaka](https://github.com/Otacon/olpaka) (User-friendly Flutter Web App for Ollama)
|
||||
- [OllamaSpring](https://github.com/CrazyNeil/OllamaSpring) (Ollama Client for macOS)
|
||||
- [LLocal.in](https://github.com/kartikm7/llocal) (Easy to use Electron Desktop Client for Ollama)
|
||||
|
||||
### Terminal
|
||||
|
||||
@@ -308,7 +307,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [ShellOracle](https://github.com/djcopley/ShellOracle)
|
||||
- [tlm](https://github.com/yusufcanb/tlm)
|
||||
- [podman-ollama](https://github.com/ericcurtin/podman-ollama)
|
||||
- [gollama](https://github.com/sammcj/gollama)
|
||||
|
||||
### Database
|
||||
|
||||
@@ -326,13 +324,11 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [LangChain](https://python.langchain.com/docs/integrations/llms/ollama) and [LangChain.js](https://js.langchain.com/docs/modules/model_io/models/llms/integrations/ollama) with [example](https://js.langchain.com/docs/use_cases/question_answering/local_retrieval_qa)
|
||||
- [LangChainGo](https://github.com/tmc/langchaingo/) with [example](https://github.com/tmc/langchaingo/tree/main/examples/ollama-completion-example)
|
||||
- [LangChain4j](https://github.com/langchain4j/langchain4j) with [example](https://github.com/langchain4j/langchain4j-examples/tree/main/ollama-examples/src/main/java)
|
||||
- [LangChainRust](https://github.com/Abraxas-365/langchain-rust) with [example](https://github.com/Abraxas-365/langchain-rust/blob/main/examples/llm_ollama.rs)
|
||||
- [LlamaIndex](https://gpt-index.readthedocs.io/en/stable/examples/llm/ollama.html)
|
||||
- [LiteLLM](https://github.com/BerriAI/litellm)
|
||||
- [OllamaSharp for .NET](https://github.com/awaescher/OllamaSharp)
|
||||
- [Ollama for Ruby](https://github.com/gbaptista/ollama-ai)
|
||||
- [Ollama-rs for Rust](https://github.com/pepperoni21/ollama-rs)
|
||||
- [Ollama-hpp for C++](https://github.com/jmont-dev/ollama-hpp)
|
||||
- [Ollama4j for Java](https://github.com/amithkoujalgi/ollama4j)
|
||||
- [ModelFusion Typescript Library](https://modelfusion.dev/integration/model-provider/ollama)
|
||||
- [OllamaKit for Swift](https://github.com/kevinhermawan/OllamaKit)
|
||||
@@ -350,7 +346,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Portkey](https://portkey.ai/docs/welcome/integration-guides/ollama)
|
||||
- [PromptingTools.jl](https://github.com/svilupp/PromptingTools.jl) with an [example](https://svilupp.github.io/PromptingTools.jl/dev/examples/working_with_ollama)
|
||||
- [LlamaScript](https://github.com/Project-Llama/llamascript)
|
||||
|
||||
### Mobile
|
||||
|
||||
- [Enchanted](https://github.com/AugustDev/enchanted)
|
||||
@@ -383,9 +378,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [AI ST Completion](https://github.com/yaroslavyaroslav/OpenAI-sublime-text) (Sublime Text 4 AI assistant plugin with Ollama support)
|
||||
- [Discord-Ollama Chat Bot](https://github.com/kevinthedang/discord-ollama) (Generalized TypeScript Discord Bot w/ Tuning Documentation)
|
||||
- [Discord AI chat/moderation bot](https://github.com/rapmd73/Companion) Chat/moderation bot written in python. Uses Ollama to create personalities.
|
||||
- [Headless Ollama](https://github.com/nischalj10/headless-ollama) (Scripts to automatically install ollama client & models on any OS for apps that depends on ollama server)
|
||||
|
||||
### Supported backends
|
||||
|
||||
### Supported backends
|
||||
- [llama.cpp](https://github.com/ggerganov/llama.cpp) project founded by Georgi Gerganov.
|
||||
|
||||
|
||||
@@ -23,9 +23,11 @@ import (
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
@@ -63,7 +65,10 @@ func checkError(resp *http.Response, body []byte) error {
|
||||
// If the variable is not specified, a default ollama host and port will be
|
||||
// used.
|
||||
func ClientFromEnvironment() (*Client, error) {
|
||||
ollamaHost := envconfig.Host
|
||||
ollamaHost, err := GetOllamaHost()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Client{
|
||||
base: &url.URL{
|
||||
@@ -74,6 +79,52 @@ func ClientFromEnvironment() (*Client, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
type OllamaHost struct {
|
||||
Scheme string
|
||||
Host string
|
||||
Port string
|
||||
}
|
||||
|
||||
func GetOllamaHost() (OllamaHost, error) {
|
||||
defaultPort := "11434"
|
||||
|
||||
hostVar := os.Getenv("OLLAMA_HOST")
|
||||
hostVar = strings.TrimSpace(strings.Trim(strings.TrimSpace(hostVar), "\"'"))
|
||||
|
||||
scheme, hostport, ok := strings.Cut(hostVar, "://")
|
||||
switch {
|
||||
case !ok:
|
||||
scheme, hostport = "http", hostVar
|
||||
case scheme == "http":
|
||||
defaultPort = "80"
|
||||
case scheme == "https":
|
||||
defaultPort = "443"
|
||||
}
|
||||
|
||||
// trim trailing slashes
|
||||
hostport = strings.TrimRight(hostport, "/")
|
||||
|
||||
host, port, err := net.SplitHostPort(hostport)
|
||||
if err != nil {
|
||||
host, port = "127.0.0.1", defaultPort
|
||||
if ip := net.ParseIP(strings.Trim(hostport, "[]")); ip != nil {
|
||||
host = ip.String()
|
||||
} else if hostport != "" {
|
||||
host = hostport
|
||||
}
|
||||
}
|
||||
|
||||
if portNum, err := strconv.ParseInt(port, 10, 32); err != nil || portNum > 65535 || portNum < 0 {
|
||||
return OllamaHost{}, ErrInvalidHostPort
|
||||
}
|
||||
|
||||
return OllamaHost{
|
||||
Scheme: scheme,
|
||||
Host: host,
|
||||
Port: port,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewClient(base *url.URL, http *http.Client) *Client {
|
||||
return &Client{
|
||||
base: base,
|
||||
@@ -304,8 +355,8 @@ func (c *Client) List(ctx context.Context) (*ListResponse, error) {
|
||||
}
|
||||
|
||||
// List running models.
|
||||
func (c *Client) ListRunning(ctx context.Context) (*ProcessResponse, error) {
|
||||
var lr ProcessResponse
|
||||
func (c *Client) ListRunning(ctx context.Context) (*ListResponse, error) {
|
||||
var lr ListResponse
|
||||
if err := c.do(ctx, http.MethodGet, "/api/ps", nil, &lr); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestClientFromEnvironment(t *testing.T) {
|
||||
@@ -33,7 +35,6 @@ func TestClientFromEnvironment(t *testing.T) {
|
||||
for k, v := range testCases {
|
||||
t.Run(k, func(t *testing.T) {
|
||||
t.Setenv("OLLAMA_HOST", v.value)
|
||||
envconfig.LoadConfig()
|
||||
|
||||
client, err := ClientFromEnvironment()
|
||||
if err != v.err {
|
||||
@@ -45,4 +46,40 @@ func TestClientFromEnvironment(t *testing.T) {
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
hostTestCases := map[string]*testCase{
|
||||
"empty": {value: "", expect: "127.0.0.1:11434"},
|
||||
"only address": {value: "1.2.3.4", expect: "1.2.3.4:11434"},
|
||||
"only port": {value: ":1234", expect: ":1234"},
|
||||
"address and port": {value: "1.2.3.4:1234", expect: "1.2.3.4:1234"},
|
||||
"hostname": {value: "example.com", expect: "example.com:11434"},
|
||||
"hostname and port": {value: "example.com:1234", expect: "example.com:1234"},
|
||||
"zero port": {value: ":0", expect: ":0"},
|
||||
"too large port": {value: ":66000", err: ErrInvalidHostPort},
|
||||
"too small port": {value: ":-1", err: ErrInvalidHostPort},
|
||||
"ipv6 localhost": {value: "[::1]", expect: "[::1]:11434"},
|
||||
"ipv6 world open": {value: "[::]", expect: "[::]:11434"},
|
||||
"ipv6 no brackets": {value: "::1", expect: "[::1]:11434"},
|
||||
"ipv6 + port": {value: "[::1]:1337", expect: "[::1]:1337"},
|
||||
"extra space": {value: " 1.2.3.4 ", expect: "1.2.3.4:11434"},
|
||||
"extra quotes": {value: "\"1.2.3.4\"", expect: "1.2.3.4:11434"},
|
||||
"extra space+quotes": {value: " \" 1.2.3.4 \" ", expect: "1.2.3.4:11434"},
|
||||
"extra single quotes": {value: "'1.2.3.4'", expect: "1.2.3.4:11434"},
|
||||
}
|
||||
|
||||
for k, v := range hostTestCases {
|
||||
t.Run(k, func(t *testing.T) {
|
||||
t.Setenv("OLLAMA_HOST", v.value)
|
||||
|
||||
oh, err := GetOllamaHost()
|
||||
if err != v.err {
|
||||
t.Fatalf("expected %s, got %s", v.err, err)
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
host := net.JoinHostPort(oh.Host, oh.Port)
|
||||
assert.Equal(t, v.expect, host, fmt.Sprintf("%s: expected %s, got %s", k, v.expect, host))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
31
api/types.go
31
api/types.go
@@ -2,6 +2,7 @@ package api
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"math"
|
||||
@@ -281,33 +282,19 @@ type PushRequest struct {
|
||||
|
||||
// ListResponse is the response from [Client.List].
|
||||
type ListResponse struct {
|
||||
Models []ListModelResponse `json:"models"`
|
||||
Models []ModelResponse `json:"models"`
|
||||
}
|
||||
|
||||
// ProcessResponse is the response from [Client.Process].
|
||||
type ProcessResponse struct {
|
||||
Models []ProcessModelResponse `json:"models"`
|
||||
}
|
||||
|
||||
// ListModelResponse is a single model description in [ListResponse].
|
||||
type ListModelResponse struct {
|
||||
// ModelResponse is a single model description in [ListResponse].
|
||||
type ModelResponse struct {
|
||||
Name string `json:"name"`
|
||||
Model string `json:"model"`
|
||||
ModifiedAt time.Time `json:"modified_at"`
|
||||
ModifiedAt time.Time `json:"modified_at,omitempty"`
|
||||
Size int64 `json:"size"`
|
||||
Digest string `json:"digest"`
|
||||
Details ModelDetails `json:"details,omitempty"`
|
||||
}
|
||||
|
||||
// ProcessModelResponse is a single model description in [ProcessResponse].
|
||||
type ProcessModelResponse struct {
|
||||
Name string `json:"name"`
|
||||
Model string `json:"model"`
|
||||
Size int64 `json:"size"`
|
||||
Digest string `json:"digest"`
|
||||
Details ModelDetails `json:"details,omitempty"`
|
||||
ExpiresAt time.Time `json:"expires_at"`
|
||||
SizeVRAM int64 `json:"size_vram"`
|
||||
ExpiresAt time.Time `json:"expires_at,omitempty"`
|
||||
SizeVRAM int64 `json:"size_vram,omitempty"`
|
||||
}
|
||||
|
||||
type TokenResponse struct {
|
||||
@@ -319,7 +306,7 @@ type GenerateResponse struct {
|
||||
// Model is the model name that generated the response.
|
||||
Model string `json:"model"`
|
||||
|
||||
// CreatedAt is the timestamp of the response.
|
||||
//CreatedAt is the timestamp of the response.
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
|
||||
// Response is the textual response itself.
|
||||
@@ -376,6 +363,8 @@ func (m *Metrics) Summary() {
|
||||
}
|
||||
}
|
||||
|
||||
var ErrInvalidHostPort = errors.New("invalid port specified in OLLAMA_HOST")
|
||||
|
||||
func (opts *Options) FromMap(m map[string]interface{}) error {
|
||||
valueOpts := reflect.ValueOf(opts).Elem() // names of the fields in the options struct
|
||||
typeOpts := reflect.TypeOf(opts).Elem() // types of the fields in the options struct
|
||||
|
||||
@@ -72,13 +72,13 @@ func TestDurationMarshalUnmarshal(t *testing.T) {
|
||||
},
|
||||
{
|
||||
"positive duration",
|
||||
42 * time.Second,
|
||||
42 * time.Second,
|
||||
time.Duration(42 * time.Second),
|
||||
time.Duration(42 * time.Second),
|
||||
},
|
||||
{
|
||||
"another positive duration",
|
||||
42 * time.Minute,
|
||||
42 * time.Minute,
|
||||
time.Duration(42 * time.Minute),
|
||||
time.Duration(42 * time.Minute),
|
||||
},
|
||||
{
|
||||
"zero duration",
|
||||
|
||||
@@ -69,6 +69,7 @@ func init() {
|
||||
slog.Error(fmt.Sprintf("create ollama dir %s: %v", AppDataDir, err))
|
||||
}
|
||||
}
|
||||
|
||||
} else if runtime.GOOS == "darwin" {
|
||||
// TODO
|
||||
AppName += ".app"
|
||||
|
||||
@@ -15,7 +15,7 @@ import (
|
||||
)
|
||||
|
||||
func getCLIFullPath(command string) string {
|
||||
var cmdPath string
|
||||
cmdPath := ""
|
||||
appExe, err := os.Executable()
|
||||
if err == nil {
|
||||
cmdPath = filepath.Join(filepath.Dir(appExe), command)
|
||||
@@ -65,6 +65,7 @@ func start(ctx context.Context, command string) (*exec.Cmd, error) {
|
||||
if err != nil {
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
return nil, fmt.Errorf("stat ollama server log dir %s: %v", logDir, err)
|
||||
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(logDir, 0o755); err != nil {
|
||||
|
||||
@@ -24,8 +24,7 @@ func terminate(cmd *exec.Cmd) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
//nolint:errcheck
|
||||
defer dll.Release()
|
||||
defer dll.Release() // nolint: errcheck
|
||||
|
||||
pid := cmd.Process.Pid
|
||||
|
||||
@@ -74,8 +73,7 @@ func isProcessExited(pid int) (bool, error) {
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to open process: %v", err)
|
||||
}
|
||||
//nolint:errcheck
|
||||
defer windows.CloseHandle(hProcess)
|
||||
defer windows.CloseHandle(hProcess) // nolint: errcheck
|
||||
|
||||
var exitCode uint32
|
||||
err = windows.GetExitCodeProcess(hProcess, &exitCode)
|
||||
|
||||
@@ -78,7 +78,7 @@ func IsNewReleaseAvailable(ctx context.Context) (bool, UpdateResponse) {
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode == http.StatusNoContent {
|
||||
if resp.StatusCode == 204 {
|
||||
slog.Debug("check update response 204 (current version is up to date)")
|
||||
return false, updateResp
|
||||
}
|
||||
@@ -87,7 +87,7 @@ func IsNewReleaseAvailable(ctx context.Context) (bool, UpdateResponse) {
|
||||
slog.Warn(fmt.Sprintf("failed to read body response: %s", err))
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if resp.StatusCode != 200 {
|
||||
slog.Info(fmt.Sprintf("check update error %d - %.96s", resp.StatusCode, string(body)))
|
||||
return false, updateResp
|
||||
}
|
||||
@@ -114,7 +114,7 @@ func DownloadNewRelease(ctx context.Context, updateResp UpdateResponse) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("error checking update: %w", err)
|
||||
}
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
if resp.StatusCode != 200 {
|
||||
return fmt.Errorf("unexpected status attempting to download update %d", resp.StatusCode)
|
||||
}
|
||||
resp.Body.Close()
|
||||
|
||||
@@ -29,6 +29,7 @@ func GetID() string {
|
||||
initStore()
|
||||
}
|
||||
return store.ID
|
||||
|
||||
}
|
||||
|
||||
func GetFirstTimeRun() bool {
|
||||
|
||||
@@ -47,6 +47,7 @@ func nativeLoop() {
|
||||
default:
|
||||
pTranslateMessage.Call(uintptr(unsafe.Pointer(m))) //nolint:errcheck
|
||||
pDispatchMessage.Call(uintptr(unsafe.Pointer(m))) //nolint:errcheck
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -159,8 +160,8 @@ func (t *winTray) wndProc(hWnd windows.Handle, message uint32, wParam, lParam ui
|
||||
lResult, _, _ = pDefWindowProc.Call(
|
||||
uintptr(hWnd),
|
||||
uintptr(message),
|
||||
wParam,
|
||||
lParam,
|
||||
uintptr(wParam),
|
||||
uintptr(lParam),
|
||||
)
|
||||
}
|
||||
return
|
||||
|
||||
@@ -186,7 +186,7 @@ func (t *winTray) initInstance() error {
|
||||
t.muNID.Lock()
|
||||
defer t.muNID.Unlock()
|
||||
t.nid = ¬ifyIconData{
|
||||
Wnd: t.window,
|
||||
Wnd: windows.Handle(t.window),
|
||||
ID: 100,
|
||||
Flags: NIF_MESSAGE,
|
||||
CallbackMessage: t.wmSystrayMessage,
|
||||
@@ -197,6 +197,7 @@ func (t *winTray) initInstance() error {
|
||||
}
|
||||
|
||||
func (t *winTray) createMenu() error {
|
||||
|
||||
menuHandle, _, err := pCreatePopupMenu.Call()
|
||||
if menuHandle == 0 {
|
||||
return err
|
||||
@@ -245,7 +246,7 @@ func (t *winTray) addOrUpdateMenuItem(menuItemId uint32, parentId uint32, title
|
||||
mi := menuItemInfo{
|
||||
Mask: MIIM_FTYPE | MIIM_STRING | MIIM_ID | MIIM_STATE,
|
||||
Type: MFT_STRING,
|
||||
ID: menuItemId,
|
||||
ID: uint32(menuItemId),
|
||||
TypeData: titlePtr,
|
||||
Cch: uint32(len(title)),
|
||||
}
|
||||
@@ -301,10 +302,11 @@ func (t *winTray) addOrUpdateMenuItem(menuItemId uint32, parentId uint32, title
|
||||
}
|
||||
|
||||
func (t *winTray) addSeparatorMenuItem(menuItemId, parentId uint32) error {
|
||||
|
||||
mi := menuItemInfo{
|
||||
Mask: MIIM_FTYPE | MIIM_ID | MIIM_STATE,
|
||||
Type: MFT_SEPARATOR,
|
||||
ID: menuItemId,
|
||||
ID: uint32(menuItemId),
|
||||
}
|
||||
|
||||
mi.Size = uint32(unsafe.Sizeof(mi))
|
||||
@@ -424,6 +426,7 @@ func iconBytesToFilePath(iconBytes []byte) (string, error) {
|
||||
// Loads an image from file and shows it in tray.
|
||||
// Shell_NotifyIcon: https://msdn.microsoft.com/en-us/library/windows/desktop/bb762159(v=vs.85).aspx
|
||||
func (t *winTray) setIcon(src string) error {
|
||||
|
||||
h, err := t.loadIconFrom(src)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -441,6 +444,7 @@ func (t *winTray) setIcon(src string) error {
|
||||
// Loads an image from file to be shown in tray or menu item.
|
||||
// LoadImage: https://msdn.microsoft.com/en-us/library/windows/desktop/ms648045(v=vs.85).aspx
|
||||
func (t *winTray) loadIconFrom(src string) (windows.Handle, error) {
|
||||
|
||||
// Save and reuse handles of loaded images
|
||||
t.muLoadedImages.RLock()
|
||||
h, ok := t.loadedImages[src]
|
||||
|
||||
29
cmd/cmd.go
29
cmd/cmd.go
@@ -20,7 +20,6 @@ import (
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"syscall"
|
||||
"time"
|
||||
@@ -30,6 +29,7 @@ import (
|
||||
"github.com/olekukonko/tablewriter"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/crypto/ssh"
|
||||
"golang.org/x/exp/slices"
|
||||
"golang.org/x/term"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
@@ -746,6 +746,7 @@ func displayResponse(content string, wordWrap bool, state *displayResponseState)
|
||||
if wordWrap && termWidth >= 10 {
|
||||
for _, ch := range content {
|
||||
if state.lineLength+1 > termWidth-5 {
|
||||
|
||||
if runewidth.StringWidth(state.wordBuffer) > termWidth-10 {
|
||||
fmt.Printf("%s%c", state.wordBuffer, ch)
|
||||
state.wordBuffer = ""
|
||||
@@ -960,11 +961,17 @@ func generate(cmd *cobra.Command, opts runOptions) error {
|
||||
}
|
||||
|
||||
func RunServer(cmd *cobra.Command, _ []string) error {
|
||||
// retrieve the OLLAMA_HOST environment variable
|
||||
ollamaHost, err := api.GetOllamaHost()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := initializeKeypair(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ln, err := net.Listen("tcp", net.JoinHostPort(envconfig.Host.Host, envconfig.Host.Port))
|
||||
ln, err := net.Listen("tcp", net.JoinHostPort(ollamaHost.Host, ollamaHost.Port))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -1023,6 +1030,24 @@ func initializeKeypair() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
//nolint:unused
|
||||
func waitForServer(ctx context.Context, client *api.Client) error {
|
||||
// wait for the server to start
|
||||
timeout := time.After(5 * time.Second)
|
||||
tick := time.Tick(500 * time.Millisecond)
|
||||
for {
|
||||
select {
|
||||
case <-timeout:
|
||||
return errors.New("timed out waiting for server to start")
|
||||
case <-tick:
|
||||
if err := client.Heartbeat(ctx); err == nil {
|
||||
return nil // server has started
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func checkServerHeartbeat(cmd *cobra.Command, _ []string) error {
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
|
||||
@@ -8,11 +8,11 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/exp/slices"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"text/template"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
@@ -86,11 +85,11 @@ MESSAGE assistant """Yes it is true, I am half horse, half shark."""
|
||||
`
|
||||
|
||||
tmpl, err := template.New("").Parse(expectedModelfile)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, err)
|
||||
|
||||
var buf bytes.Buffer
|
||||
err = tmpl.Execute(&buf, opts)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, buf.String(), mf)
|
||||
|
||||
opts.ParentModel = "horseshark"
|
||||
@@ -108,10 +107,10 @@ MESSAGE assistant """Yes it is true, I am half horse, half shark."""
|
||||
`
|
||||
|
||||
tmpl, err = template.New("").Parse(expectedModelfile)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, err)
|
||||
|
||||
var parentBuf bytes.Buffer
|
||||
err = tmpl.Execute(&parentBuf, opts)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, parentBuf.String(), mf)
|
||||
}
|
||||
|
||||
27
cmd/start.go
27
cmd/start.go
@@ -1,27 +0,0 @@
|
||||
//go:build darwin || windows
|
||||
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func waitForServer(ctx context.Context, client *api.Client) error {
|
||||
// wait for the server to start
|
||||
timeout := time.After(5 * time.Second)
|
||||
tick := time.Tick(500 * time.Millisecond)
|
||||
for {
|
||||
select {
|
||||
case <-timeout:
|
||||
return errors.New("timed out waiting for server to start")
|
||||
case <-tick:
|
||||
if err := client.Heartbeat(ctx); err == nil {
|
||||
return nil // server has started
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -189,7 +189,7 @@ func LoadSentencePieceTokens(dirpath string, params *Params) (*Vocab, error) {
|
||||
if params.VocabSize > len(v.Tokens) {
|
||||
missingTokens := params.VocabSize - len(v.Tokens)
|
||||
slog.Warn(fmt.Sprintf("vocab is missing %d tokens", missingTokens))
|
||||
for cnt := range missingTokens {
|
||||
for cnt := 0; cnt < missingTokens; cnt++ {
|
||||
v.Tokens = append(v.Tokens, fmt.Sprintf("<dummy%05d>", cnt+1))
|
||||
v.Scores = append(v.Scores, -1)
|
||||
v.Types = append(v.Types, tokenTypeUserDefined)
|
||||
|
||||
@@ -35,6 +35,7 @@ func addOnes(data []float32, vectorSize int) ([]float32, error) {
|
||||
f32s = append(f32s, t...)
|
||||
}
|
||||
|
||||
|
||||
return f32s, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -119,12 +119,11 @@ func llamaRepack(name string, params *Params, data []float32, shape []uint64) ([
|
||||
}
|
||||
|
||||
var heads int
|
||||
switch {
|
||||
case strings.HasSuffix(name, "attn_q.weight"):
|
||||
if strings.HasSuffix(name, "attn_q.weight") {
|
||||
heads = params.AttentionHeads
|
||||
case strings.HasSuffix(name, "attn_k.weight"):
|
||||
} else if strings.HasSuffix(name, "attn_k.weight") {
|
||||
heads = cmp.Or(params.KeyValHeads, params.AttentionHeads)
|
||||
default:
|
||||
} else {
|
||||
return nil, fmt.Errorf("unknown tensor name: %s", name)
|
||||
}
|
||||
|
||||
|
||||
@@ -120,7 +120,7 @@ func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params)
|
||||
Name: name,
|
||||
Kind: kind,
|
||||
Offset: offset,
|
||||
Shape: shape,
|
||||
Shape: shape[:],
|
||||
}
|
||||
|
||||
t.WriterTo = safetensorWriterTo{
|
||||
|
||||
@@ -85,8 +85,11 @@ func parseTokens(dirpath string) (pre string, tokens []Token, merges []string, e
|
||||
|
||||
sha256sum := sha256.New()
|
||||
for _, pt := range t.PreTokenizer.PreTokenizers {
|
||||
if pt.Type == "Split" && pt.Pattern.Regex != "" {
|
||||
sha256sum.Write([]byte(pt.Pattern.Regex))
|
||||
switch pt.Type {
|
||||
case "Split":
|
||||
if pt.Pattern.Regex != "" {
|
||||
sha256sum.Write([]byte(pt.Pattern.Regex))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -88,7 +88,7 @@ func (tf *TorchFormat) GetTensors(dirpath string, params *Params) ([]llm.Tensor,
|
||||
Name: ggufName,
|
||||
Kind: kind,
|
||||
Offset: offset, // calculate the offset
|
||||
Shape: shape,
|
||||
Shape: shape[:],
|
||||
}
|
||||
|
||||
tensor.WriterTo = torchWriterTo{
|
||||
@@ -104,6 +104,7 @@ func (tf *TorchFormat) GetTensors(dirpath string, params *Params) ([]llm.Tensor,
|
||||
}
|
||||
|
||||
return tensors, nil
|
||||
|
||||
}
|
||||
|
||||
func getAltParams(dirpath string) (*Params, error) {
|
||||
|
||||
50
docs/api.md
50
docs/api.md
@@ -12,7 +12,6 @@
|
||||
- [Pull a Model](#pull-a-model)
|
||||
- [Push a Model](#push-a-model)
|
||||
- [Generate Embeddings](#generate-embeddings)
|
||||
- [List Running Models](#list-running-models)
|
||||
|
||||
## Conventions
|
||||
|
||||
@@ -250,7 +249,7 @@ curl http://localhost:11434/api/generate -d '{
|
||||
|
||||
#### Request (Reproducible outputs)
|
||||
|
||||
For reproducible outputs, set `seed` to a number:
|
||||
For reproducible outputs, set `temperature` to 0 and `seed` to a number:
|
||||
|
||||
##### Request
|
||||
|
||||
@@ -259,7 +258,8 @@ curl http://localhost:11434/api/generate -d '{
|
||||
"model": "mistral",
|
||||
"prompt": "Why is the sky blue?",
|
||||
"options": {
|
||||
"seed": 123
|
||||
"seed": 123,
|
||||
"temperature": 0
|
||||
}
|
||||
}'
|
||||
```
|
||||
@@ -1035,47 +1035,3 @@ curl http://localhost:11434/api/embeddings -d '{
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
## List Running Models
|
||||
```shell
|
||||
GET /api/ps
|
||||
```
|
||||
|
||||
List models that are currently loaded into memory.
|
||||
|
||||
#### Examples
|
||||
|
||||
### Request
|
||||
|
||||
```shell
|
||||
curl http://localhost:11434/api/ps
|
||||
```
|
||||
|
||||
#### Response
|
||||
|
||||
A single JSON object will be returned.
|
||||
|
||||
```json
|
||||
{
|
||||
"models": [
|
||||
{
|
||||
"name": "mistral:latest",
|
||||
"model": "mistral:latest",
|
||||
"size": 5137025024,
|
||||
"digest": "2ae6f6dd7a3dd734790bbbf58b8909a606e0e7e97e94b7604e0aa7ae4490e6d8",
|
||||
"details": {
|
||||
"parent_model": "",
|
||||
"format": "gguf",
|
||||
"family": "llama",
|
||||
"families": [
|
||||
"llama"
|
||||
],
|
||||
"parameter_size": "7.2B",
|
||||
"quantization_level": "Q4_0"
|
||||
},
|
||||
"expires_at": "2024-06-04T14:38:31.83753-07:00",
|
||||
"size_vram": 5137025024
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
@@ -8,7 +8,7 @@ Check your compute compatibility to see if your card is supported:
|
||||
| Compute Capability | Family | Cards |
|
||||
| ------------------ | ------------------- | ----------------------------------------------------------------------------------------------------------- |
|
||||
| 9.0 | NVIDIA | `H100` |
|
||||
| 8.9 | GeForce RTX 40xx | `RTX 4090` `RTX 4080 SUPER` `RTX 4080` `RTX 4070 Ti SUPER` `RTX 4070 Ti` `RTX 4070 SUPER` `RTX 4070` `RTX 4060 Ti` `RTX 4060` |
|
||||
| 8.9 | GeForce RTX 40xx | `RTX 4090` `RTX 4080` `RTX 4070 Ti` `RTX 4060 Ti` |
|
||||
| | NVIDIA Professional | `L4` `L40` `RTX 6000` |
|
||||
| 8.6 | GeForce RTX 30xx | `RTX 3090 Ti` `RTX 3090` `RTX 3080 Ti` `RTX 3080` `RTX 3070 Ti` `RTX 3070` `RTX 3060 Ti` `RTX 3060` |
|
||||
| | NVIDIA Professional | `A40` `RTX A6000` `RTX A5000` `RTX A4000` `RTX A3000` `RTX A2000` `A10` `A16` `A2` |
|
||||
|
||||
227
docs/import.md
227
docs/import.md
@@ -1,99 +1,170 @@
|
||||
# Import
|
||||
# Import a model
|
||||
|
||||
GGUF models and select Safetensors models can be imported directly into Ollama.
|
||||
This guide walks through importing a GGUF, PyTorch or Safetensors model.
|
||||
|
||||
## Import GGUF
|
||||
## Importing (GGUF)
|
||||
|
||||
A binary GGUF file can be imported directly into Ollama through a Modelfile.
|
||||
### Step 1: Write a `Modelfile`
|
||||
|
||||
```dockerfile
|
||||
FROM /path/to/file.gguf
|
||||
Start by creating a `Modelfile`. This file is the blueprint for your model, specifying weights, parameters, prompt templates and more.
|
||||
|
||||
```
|
||||
FROM ./mistral-7b-v0.1.Q4_0.gguf
|
||||
```
|
||||
|
||||
## Import Safetensors
|
||||
(Optional) many chat models require a prompt template in order to answer correctly. A default prompt template can be specified with the `TEMPLATE` instruction in the `Modelfile`:
|
||||
|
||||
If the model being imported is one of these architectures, it can be imported directly into Ollama through a Modelfile:
|
||||
|
||||
- LlamaForCausalLM
|
||||
- MistralForCausalLM
|
||||
- GemmaForCausalLM
|
||||
|
||||
```dockerfile
|
||||
FROM /path/to/safetensors/directory
|
||||
```
|
||||
FROM ./mistral-7b-v0.1.Q4_0.gguf
|
||||
TEMPLATE "[INST] {{ .Prompt }} [/INST]"
|
||||
```
|
||||
|
||||
For architectures not directly convertable by Ollama, see llama.cpp's [guide](https://github.com/ggerganov/llama.cpp/blob/master/README.md#prepare-and-quantize) on conversion. After conversion, see [Import GGUF](#import-gguf).
|
||||
### Step 2: Create the Ollama model
|
||||
|
||||
## Automatic Quantization
|
||||
Finally, create a model from your `Modelfile`:
|
||||
|
||||
> [!NOTE]
|
||||
> Automatic quantization requires v0.1.35 or higher.
|
||||
|
||||
Ollama is capable of quantizing FP16 or FP32 models to any of the supported quantizations with the `-q/--quantize` flag in `ollama create`.
|
||||
|
||||
```dockerfile
|
||||
FROM /path/to/my/gemma/f16/model
|
||||
```
|
||||
ollama create example -f Modelfile
|
||||
```
|
||||
|
||||
### Step 3: Run your model
|
||||
|
||||
Next, test the model with `ollama run`:
|
||||
|
||||
```
|
||||
ollama run example "What is your favourite condiment?"
|
||||
```
|
||||
|
||||
## Importing (PyTorch & Safetensors)
|
||||
|
||||
> Importing from PyTorch and Safetensors is a longer process than importing from GGUF. Improvements that make it easier are a work in progress.
|
||||
|
||||
### Setup
|
||||
|
||||
First, clone the `ollama/ollama` repo:
|
||||
|
||||
```
|
||||
git clone git@github.com:ollama/ollama.git ollama
|
||||
cd ollama
|
||||
```
|
||||
|
||||
and then fetch its `llama.cpp` submodule:
|
||||
|
||||
```shell
|
||||
$ ollama create -q Q4_K_M mymodel
|
||||
transferring model data
|
||||
quantizing F16 model to Q4_K_M
|
||||
creating new layer sha256:735e246cc1abfd06e9cdcf95504d6789a6cd1ad7577108a70d9902fef503c1bd
|
||||
creating new layer sha256:0853f0ad24e5865173bbf9ffcc7b0f5d56b66fd690ab1009867e45e7d2c4db0f
|
||||
writing manifest
|
||||
success
|
||||
git submodule init
|
||||
git submodule update llm/llama.cpp
|
||||
```
|
||||
|
||||
### Supported Quantizations
|
||||
Next, install the Python dependencies:
|
||||
|
||||
<details>
|
||||
<summary>Legacy Quantization</summary>
|
||||
|
||||
- `Q4_0`
|
||||
- `Q4_1`
|
||||
- `Q5_0`
|
||||
- `Q5_1`
|
||||
- `Q8_0`
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary>K-means Quantization</summary>`
|
||||
|
||||
- `Q3_K_S`
|
||||
- `Q3_K_M`
|
||||
- `Q3_K_L`
|
||||
- `Q4_K_S`
|
||||
- `Q4_K_M`
|
||||
- `Q5_K_S`
|
||||
- `Q5_K_M`
|
||||
- `Q6_K`
|
||||
|
||||
</details>
|
||||
|
||||
> [!NOTE]
|
||||
> Activation-aware Weight Quantization (i.e. IQ) are not currently supported for automatic quantization however you can still import the quantized model into Ollama, see [Import GGUF](#import-gguf).
|
||||
|
||||
## Template Detection
|
||||
|
||||
> [!NOTE]
|
||||
> Template detection requires v0.1.42 or higher.
|
||||
|
||||
Ollama uses model metadata, specifically `tokenizer.chat_template`, to automatically create a template appropriate for the model you're importing.
|
||||
|
||||
```dockerfile
|
||||
FROM /path/to/my/gemma/model
|
||||
```
|
||||
python3 -m venv llm/llama.cpp/.venv
|
||||
source llm/llama.cpp/.venv/bin/activate
|
||||
pip install -r llm/llama.cpp/requirements.txt
|
||||
```
|
||||
|
||||
```shell
|
||||
$ ollama create mymodel
|
||||
transferring model data
|
||||
using autodetected template gemma-instruct
|
||||
creating new layer sha256:baa2a0edc27d19cc6b7537578a9a7ba1a4e3214dc185ed5ae43692b319af7b84
|
||||
creating new layer sha256:ba66c3309914dbef07e5149a648fd1877f030d337a4f240d444ea335008943cb
|
||||
writing manifest
|
||||
success
|
||||
Then build the `quantize` tool:
|
||||
|
||||
```
|
||||
make -C llm/llama.cpp quantize
|
||||
```
|
||||
|
||||
Defining a template in the Modelfile will disable this feature which may be useful if you want to use a different template than the autodetected one.
|
||||
### Clone the HuggingFace repository (optional)
|
||||
|
||||
If the model is currently hosted in a HuggingFace repository, first clone that repository to download the raw model.
|
||||
|
||||
Install [Git LFS](https://docs.github.com/en/repositories/working-with-files/managing-large-files/installing-git-large-file-storage), verify it's installed, and then clone the model's repository:
|
||||
|
||||
```
|
||||
git lfs install
|
||||
git clone https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1 model
|
||||
```
|
||||
|
||||
### Convert the model
|
||||
|
||||
> Note: some model architectures require using specific convert scripts. For example, Qwen models require running `convert-hf-to-gguf.py` instead of `convert.py`
|
||||
|
||||
```
|
||||
python llm/llama.cpp/convert.py ./model --outtype f16 --outfile converted.bin
|
||||
```
|
||||
|
||||
### Quantize the model
|
||||
|
||||
```
|
||||
llm/llama.cpp/quantize converted.bin quantized.bin q4_0
|
||||
```
|
||||
|
||||
### Step 3: Write a `Modelfile`
|
||||
|
||||
Next, create a `Modelfile` for your model:
|
||||
|
||||
```
|
||||
FROM quantized.bin
|
||||
TEMPLATE "[INST] {{ .Prompt }} [/INST]"
|
||||
```
|
||||
|
||||
### Step 4: Create the Ollama model
|
||||
|
||||
Finally, create a model from your `Modelfile`:
|
||||
|
||||
```
|
||||
ollama create example -f Modelfile
|
||||
```
|
||||
|
||||
### Step 5: Run your model
|
||||
|
||||
Next, test the model with `ollama run`:
|
||||
|
||||
```
|
||||
ollama run example "What is your favourite condiment?"
|
||||
```
|
||||
|
||||
## Publishing your model (optional – early alpha)
|
||||
|
||||
Publishing models is in early alpha. If you'd like to publish your model to share with others, follow these steps:
|
||||
|
||||
1. Create [an account](https://ollama.com/signup)
|
||||
2. Copy your Ollama public key:
|
||||
- macOS: `cat ~/.ollama/id_ed25519.pub | pbcopy`
|
||||
- Windows: `type %USERPROFILE%\.ollama\id_ed25519.pub`
|
||||
- Linux: `cat /usr/share/ollama/.ollama/id_ed25519.pub`
|
||||
3. Add your public key to your [Ollama account](https://ollama.com/settings/keys)
|
||||
|
||||
Next, copy your model to your username's namespace:
|
||||
|
||||
```
|
||||
ollama cp example <your username>/example
|
||||
```
|
||||
|
||||
> Note: model names may only contain lowercase letters, digits, and the characters `.`, `-`, and `_`.
|
||||
|
||||
Then push the model:
|
||||
|
||||
```
|
||||
ollama push <your username>/example
|
||||
```
|
||||
|
||||
After publishing, your model will be available at `https://ollama.com/<your username>/example`.
|
||||
|
||||
## Quantization reference
|
||||
|
||||
The quantization options are as follow (from highest highest to lowest levels of quantization). Note: some architectures such as Falcon do not support K quants.
|
||||
|
||||
- `q2_K`
|
||||
- `q3_K`
|
||||
- `q3_K_S`
|
||||
- `q3_K_M`
|
||||
- `q3_K_L`
|
||||
- `q4_0` (recommended)
|
||||
- `q4_1`
|
||||
- `q4_K`
|
||||
- `q4_K_S`
|
||||
- `q4_K_M`
|
||||
- `q5_0`
|
||||
- `q5_1`
|
||||
- `q5_K`
|
||||
- `q5_K_S`
|
||||
- `q5_K_M`
|
||||
- `q6_K`
|
||||
- `q8_0`
|
||||
- `f16`
|
||||
|
||||
@@ -100,16 +100,6 @@ sudo curl -L https://ollama.com/download/ollama-linux-amd64 -o /usr/bin/ollama
|
||||
sudo chmod +x /usr/bin/ollama
|
||||
```
|
||||
|
||||
## Installing specific versions
|
||||
|
||||
Use `OLLAMA_VERSION` environment variable with the install script to install a specific version of Ollama, including pre-releases. You can find the version numbers in the [releases page](https://github.com/ollama/ollama/releases).
|
||||
|
||||
For example:
|
||||
|
||||
```
|
||||
curl -fsSL https://ollama.com/install.sh | OLLAMA_VERSION=0.1.32 sh
|
||||
```
|
||||
|
||||
## Viewing logs
|
||||
|
||||
To view logs of Ollama running as a startup service, run:
|
||||
|
||||
@@ -104,6 +104,7 @@ curl http://localhost:11434/v1/chat/completions \
|
||||
|
||||
#### Notes
|
||||
|
||||
- Setting `seed` will always set `temperature` to `0`
|
||||
- `finish_reason` will always be `stop`
|
||||
- `usage.prompt_tokens` will be 0 for completions where prompt evaluation is cached
|
||||
|
||||
|
||||
@@ -45,7 +45,7 @@ all_splits = text_splitter.split_documents(data)
|
||||
```
|
||||
|
||||
It's split up, but we have to find the relevant splits and then submit those to the model. We can do this by creating embeddings and storing them in a vector database. We can use Ollama directly to instantiate an embedding model. We will use ChromaDB in this example for a vector database. `pip install chromadb`
|
||||
We also need to pull embedding model: `ollama pull nomic-embed-text`
|
||||
|
||||
```python
|
||||
from langchain.embeddings import OllamaEmbeddings
|
||||
from langchain.vectorstores import Chroma
|
||||
@@ -68,8 +68,7 @@ The next thing is to send the question and the relevant parts of the docs to the
|
||||
```python
|
||||
from langchain.chains import RetrievalQA
|
||||
qachain=RetrievalQA.from_chain_type(ollama, retriever=vectorstore.as_retriever())
|
||||
res = qachain.invoke({"query": question})
|
||||
print(res['result'])
|
||||
qachain.invoke({"query": question})
|
||||
```
|
||||
|
||||
The answer received from this chain was:
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
package envconfig
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
@@ -12,18 +10,6 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
type OllamaHost struct {
|
||||
Scheme string
|
||||
Host string
|
||||
Port string
|
||||
}
|
||||
|
||||
func (o OllamaHost) String() string {
|
||||
return fmt.Sprintf("%s://%s:%s", o.Scheme, o.Host, o.Port)
|
||||
}
|
||||
|
||||
var ErrInvalidHostPort = errors.New("invalid port specified in OLLAMA_HOST")
|
||||
|
||||
var (
|
||||
// Set via OLLAMA_ORIGINS in the environment
|
||||
AllowOrigins []string
|
||||
@@ -31,8 +17,6 @@ var (
|
||||
Debug bool
|
||||
// Experimental flash attention
|
||||
FlashAttention bool
|
||||
// Set via OLLAMA_HOST in the environment
|
||||
Host *OllamaHost
|
||||
// Set via OLLAMA_KEEP_ALIVE in the environment
|
||||
KeepAlive string
|
||||
// Set via OLLAMA_LLM_LIBRARY in the environment
|
||||
@@ -41,8 +25,6 @@ var (
|
||||
MaxRunners int
|
||||
// Set via OLLAMA_MAX_QUEUE in the environment
|
||||
MaxQueuedRequests int
|
||||
// Set via OLLAMA_MODELS in the environment
|
||||
ModelsDir string
|
||||
// Set via OLLAMA_MAX_VRAM in the environment
|
||||
MaxVRAM uint64
|
||||
// Set via OLLAMA_NOHISTORY in the environment
|
||||
@@ -53,21 +35,8 @@ var (
|
||||
NumParallel int
|
||||
// Set via OLLAMA_RUNNERS_DIR in the environment
|
||||
RunnersDir string
|
||||
// Set via OLLAMA_SCHED_SPREAD in the environment
|
||||
SchedSpread bool
|
||||
// Set via OLLAMA_TMPDIR in the environment
|
||||
TmpDir string
|
||||
|
||||
// Set via CUDA_VISIBLE_DEVICES in the environment
|
||||
CudaVisibleDevices string
|
||||
// Set via HIP_VISIBLE_DEVICES in the environment
|
||||
HipVisibleDevices string
|
||||
// Set via ROCR_VISIBLE_DEVICES in the environment
|
||||
RocrVisibleDevices string
|
||||
// Set via GPU_DEVICE_ORDINAL in the environment
|
||||
GpuDeviceOrdinal string
|
||||
// Set via HSA_OVERRIDE_GFX_VERSION in the environment
|
||||
HsaOverrideGfxVersion string
|
||||
)
|
||||
|
||||
type EnvVar struct {
|
||||
@@ -77,32 +46,23 @@ type EnvVar struct {
|
||||
}
|
||||
|
||||
func AsMap() map[string]EnvVar {
|
||||
ret := map[string]EnvVar{
|
||||
return map[string]EnvVar{
|
||||
"OLLAMA_DEBUG": {"OLLAMA_DEBUG", Debug, "Show additional debug information (e.g. OLLAMA_DEBUG=1)"},
|
||||
"OLLAMA_FLASH_ATTENTION": {"OLLAMA_FLASH_ATTENTION", FlashAttention, "Enabled flash attention"},
|
||||
"OLLAMA_HOST": {"OLLAMA_HOST", Host, "IP Address for the ollama server (default 127.0.0.1:11434)"},
|
||||
"OLLAMA_HOST": {"OLLAMA_HOST", "", "IP Address for the ollama server (default 127.0.0.1:11434)"},
|
||||
"OLLAMA_KEEP_ALIVE": {"OLLAMA_KEEP_ALIVE", KeepAlive, "The duration that models stay loaded in memory (default \"5m\")"},
|
||||
"OLLAMA_LLM_LIBRARY": {"OLLAMA_LLM_LIBRARY", LLMLibrary, "Set LLM library to bypass autodetection"},
|
||||
"OLLAMA_MAX_LOADED_MODELS": {"OLLAMA_MAX_LOADED_MODELS", MaxRunners, "Maximum number of loaded models (default 1)"},
|
||||
"OLLAMA_MAX_QUEUE": {"OLLAMA_MAX_QUEUE", MaxQueuedRequests, "Maximum number of queued requests"},
|
||||
"OLLAMA_MAX_VRAM": {"OLLAMA_MAX_VRAM", MaxVRAM, "Maximum VRAM"},
|
||||
"OLLAMA_MODELS": {"OLLAMA_MODELS", ModelsDir, "The path to the models directory"},
|
||||
"OLLAMA_MODELS": {"OLLAMA_MODELS", "", "The path to the models directory"},
|
||||
"OLLAMA_NOHISTORY": {"OLLAMA_NOHISTORY", NoHistory, "Do not preserve readline history"},
|
||||
"OLLAMA_NOPRUNE": {"OLLAMA_NOPRUNE", NoPrune, "Do not prune model blobs on startup"},
|
||||
"OLLAMA_NUM_PARALLEL": {"OLLAMA_NUM_PARALLEL", NumParallel, "Maximum number of parallel requests (default 1)"},
|
||||
"OLLAMA_ORIGINS": {"OLLAMA_ORIGINS", AllowOrigins, "A comma separated list of allowed origins"},
|
||||
"OLLAMA_RUNNERS_DIR": {"OLLAMA_RUNNERS_DIR", RunnersDir, "Location for runners"},
|
||||
"OLLAMA_SCHED_SPREAD": {"OLLAMA_SCHED_SPREAD", SchedSpread, "Always schedule model across all GPUs"},
|
||||
"OLLAMA_TMPDIR": {"OLLAMA_TMPDIR", TmpDir, "Location for temporary files"},
|
||||
}
|
||||
if runtime.GOOS != "darwin" {
|
||||
ret["CUDA_VISIBLE_DEVICES"] = EnvVar{"CUDA_VISIBLE_DEVICES", CudaVisibleDevices, "Set which NVIDIA devices are visible"}
|
||||
ret["HIP_VISIBLE_DEVICES"] = EnvVar{"HIP_VISIBLE_DEVICES", HipVisibleDevices, "Set which AMD devices are visible"}
|
||||
ret["ROCR_VISIBLE_DEVICES"] = EnvVar{"ROCR_VISIBLE_DEVICES", RocrVisibleDevices, "Set which AMD devices are visible"}
|
||||
ret["GPU_DEVICE_ORDINAL"] = EnvVar{"GPU_DEVICE_ORDINAL", GpuDeviceOrdinal, "Set which AMD devices are visible"}
|
||||
ret["HSA_OVERRIDE_GFX_VERSION"] = EnvVar{"HSA_OVERRIDE_GFX_VERSION", HsaOverrideGfxVersion, "Override the gfx used for all detected AMD GPUs"}
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func Values() map[string]string {
|
||||
@@ -166,7 +126,7 @@ func LoadConfig() {
|
||||
var paths []string
|
||||
for _, root := range []string{filepath.Dir(appExe), cwd} {
|
||||
paths = append(paths,
|
||||
root,
|
||||
filepath.Join(root),
|
||||
filepath.Join(root, "windows-"+runtime.GOARCH),
|
||||
filepath.Join(root, "dist", "windows-"+runtime.GOARCH),
|
||||
)
|
||||
@@ -213,15 +173,6 @@ func LoadConfig() {
|
||||
NoHistory = true
|
||||
}
|
||||
|
||||
if spread := clean("OLLAMA_SCHED_SPREAD"); spread != "" {
|
||||
s, err := strconv.ParseBool(spread)
|
||||
if err == nil {
|
||||
SchedSpread = s
|
||||
} else {
|
||||
SchedSpread = true
|
||||
}
|
||||
}
|
||||
|
||||
if noprune := clean("OLLAMA_NOPRUNE"); noprune != "" {
|
||||
NoPrune = true
|
||||
}
|
||||
@@ -233,17 +184,11 @@ func LoadConfig() {
|
||||
AllowOrigins = append(AllowOrigins,
|
||||
fmt.Sprintf("http://%s", allowOrigin),
|
||||
fmt.Sprintf("https://%s", allowOrigin),
|
||||
fmt.Sprintf("http://%s", net.JoinHostPort(allowOrigin, "*")),
|
||||
fmt.Sprintf("https://%s", net.JoinHostPort(allowOrigin, "*")),
|
||||
fmt.Sprintf("http://%s:*", allowOrigin),
|
||||
fmt.Sprintf("https://%s:*", allowOrigin),
|
||||
)
|
||||
}
|
||||
|
||||
AllowOrigins = append(AllowOrigins,
|
||||
"app://*",
|
||||
"file://*",
|
||||
"tauri://*",
|
||||
)
|
||||
|
||||
maxRunners := clean("OLLAMA_MAX_LOADED_MODELS")
|
||||
if maxRunners != "" {
|
||||
m, err := strconv.Atoi(maxRunners)
|
||||
@@ -264,76 +209,4 @@ func LoadConfig() {
|
||||
}
|
||||
|
||||
KeepAlive = clean("OLLAMA_KEEP_ALIVE")
|
||||
|
||||
var err error
|
||||
ModelsDir, err = getModelsDir()
|
||||
if err != nil {
|
||||
slog.Error("invalid setting", "OLLAMA_MODELS", ModelsDir, "error", err)
|
||||
}
|
||||
|
||||
Host, err = getOllamaHost()
|
||||
if err != nil {
|
||||
slog.Error("invalid setting", "OLLAMA_HOST", Host, "error", err, "using default port", Host.Port)
|
||||
}
|
||||
|
||||
CudaVisibleDevices = clean("CUDA_VISIBLE_DEVICES")
|
||||
HipVisibleDevices = clean("HIP_VISIBLE_DEVICES")
|
||||
RocrVisibleDevices = clean("ROCR_VISIBLE_DEVICES")
|
||||
GpuDeviceOrdinal = clean("GPU_DEVICE_ORDINAL")
|
||||
HsaOverrideGfxVersion = clean("HSA_OVERRIDE_GFX_VERSION")
|
||||
}
|
||||
|
||||
func getModelsDir() (string, error) {
|
||||
if models, exists := os.LookupEnv("OLLAMA_MODELS"); exists {
|
||||
return models, nil
|
||||
}
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(home, ".ollama", "models"), nil
|
||||
}
|
||||
|
||||
func getOllamaHost() (*OllamaHost, error) {
|
||||
defaultPort := "11434"
|
||||
|
||||
hostVar := os.Getenv("OLLAMA_HOST")
|
||||
hostVar = strings.TrimSpace(strings.Trim(strings.TrimSpace(hostVar), "\"'"))
|
||||
|
||||
scheme, hostport, ok := strings.Cut(hostVar, "://")
|
||||
switch {
|
||||
case !ok:
|
||||
scheme, hostport = "http", hostVar
|
||||
case scheme == "http":
|
||||
defaultPort = "80"
|
||||
case scheme == "https":
|
||||
defaultPort = "443"
|
||||
}
|
||||
|
||||
// trim trailing slashes
|
||||
hostport = strings.TrimRight(hostport, "/")
|
||||
|
||||
host, port, err := net.SplitHostPort(hostport)
|
||||
if err != nil {
|
||||
host, port = "127.0.0.1", defaultPort
|
||||
if ip := net.ParseIP(strings.Trim(hostport, "[]")); ip != nil {
|
||||
host = ip.String()
|
||||
} else if hostport != "" {
|
||||
host = hostport
|
||||
}
|
||||
}
|
||||
|
||||
if portNum, err := strconv.ParseInt(port, 10, 32); err != nil || portNum > 65535 || portNum < 0 {
|
||||
return &OllamaHost{
|
||||
Scheme: scheme,
|
||||
Host: host,
|
||||
Port: defaultPort,
|
||||
}, ErrInvalidHostPort
|
||||
}
|
||||
|
||||
return &OllamaHost{
|
||||
Scheme: scheme,
|
||||
Host: host,
|
||||
Port: port,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -1,11 +1,8 @@
|
||||
package envconfig
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@@ -24,48 +21,3 @@ func TestConfig(t *testing.T) {
|
||||
LoadConfig()
|
||||
require.True(t, FlashAttention)
|
||||
}
|
||||
|
||||
func TestClientFromEnvironment(t *testing.T) {
|
||||
type testCase struct {
|
||||
value string
|
||||
expect string
|
||||
err error
|
||||
}
|
||||
|
||||
hostTestCases := map[string]*testCase{
|
||||
"empty": {value: "", expect: "127.0.0.1:11434"},
|
||||
"only address": {value: "1.2.3.4", expect: "1.2.3.4:11434"},
|
||||
"only port": {value: ":1234", expect: ":1234"},
|
||||
"address and port": {value: "1.2.3.4:1234", expect: "1.2.3.4:1234"},
|
||||
"hostname": {value: "example.com", expect: "example.com:11434"},
|
||||
"hostname and port": {value: "example.com:1234", expect: "example.com:1234"},
|
||||
"zero port": {value: ":0", expect: ":0"},
|
||||
"too large port": {value: ":66000", err: ErrInvalidHostPort},
|
||||
"too small port": {value: ":-1", err: ErrInvalidHostPort},
|
||||
"ipv6 localhost": {value: "[::1]", expect: "[::1]:11434"},
|
||||
"ipv6 world open": {value: "[::]", expect: "[::]:11434"},
|
||||
"ipv6 no brackets": {value: "::1", expect: "[::1]:11434"},
|
||||
"ipv6 + port": {value: "[::1]:1337", expect: "[::1]:1337"},
|
||||
"extra space": {value: " 1.2.3.4 ", expect: "1.2.3.4:11434"},
|
||||
"extra quotes": {value: "\"1.2.3.4\"", expect: "1.2.3.4:11434"},
|
||||
"extra space+quotes": {value: " \" 1.2.3.4 \" ", expect: "1.2.3.4:11434"},
|
||||
"extra single quotes": {value: "'1.2.3.4'", expect: "1.2.3.4:11434"},
|
||||
}
|
||||
|
||||
for k, v := range hostTestCases {
|
||||
t.Run(k, func(t *testing.T) {
|
||||
t.Setenv("OLLAMA_HOST", v.value)
|
||||
LoadConfig()
|
||||
|
||||
oh, err := getOllamaHost()
|
||||
if err != v.err {
|
||||
t.Fatalf("expected %s, got %s", v.err, err)
|
||||
}
|
||||
|
||||
if err == nil {
|
||||
host := net.JoinHostPort(oh.Host, oh.Port)
|
||||
assert.Equal(t, v.expect, host, fmt.Sprintf("%s: expected %s, got %s", k, v.expect, host))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,21 +77,13 @@ LOADER_MAPPING = {
|
||||
|
||||
|
||||
def load_single_document(file_path: str) -> List[Document]:
|
||||
if os.path.getsize(file_path) != 0:
|
||||
filename, ext = os.path.splitext(file_path)
|
||||
if ext in LOADER_MAPPING:
|
||||
loader_class, loader_args = LOADER_MAPPING[ext]
|
||||
try:
|
||||
loader = loader_class(file_path, **loader_args)
|
||||
if loader:
|
||||
return loader.load()
|
||||
except:
|
||||
print(f"Corrupted file {file_path}. Ignoring it.")
|
||||
else:
|
||||
print(f"Unsupported file {file_path}. Ignoring it.")
|
||||
else:
|
||||
print(f"Empty file {file_path}. Ignoring it.")
|
||||
ext = "." + file_path.rsplit(".", 1)[-1]
|
||||
if ext in LOADER_MAPPING:
|
||||
loader_class, loader_args = LOADER_MAPPING[ext]
|
||||
loader = loader_class(file_path, **loader_args)
|
||||
return loader.load()
|
||||
|
||||
raise ValueError(f"Unsupported file extension '{ext}'")
|
||||
|
||||
def load_documents(source_dir: str, ignored_files: List[str] = []) -> List[Document]:
|
||||
"""
|
||||
@@ -108,8 +100,7 @@ def load_documents(source_dir: str, ignored_files: List[str] = []) -> List[Docum
|
||||
results = []
|
||||
with tqdm(total=len(filtered_files), desc='Loading new documents', ncols=80) as pbar:
|
||||
for i, docs in enumerate(pool.imap_unordered(load_single_document, filtered_files)):
|
||||
if docs:
|
||||
results.extend(docs)
|
||||
results.extend(docs)
|
||||
pbar.update()
|
||||
|
||||
return results
|
||||
|
||||
@@ -11,5 +11,4 @@ tabulate==0.9.0
|
||||
pandoc==2.3
|
||||
pypandoc==1.11
|
||||
tqdm==4.66.1
|
||||
sentence_transformers==2.2.2
|
||||
numpy>=1.22.2 # not directly required, pinned by Snyk to avoid a vulnerability
|
||||
sentence_transformers==2.2.2
|
||||
@@ -2,32 +2,41 @@ package format
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
)
|
||||
|
||||
const (
|
||||
Thousand = 1000
|
||||
Million = Thousand * 1000
|
||||
Billion = Million * 1000
|
||||
Trillion = Billion * 1000
|
||||
)
|
||||
|
||||
func HumanNumber(b uint64) string {
|
||||
switch {
|
||||
case b >= Trillion:
|
||||
number := float64(b) / Trillion
|
||||
return fmt.Sprintf("%sT", DecimalPlace(number))
|
||||
case b >= Billion:
|
||||
number := float64(b) / Billion
|
||||
if number == math.Floor(number) {
|
||||
return fmt.Sprintf("%.0fB", number) // no decimals if whole number
|
||||
}
|
||||
return fmt.Sprintf("%.1fB", number) // one decimal if not a whole number
|
||||
return fmt.Sprintf("%sB", DecimalPlace(number))
|
||||
case b >= Million:
|
||||
number := float64(b) / Million
|
||||
if number == math.Floor(number) {
|
||||
return fmt.Sprintf("%.0fM", number) // no decimals if whole number
|
||||
}
|
||||
return fmt.Sprintf("%.2fM", number) // two decimals if not a whole number
|
||||
return fmt.Sprintf("%sM", DecimalPlace(number))
|
||||
case b >= Thousand:
|
||||
return fmt.Sprintf("%.0fK", float64(b)/Thousand)
|
||||
number := float64(b) / Thousand
|
||||
return fmt.Sprintf("%sK", DecimalPlace(number))
|
||||
default:
|
||||
return fmt.Sprintf("%d", b)
|
||||
}
|
||||
}
|
||||
|
||||
func DecimalPlace(number float64) string {
|
||||
switch {
|
||||
case number >= 100:
|
||||
return fmt.Sprintf("%.0f", number)
|
||||
case number >= 10:
|
||||
return fmt.Sprintf("%.1f", number)
|
||||
default:
|
||||
return fmt.Sprintf("%.2f", number)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
)
|
||||
|
||||
func TestHumanNumber(t *testing.T) {
|
||||
|
||||
type testCase struct {
|
||||
input uint64
|
||||
expected string
|
||||
@@ -12,14 +13,15 @@ func TestHumanNumber(t *testing.T) {
|
||||
|
||||
testCases := []testCase{
|
||||
{0, "0"},
|
||||
{1000000, "1M"},
|
||||
{1000000, "1.00M"},
|
||||
{125000000, "125M"},
|
||||
{500500000, "500.50M"},
|
||||
{500550000, "500.55M"},
|
||||
{1000000000, "1B"},
|
||||
{2800000000, "2.8B"},
|
||||
{2850000000, "2.9B"},
|
||||
{1000000000000, "1000B"},
|
||||
{500500000, "500M"},
|
||||
{500550000, "501M"},
|
||||
{1000000000, "1.00B"},
|
||||
{2800000000, "2.80B"},
|
||||
{2850000000, "2.85B"},
|
||||
{28550000000, "28.6B"},
|
||||
{1000000000000, "1.00T"},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
|
||||
1
go.mod
1
go.mod
@@ -16,7 +16,6 @@ require (
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/agnivade/levenshtein v1.1.1
|
||||
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
|
||||
github.com/mattn/go-runewidth v0.0.14
|
||||
github.com/nlpodyssey/gopickle v0.3.0
|
||||
|
||||
6
go.sum
6
go.sum
@@ -4,14 +4,10 @@ dmitri.shuralyov.com/gpu/mtl v0.0.0-20190408044501-666a987793e9/go.mod h1:H6x//7
|
||||
gioui.org v0.0.0-20210308172011-57750fc8a0a6/go.mod h1:RSH6KIUZ0p2xy5zHDxgAM4zumjgTw83q2ge/PI+yyw8=
|
||||
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
||||
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
|
||||
github.com/agnivade/levenshtein v1.1.1 h1:QY8M92nrzkmr798gCo3kmMyqXFzdQVpxLlGPRBij0P8=
|
||||
github.com/agnivade/levenshtein v1.1.1/go.mod h1:veldBMzWxcCG2ZvUTKD2kJNRdCk5hVbJomOvKkmgYbo=
|
||||
github.com/ajstarks/svgo v0.0.0-20180226025133-644b8db467af/go.mod h1:K08gAheRH3/J6wwsYMMT4xOr94bZjxIelGM0+d/wbFw=
|
||||
github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY=
|
||||
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40 h1:q4dksr6ICHXqG5hm0ZW5IHyeEJXoIJSOZeBLmWPNeIQ=
|
||||
github.com/apache/arrow/go/arrow v0.0.0-20211112161151-bc219186db40/go.mod h1:Q7yQnSMnLvcXlZ8RV+jwz/6y1rQTqbX6C82SndT52Zs=
|
||||
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q=
|
||||
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE=
|
||||
github.com/boombuler/barcode v1.0.0/go.mod h1:paBWMcWSl3LHKBqUq+rly7CNSldXjb2rDl3JlRe0mD8=
|
||||
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
|
||||
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
|
||||
@@ -40,8 +36,6 @@ github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1/go.mod h1:uw2gLc
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48 h1:fRzb/w+pyskVMQ+UbP35JkH8yB7MYb4q/qhBarqZE6g=
|
||||
github.com/dgryski/trifles v0.0.0-20200323201526-dd97f9abfb48/go.mod h1:if7Fbed8SFyPtHLHbg49SI7NAdJiC5WIA09pe59rfAA=
|
||||
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
|
||||
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
|
||||
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
||||
|
||||
222
gpu/amd_linux.go
222
gpu/amd_linux.go
@@ -13,7 +13,6 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
)
|
||||
|
||||
@@ -26,16 +25,7 @@ const (
|
||||
|
||||
// Prefix with the node dir
|
||||
GPUTotalMemoryFileGlob = "mem_banks/*/properties" // size_in_bytes line
|
||||
|
||||
// Direct Rendering Manager sysfs location
|
||||
DRMDeviceDirGlob = "/sys/class/drm/card*/device"
|
||||
DRMTotalMemoryFile = "mem_info_vram_total"
|
||||
DRMUsedMemoryFile = "mem_info_vram_used"
|
||||
|
||||
// In hex; properties file is in decimal
|
||||
DRMUniqueIDFile = "unique_id"
|
||||
DRMVendorFile = "vendor"
|
||||
DRMDeviceFile = "device"
|
||||
GPUUsedMemoryFileGlob = "mem_banks/*/used_memory"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -45,8 +35,8 @@ var (
|
||||
)
|
||||
|
||||
// Gather GPU information from the amdgpu driver if any supported GPUs are detected
|
||||
func AMDGetGPUInfo() []RocmGPUInfo {
|
||||
resp := []RocmGPUInfo{}
|
||||
func AMDGetGPUInfo() []GpuInfo {
|
||||
resp := []GpuInfo{}
|
||||
if !AMDDetected() {
|
||||
return resp
|
||||
}
|
||||
@@ -60,9 +50,9 @@ func AMDGetGPUInfo() []RocmGPUInfo {
|
||||
|
||||
// Determine if the user has already pre-selected which GPUs to look at, then ignore the others
|
||||
var visibleDevices []string
|
||||
hipVD := envconfig.HipVisibleDevices // zero based index only
|
||||
rocrVD := envconfig.RocrVisibleDevices // zero based index or UUID, but consumer cards seem to not support UUID
|
||||
gpuDO := envconfig.GpuDeviceOrdinal // zero based index
|
||||
hipVD := os.Getenv("HIP_VISIBLE_DEVICES") // zero based index only
|
||||
rocrVD := os.Getenv("ROCR_VISIBLE_DEVICES") // zero based index or UUID, but consumer cards seem to not support UUID
|
||||
gpuDO := os.Getenv("GPU_DEVICE_ORDINAL") // zero based index
|
||||
switch {
|
||||
// TODO is this priorty order right?
|
||||
case hipVD != "":
|
||||
@@ -75,7 +65,7 @@ func AMDGetGPUInfo() []RocmGPUInfo {
|
||||
visibleDevices = strings.Split(gpuDO, ",")
|
||||
}
|
||||
|
||||
gfxOverride := envconfig.HsaOverrideGfxVersion
|
||||
gfxOverride := os.Getenv("HSA_OVERRIDE_GFX_VERSION")
|
||||
var supported []string
|
||||
libDir := ""
|
||||
|
||||
@@ -100,7 +90,7 @@ func AMDGetGPUInfo() []RocmGPUInfo {
|
||||
scanner := bufio.NewScanner(fp)
|
||||
isCPU := false
|
||||
var major, minor, patch uint64
|
||||
var vendor, device, uniqueID uint64
|
||||
var vendor, device uint64
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
// Note: we could also use "cpu_cores_count X" where X is greater than zero to detect CPUs
|
||||
@@ -131,43 +121,30 @@ func AMDGetGPUInfo() []RocmGPUInfo {
|
||||
} else if strings.HasPrefix(line, "vendor_id") {
|
||||
ver := strings.Fields(line)
|
||||
if len(ver) != 2 {
|
||||
slog.Debug("malformed", "vendor_id", line)
|
||||
slog.Debug("malformed vendor_id", "vendor_id", line)
|
||||
continue
|
||||
}
|
||||
vendor, err = strconv.ParseUint(ver[1], 10, 64)
|
||||
vendor, err = strconv.ParseUint(ver[1], 10, 32)
|
||||
if err != nil {
|
||||
slog.Debug("malformed", "vendor_id", line, "error", err)
|
||||
slog.Debug("malformed vendor_id" + line)
|
||||
}
|
||||
} else if strings.HasPrefix(line, "device_id") {
|
||||
ver := strings.Fields(line)
|
||||
if len(ver) != 2 {
|
||||
slog.Debug("malformed", "device_id", line)
|
||||
slog.Debug("malformed device_id", "device_id", line)
|
||||
continue
|
||||
}
|
||||
device, err = strconv.ParseUint(ver[1], 10, 64)
|
||||
device, err = strconv.ParseUint(ver[1], 10, 32)
|
||||
if err != nil {
|
||||
slog.Debug("malformed", "device_id", line, "error", err)
|
||||
}
|
||||
} else if strings.HasPrefix(line, "unique_id") {
|
||||
ver := strings.Fields(line)
|
||||
if len(ver) != 2 {
|
||||
slog.Debug("malformed", "unique_id", line)
|
||||
continue
|
||||
}
|
||||
uniqueID, err = strconv.ParseUint(ver[1], 10, 64)
|
||||
if err != nil {
|
||||
slog.Debug("malformed", "unique_id", line, "error", err)
|
||||
slog.Debug("malformed device_id" + line)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO - any other properties we want to extract and record?
|
||||
// vendor_id + device_id -> pci lookup for "Name"
|
||||
// Other metrics that may help us understand relative performance between multiple GPUs
|
||||
}
|
||||
|
||||
// Note: while ./mem_banks/*/used_memory exists, it doesn't appear to take other VRAM consumers
|
||||
// into consideration, so we instead map the device over to the DRM driver sysfs nodes which
|
||||
// do reliably report VRAM usage.
|
||||
|
||||
if isCPU {
|
||||
cpuCount++
|
||||
continue
|
||||
@@ -179,7 +156,7 @@ func AMDGetGPUInfo() []RocmGPUInfo {
|
||||
// Shouldn't happen, but just in case...
|
||||
if gpuID < 0 {
|
||||
slog.Error("unexpected amdgpu sysfs data resulted in negative GPU ID, please set OLLAMA_DEBUG=1 and report an issue")
|
||||
return nil
|
||||
return []GpuInfo{}
|
||||
}
|
||||
|
||||
if int(major) < RocmComputeMin {
|
||||
@@ -190,68 +167,65 @@ func AMDGetGPUInfo() []RocmGPUInfo {
|
||||
// Look up the memory for the current node
|
||||
totalMemory := uint64(0)
|
||||
usedMemory := uint64(0)
|
||||
var usedFile string
|
||||
mapping := []struct {
|
||||
id uint64
|
||||
filename string
|
||||
}{
|
||||
{vendor, DRMVendorFile},
|
||||
{device, DRMDeviceFile},
|
||||
{uniqueID, DRMUniqueIDFile}, // Not all devices will report this
|
||||
propGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(nodeID), GPUTotalMemoryFileGlob)
|
||||
propFiles, err := filepath.Glob(propGlob)
|
||||
if err != nil {
|
||||
slog.Warn("error looking up total GPU memory", "glob", propGlob, "error", err)
|
||||
}
|
||||
slog.Debug("mapping amdgpu to drm sysfs nodes", "amdgpu", match, "vendor", vendor, "device", device, "unique_id", uniqueID)
|
||||
// Map over to DRM location to find the total/free memory
|
||||
drmMatches, _ := filepath.Glob(DRMDeviceDirGlob)
|
||||
for _, devDir := range drmMatches {
|
||||
matched := true
|
||||
for _, m := range mapping {
|
||||
if m.id == 0 {
|
||||
// Null ID means it didn't populate, so we can't use it to match
|
||||
continue
|
||||
}
|
||||
filename := filepath.Join(devDir, m.filename)
|
||||
buf, err := os.ReadFile(filename)
|
||||
if err != nil {
|
||||
slog.Debug("failed to read sysfs node", "file", filename, "error", err)
|
||||
matched = false
|
||||
break
|
||||
}
|
||||
// values here are in hex, strip off the lead 0x and parse so we can compare the numeric (decimal) values in amdgpu
|
||||
cmp, err := strconv.ParseUint(strings.TrimPrefix(strings.TrimSpace(string(buf)), "0x"), 16, 64)
|
||||
if err != nil {
|
||||
slog.Debug("failed to parse sysfs node", "file", filename, "error", err)
|
||||
matched = false
|
||||
break
|
||||
}
|
||||
if cmp != m.id {
|
||||
matched = false
|
||||
break
|
||||
}
|
||||
}
|
||||
if !matched {
|
||||
// 1 or more memory banks - sum the values of all of them
|
||||
for _, propFile := range propFiles {
|
||||
fp, err := os.Open(propFile)
|
||||
if err != nil {
|
||||
slog.Warn("failed to open sysfs node", "file", propFile, "erroir", err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Found the matching DRM directory
|
||||
slog.Debug("matched", "amdgpu", match, "drm", devDir)
|
||||
totalFile := filepath.Join(devDir, DRMTotalMemoryFile)
|
||||
buf, err := os.ReadFile(totalFile)
|
||||
if err != nil {
|
||||
slog.Debug("failed to read sysfs node", "file", totalFile, "error", err)
|
||||
break
|
||||
defer fp.Close()
|
||||
scanner := bufio.NewScanner(fp)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if strings.HasPrefix(line, "size_in_bytes") {
|
||||
ver := strings.Fields(line)
|
||||
if len(ver) != 2 {
|
||||
slog.Warn("malformed " + line)
|
||||
continue
|
||||
}
|
||||
bankSizeInBytes, err := strconv.ParseUint(ver[1], 10, 64)
|
||||
if err != nil {
|
||||
slog.Warn("malformed int " + line)
|
||||
continue
|
||||
}
|
||||
totalMemory += bankSizeInBytes
|
||||
}
|
||||
}
|
||||
totalMemory, err = strconv.ParseUint(strings.TrimSpace(string(buf)), 10, 64)
|
||||
}
|
||||
if totalMemory == 0 {
|
||||
slog.Warn("amdgpu reports zero total memory", "gpu", gpuID)
|
||||
continue
|
||||
}
|
||||
usedGlob := filepath.Join(AMDNodesSysfsDir, strconv.Itoa(nodeID), GPUUsedMemoryFileGlob)
|
||||
usedFiles, err := filepath.Glob(usedGlob)
|
||||
if err != nil {
|
||||
slog.Warn("error looking up used GPU memory", "glob", usedGlob, "error", err)
|
||||
continue
|
||||
}
|
||||
for _, usedFile := range usedFiles {
|
||||
fp, err := os.Open(usedFile)
|
||||
if err != nil {
|
||||
slog.Debug("failed to parse sysfs node", "file", totalFile, "error", err)
|
||||
break
|
||||
slog.Warn("failed to open sysfs node", "file", usedFile, "error", err)
|
||||
continue
|
||||
}
|
||||
|
||||
usedFile = filepath.Join(devDir, DRMUsedMemoryFile)
|
||||
usedMemory, err = getFreeMemory(usedFile)
|
||||
defer fp.Close()
|
||||
data, err := io.ReadAll(fp)
|
||||
if err != nil {
|
||||
slog.Debug("failed to update used memory", "error", err)
|
||||
slog.Warn("failed to read sysfs node", "file", usedFile, "error", err)
|
||||
continue
|
||||
}
|
||||
break
|
||||
used, err := strconv.ParseUint(strings.TrimSpace(string(data)), 10, 64)
|
||||
if err != nil {
|
||||
slog.Warn("malformed used memory", "data", string(data), "error", err)
|
||||
continue
|
||||
}
|
||||
usedMemory += used
|
||||
}
|
||||
|
||||
// iGPU detection, remove this check once we can support an iGPU variant of the rocm library
|
||||
@@ -267,21 +241,18 @@ func AMDGetGPUInfo() []RocmGPUInfo {
|
||||
|
||||
slog.Debug("amdgpu memory", "gpu", gpuID, "total", format.HumanBytes2(totalMemory))
|
||||
slog.Debug("amdgpu memory", "gpu", gpuID, "available", format.HumanBytes2(totalMemory-usedMemory))
|
||||
gpuInfo := RocmGPUInfo{
|
||||
GpuInfo: GpuInfo{
|
||||
Library: "rocm",
|
||||
memInfo: memInfo{
|
||||
TotalMemory: totalMemory,
|
||||
FreeMemory: (totalMemory - usedMemory),
|
||||
},
|
||||
ID: strconv.Itoa(gpuID),
|
||||
Name: name,
|
||||
Compute: fmt.Sprintf("gfx%d%x%x", major, minor, patch),
|
||||
MinimumMemory: rocmMinimumMemory,
|
||||
DriverMajor: driverMajor,
|
||||
DriverMinor: driverMinor,
|
||||
gpuInfo := GpuInfo{
|
||||
Library: "rocm",
|
||||
memInfo: memInfo{
|
||||
TotalMemory: totalMemory,
|
||||
FreeMemory: (totalMemory - usedMemory),
|
||||
},
|
||||
usedFilepath: usedFile,
|
||||
ID: fmt.Sprintf("%d", gpuID),
|
||||
Name: name,
|
||||
Compute: fmt.Sprintf("gfx%d%x%x", major, minor, patch),
|
||||
MinimumMemory: rocmMinimumMemory,
|
||||
DriverMajor: driverMajor,
|
||||
DriverMinor: driverMinor,
|
||||
}
|
||||
|
||||
// If the user wants to filter to a subset of devices, filter out if we aren't a match
|
||||
@@ -305,7 +276,7 @@ func AMDGetGPUInfo() []RocmGPUInfo {
|
||||
libDir, err = AMDValidateLibDir()
|
||||
if err != nil {
|
||||
slog.Warn("unable to verify rocm library, will use cpu", "error", err)
|
||||
return nil
|
||||
return []GpuInfo{}
|
||||
}
|
||||
}
|
||||
gpuInfo.DependencyPath = libDir
|
||||
@@ -316,7 +287,7 @@ func AMDGetGPUInfo() []RocmGPUInfo {
|
||||
supported, err = GetSupportedGFX(libDir)
|
||||
if err != nil {
|
||||
slog.Warn("failed to lookup supported GFX types, falling back to CPU mode", "error", err)
|
||||
return nil
|
||||
return []GpuInfo{}
|
||||
}
|
||||
slog.Debug("rocm supported GPUs", "types", supported)
|
||||
}
|
||||
@@ -333,11 +304,6 @@ func AMDGetGPUInfo() []RocmGPUInfo {
|
||||
slog.Info("skipping rocm gfx compatibility check", "HSA_OVERRIDE_GFX_VERSION", gfxOverride)
|
||||
}
|
||||
|
||||
// Check for env var workarounds
|
||||
if name == "1002:687f" { // Vega RX 56
|
||||
gpuInfo.EnvWorkarounds = append(gpuInfo.EnvWorkarounds, [2]string{"HSA_ENABLE_SDMA", "0"})
|
||||
}
|
||||
|
||||
// The GPU has passed all the verification steps and is supported
|
||||
resp = append(resp, gpuInfo)
|
||||
}
|
||||
@@ -412,31 +378,3 @@ func AMDDriverVersion() (driverMajor, driverMinor int, err error) {
|
||||
}
|
||||
return driverMajor, driverMinor, nil
|
||||
}
|
||||
|
||||
func (gpus RocmGPUInfoList) RefreshFreeMemory() error {
|
||||
if len(gpus) == 0 {
|
||||
return nil
|
||||
}
|
||||
for i := range gpus {
|
||||
usedMemory, err := getFreeMemory(gpus[i].usedFilepath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
slog.Debug("updating rocm free memory", "gpu", gpus[i].ID, "name", gpus[i].Name, "before", format.HumanBytes2(gpus[i].FreeMemory), "now", format.HumanBytes2(gpus[i].TotalMemory-usedMemory))
|
||||
gpus[i].FreeMemory = gpus[i].TotalMemory - usedMemory
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func getFreeMemory(usedFile string) (uint64, error) {
|
||||
buf, err := os.ReadFile(usedFile)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to read sysfs node %s %w", usedFile, err)
|
||||
}
|
||||
usedMemory, err := strconv.ParseUint(strings.TrimSpace(string(buf)), 10, 64)
|
||||
if err != nil {
|
||||
slog.Debug("failed to parse sysfs node", "file", usedFile, "error", err)
|
||||
return 0, fmt.Errorf("failed to parse sysfs node %s %w", usedFile, err)
|
||||
}
|
||||
return usedMemory, nil
|
||||
}
|
||||
|
||||
@@ -7,10 +7,8 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
)
|
||||
|
||||
@@ -26,8 +24,8 @@ var (
|
||||
RocmStandardLocations = []string{"C:\\Program Files\\AMD\\ROCm\\5.7\\bin"} // TODO glob?
|
||||
)
|
||||
|
||||
func AMDGetGPUInfo() []RocmGPUInfo {
|
||||
resp := []RocmGPUInfo{}
|
||||
func AMDGetGPUInfo() []GpuInfo {
|
||||
resp := []GpuInfo{}
|
||||
hl, err := NewHipLib()
|
||||
if err != nil {
|
||||
slog.Debug(err.Error())
|
||||
@@ -54,7 +52,7 @@ func AMDGetGPUInfo() []RocmGPUInfo {
|
||||
}
|
||||
|
||||
var supported []string
|
||||
gfxOverride := envconfig.HsaOverrideGfxVersion
|
||||
gfxOverride := os.Getenv("HSA_OVERRIDE_GFX_VERSION")
|
||||
if gfxOverride == "" {
|
||||
supported, err = GetSupportedGFX(libDir)
|
||||
if err != nil {
|
||||
@@ -67,7 +65,7 @@ func AMDGetGPUInfo() []RocmGPUInfo {
|
||||
|
||||
slog.Debug("detected hip devices", "count", count)
|
||||
// TODO how to determine the underlying device ID when visible devices is causing this to subset?
|
||||
for i := range count {
|
||||
for i := 0; i < count; i++ {
|
||||
err = hl.HipSetDevice(i)
|
||||
if err != nil {
|
||||
slog.Warn("set device", "id", i, "error", err)
|
||||
@@ -119,24 +117,21 @@ func AMDGetGPUInfo() []RocmGPUInfo {
|
||||
// v5.7 only reports VRAM used by this process, so it's completely wrong and unusable
|
||||
slog.Debug("amdgpu memory", "gpu", i, "total", format.HumanBytes2(totalMemory))
|
||||
slog.Debug("amdgpu memory", "gpu", i, "available", format.HumanBytes2(freeMemory))
|
||||
gpuInfo := RocmGPUInfo{
|
||||
GpuInfo: GpuInfo{
|
||||
Library: "rocm",
|
||||
memInfo: memInfo{
|
||||
TotalMemory: totalMemory,
|
||||
FreeMemory: freeMemory,
|
||||
},
|
||||
ID: strconv.Itoa(i), // TODO this is probably wrong if we specify visible devices
|
||||
DependencyPath: libDir,
|
||||
MinimumMemory: rocmMinimumMemory,
|
||||
Name: name,
|
||||
Compute: gfx,
|
||||
|
||||
// TODO - this information isn't accurate on windows, so don't report it until we find the right way to retrieve
|
||||
// DriverMajor: driverMajor,
|
||||
// DriverMinor: driverMinor,
|
||||
gpuInfo := GpuInfo{
|
||||
Library: "rocm",
|
||||
memInfo: memInfo{
|
||||
TotalMemory: totalMemory,
|
||||
FreeMemory: freeMemory,
|
||||
},
|
||||
index: i,
|
||||
ID: fmt.Sprintf("%d", i), // TODO this is probably wrong if we specify visible devices
|
||||
DependencyPath: libDir,
|
||||
MinimumMemory: rocmMinimumMemory,
|
||||
Name: name,
|
||||
Compute: gfx,
|
||||
|
||||
// TODO - this information isn't accurate on windows, so don't report it until we find the right way to retrieve
|
||||
// DriverMajor: driverMajor,
|
||||
// DriverMinor: driverMinor,
|
||||
}
|
||||
|
||||
resp = append(resp, gpuInfo)
|
||||
@@ -164,30 +159,3 @@ func AMDValidateLibDir() (string, error) {
|
||||
slog.Warn("amdgpu detected, but no compatible rocm library found. Please install ROCm")
|
||||
return "", fmt.Errorf("no suitable rocm found, falling back to CPU")
|
||||
}
|
||||
|
||||
func (gpus RocmGPUInfoList) RefreshFreeMemory() error {
|
||||
if len(gpus) == 0 {
|
||||
return nil
|
||||
}
|
||||
hl, err := NewHipLib()
|
||||
if err != nil {
|
||||
slog.Debug(err.Error())
|
||||
return nil
|
||||
}
|
||||
defer hl.Release()
|
||||
|
||||
for i := range gpus {
|
||||
err := hl.HipSetDevice(gpus[i].index)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
freeMemory, _, err := hl.HipMemGetInfo()
|
||||
if err != nil {
|
||||
slog.Warn("get mem info", "id", i, "error", err)
|
||||
continue
|
||||
}
|
||||
slog.Debug("updating rocm free memory", "gpu", gpus[i].ID, "name", gpus[i].Name, "before", format.HumanBytes2(gpus[i].FreeMemory), "now", format.HumanBytes2(freeMemory))
|
||||
gpus[i].FreeMemory = freeMemory
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -80,7 +80,7 @@ func cleanupTmpDirs() {
|
||||
if err == nil {
|
||||
pid, err := strconv.Atoi(string(raw))
|
||||
if err == nil {
|
||||
if proc, err := os.FindProcess(pid); err == nil && !errors.Is(proc.Signal(syscall.Signal(0)), os.ErrProcessDone) {
|
||||
if proc, err := os.FindProcess(int(pid)); err == nil && !errors.Is(proc.Signal(syscall.Signal(0)), os.ErrProcessDone) {
|
||||
// Another running ollama, ignore this tmpdir
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -1,16 +1,21 @@
|
||||
package gpu
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
|
||||
"golang.org/x/sys/cpu"
|
||||
)
|
||||
|
||||
func GetCPUCapability() CPUCapability {
|
||||
func GetCPUVariant() string {
|
||||
if cpu.X86.HasAVX2 {
|
||||
return CPUCapabilityAVX2
|
||||
slog.Debug("CPU has AVX2")
|
||||
return "avx2"
|
||||
}
|
||||
if cpu.X86.HasAVX {
|
||||
return CPUCapabilityAVX
|
||||
slog.Debug("CPU has AVX")
|
||||
return "avx"
|
||||
}
|
||||
slog.Debug("CPU does not have vector extensions")
|
||||
// else LCD
|
||||
return CPUCapabilityNone
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -18,4 +18,5 @@ func cudaGetVisibleDevicesEnv(gpuInfo []GpuInfo) (string, string) {
|
||||
ids = append(ids, info.ID)
|
||||
}
|
||||
return "CUDA_VISIBLE_DEVICES", strings.Join(ids, ",")
|
||||
|
||||
}
|
||||
|
||||
497
gpu/gpu.go
497
gpu/gpu.go
@@ -24,37 +24,19 @@ import (
|
||||
"github.com/ollama/ollama/format"
|
||||
)
|
||||
|
||||
type cudaHandles struct {
|
||||
type handles struct {
|
||||
deviceCount int
|
||||
cudart *C.cudart_handle_t
|
||||
nvcuda *C.nvcuda_handle_t
|
||||
nvml *C.nvml_handle_t
|
||||
}
|
||||
|
||||
type oneapiHandles struct {
|
||||
oneapi *C.oneapi_handle_t
|
||||
deviceCount int
|
||||
}
|
||||
|
||||
const (
|
||||
cudaMinimumMemory = 457 * format.MebiByte
|
||||
rocmMinimumMemory = 457 * format.MebiByte
|
||||
// TODO OneAPI minimum memory
|
||||
)
|
||||
|
||||
var (
|
||||
gpuMutex sync.Mutex
|
||||
bootstrapped bool
|
||||
cpuCapability CPUCapability
|
||||
cpus []CPUInfo
|
||||
cudaGPUs []CudaGPUInfo
|
||||
nvcudaLibPath string
|
||||
cudartLibPath string
|
||||
oneapiLibPath string
|
||||
nvmlLibPath string
|
||||
rocmGPUs []RocmGPUInfo
|
||||
oneapiGPUs []OneapiGPUInfo
|
||||
)
|
||||
var gpuMutex sync.Mutex
|
||||
|
||||
// With our current CUDA compile flags, older than 5.0 will not work properly
|
||||
var CudaComputeMin = [2]C.int{5, 0}
|
||||
@@ -64,113 +46,113 @@ var RocmComputeMin = 9
|
||||
// TODO find a better way to detect iGPU instead of minimum memory
|
||||
const IGPUMemLimit = 1 * format.GibiByte // 512G is what they typically report, so anything less than 1G must be iGPU
|
||||
|
||||
var CudartLinuxGlobs = []string{
|
||||
"/usr/local/cuda/lib64/libcudart.so*",
|
||||
"/usr/lib/x86_64-linux-gnu/nvidia/current/libcudart.so*",
|
||||
"/usr/lib/x86_64-linux-gnu/libcudart.so*",
|
||||
"/usr/lib/wsl/lib/libcudart.so*",
|
||||
"/usr/lib/wsl/drivers/*/libcudart.so*",
|
||||
"/opt/cuda/lib64/libcudart.so*",
|
||||
"/usr/local/cuda*/targets/aarch64-linux/lib/libcudart.so*",
|
||||
"/usr/lib/aarch64-linux-gnu/nvidia/current/libcudart.so*",
|
||||
"/usr/lib/aarch64-linux-gnu/libcudart.so*",
|
||||
"/usr/local/cuda/lib*/libcudart.so*",
|
||||
"/usr/lib*/libcudart.so*",
|
||||
"/usr/local/lib*/libcudart.so*",
|
||||
}
|
||||
|
||||
var CudartWindowsGlobs = []string{
|
||||
"c:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v*\\bin\\cudart64_*.dll",
|
||||
}
|
||||
|
||||
var NvcudaLinuxGlobs = []string{
|
||||
"/usr/local/cuda*/targets/*/lib/libcuda.so*",
|
||||
"/usr/lib/*-linux-gnu/nvidia/current/libcuda.so*",
|
||||
"/usr/lib/*-linux-gnu/libcuda.so*",
|
||||
"/usr/lib/wsl/lib/libcuda.so*",
|
||||
"/usr/lib/wsl/drivers/*/libcuda.so*",
|
||||
"/opt/cuda/lib*/libcuda.so*",
|
||||
"/usr/local/cuda/lib*/libcuda.so*",
|
||||
"/usr/lib*/libcuda.so*",
|
||||
"/usr/local/lib*/libcuda.so*",
|
||||
}
|
||||
|
||||
var NvcudaWindowsGlobs = []string{
|
||||
"c:\\windows\\system*\\nvcuda.dll",
|
||||
}
|
||||
|
||||
var OneapiWindowsGlobs = []string{
|
||||
"c:\\Windows\\System32\\DriverStore\\FileRepository\\*\\ze_intel_gpu64.dll",
|
||||
}
|
||||
|
||||
var OneapiLinuxGlobs = []string{
|
||||
"/usr/lib/x86_64-linux-gnu/libze_intel_gpu.so*",
|
||||
"/usr/lib*/libze_intel_gpu.so*",
|
||||
}
|
||||
|
||||
// Jetson devices have JETSON_JETPACK="x.y.z" factory set to the Jetpack version installed.
|
||||
// Included to drive logic for reducing Ollama-allocated overhead on L4T/Jetson devices.
|
||||
var CudaTegra string = os.Getenv("JETSON_JETPACK")
|
||||
|
||||
// Note: gpuMutex must already be held
|
||||
func initCudaHandles() *cudaHandles {
|
||||
func initGPUHandles() *handles {
|
||||
|
||||
// TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing
|
||||
|
||||
cHandles := &cudaHandles{}
|
||||
// Short Circuit if we already know which library to use
|
||||
if nvmlLibPath != "" {
|
||||
cHandles.nvml, _ = LoadNVMLMgmt([]string{nvmlLibPath})
|
||||
return cHandles
|
||||
}
|
||||
if nvcudaLibPath != "" {
|
||||
cHandles.deviceCount, cHandles.nvcuda, _ = LoadNVCUDAMgmt([]string{nvcudaLibPath})
|
||||
return cHandles
|
||||
}
|
||||
if cudartLibPath != "" {
|
||||
cHandles.deviceCount, cHandles.cudart, _ = LoadCUDARTMgmt([]string{cudartLibPath})
|
||||
return cHandles
|
||||
}
|
||||
|
||||
slog.Debug("searching for GPU discovery libraries for NVIDIA")
|
||||
gpuHandles := &handles{}
|
||||
var cudartMgmtName string
|
||||
var cudartMgmtPatterns []string
|
||||
var nvcudaMgmtName string
|
||||
var nvcudaMgmtPatterns []string
|
||||
|
||||
// Aligned with driver, we can't carry as payloads
|
||||
nvcudaMgmtPatterns := NvcudaGlobs
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
localAppData := os.Getenv("LOCALAPPDATA")
|
||||
cudartMgmtPatterns = []string{filepath.Join(localAppData, "Programs", "Ollama", CudartMgmtName)}
|
||||
}
|
||||
tmpDir, _ := PayloadsDir()
|
||||
if tmpDir != "" {
|
||||
// TODO - add "payloads" for subprocess
|
||||
cudartMgmtPatterns = []string{filepath.Join(tmpDir, "cuda*", CudartMgmtName)}
|
||||
}
|
||||
cudartMgmtPatterns = append(cudartMgmtPatterns, CudartGlobs...)
|
||||
|
||||
if len(NvmlGlobs) > 0 {
|
||||
nvmlLibPaths := FindGPULibs(NvmlMgmtName, NvmlGlobs)
|
||||
if len(nvmlLibPaths) > 0 {
|
||||
nvml, libPath := LoadNVMLMgmt(nvmlLibPaths)
|
||||
if nvml != nil {
|
||||
slog.Debug("nvidia-ml loaded", "library", libPath)
|
||||
cHandles.nvml = nvml
|
||||
nvmlLibPath = libPath
|
||||
}
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
cudartMgmtName = "cudart64_*.dll"
|
||||
localAppData := os.Getenv("LOCALAPPDATA")
|
||||
cudartMgmtPatterns = []string{filepath.Join(localAppData, "Programs", "Ollama", cudartMgmtName)}
|
||||
cudartMgmtPatterns = append(cudartMgmtPatterns, CudartWindowsGlobs...)
|
||||
// Aligned with driver, we can't carry as payloads
|
||||
nvcudaMgmtName = "nvcuda.dll"
|
||||
nvcudaMgmtPatterns = NvcudaWindowsGlobs
|
||||
case "linux":
|
||||
cudartMgmtName = "libcudart.so*"
|
||||
if tmpDir != "" {
|
||||
// TODO - add "payloads" for subprocess
|
||||
cudartMgmtPatterns = []string{filepath.Join(tmpDir, "cuda*", cudartMgmtName)}
|
||||
}
|
||||
cudartMgmtPatterns = append(cudartMgmtPatterns, CudartLinuxGlobs...)
|
||||
// Aligned with driver, we can't carry as payloads
|
||||
nvcudaMgmtName = "libcuda.so*"
|
||||
nvcudaMgmtPatterns = NvcudaLinuxGlobs
|
||||
default:
|
||||
return gpuHandles
|
||||
}
|
||||
|
||||
nvcudaLibPaths := FindGPULibs(NvcudaMgmtName, nvcudaMgmtPatterns)
|
||||
slog.Debug("Detecting GPUs")
|
||||
nvcudaLibPaths := FindGPULibs(nvcudaMgmtName, nvcudaMgmtPatterns)
|
||||
if len(nvcudaLibPaths) > 0 {
|
||||
deviceCount, nvcuda, libPath := LoadNVCUDAMgmt(nvcudaLibPaths)
|
||||
if nvcuda != nil {
|
||||
slog.Debug("detected GPUs", "count", deviceCount, "library", libPath)
|
||||
cHandles.nvcuda = nvcuda
|
||||
cHandles.deviceCount = deviceCount
|
||||
nvcudaLibPath = libPath
|
||||
return cHandles
|
||||
gpuHandles.nvcuda = nvcuda
|
||||
gpuHandles.deviceCount = deviceCount
|
||||
return gpuHandles
|
||||
}
|
||||
}
|
||||
|
||||
cudartLibPaths := FindGPULibs(CudartMgmtName, cudartMgmtPatterns)
|
||||
cudartLibPaths := FindGPULibs(cudartMgmtName, cudartMgmtPatterns)
|
||||
if len(cudartLibPaths) > 0 {
|
||||
deviceCount, cudart, libPath := LoadCUDARTMgmt(cudartLibPaths)
|
||||
if cudart != nil {
|
||||
slog.Debug("detected GPUs", "library", libPath, "count", deviceCount)
|
||||
cHandles.cudart = cudart
|
||||
cHandles.deviceCount = deviceCount
|
||||
cudartLibPath = libPath
|
||||
return cHandles
|
||||
gpuHandles.cudart = cudart
|
||||
gpuHandles.deviceCount = deviceCount
|
||||
return gpuHandles
|
||||
}
|
||||
}
|
||||
|
||||
return cHandles
|
||||
}
|
||||
|
||||
// Note: gpuMutex must already be held
|
||||
func initOneAPIHandles() *oneapiHandles {
|
||||
oHandles := &oneapiHandles{}
|
||||
|
||||
// Short Circuit if we already know which library to use
|
||||
if oneapiLibPath != "" {
|
||||
oHandles.deviceCount, oHandles.oneapi, _ = LoadOneapiMgmt([]string{oneapiLibPath})
|
||||
return oHandles
|
||||
}
|
||||
|
||||
oneapiLibPaths := FindGPULibs(OneapiMgmtName, OneapiGlobs)
|
||||
if len(oneapiLibPaths) > 0 {
|
||||
oHandles.deviceCount, oHandles.oneapi, oneapiLibPath = LoadOneapiMgmt(oneapiLibPaths)
|
||||
}
|
||||
|
||||
return oHandles
|
||||
}
|
||||
|
||||
func GetCPUInfo() GpuInfoList {
|
||||
gpuMutex.Lock()
|
||||
if !bootstrapped {
|
||||
gpuMutex.Unlock()
|
||||
GetGPUInfo()
|
||||
} else {
|
||||
gpuMutex.Unlock()
|
||||
}
|
||||
return GpuInfoList{cpus[0].GpuInfo}
|
||||
return gpuHandles
|
||||
}
|
||||
|
||||
func GetGPUInfo() GpuInfoList {
|
||||
@@ -178,247 +160,112 @@ func GetGPUInfo() GpuInfoList {
|
||||
// GPUs so we can report warnings if we see Nvidia/AMD but fail to load the libraries
|
||||
gpuMutex.Lock()
|
||||
defer gpuMutex.Unlock()
|
||||
needRefresh := true
|
||||
var cHandles *cudaHandles
|
||||
var oHandles *oneapiHandles
|
||||
|
||||
gpuHandles := initGPUHandles()
|
||||
defer func() {
|
||||
if cHandles != nil {
|
||||
if cHandles.cudart != nil {
|
||||
C.cudart_release(*cHandles.cudart)
|
||||
}
|
||||
if cHandles.nvcuda != nil {
|
||||
C.nvcuda_release(*cHandles.nvcuda)
|
||||
}
|
||||
if cHandles.nvml != nil {
|
||||
C.nvml_release(*cHandles.nvml)
|
||||
}
|
||||
if gpuHandles.cudart != nil {
|
||||
C.cudart_release(*gpuHandles.cudart)
|
||||
}
|
||||
if oHandles != nil {
|
||||
if oHandles.oneapi != nil {
|
||||
// TODO - is this needed?
|
||||
C.oneapi_release(*oHandles.oneapi)
|
||||
}
|
||||
if gpuHandles.nvcuda != nil {
|
||||
C.nvcuda_release(*gpuHandles.nvcuda)
|
||||
}
|
||||
}()
|
||||
|
||||
if !bootstrapped {
|
||||
slog.Debug("Detecting GPUs")
|
||||
needRefresh = false
|
||||
cpuCapability = GetCPUCapability()
|
||||
var memInfo C.mem_info_t
|
||||
|
||||
mem, err := GetCPUMem()
|
||||
if err != nil {
|
||||
slog.Warn("error looking up system memory", "error", err)
|
||||
}
|
||||
cpus = []CPUInfo{CPUInfo{
|
||||
GpuInfo: GpuInfo{
|
||||
memInfo: mem,
|
||||
Library: "cpu",
|
||||
Variant: cpuCapability,
|
||||
ID: "0",
|
||||
},
|
||||
}}
|
||||
|
||||
// Fallback to CPU mode if we're lacking required vector extensions on x86
|
||||
if cpuCapability < GPURunnerCPUCapability && runtime.GOARCH == "amd64" {
|
||||
slog.Warn("CPU does not have minimum vector extensions, GPU inference disabled", "required", GPURunnerCPUCapability, "detected", cpuCapability)
|
||||
bootstrapped = true
|
||||
// No need to do any GPU discovery, since we can't run on them
|
||||
return GpuInfoList{cpus[0].GpuInfo}
|
||||
}
|
||||
|
||||
// On windows we bundle the nvidia library one level above the runner dir
|
||||
depPath := ""
|
||||
if runtime.GOOS == "windows" && envconfig.RunnersDir != "" {
|
||||
depPath = filepath.Dir(envconfig.RunnersDir)
|
||||
}
|
||||
|
||||
// Load ALL libraries
|
||||
cHandles = initCudaHandles()
|
||||
|
||||
// NVIDIA
|
||||
for i := range cHandles.deviceCount {
|
||||
if cHandles.cudart != nil || cHandles.nvcuda != nil {
|
||||
gpuInfo := CudaGPUInfo{
|
||||
GpuInfo: GpuInfo{
|
||||
Library: "cuda",
|
||||
},
|
||||
index: i,
|
||||
}
|
||||
var driverMajor int
|
||||
var driverMinor int
|
||||
if cHandles.cudart != nil {
|
||||
C.cudart_bootstrap(*cHandles.cudart, C.int(i), &memInfo)
|
||||
} else {
|
||||
C.nvcuda_bootstrap(*cHandles.nvcuda, C.int(i), &memInfo)
|
||||
driverMajor = int(cHandles.nvcuda.driver_major)
|
||||
driverMinor = int(cHandles.nvcuda.driver_minor)
|
||||
}
|
||||
if memInfo.err != nil {
|
||||
slog.Info("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
|
||||
C.free(unsafe.Pointer(memInfo.err))
|
||||
continue
|
||||
}
|
||||
if memInfo.major < CudaComputeMin[0] || (memInfo.major == CudaComputeMin[0] && memInfo.minor < CudaComputeMin[1]) {
|
||||
slog.Info(fmt.Sprintf("[%d] CUDA GPU is too old. Compute Capability detected: %d.%d", i, memInfo.major, memInfo.minor))
|
||||
continue
|
||||
}
|
||||
gpuInfo.TotalMemory = uint64(memInfo.total)
|
||||
gpuInfo.FreeMemory = uint64(memInfo.free)
|
||||
gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
|
||||
gpuInfo.Compute = fmt.Sprintf("%d.%d", memInfo.major, memInfo.minor)
|
||||
gpuInfo.MinimumMemory = cudaMinimumMemory
|
||||
gpuInfo.DependencyPath = depPath
|
||||
gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
|
||||
gpuInfo.DriverMajor = driverMajor
|
||||
gpuInfo.DriverMinor = driverMinor
|
||||
|
||||
// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
|
||||
cudaGPUs = append(cudaGPUs, gpuInfo)
|
||||
}
|
||||
}
|
||||
|
||||
// Intel
|
||||
oHandles = initOneAPIHandles()
|
||||
for d := 0; oHandles.oneapi != nil && d < int(oHandles.oneapi.num_drivers); d++ {
|
||||
if oHandles.oneapi == nil {
|
||||
// shouldn't happen
|
||||
slog.Warn("nil oneapi handle with driver count", "count", int(oHandles.oneapi.num_drivers))
|
||||
continue
|
||||
}
|
||||
devCount := C.oneapi_get_device_count(*oHandles.oneapi, C.int(d))
|
||||
for i := range devCount {
|
||||
gpuInfo := OneapiGPUInfo{
|
||||
GpuInfo: GpuInfo{
|
||||
Library: "oneapi",
|
||||
},
|
||||
driverIndex: d,
|
||||
gpuIndex: int(i),
|
||||
}
|
||||
// TODO - split bootstrapping from updating free memory
|
||||
C.oneapi_check_vram(*oHandles.oneapi, C.int(d), i, &memInfo)
|
||||
// TODO - convert this to MinimumMemory based on testing...
|
||||
var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend.
|
||||
memInfo.free = C.uint64_t(totalFreeMem)
|
||||
gpuInfo.TotalMemory = uint64(memInfo.total)
|
||||
gpuInfo.FreeMemory = uint64(memInfo.free)
|
||||
gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
|
||||
gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
|
||||
// TODO dependency path?
|
||||
oneapiGPUs = append(oneapiGPUs, gpuInfo)
|
||||
}
|
||||
}
|
||||
|
||||
rocmGPUs = AMDGetGPUInfo()
|
||||
bootstrapped = true
|
||||
// All our GPU builds on x86 have AVX enabled, so fallback to CPU if we don't detect at least AVX
|
||||
cpuVariant := GetCPUVariant()
|
||||
if cpuVariant == "" && runtime.GOARCH == "amd64" {
|
||||
slog.Warn("CPU does not have AVX or AVX2, disabling GPU support.")
|
||||
}
|
||||
|
||||
// For detected GPUs, load library if not loaded
|
||||
// On windows we bundle the nvidia library one level above the runner dir
|
||||
depPath := ""
|
||||
if runtime.GOOS == "windows" && envconfig.RunnersDir != "" {
|
||||
depPath = filepath.Dir(envconfig.RunnersDir)
|
||||
}
|
||||
|
||||
// Refresh free memory usage
|
||||
if needRefresh {
|
||||
mem, err := GetCPUMem()
|
||||
if err != nil {
|
||||
slog.Warn("error looking up system memory", "error", err)
|
||||
} else {
|
||||
slog.Debug("updating system memory data",
|
||||
slog.Group(
|
||||
"before",
|
||||
"total", format.HumanBytes2(cpus[0].TotalMemory),
|
||||
"free", format.HumanBytes2(cpus[0].FreeMemory),
|
||||
),
|
||||
slog.Group(
|
||||
"now",
|
||||
"total", format.HumanBytes2(mem.TotalMemory),
|
||||
"free", format.HumanBytes2(mem.FreeMemory),
|
||||
),
|
||||
)
|
||||
cpus[0].FreeMemory = mem.FreeMemory
|
||||
}
|
||||
var memInfo C.mem_info_t
|
||||
resp := []GpuInfo{}
|
||||
|
||||
var memInfo C.mem_info_t
|
||||
if cHandles == nil && len(cudaGPUs) > 0 {
|
||||
cHandles = initCudaHandles()
|
||||
// NVIDIA first
|
||||
for i := 0; i < gpuHandles.deviceCount; i++ {
|
||||
// TODO once we support CPU compilation variants of GPU libraries refine this...
|
||||
if cpuVariant == "" && runtime.GOARCH == "amd64" {
|
||||
continue
|
||||
}
|
||||
for i, gpu := range cudaGPUs {
|
||||
if cHandles.nvml != nil {
|
||||
C.nvml_get_free(*cHandles.nvml, C.int(gpu.index), &memInfo.free, &memInfo.total, &memInfo.used)
|
||||
} else if cHandles.cudart != nil {
|
||||
C.cudart_bootstrap(*cHandles.cudart, C.int(gpu.index), &memInfo)
|
||||
} else if cHandles.nvcuda != nil {
|
||||
C.nvcuda_get_free(*cHandles.nvcuda, C.int(gpu.index), &memInfo.free, &memInfo.total)
|
||||
memInfo.used = memInfo.total - memInfo.free
|
||||
if gpuHandles.cudart != nil || gpuHandles.nvcuda != nil {
|
||||
gpuInfo := GpuInfo{
|
||||
Library: "cuda",
|
||||
}
|
||||
var driverMajor int
|
||||
var driverMinor int
|
||||
if gpuHandles.cudart != nil {
|
||||
C.cudart_check_vram(*gpuHandles.cudart, C.int(i), &memInfo)
|
||||
} else {
|
||||
// shouldn't happen
|
||||
slog.Warn("no valid cuda library loaded to refresh vram usage")
|
||||
break
|
||||
C.nvcuda_check_vram(*gpuHandles.nvcuda, C.int(i), &memInfo)
|
||||
driverMajor = int(gpuHandles.nvcuda.driver_major)
|
||||
driverMinor = int(gpuHandles.nvcuda.driver_minor)
|
||||
}
|
||||
if memInfo.err != nil {
|
||||
slog.Warn("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
|
||||
slog.Info("error looking up nvidia GPU memory", "error", C.GoString(memInfo.err))
|
||||
C.free(unsafe.Pointer(memInfo.err))
|
||||
continue
|
||||
}
|
||||
if memInfo.free == 0 {
|
||||
slog.Warn("error looking up nvidia GPU memory")
|
||||
if memInfo.major < CudaComputeMin[0] || (memInfo.major == CudaComputeMin[0] && memInfo.minor < CudaComputeMin[1]) {
|
||||
slog.Info(fmt.Sprintf("[%d] CUDA GPU is too old. Compute Capability detected: %d.%d", i, memInfo.major, memInfo.minor))
|
||||
continue
|
||||
}
|
||||
slog.Debug("updating cuda memory data",
|
||||
"gpu", gpu.ID,
|
||||
"name", gpu.Name,
|
||||
slog.Group(
|
||||
"before",
|
||||
"total", format.HumanBytes2(gpu.TotalMemory),
|
||||
"free", format.HumanBytes2(gpu.FreeMemory),
|
||||
),
|
||||
slog.Group(
|
||||
"now",
|
||||
"total", format.HumanBytes2(uint64(memInfo.total)),
|
||||
"free", format.HumanBytes2(uint64(memInfo.free)),
|
||||
"used", format.HumanBytes2(uint64(memInfo.used)),
|
||||
),
|
||||
)
|
||||
cudaGPUs[i].FreeMemory = uint64(memInfo.free)
|
||||
}
|
||||
gpuInfo.TotalMemory = uint64(memInfo.total)
|
||||
gpuInfo.FreeMemory = uint64(memInfo.free)
|
||||
gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
|
||||
gpuInfo.Compute = fmt.Sprintf("%d.%d", memInfo.major, memInfo.minor)
|
||||
gpuInfo.MinimumMemory = cudaMinimumMemory
|
||||
gpuInfo.DependencyPath = depPath
|
||||
gpuInfo.Name = C.GoString(&memInfo.gpu_name[0])
|
||||
gpuInfo.DriverMajor = int(driverMajor)
|
||||
gpuInfo.DriverMinor = int(driverMinor)
|
||||
|
||||
if oHandles == nil && len(oneapiGPUs) > 0 {
|
||||
oHandles = initOneAPIHandles()
|
||||
}
|
||||
for i, gpu := range oneapiGPUs {
|
||||
if oHandles.oneapi == nil {
|
||||
// shouldn't happen
|
||||
slog.Warn("nil oneapi handle with device count", "count", oHandles.deviceCount)
|
||||
continue
|
||||
}
|
||||
C.oneapi_check_vram(*oHandles.oneapi, C.int(gpu.driverIndex), C.int(gpu.gpuIndex), &memInfo)
|
||||
// TODO - convert this to MinimumMemory based on testing...
|
||||
var totalFreeMem float64 = float64(memInfo.free) * 0.95 // work-around: leave some reserve vram for mkl lib used in ggml-sycl backend.
|
||||
memInfo.free = C.uint64_t(totalFreeMem)
|
||||
oneapiGPUs[i].FreeMemory = uint64(memInfo.free)
|
||||
}
|
||||
|
||||
err = RocmGPUInfoList(rocmGPUs).RefreshFreeMemory()
|
||||
if err != nil {
|
||||
slog.Debug("problem refreshing ROCm free memory", "error", err)
|
||||
// TODO potentially sort on our own algorithm instead of what the underlying GPU library does...
|
||||
resp = append(resp, gpuInfo)
|
||||
}
|
||||
}
|
||||
|
||||
resp := []GpuInfo{}
|
||||
for _, gpu := range cudaGPUs {
|
||||
resp = append(resp, gpu.GpuInfo)
|
||||
}
|
||||
for _, gpu := range rocmGPUs {
|
||||
resp = append(resp, gpu.GpuInfo)
|
||||
}
|
||||
for _, gpu := range oneapiGPUs {
|
||||
resp = append(resp, gpu.GpuInfo)
|
||||
}
|
||||
// Then AMD
|
||||
resp = append(resp, AMDGetGPUInfo()...)
|
||||
|
||||
if len(resp) == 0 {
|
||||
resp = append(resp, cpus[0].GpuInfo)
|
||||
C.cpu_check_ram(&memInfo)
|
||||
if memInfo.err != nil {
|
||||
slog.Info("error looking up CPU memory", "error", C.GoString(memInfo.err))
|
||||
C.free(unsafe.Pointer(memInfo.err))
|
||||
return resp
|
||||
}
|
||||
gpuInfo := GpuInfo{
|
||||
Library: "cpu",
|
||||
Variant: cpuVariant,
|
||||
}
|
||||
gpuInfo.TotalMemory = uint64(memInfo.total)
|
||||
gpuInfo.FreeMemory = uint64(memInfo.free)
|
||||
gpuInfo.ID = C.GoString(&memInfo.gpu_id[0])
|
||||
|
||||
resp = append(resp, gpuInfo)
|
||||
}
|
||||
|
||||
return resp
|
||||
}
|
||||
|
||||
func GetCPUMem() (memInfo, error) {
|
||||
var ret memInfo
|
||||
var info C.mem_info_t
|
||||
C.cpu_check_ram(&info)
|
||||
if info.err != nil {
|
||||
defer C.free(unsafe.Pointer(info.err))
|
||||
return ret, fmt.Errorf(C.GoString(info.err))
|
||||
}
|
||||
ret.FreeMemory = uint64(info.free)
|
||||
ret.TotalMemory = uint64(info.total)
|
||||
return ret, nil
|
||||
}
|
||||
|
||||
func FindGPULibs(baseLibName string, defaultPatterns []string) []string {
|
||||
// Multiple GPU libraries may exist, and some may not work, so keep trying until we exhaust them
|
||||
var ldPaths []string
|
||||
@@ -449,7 +296,6 @@ func FindGPULibs(baseLibName string, defaultPatterns []string) []string {
|
||||
// Nvidia PhysX known to return bogus results
|
||||
if strings.Contains(pattern, "PhysX") {
|
||||
slog.Debug("skipping PhysX cuda library path", "path", pattern)
|
||||
continue
|
||||
}
|
||||
// Ignore glob discovery errors
|
||||
matches, _ := filepath.Glob(pattern)
|
||||
@@ -515,26 +361,8 @@ func LoadNVCUDAMgmt(nvcudaLibPaths []string) (int, *C.nvcuda_handle_t, string) {
|
||||
return 0, nil, ""
|
||||
}
|
||||
|
||||
func LoadNVMLMgmt(nvmlLibPaths []string) (*C.nvml_handle_t, string) {
|
||||
var resp C.nvml_init_resp_t
|
||||
resp.ch.verbose = getVerboseState()
|
||||
for _, libPath := range nvmlLibPaths {
|
||||
lib := C.CString(libPath)
|
||||
defer C.free(unsafe.Pointer(lib))
|
||||
C.nvml_init(lib, &resp)
|
||||
if resp.err != nil {
|
||||
slog.Info(fmt.Sprintf("Unable to load NVML management library %s: %s", libPath, C.GoString(resp.err)))
|
||||
C.free(unsafe.Pointer(resp.err))
|
||||
} else {
|
||||
return &resp.ch, libPath
|
||||
}
|
||||
}
|
||||
return nil, ""
|
||||
}
|
||||
|
||||
func LoadOneapiMgmt(oneapiLibPaths []string) (int, *C.oneapi_handle_t, string) {
|
||||
var resp C.oneapi_init_resp_t
|
||||
num_devices := 0
|
||||
resp.oh.verbose = getVerboseState()
|
||||
for _, libPath := range oneapiLibPaths {
|
||||
lib := C.CString(libPath)
|
||||
@@ -544,10 +372,7 @@ func LoadOneapiMgmt(oneapiLibPaths []string) (int, *C.oneapi_handle_t, string) {
|
||||
slog.Debug("Unable to load oneAPI management library", "library", libPath, "error", C.GoString(resp.err))
|
||||
C.free(unsafe.Pointer(resp.err))
|
||||
} else {
|
||||
for i := range resp.oh.num_drivers {
|
||||
num_devices += int(C.oneapi_get_device_count(resp.oh, C.int(i)))
|
||||
}
|
||||
return num_devices, &resp.oh, libPath
|
||||
return int(resp.num_devices), &resp.oh, libPath
|
||||
}
|
||||
}
|
||||
return 0, nil, ""
|
||||
|
||||
@@ -24,7 +24,7 @@ func GetGPUInfo() GpuInfoList {
|
||||
return []GpuInfo{
|
||||
{
|
||||
Library: "cpu",
|
||||
Variant: GetCPUCapability(),
|
||||
Variant: GetCPUVariant(),
|
||||
memInfo: mem,
|
||||
},
|
||||
}
|
||||
@@ -42,17 +42,6 @@ func GetGPUInfo() GpuInfoList {
|
||||
return []GpuInfo{info}
|
||||
}
|
||||
|
||||
func GetCPUInfo() GpuInfoList {
|
||||
mem, _ := GetCPUMem()
|
||||
return []GpuInfo{
|
||||
{
|
||||
Library: "cpu",
|
||||
Variant: GetCPUCapability(),
|
||||
memInfo: mem,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func GetCPUMem() (memInfo, error) {
|
||||
return memInfo{
|
||||
TotalMemory: uint64(C.getPhysicalMemory()),
|
||||
|
||||
@@ -47,7 +47,6 @@ typedef struct mem_info {
|
||||
char gpu_name[GPU_NAME_LEN];
|
||||
uint64_t total;
|
||||
uint64_t free;
|
||||
uint64_t used;
|
||||
|
||||
// Compute Capability
|
||||
int major;
|
||||
@@ -63,7 +62,6 @@ void cpu_check_ram(mem_info_t *resp);
|
||||
|
||||
#include "gpu_info_cudart.h"
|
||||
#include "gpu_info_nvcuda.h"
|
||||
#include "gpu_info_nvml.h"
|
||||
#include "gpu_info_oneapi.h"
|
||||
|
||||
#endif // __GPU_INFO_H__
|
||||
|
||||
45
gpu/gpu_info_cpu.c
Normal file
45
gpu/gpu_info_cpu.c
Normal file
@@ -0,0 +1,45 @@
|
||||
#include "gpu_info.h"
|
||||
// Fallbacks for CPU mode
|
||||
|
||||
#ifdef _WIN32
|
||||
#include <sysinfoapi.h>
|
||||
void cpu_check_ram(mem_info_t *resp) {
|
||||
resp->err = NULL;
|
||||
MEMORYSTATUSEX info;
|
||||
info.dwLength = sizeof(info);
|
||||
if (GlobalMemoryStatusEx(&info) != 0) {
|
||||
resp->total = info.ullTotalPhys;
|
||||
resp->free = info.ullAvailPhys;
|
||||
snprintf(&resp->gpu_id[0], GPU_ID_LEN, "0");
|
||||
} else {
|
||||
resp->err = LOAD_ERR();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
#elif __linux__
|
||||
#include <errno.h>
|
||||
#include <string.h>
|
||||
#include <sys/sysinfo.h>
|
||||
void cpu_check_ram(mem_info_t *resp) {
|
||||
struct sysinfo info;
|
||||
resp->err = NULL;
|
||||
if (sysinfo(&info) != 0) {
|
||||
resp->err = strdup(strerror(errno));
|
||||
} else {
|
||||
resp->total = info.totalram * info.mem_unit;
|
||||
resp->free = info.freeram * info.mem_unit;
|
||||
snprintf(&resp->gpu_id[0], GPU_ID_LEN, "0");
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
#elif __APPLE__
|
||||
// TODO consider an Apple implementation that does something useful
|
||||
// mem_info_t cpu_check_ram() {
|
||||
// mem_info_t resp = {0, 0, NULL};
|
||||
// return resp;
|
||||
// }
|
||||
#else
|
||||
#error "Unsupported platform"
|
||||
#endif
|
||||
@@ -94,7 +94,7 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
|
||||
}
|
||||
|
||||
|
||||
void cudart_bootstrap(cudart_handle_t h, int i, mem_info_t *resp) {
|
||||
void cudart_check_vram(cudart_handle_t h, int i, mem_info_t *resp) {
|
||||
resp->err = NULL;
|
||||
cudartMemory_t memInfo = {0,0,0};
|
||||
cudartReturn_t ret;
|
||||
@@ -166,11 +166,9 @@ void cudart_bootstrap(cudart_handle_t h, int i, mem_info_t *resp) {
|
||||
|
||||
resp->total = memInfo.total;
|
||||
resp->free = memInfo.free;
|
||||
resp->used = memInfo.used;
|
||||
|
||||
LOG(h.verbose, "[%s] CUDA totalMem %lu\n", resp->gpu_id, resp->total);
|
||||
LOG(h.verbose, "[%s] CUDA freeMem %lu\n", resp->gpu_id, resp->free);
|
||||
LOG(h.verbose, "[%s] CUDA usedMem %lu\n", resp->gpu_id, resp->used);
|
||||
LOG(h.verbose, "[%s] Compute Capability %d.%d\n", resp->gpu_id, resp->major, resp->minor);
|
||||
}
|
||||
|
||||
|
||||
@@ -140,8 +140,7 @@ typedef struct cudart_init_resp {
|
||||
} cudart_init_resp_t;
|
||||
|
||||
void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp);
|
||||
void cudart_bootstrap(cudart_handle_t ch, int device_id, mem_info_t *resp);
|
||||
// TODO - if we keep this library longer term, add cudart_get_free
|
||||
void cudart_check_vram(cudart_handle_t ch, int device_id, mem_info_t *resp);
|
||||
void cudart_release(cudart_handle_t ch);
|
||||
|
||||
#endif // __GPU_INFO_CUDART_H__
|
||||
|
||||
@@ -96,7 +96,7 @@ void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp) {
|
||||
}
|
||||
|
||||
const int buflen = 256;
|
||||
void nvcuda_bootstrap(nvcuda_handle_t h, int i, mem_info_t *resp) {
|
||||
void nvcuda_check_vram(nvcuda_handle_t h, int i, mem_info_t *resp) {
|
||||
resp->err = NULL;
|
||||
nvcudaMemory_t memInfo = {0,0};
|
||||
CUresult ret;
|
||||
@@ -168,7 +168,7 @@ void nvcuda_bootstrap(nvcuda_handle_t h, int i, mem_info_t *resp) {
|
||||
// To get memory we have to set (and release) a context
|
||||
ret = (*h.cuCtxCreate_v3)(&ctx, NULL, 0, 0, device);
|
||||
if (ret != CUDA_SUCCESS) {
|
||||
snprintf(buf, buflen, "nvcuda failed to get device context %d", ret);
|
||||
snprintf(buf, buflen, "nvcuda failed to get primary device context %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
@@ -193,42 +193,7 @@ void nvcuda_bootstrap(nvcuda_handle_t h, int i, mem_info_t *resp) {
|
||||
|
||||
ret = (*h.cuCtxDestroy)(ctx);
|
||||
if (ret != CUDA_SUCCESS) {
|
||||
LOG(1, "nvcuda failed to release device context %d", ret);
|
||||
}
|
||||
}
|
||||
|
||||
void nvcuda_get_free(nvcuda_handle_t h, int i, uint64_t *free, uint64_t *total) {
|
||||
CUresult ret;
|
||||
CUcontext ctx = NULL;
|
||||
CUdevice device = -1;
|
||||
*free = 0;
|
||||
*total = 0;
|
||||
|
||||
ret = (*h.cuDeviceGet)(&device, i);
|
||||
if (ret != CUDA_SUCCESS) {
|
||||
LOG(1, "nvcuda device failed to initialize");
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
// To get memory we have to set (and release) a context
|
||||
ret = (*h.cuCtxCreate_v3)(&ctx, NULL, 0, 0, device);
|
||||
if (ret != CUDA_SUCCESS) {
|
||||
LOG(1, "nvcuda failed to get device context %d", ret);
|
||||
return;
|
||||
}
|
||||
|
||||
ret = (*h.cuMemGetInfo_v2)(free, total);
|
||||
if (ret != CUDA_SUCCESS) {
|
||||
LOG(1, "nvcuda device memory info lookup failure %d", ret);
|
||||
// Best effort on failure...
|
||||
(*h.cuCtxDestroy)(ctx);
|
||||
return;
|
||||
}
|
||||
|
||||
ret = (*h.cuCtxDestroy)(ctx);
|
||||
if (ret != CUDA_SUCCESS) {
|
||||
LOG(1, "nvcuda failed to release device context %d", ret);
|
||||
LOG(1, "nvcuda failed to release primary device context %d", ret);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -67,8 +67,7 @@ typedef struct nvcuda_init_resp {
|
||||
} nvcuda_init_resp_t;
|
||||
|
||||
void nvcuda_init(char *nvcuda_lib_path, nvcuda_init_resp_t *resp);
|
||||
void nvcuda_bootstrap(nvcuda_handle_t ch, int device_id, mem_info_t *resp);
|
||||
void nvcuda_get_free(nvcuda_handle_t ch, int device_id, uint64_t *free, uint64_t *total);
|
||||
void nvcuda_check_vram(nvcuda_handle_t ch, int device_id, mem_info_t *resp);
|
||||
void nvcuda_release(nvcuda_handle_t ch);
|
||||
|
||||
#endif // __GPU_INFO_NVCUDA_H__
|
||||
|
||||
@@ -1,104 +0,0 @@
|
||||
#ifndef __APPLE__ // TODO - maybe consider nvidia support on intel macs?
|
||||
|
||||
#include <string.h>
|
||||
|
||||
#include "gpu_info_nvml.h"
|
||||
|
||||
void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp) {
|
||||
nvmlReturn_t ret;
|
||||
resp->err = NULL;
|
||||
const int buflen = 256;
|
||||
char buf[buflen + 1];
|
||||
int i;
|
||||
|
||||
struct lookup {
|
||||
char *s;
|
||||
void **p;
|
||||
} l[] = {
|
||||
{"nvmlInit_v2", (void *)&resp->ch.nvmlInit_v2},
|
||||
{"nvmlShutdown", (void *)&resp->ch.nvmlShutdown},
|
||||
{"nvmlDeviceGetHandleByIndex", (void *)&resp->ch.nvmlDeviceGetHandleByIndex},
|
||||
{"nvmlDeviceGetMemoryInfo", (void *)&resp->ch.nvmlDeviceGetMemoryInfo},
|
||||
{NULL, NULL},
|
||||
};
|
||||
|
||||
resp->ch.handle = LOAD_LIBRARY(nvml_lib_path, RTLD_LAZY);
|
||||
if (!resp->ch.handle) {
|
||||
char *msg = LOAD_ERR();
|
||||
LOG(resp->ch.verbose, "library %s load err: %s\n", nvml_lib_path, msg);
|
||||
snprintf(buf, buflen,
|
||||
"Unable to load %s library to query for Nvidia GPUs: %s",
|
||||
nvml_lib_path, msg);
|
||||
free(msg);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
|
||||
// TODO once we've squashed the remaining corner cases remove this log
|
||||
// LOG(resp->ch.verbose, "wiring nvidia management library functions in %s\n", nvml_lib_path);
|
||||
|
||||
for (i = 0; l[i].s != NULL; i++) {
|
||||
// TODO once we've squashed the remaining corner cases remove this log
|
||||
// LOG(resp->ch.verbose, "dlsym: %s\n", l[i].s);
|
||||
|
||||
*l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
|
||||
if (!l[i].p) {
|
||||
resp->ch.handle = NULL;
|
||||
char *msg = LOAD_ERR();
|
||||
LOG(resp->ch.verbose, "dlerr: %s\n", msg);
|
||||
UNLOAD_LIBRARY(resp->ch.handle);
|
||||
snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s,
|
||||
msg);
|
||||
free(msg);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
ret = (*resp->ch.nvmlInit_v2)();
|
||||
if (ret != NVML_SUCCESS) {
|
||||
LOG(resp->ch.verbose, "nvmlInit_v2 err: %d\n", ret);
|
||||
UNLOAD_LIBRARY(resp->ch.handle);
|
||||
resp->ch.handle = NULL;
|
||||
snprintf(buf, buflen, "nvml vram init failure: %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void nvml_get_free(nvml_handle_t h, int device_id, uint64_t *free, uint64_t *total, uint64_t *used) {
|
||||
nvmlDevice_t device;
|
||||
nvmlMemory_t memInfo = {0};
|
||||
nvmlReturn_t ret;
|
||||
ret = (*h.nvmlDeviceGetHandleByIndex)(device_id, &device);
|
||||
if (ret != NVML_SUCCESS) {
|
||||
LOG(1, "unable to get device handle %d: %d", device_id, ret);
|
||||
*free = 0;
|
||||
return;
|
||||
}
|
||||
|
||||
ret = (*h.nvmlDeviceGetMemoryInfo)(device, &memInfo);
|
||||
if (ret != NVML_SUCCESS) {
|
||||
LOG(1, "device memory info lookup failure %d: %d", device_id, ret);
|
||||
*free = 0;
|
||||
return;
|
||||
}
|
||||
*free = memInfo.free;
|
||||
*total = memInfo.total;
|
||||
*used = memInfo.used;
|
||||
}
|
||||
|
||||
|
||||
void nvml_release(nvml_handle_t h) {
|
||||
LOG(h.verbose, "releasing nvml library\n");
|
||||
nvmlReturn_t ret;
|
||||
ret = (*h.nvmlShutdown)();
|
||||
if (ret != NVML_SUCCESS) {
|
||||
LOG(1, "error during nvmlShutdown %d", ret);
|
||||
}
|
||||
UNLOAD_LIBRARY(h.handle);
|
||||
h.handle = NULL;
|
||||
}
|
||||
|
||||
#endif // __APPLE__
|
||||
@@ -1,48 +0,0 @@
|
||||
#ifndef __APPLE__
|
||||
#ifndef __GPU_INFO_NVML_H__
|
||||
#define __GPU_INFO_NVML_H__
|
||||
#include "gpu_info.h"
|
||||
|
||||
// Just enough typedef's to dlopen/dlsym for memory information
|
||||
typedef enum nvmlReturn_enum {
|
||||
NVML_SUCCESS = 0,
|
||||
// Other values omitted for now...
|
||||
} nvmlReturn_t;
|
||||
typedef void *nvmlDevice_t; // Opaque is sufficient
|
||||
typedef struct nvmlMemory_st {
|
||||
unsigned long long total;
|
||||
unsigned long long free;
|
||||
unsigned long long used;
|
||||
} nvmlMemory_t;
|
||||
|
||||
typedef enum nvmlBrandType_enum
|
||||
{
|
||||
NVML_BRAND_UNKNOWN = 0,
|
||||
} nvmlBrandType_t;
|
||||
|
||||
typedef struct nvml_handle {
|
||||
void *handle;
|
||||
uint16_t verbose;
|
||||
nvmlReturn_t (*nvmlInit_v2)(void);
|
||||
nvmlReturn_t (*nvmlShutdown)(void);
|
||||
nvmlReturn_t (*nvmlDeviceGetHandleByIndex)(unsigned int, nvmlDevice_t *);
|
||||
nvmlReturn_t (*nvmlDeviceGetMemoryInfo)(nvmlDevice_t, nvmlMemory_t *);
|
||||
} nvml_handle_t;
|
||||
|
||||
typedef struct nvml_init_resp {
|
||||
char *err; // If err is non-null handle is invalid
|
||||
nvml_handle_t ch;
|
||||
} nvml_init_resp_t;
|
||||
|
||||
typedef struct nvml_compute_capability {
|
||||
char *err;
|
||||
int major;
|
||||
int minor;
|
||||
} nvml_compute_capability_t;
|
||||
|
||||
void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp);
|
||||
void nvml_get_free(nvml_handle_t ch, int device_id, uint64_t *free, uint64_t *total, uint64_t *used);
|
||||
void nvml_release(nvml_handle_t ch);
|
||||
|
||||
#endif // __GPU_INFO_NVML_H__
|
||||
#endif // __APPLE__
|
||||
@@ -4,17 +4,15 @@
|
||||
|
||||
#include <string.h>
|
||||
|
||||
void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp) {
|
||||
void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp)
|
||||
{
|
||||
ze_result_t ret;
|
||||
resp->err = NULL;
|
||||
resp->oh.devices = NULL;
|
||||
resp->oh.num_devices = NULL;
|
||||
resp->oh.drivers = NULL;
|
||||
resp->oh.num_drivers = 0;
|
||||
const int buflen = 256;
|
||||
char buf[buflen + 1];
|
||||
int i, d, count;
|
||||
struct lookup {
|
||||
int i;
|
||||
struct lookup
|
||||
{
|
||||
char *s;
|
||||
void **p;
|
||||
} l[] = {
|
||||
@@ -30,7 +28,8 @@ void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp) {
|
||||
};
|
||||
|
||||
resp->oh.handle = LOAD_LIBRARY(oneapi_lib_path, RTLD_LAZY);
|
||||
if (!resp->oh.handle) {
|
||||
if (!resp->oh.handle)
|
||||
{
|
||||
char *msg = LOAD_ERR();
|
||||
snprintf(buf, buflen,
|
||||
"Unable to load %s library to query for Intel GPUs: %s\n",
|
||||
@@ -45,12 +44,14 @@ void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp) {
|
||||
"wiring Level-Zero management library functions in %s\n",
|
||||
oneapi_lib_path);
|
||||
|
||||
for (i = 0; l[i].s != NULL; i++) {
|
||||
for (i = 0; l[i].s != NULL; i++)
|
||||
{
|
||||
// TODO once we've squashed the remaining corner cases remove this log
|
||||
LOG(resp->oh.verbose, "dlsym: %s\n", l[i].s);
|
||||
|
||||
*l[i].p = LOAD_SYMBOL(resp->oh.handle, l[i].s);
|
||||
if (!l[i].p) {
|
||||
if (!l[i].p)
|
||||
{
|
||||
resp->oh.handle = NULL;
|
||||
char *msg = LOAD_ERR();
|
||||
LOG(resp->oh.verbose, "dlerr: %s\n", msg);
|
||||
@@ -63,67 +64,22 @@ void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp) {
|
||||
}
|
||||
|
||||
ret = (*resp->oh.zesInit)(0);
|
||||
if (ret != ZE_RESULT_SUCCESS) {
|
||||
LOG(resp->oh.verbose, "zesInit err: %x\n", ret);
|
||||
snprintf(buf, buflen, "oneapi vram init failure: %x", ret);
|
||||
if (ret != ZE_RESULT_SUCCESS)
|
||||
{
|
||||
LOG(resp->oh.verbose, "zesInit err: %d\n", ret);
|
||||
UNLOAD_LIBRARY(resp->oh.handle);
|
||||
resp->oh.handle = NULL;
|
||||
snprintf(buf, buflen, "oneapi vram init failure: %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
oneapi_release(resp->oh);
|
||||
return;
|
||||
}
|
||||
|
||||
count = 0;
|
||||
ret = (*resp->oh.zesDriverGet)(&resp->oh.num_drivers, NULL);
|
||||
if (ret != ZE_RESULT_SUCCESS) {
|
||||
LOG(resp->oh.verbose, "zesDriverGet err: %x\n", ret);
|
||||
snprintf(buf, buflen, "unable to get driver count: %x", ret);
|
||||
resp->err = strdup(buf);
|
||||
oneapi_release(resp->oh);
|
||||
return;
|
||||
}
|
||||
LOG(resp->oh.verbose, "oneapi driver count: %d\n", resp->oh.num_drivers);
|
||||
resp->oh.drivers = malloc(resp->oh.num_drivers * sizeof(zes_driver_handle_t));
|
||||
resp->oh.num_devices = malloc(resp->oh.num_drivers * sizeof(uint32_t));
|
||||
memset(&resp->oh.num_devices[0], 0, resp->oh.num_drivers * sizeof(uint32_t));
|
||||
resp->oh.devices =
|
||||
malloc(resp->oh.num_drivers * sizeof(zes_device_handle_t *));
|
||||
ret = (*resp->oh.zesDriverGet)(&resp->oh.num_drivers, &resp->oh.drivers[0]);
|
||||
if (ret != ZE_RESULT_SUCCESS) {
|
||||
LOG(resp->oh.verbose, "zesDriverGet err: %x\n", ret);
|
||||
snprintf(buf, buflen, "unable to get driver count: %x", ret);
|
||||
resp->err = strdup(buf);
|
||||
oneapi_release(resp->oh);
|
||||
return;
|
||||
}
|
||||
|
||||
for (d = 0; d < resp->oh.num_drivers; d++) {
|
||||
ret = (*resp->oh.zesDeviceGet)(resp->oh.drivers[d],
|
||||
&resp->oh.num_devices[d], NULL);
|
||||
if (ret != ZE_RESULT_SUCCESS) {
|
||||
LOG(resp->oh.verbose, "zesDeviceGet err: %x\n", ret);
|
||||
snprintf(buf, buflen, "unable to get device count: %x", ret);
|
||||
resp->err = strdup(buf);
|
||||
oneapi_release(resp->oh);
|
||||
return;
|
||||
}
|
||||
resp->oh.devices[d] =
|
||||
malloc(resp->oh.num_devices[d] * sizeof(zes_device_handle_t));
|
||||
ret = (*resp->oh.zesDeviceGet)(
|
||||
resp->oh.drivers[d], &resp->oh.num_devices[d], resp->oh.devices[d]);
|
||||
if (ret != ZE_RESULT_SUCCESS) {
|
||||
LOG(resp->oh.verbose, "zesDeviceGet err: %x\n", ret);
|
||||
snprintf(buf, buflen, "unable to get device count: %x", ret);
|
||||
resp->err = strdup(buf);
|
||||
oneapi_release(resp->oh);
|
||||
return;
|
||||
}
|
||||
count += resp->oh.num_devices[d];
|
||||
}
|
||||
(*resp->oh.zesDriverGet)(&resp->num_devices, NULL);
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
void oneapi_check_vram(oneapi_handle_t h, int driver, int device,
|
||||
mem_info_t *resp) {
|
||||
void oneapi_check_vram(oneapi_handle_t h, mem_info_t *resp)
|
||||
{
|
||||
ze_result_t ret;
|
||||
resp->err = NULL;
|
||||
uint64_t totalMem = 0;
|
||||
@@ -132,126 +88,127 @@ void oneapi_check_vram(oneapi_handle_t h, int driver, int device,
|
||||
char buf[buflen + 1];
|
||||
int i, d, m;
|
||||
|
||||
if (h.handle == NULL) {
|
||||
if (h.handle == NULL)
|
||||
{
|
||||
resp->err = strdup("Level-Zero handle not initialized");
|
||||
return;
|
||||
}
|
||||
|
||||
if (driver > h.num_drivers || device > h.num_devices[driver]) {
|
||||
resp->err = strdup("driver of device index out of bounds");
|
||||
uint32_t driversCount = 0;
|
||||
ret = (*h.zesDriverGet)(&driversCount, NULL);
|
||||
if (ret != ZE_RESULT_SUCCESS)
|
||||
{
|
||||
snprintf(buf, buflen, "unable to get driver count: %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
LOG(h.verbose, "discovered %d Level-Zero drivers\n", driversCount);
|
||||
|
||||
zes_driver_handle_t *allDrivers =
|
||||
malloc(driversCount * sizeof(zes_driver_handle_t));
|
||||
(*h.zesDriverGet)(&driversCount, allDrivers);
|
||||
|
||||
resp->total = 0;
|
||||
resp->free = 0;
|
||||
|
||||
zes_device_ext_properties_t ext_props;
|
||||
ext_props.stype = ZES_STRUCTURE_TYPE_DEVICE_EXT_PROPERTIES;
|
||||
ext_props.pNext = NULL;
|
||||
|
||||
zes_device_properties_t props;
|
||||
props.stype = ZES_STRUCTURE_TYPE_DEVICE_PROPERTIES;
|
||||
props.pNext = &ext_props;
|
||||
|
||||
ret = (*h.zesDeviceGetProperties)(h.devices[driver][device], &props);
|
||||
if (ret != ZE_RESULT_SUCCESS) {
|
||||
snprintf(buf, buflen, "unable to get device properties: %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
|
||||
snprintf(&resp->gpu_name[0], GPU_NAME_LEN, props.modelName);
|
||||
|
||||
// TODO this needs to map to ONEAPI_DEVICE_SELECTOR syntax
|
||||
// (this is probably wrong...)
|
||||
// TODO - the driver isn't included - what if there are multiple drivers?
|
||||
snprintf(&resp->gpu_id[0], GPU_ID_LEN, "%d", device);
|
||||
|
||||
if (h.verbose) {
|
||||
// When in verbose mode, report more information about
|
||||
// the card we discover.
|
||||
LOG(h.verbose, "[%d:%d] oneAPI device name: %s\n", driver, device,
|
||||
props.modelName);
|
||||
LOG(h.verbose, "[%d:%d] oneAPI brand: %s\n", driver, device,
|
||||
props.brandName);
|
||||
LOG(h.verbose, "[%d:%d] oneAPI vendor: %s\n", driver, device,
|
||||
props.vendorName);
|
||||
LOG(h.verbose, "[%d:%d] oneAPI S/N: %s\n", driver, device,
|
||||
props.serialNumber);
|
||||
LOG(h.verbose, "[%d:%d] oneAPI board number: %s\n", driver, device,
|
||||
props.boardNumber);
|
||||
}
|
||||
|
||||
// TODO
|
||||
// Compute Capability equivalent in resp->major, resp->minor, resp->patch
|
||||
|
||||
uint32_t memCount = 0;
|
||||
ret = (*h.zesDeviceEnumMemoryModules)(h.devices[driver][device], &memCount,
|
||||
NULL);
|
||||
if (ret != ZE_RESULT_SUCCESS) {
|
||||
snprintf(buf, buflen, "unable to enumerate Level-Zero memory modules: %x",
|
||||
ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
}
|
||||
|
||||
LOG(h.verbose, "discovered %d Level-Zero memory modules\n", memCount);
|
||||
|
||||
zes_mem_handle_t *mems = malloc(memCount * sizeof(zes_mem_handle_t));
|
||||
(*h.zesDeviceEnumMemoryModules)(h.devices[driver][device], &memCount, mems);
|
||||
|
||||
for (m = 0; m < memCount; m++) {
|
||||
zes_mem_state_t state;
|
||||
state.stype = ZES_STRUCTURE_TYPE_MEM_STATE;
|
||||
state.pNext = NULL;
|
||||
ret = (*h.zesMemoryGetState)(mems[m], &state);
|
||||
if (ret != ZE_RESULT_SUCCESS) {
|
||||
snprintf(buf, buflen, "unable to get memory state: %x", ret);
|
||||
for (d = 0; d < driversCount; d++)
|
||||
{
|
||||
uint32_t deviceCount = 0;
|
||||
ret = (*h.zesDeviceGet)(allDrivers[d], &deviceCount, NULL);
|
||||
if (ret != ZE_RESULT_SUCCESS)
|
||||
{
|
||||
snprintf(buf, buflen, "unable to get device count: %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
free(mems);
|
||||
free(allDrivers);
|
||||
return;
|
||||
}
|
||||
|
||||
resp->total += state.size;
|
||||
resp->free += state.free;
|
||||
}
|
||||
LOG(h.verbose, "discovered %d Level-Zero devices\n", deviceCount);
|
||||
|
||||
free(mems);
|
||||
}
|
||||
zes_device_handle_t *devices =
|
||||
malloc(deviceCount * sizeof(zes_device_handle_t));
|
||||
(*h.zesDeviceGet)(allDrivers[d], &deviceCount, devices);
|
||||
|
||||
void oneapi_release(oneapi_handle_t h) {
|
||||
int d;
|
||||
LOG(h.verbose, "releasing oneapi library\n");
|
||||
for (d = 0; d < h.num_drivers; d++) {
|
||||
if (h.devices != NULL && h.devices[d] != NULL) {
|
||||
free(h.devices[d]);
|
||||
for (i = 0; i < deviceCount; i++)
|
||||
{
|
||||
zes_device_ext_properties_t ext_props;
|
||||
ext_props.stype = ZES_STRUCTURE_TYPE_DEVICE_EXT_PROPERTIES;
|
||||
ext_props.pNext = NULL;
|
||||
|
||||
zes_device_properties_t props;
|
||||
props.stype = ZES_STRUCTURE_TYPE_DEVICE_PROPERTIES;
|
||||
props.pNext = &ext_props;
|
||||
|
||||
ret = (*h.zesDeviceGetProperties)(devices[i], &props);
|
||||
if (ret != ZE_RESULT_SUCCESS)
|
||||
{
|
||||
snprintf(buf, buflen, "unable to get device properties: %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
free(allDrivers);
|
||||
free(devices);
|
||||
return;
|
||||
}
|
||||
|
||||
if (h.verbose)
|
||||
{
|
||||
// When in verbose mode, report more information about
|
||||
// the card we discover.
|
||||
LOG(h.verbose, "[%d] oneAPI device name: %s\n", i,
|
||||
props.modelName);
|
||||
LOG(h.verbose, "[%d] oneAPI brand: %s\n", i,
|
||||
props.brandName);
|
||||
LOG(h.verbose, "[%d] oneAPI vendor: %s\n", i,
|
||||
props.vendorName);
|
||||
LOG(h.verbose, "[%d] oneAPI S/N: %s\n", i,
|
||||
props.serialNumber);
|
||||
LOG(h.verbose, "[%d] oneAPI board number: %s\n", i,
|
||||
props.boardNumber);
|
||||
}
|
||||
|
||||
uint32_t memCount = 0;
|
||||
ret = (*h.zesDeviceEnumMemoryModules)(devices[i], &memCount, NULL);
|
||||
if (ret != ZE_RESULT_SUCCESS)
|
||||
{
|
||||
snprintf(buf, buflen,
|
||||
"unable to enumerate Level-Zero memory modules: %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
free(allDrivers);
|
||||
free(devices);
|
||||
return;
|
||||
}
|
||||
|
||||
LOG(h.verbose, "discovered %d Level-Zero memory modules\n", memCount);
|
||||
|
||||
zes_mem_handle_t *mems = malloc(memCount * sizeof(zes_mem_handle_t));
|
||||
(*h.zesDeviceEnumMemoryModules)(devices[i], &memCount, mems);
|
||||
|
||||
for (m = 0; m < memCount; m++)
|
||||
{
|
||||
zes_mem_state_t state;
|
||||
state.stype = ZES_STRUCTURE_TYPE_MEM_STATE;
|
||||
state.pNext = NULL;
|
||||
ret = (*h.zesMemoryGetState)(mems[m], &state);
|
||||
if (ret != ZE_RESULT_SUCCESS)
|
||||
{
|
||||
snprintf(buf, buflen, "unable to get memory state: %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
free(allDrivers);
|
||||
free(devices);
|
||||
free(mems);
|
||||
return;
|
||||
}
|
||||
|
||||
resp->total += state.size;
|
||||
resp->free += state.free;
|
||||
}
|
||||
|
||||
free(mems);
|
||||
}
|
||||
}
|
||||
if (h.devices != NULL) {
|
||||
free(h.devices);
|
||||
h.devices = NULL;
|
||||
}
|
||||
if (h.num_devices != NULL) {
|
||||
free(h.num_devices);
|
||||
h.num_devices = NULL;
|
||||
}
|
||||
if (h.drivers != NULL) {
|
||||
free(h.drivers);
|
||||
h.drivers = NULL;
|
||||
}
|
||||
h.num_drivers = 0;
|
||||
UNLOAD_LIBRARY(h.handle);
|
||||
h.handle = NULL;
|
||||
}
|
||||
|
||||
int oneapi_get_device_count(oneapi_handle_t h, int driver) {
|
||||
if (h.handle == NULL || h.num_devices == NULL) {
|
||||
return 0;
|
||||
free(devices);
|
||||
}
|
||||
if (driver > h.num_drivers) {
|
||||
return 0;
|
||||
}
|
||||
return (int)h.num_devices[driver];
|
||||
|
||||
free(allDrivers);
|
||||
}
|
||||
|
||||
#endif // __APPLE__
|
||||
|
||||
@@ -9,7 +9,8 @@
|
||||
#define ZE_BIT(_i) (1 << _i)
|
||||
|
||||
// Just enough typedef's to dlopen/dlsym for memory information
|
||||
typedef enum ze_result_t {
|
||||
typedef enum ze_result_t
|
||||
{
|
||||
ZE_RESULT_SUCCESS = 0,
|
||||
// Other values omitted for now...
|
||||
} ze_result_t;
|
||||
@@ -19,11 +20,13 @@ typedef struct _zes_driver_handle_t *zes_driver_handle_t;
|
||||
typedef struct _zes_device_handle_t *zes_device_handle_t;
|
||||
typedef struct _zes_mem_handle_t *zes_mem_handle_t;
|
||||
|
||||
typedef enum _ze_structure_type_t {
|
||||
typedef enum _ze_structure_type_t
|
||||
{
|
||||
ZE_STRUCTURE_TYPE_FORCE_UINT32 = 0x7fffffff
|
||||
} ze_structure_type_t;
|
||||
|
||||
typedef enum _zes_structure_type_t {
|
||||
typedef enum _zes_structure_type_t
|
||||
{
|
||||
ZES_STRUCTURE_TYPE_DEVICE_PROPERTIES = 0x1,
|
||||
ZES_STRUCTURE_TYPE_MEM_PROPERTIES = 0xb,
|
||||
ZES_STRUCTURE_TYPE_MEM_STATE = 0x1e,
|
||||
@@ -31,29 +34,35 @@ typedef enum _zes_structure_type_t {
|
||||
ZES_STRUCTURE_TYPE_FORCE_UINT32 = 0x7fffffff
|
||||
} zes_structure_type_t;
|
||||
|
||||
typedef enum _zes_mem_type_t {
|
||||
typedef enum _zes_mem_type_t
|
||||
{
|
||||
ZES_MEM_TYPE_FORCE_UINT32 = 0x7fffffff
|
||||
} zes_mem_type_t;
|
||||
|
||||
typedef enum _zes_mem_loc_t {
|
||||
typedef enum _zes_mem_loc_t
|
||||
{
|
||||
ZES_MEM_LOC_SYSTEM = 0,
|
||||
ZES_MEM_LOC_DEVICE = 1,
|
||||
ZES_MEM_LOC_FORCE_UINT32 = 0x7fffffff
|
||||
} zes_mem_loc_t;
|
||||
|
||||
typedef enum _zes_mem_health_t {
|
||||
typedef enum _zes_mem_health_t
|
||||
{
|
||||
ZES_MEM_HEALTH_FORCE_UINT32 = 0x7fffffff
|
||||
} zes_mem_health_t;
|
||||
|
||||
typedef struct _ze_device_uuid_t {
|
||||
typedef struct _ze_device_uuid_t
|
||||
{
|
||||
uint8_t id[ZE_MAX_DEVICE_UUID_SIZE];
|
||||
} ze_device_uuid_t;
|
||||
|
||||
typedef struct _zes_uuid_t {
|
||||
typedef struct _zes_uuid_t
|
||||
{
|
||||
uint8_t id[ZE_MAX_DEVICE_UUID_SIZE];
|
||||
} zes_uuid_t;
|
||||
|
||||
typedef enum _ze_device_type_t {
|
||||
typedef enum _ze_device_type_t
|
||||
{
|
||||
ZE_DEVICE_TYPE_GPU = 1,
|
||||
ZE_DEVICE_TYPE_CPU = 2,
|
||||
ZE_DEVICE_TYPE_FPGA = 3,
|
||||
@@ -62,7 +71,8 @@ typedef enum _ze_device_type_t {
|
||||
ZE_DEVICE_TYPE_FORCE_UINT32 = 0x7fffffff
|
||||
} ze_device_type_t;
|
||||
|
||||
typedef enum _zes_device_type_t {
|
||||
typedef enum _zes_device_type_t
|
||||
{
|
||||
ZES_DEVICE_TYPE_GPU = 1,
|
||||
ZES_DEVICE_TYPE_CPU = 2,
|
||||
ZES_DEVICE_TYPE_FPGA = 3,
|
||||
@@ -72,7 +82,8 @@ typedef enum _zes_device_type_t {
|
||||
} zes_device_type_t;
|
||||
|
||||
typedef uint32_t ze_device_property_flags_t;
|
||||
typedef enum _ze_device_property_flag_t {
|
||||
typedef enum _ze_device_property_flag_t
|
||||
{
|
||||
ZE_DEVICE_PROPERTY_FLAG_INTEGRATED = ZE_BIT(0),
|
||||
ZE_DEVICE_PROPERTY_FLAG_SUBDEVICE = ZE_BIT(1),
|
||||
ZE_DEVICE_PROPERTY_FLAG_ECC = ZE_BIT(2),
|
||||
@@ -81,7 +92,8 @@ typedef enum _ze_device_property_flag_t {
|
||||
} ze_device_property_flag_t;
|
||||
|
||||
typedef uint32_t zes_device_property_flags_t;
|
||||
typedef enum _zes_device_property_flag_t {
|
||||
typedef enum _zes_device_property_flag_t
|
||||
{
|
||||
ZES_DEVICE_PROPERTY_FLAG_INTEGRATED = ZE_BIT(0),
|
||||
ZES_DEVICE_PROPERTY_FLAG_SUBDEVICE = ZE_BIT(1),
|
||||
ZES_DEVICE_PROPERTY_FLAG_ECC = ZE_BIT(2),
|
||||
@@ -89,7 +101,8 @@ typedef enum _zes_device_property_flag_t {
|
||||
ZES_DEVICE_PROPERTY_FLAG_FORCE_UINT32 = 0x7fffffff
|
||||
} zes_device_property_flag_t;
|
||||
|
||||
typedef struct _ze_device_properties_t {
|
||||
typedef struct _ze_device_properties_t
|
||||
{
|
||||
ze_structure_type_t stype;
|
||||
void *pNext;
|
||||
ze_device_type_t type;
|
||||
@@ -113,7 +126,8 @@ typedef struct _ze_device_properties_t {
|
||||
char name[ZE_MAX_DEVICE_NAME];
|
||||
} ze_device_properties_t;
|
||||
|
||||
typedef struct _zes_device_properties_t {
|
||||
typedef struct _zes_device_properties_t
|
||||
{
|
||||
zes_structure_type_t stype;
|
||||
void *pNext;
|
||||
ze_device_properties_t core;
|
||||
@@ -126,7 +140,8 @@ typedef struct _zes_device_properties_t {
|
||||
char driverVersion[ZES_STRING_PROPERTY_SIZE];
|
||||
} zes_device_properties_t;
|
||||
|
||||
typedef struct _zes_device_ext_properties_t {
|
||||
typedef struct _zes_device_ext_properties_t
|
||||
{
|
||||
zes_structure_type_t stype;
|
||||
void *pNext;
|
||||
zes_uuid_t uuid;
|
||||
@@ -134,7 +149,8 @@ typedef struct _zes_device_ext_properties_t {
|
||||
zes_device_property_flags_t flags;
|
||||
} zes_device_ext_properties_t;
|
||||
|
||||
typedef struct _zes_mem_properties_t {
|
||||
typedef struct _zes_mem_properties_t
|
||||
{
|
||||
zes_structure_type_t stype;
|
||||
void *pNext;
|
||||
zes_mem_type_t type;
|
||||
@@ -146,7 +162,8 @@ typedef struct _zes_mem_properties_t {
|
||||
int32_t numChannels;
|
||||
} zes_mem_properties_t;
|
||||
|
||||
typedef struct _zes_mem_state_t {
|
||||
typedef struct _zes_mem_state_t
|
||||
{
|
||||
zes_structure_type_t stype;
|
||||
const void *pNext;
|
||||
zes_mem_health_t health;
|
||||
@@ -154,19 +171,10 @@ typedef struct _zes_mem_state_t {
|
||||
uint64_t size;
|
||||
} zes_mem_state_t;
|
||||
|
||||
typedef struct oneapi_handle {
|
||||
typedef struct oneapi_handle
|
||||
{
|
||||
void *handle;
|
||||
uint16_t verbose;
|
||||
|
||||
uint32_t num_drivers;
|
||||
zes_driver_handle_t *drivers;
|
||||
uint32_t *num_devices;
|
||||
zes_device_handle_t **devices;
|
||||
|
||||
// TODO Driver major, minor information
|
||||
// int driver_major;
|
||||
// int driver_minor;
|
||||
|
||||
ze_result_t (*zesInit)(int);
|
||||
ze_result_t (*zesDriverGet)(uint32_t *pCount, zes_driver_handle_t *phDrivers);
|
||||
ze_result_t (*zesDeviceGet)(zes_driver_handle_t hDriver, uint32_t *pCount,
|
||||
@@ -183,21 +191,21 @@ typedef struct oneapi_handle {
|
||||
|
||||
} oneapi_handle_t;
|
||||
|
||||
typedef struct oneapi_init_resp {
|
||||
typedef struct oneapi_init_resp
|
||||
{
|
||||
char *err; // If err is non-null handle is invalid
|
||||
int num_devices;
|
||||
oneapi_handle_t oh;
|
||||
} oneapi_init_resp_t;
|
||||
|
||||
typedef struct oneapi_version_resp {
|
||||
typedef struct oneapi_version_resp
|
||||
{
|
||||
ze_result_t status;
|
||||
char *str; // Contains version or error string if status != 0
|
||||
} oneapi_version_resp_t;
|
||||
|
||||
void oneapi_init(char *oneapi_lib_path, oneapi_init_resp_t *resp);
|
||||
void oneapi_check_vram(oneapi_handle_t h, int driver, int device,
|
||||
mem_info_t *resp);
|
||||
void oneapi_release(oneapi_handle_t h);
|
||||
int oneapi_get_device_count(oneapi_handle_t h, int driver);
|
||||
void oneapi_check_vram(oneapi_handle_t rh, mem_info_t *resp);
|
||||
|
||||
#endif // __GPU_INFO_INTEL_H__
|
||||
#endif // __APPLE__
|
||||
|
||||
@@ -1,89 +0,0 @@
|
||||
package gpu
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/format"
|
||||
)
|
||||
|
||||
var CudartGlobs = []string{
|
||||
"/usr/local/cuda/lib64/libcudart.so*",
|
||||
"/usr/lib/x86_64-linux-gnu/nvidia/current/libcudart.so*",
|
||||
"/usr/lib/x86_64-linux-gnu/libcudart.so*",
|
||||
"/usr/lib/wsl/lib/libcudart.so*",
|
||||
"/usr/lib/wsl/drivers/*/libcudart.so*",
|
||||
"/opt/cuda/lib64/libcudart.so*",
|
||||
"/usr/local/cuda*/targets/aarch64-linux/lib/libcudart.so*",
|
||||
"/usr/lib/aarch64-linux-gnu/nvidia/current/libcudart.so*",
|
||||
"/usr/lib/aarch64-linux-gnu/libcudart.so*",
|
||||
"/usr/local/cuda/lib*/libcudart.so*",
|
||||
"/usr/lib*/libcudart.so*",
|
||||
"/usr/local/lib*/libcudart.so*",
|
||||
}
|
||||
|
||||
var NvmlGlobs = []string{}
|
||||
|
||||
var NvcudaGlobs = []string{
|
||||
"/usr/local/cuda*/targets/*/lib/libcuda.so*",
|
||||
"/usr/lib/*-linux-gnu/nvidia/current/libcuda.so*",
|
||||
"/usr/lib/*-linux-gnu/libcuda.so*",
|
||||
"/usr/lib/wsl/lib/libcuda.so*",
|
||||
"/usr/lib/wsl/drivers/*/libcuda.so*",
|
||||
"/opt/cuda/lib*/libcuda.so*",
|
||||
"/usr/local/cuda/lib*/libcuda.so*",
|
||||
"/usr/lib*/libcuda.so*",
|
||||
"/usr/local/lib*/libcuda.so*",
|
||||
}
|
||||
|
||||
var OneapiGlobs = []string{
|
||||
"/usr/lib/x86_64-linux-gnu/libze_intel_gpu.so*",
|
||||
"/usr/lib*/libze_intel_gpu.so*",
|
||||
}
|
||||
|
||||
var CudartMgmtName = "libcudart.so*"
|
||||
var NvcudaMgmtName = "libcuda.so*"
|
||||
var NvmlMgmtName = "" // not currently wired on linux
|
||||
var OneapiMgmtName = "libze_intel_gpu.so"
|
||||
|
||||
func GetCPUMem() (memInfo, error) {
|
||||
var mem memInfo
|
||||
var total, available, free, buffers, cached uint64
|
||||
f, err := os.Open("/proc/meminfo")
|
||||
if err != nil {
|
||||
return mem, err
|
||||
}
|
||||
defer f.Close()
|
||||
s := bufio.NewScanner(f)
|
||||
for s.Scan() {
|
||||
line := s.Text()
|
||||
switch {
|
||||
case strings.HasPrefix(line, "MemTotal:"):
|
||||
_, err = fmt.Sscanf(line, "MemTotal:%d", &total)
|
||||
case strings.HasPrefix(line, "MemAvailable:"):
|
||||
_, err = fmt.Sscanf(line, "MemAvailable:%d", &available)
|
||||
case strings.HasPrefix(line, "MemFree:"):
|
||||
_, err = fmt.Sscanf(line, "MemFree:%d", &free)
|
||||
case strings.HasPrefix(line, "Buffers:"):
|
||||
_, err = fmt.Sscanf(line, "Buffers:%d", &buffers)
|
||||
case strings.HasPrefix(line, "Cached:"):
|
||||
_, err = fmt.Sscanf(line, "Cached:%d", &cached)
|
||||
default:
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
return mem, err
|
||||
}
|
||||
|
||||
if total > 0 && available > 0 {
|
||||
mem.TotalMemory = total * format.KibiByte
|
||||
mem.FreeMemory = available * format.KibiByte
|
||||
return mem, nil
|
||||
}
|
||||
}
|
||||
mem.TotalMemory = total * format.KibiByte
|
||||
mem.FreeMemory = (free + buffers + cached) * format.KibiByte
|
||||
return mem, nil
|
||||
}
|
||||
@@ -5,12 +5,11 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestBasicGetGPUInfo(t *testing.T) {
|
||||
info := GetGPUInfo()
|
||||
assert.NotEmpty(t, len(info))
|
||||
assert.Greater(t, len(info), 0)
|
||||
assert.Contains(t, "cuda rocm cpu metal", info[0].Library)
|
||||
if info[0].Library != "cpu" {
|
||||
assert.Greater(t, info[0].TotalMemory, uint64(0))
|
||||
@@ -20,7 +19,7 @@ func TestBasicGetGPUInfo(t *testing.T) {
|
||||
|
||||
func TestCPUMemInfo(t *testing.T) {
|
||||
info, err := GetCPUMem()
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, err)
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
t.Skip("CPU memory not populated on darwin")
|
||||
|
||||
@@ -1,55 +0,0 @@
|
||||
package gpu
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
type MEMORYSTATUSEX struct {
|
||||
length uint32
|
||||
MemoryLoad uint32
|
||||
TotalPhys uint64
|
||||
AvailPhys uint64
|
||||
TotalPageFile uint64
|
||||
AvailPageFile uint64
|
||||
TotalVirtual uint64
|
||||
AvailVirtual uint64
|
||||
AvailExtendedVirtual uint64
|
||||
}
|
||||
|
||||
var (
|
||||
k32 = syscall.NewLazyDLL("kernel32.dll")
|
||||
globalMemoryStatusExProc = k32.NewProc("GlobalMemoryStatusEx")
|
||||
sizeofMemoryStatusEx = uint32(unsafe.Sizeof(MEMORYSTATUSEX{}))
|
||||
)
|
||||
|
||||
var CudartGlobs = []string{
|
||||
"c:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v*\\bin\\cudart64_*.dll",
|
||||
}
|
||||
|
||||
var NvmlGlobs = []string{
|
||||
"c:\\Windows\\System32\\nvml.dll",
|
||||
}
|
||||
|
||||
var NvcudaGlobs = []string{
|
||||
"c:\\windows\\system*\\nvcuda.dll",
|
||||
}
|
||||
|
||||
var OneapiGlobs = []string{
|
||||
"c:\\Windows\\System32\\DriverStore\\FileRepository\\*\\ze_intel_gpu64.dll",
|
||||
}
|
||||
|
||||
var CudartMgmtName = "cudart64_*.dll"
|
||||
var NvcudaMgmtName = "nvcuda.dll"
|
||||
var NvmlMgmtName = "nvml.dll"
|
||||
var OneapiMgmtName = "ze_intel_gpu64.dll"
|
||||
|
||||
func GetCPUMem() (memInfo, error) {
|
||||
memStatus := MEMORYSTATUSEX{length: sizeofMemoryStatusEx}
|
||||
r1, _, err := globalMemoryStatusExProc.Call(uintptr(unsafe.Pointer(&memStatus)))
|
||||
if r1 == 0 {
|
||||
return memInfo{}, fmt.Errorf("GlobalMemoryStatusEx failed: %w", err)
|
||||
}
|
||||
return memInfo{TotalMemory: memStatus.TotalPhys, FreeMemory: memStatus.AvailPhys}, nil
|
||||
}
|
||||
56
gpu/types.go
56
gpu/types.go
@@ -18,7 +18,7 @@ type GpuInfo struct {
|
||||
Library string `json:"library,omitempty"`
|
||||
|
||||
// Optional variant to select (e.g. versions, cpu feature flags)
|
||||
Variant CPUCapability `json:"variant"`
|
||||
Variant string `json:"variant,omitempty"`
|
||||
|
||||
// MinimumMemory represents the minimum memory required to use the GPU
|
||||
MinimumMemory uint64 `json:"-"`
|
||||
@@ -26,9 +26,6 @@ type GpuInfo struct {
|
||||
// Any extra PATH/LD_LIBRARY_PATH dependencies required for the Library to operate properly
|
||||
DependencyPath string `json:"lib_path,omitempty"`
|
||||
|
||||
// Extra environment variables specific to the GPU as list of [key,value]
|
||||
EnvWorkarounds [][2]string `json:"envs,omitempty"`
|
||||
|
||||
// GPU information
|
||||
ID string `json:"gpu_id"` // string to use for selection of this specific GPU
|
||||
Name string `json:"name"` // user friendly name if available
|
||||
@@ -41,30 +38,6 @@ type GpuInfo struct {
|
||||
// TODO other performance capability info to help in scheduling decisions
|
||||
}
|
||||
|
||||
type CPUInfo struct {
|
||||
GpuInfo
|
||||
}
|
||||
|
||||
type CudaGPUInfo struct {
|
||||
GpuInfo
|
||||
index int //nolint:unused,nolintlint
|
||||
}
|
||||
type CudaGPUInfoList []CudaGPUInfo
|
||||
|
||||
type RocmGPUInfo struct {
|
||||
GpuInfo
|
||||
usedFilepath string //nolint:unused,nolintlint
|
||||
index int //nolint:unused,nolintlint
|
||||
}
|
||||
type RocmGPUInfoList []RocmGPUInfo
|
||||
|
||||
type OneapiGPUInfo struct {
|
||||
GpuInfo
|
||||
driverIndex int //nolint:unused,nolintlint
|
||||
gpuIndex int //nolint:unused,nolintlint
|
||||
}
|
||||
type OneapiGPUInfoList []OneapiGPUInfo
|
||||
|
||||
type GpuInfoList []GpuInfo
|
||||
|
||||
// Split up the set of gpu info's by Library and variant
|
||||
@@ -74,8 +47,8 @@ func (l GpuInfoList) ByLibrary() []GpuInfoList {
|
||||
for _, info := range l {
|
||||
found := false
|
||||
requested := info.Library
|
||||
if info.Variant != CPUCapabilityNone {
|
||||
requested += "_" + info.Variant.String()
|
||||
if info.Variant != "" {
|
||||
requested += "_" + info.Variant
|
||||
}
|
||||
for i, lib := range libs {
|
||||
if lib == requested {
|
||||
@@ -113,26 +86,3 @@ type ByFreeMemory []GpuInfo
|
||||
func (a ByFreeMemory) Len() int { return len(a) }
|
||||
func (a ByFreeMemory) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
|
||||
func (a ByFreeMemory) Less(i, j int) bool { return a[i].FreeMemory < a[j].FreeMemory }
|
||||
|
||||
type CPUCapability uint32
|
||||
|
||||
// Override at build time when building base GPU runners
|
||||
var GPURunnerCPUCapability = CPUCapabilityAVX
|
||||
|
||||
const (
|
||||
CPUCapabilityNone CPUCapability = iota
|
||||
CPUCapabilityAVX
|
||||
CPUCapabilityAVX2
|
||||
// TODO AVX512
|
||||
)
|
||||
|
||||
func (c CPUCapability) String() string {
|
||||
switch c {
|
||||
case CPUCapabilityAVX:
|
||||
return "avx"
|
||||
case CPUCapabilityAVX2:
|
||||
return "avx2"
|
||||
default:
|
||||
return "no vector extensions"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,19 +19,17 @@ func TestMultiModelConcurrency(t *testing.T) {
|
||||
var (
|
||||
req = [2]api.GenerateRequest{
|
||||
{
|
||||
Model: "orca-mini",
|
||||
Prompt: "why is the ocean blue?",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Model: "orca-mini",
|
||||
Prompt: "why is the ocean blue?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
}, {
|
||||
Model: "tinydolphin",
|
||||
Prompt: "what is the origin of the us thanksgiving holiday?",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Model: "tinydolphin",
|
||||
Prompt: "what is the origin of the us thanksgiving holiday?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
@@ -40,64 +38,42 @@ func TestMultiModelConcurrency(t *testing.T) {
|
||||
}
|
||||
resp = [2][]string{
|
||||
[]string{"sunlight"},
|
||||
[]string{"england", "english", "massachusetts", "pilgrims", "british"},
|
||||
[]string{"england", "english", "massachusetts", "pilgrims"},
|
||||
}
|
||||
)
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(req))
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*240)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second*120)
|
||||
defer cancel()
|
||||
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
for i := 0; i < len(req); i++ {
|
||||
require.NoError(t, PullIfMissing(ctx, client, req[i].Model))
|
||||
}
|
||||
|
||||
for i := 0; i < len(req); i++ {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
DoGenerate(ctx, t, client, req[i], resp[i], 60*time.Second, 10*time.Second)
|
||||
GenerateTestHelper(ctx, t, req[i], resp[i])
|
||||
}(i)
|
||||
}
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
func TestIntegrationConcurrentPredictOrcaMini(t *testing.T) {
|
||||
req, resp := GenerateRequests()
|
||||
reqLimit := len(req)
|
||||
iterLimit := 5
|
||||
|
||||
vram := os.Getenv("OLLAMA_MAX_VRAM")
|
||||
if vram != "" {
|
||||
max, err := strconv.ParseUint(vram, 10, 64)
|
||||
require.NoError(t, err)
|
||||
// Don't hammer on small VRAM cards...
|
||||
if max < 4*1024*1024*1024 {
|
||||
reqLimit = min(reqLimit, 2)
|
||||
iterLimit = 2
|
||||
}
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 9*time.Minute)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) // GTX 750 2G card takes ~9 minutes
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
|
||||
req, resp := GenerateRequests()
|
||||
// Get the server running (if applicable) warm the model up with a single initial request
|
||||
DoGenerate(ctx, t, client, req[0], resp[0], 60*time.Second, 10*time.Second)
|
||||
DoGenerate(ctx, t, client, req[0], resp[0], 60*time.Second, 5*time.Second)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(reqLimit)
|
||||
for i := 0; i < reqLimit; i++ {
|
||||
wg.Add(len(req))
|
||||
for i := 0; i < len(req); i++ {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
for j := 0; j < iterLimit; j++ {
|
||||
for j := 0; j < 5; j++ {
|
||||
slog.Info("Starting", "req", i, "iter", j)
|
||||
// On slower GPUs it can take a while to process the concurrent requests
|
||||
// On slower GPUs it can take a while to process the 4 concurrent requests
|
||||
// so we allow a much longer initial timeout
|
||||
DoGenerate(ctx, t, client, req[i], resp[i], 120*time.Second, 20*time.Second)
|
||||
DoGenerate(ctx, t, client, req[i], resp[i], 90*time.Second, 5*time.Second)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
@@ -245,23 +221,5 @@ func TestMultiModelStress(t *testing.T) {
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
go func() {
|
||||
for {
|
||||
time.Sleep(2 * time.Second)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
models, err := client.ListRunning(ctx)
|
||||
if err != nil {
|
||||
slog.Warn("failed to list running models", "error", err)
|
||||
continue
|
||||
}
|
||||
for _, m := range models.Models {
|
||||
slog.Info("loaded model snapshot", "model", m)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
wg.Wait()
|
||||
}
|
||||
|
||||
@@ -11,8 +11,7 @@ import (
|
||||
)
|
||||
|
||||
func TestContextExhaustion(t *testing.T) {
|
||||
// Longer needed for small footprint GPUs
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 6*time.Minute)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) // TODO maybe shorter?
|
||||
defer cancel()
|
||||
// Set up the test data
|
||||
req := api.GenerateRequest{
|
||||
|
||||
@@ -32,11 +32,7 @@ func TestIntegrationMultimodal(t *testing.T) {
|
||||
resp := "the ollam"
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
|
||||
defer cancel()
|
||||
client, _, cleanup := InitServerConnection(ctx, t)
|
||||
defer cleanup()
|
||||
require.NoError(t, PullIfMissing(ctx, client, req.Model))
|
||||
// llava models on CPU can be quite slow to start,
|
||||
DoGenerate(ctx, t, client, req, []string{resp}, 120*time.Second, 30*time.Second)
|
||||
GenerateTestHelper(ctx, t, req, []string{resp})
|
||||
}
|
||||
|
||||
const imageEncoding = `iVBORw0KGgoAAAANSUhEUgAAANIAAAB4CAYAAACHHqzKAAAAAXNSR0IArs4c6QAAAIRlWElmTU0AKgAAAAgABQESAAMAAAABAAEAAAEaAAUAAAABAAAASgEb
|
||||
|
||||
@@ -140,7 +140,7 @@ func PullIfMissing(ctx context.Context, client *api.Client, modelName string) er
|
||||
|
||||
showCtx, cancel := context.WithDeadlineCause(
|
||||
ctx,
|
||||
time.Now().Add(10*time.Second),
|
||||
time.Now().Add(5*time.Second),
|
||||
fmt.Errorf("show for existing model %s took too long", modelName),
|
||||
)
|
||||
defer cancel()
|
||||
@@ -287,46 +287,41 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
|
||||
func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
||||
return []api.GenerateRequest{
|
||||
{
|
||||
Model: "orca-mini",
|
||||
Prompt: "why is the ocean blue?",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Model: "orca-mini",
|
||||
Prompt: "why is the ocean blue?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
}, {
|
||||
Model: "orca-mini",
|
||||
Prompt: "why is the color of dirt brown?",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Model: "orca-mini",
|
||||
Prompt: "why is the color of dirt brown?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
}, {
|
||||
Model: "orca-mini",
|
||||
Prompt: "what is the origin of the us thanksgiving holiday?",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Model: "orca-mini",
|
||||
Prompt: "what is the origin of the us thanksgiving holiday?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
}, {
|
||||
Model: "orca-mini",
|
||||
Prompt: "what is the origin of independence day?",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Model: "orca-mini",
|
||||
Prompt: "what is the origin of independence day?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
},
|
||||
}, {
|
||||
Model: "orca-mini",
|
||||
Prompt: "what is the composition of air?",
|
||||
Stream: &stream,
|
||||
KeepAlive: &api.Duration{Duration: 10 * time.Second},
|
||||
Model: "orca-mini",
|
||||
Prompt: "what is the composition of air?",
|
||||
Stream: &stream,
|
||||
Options: map[string]interface{}{
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
@@ -336,7 +331,7 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
||||
[][]string{
|
||||
[]string{"sunlight"},
|
||||
[]string{"soil", "organic", "earth", "black", "tan"},
|
||||
[]string{"england", "english", "massachusetts", "pilgrims", "british"},
|
||||
[]string{"england", "english", "massachusetts", "pilgrims"},
|
||||
[]string{"fourth", "july", "declaration", "independence"},
|
||||
[]string{"nitrogen", "oxygen", "carbon", "dioxide"},
|
||||
}
|
||||
|
||||
25
llm/ext_server/server.cpp
vendored
25
llm/ext_server/server.cpp
vendored
@@ -359,6 +359,7 @@ struct llama_server_context
|
||||
|
||||
// slots / clients
|
||||
std::vector<server_slot> slots;
|
||||
json default_generation_settings_for_props;
|
||||
|
||||
llama_server_queue queue_tasks;
|
||||
llama_server_response queue_results;
|
||||
@@ -482,6 +483,9 @@ struct llama_server_context
|
||||
slots.push_back(slot);
|
||||
}
|
||||
|
||||
default_generation_settings_for_props = get_formated_generation(slots.front());
|
||||
default_generation_settings_for_props["seed"] = -1;
|
||||
|
||||
batch = llama_batch_init(n_ctx, 0, params.n_parallel);
|
||||
}
|
||||
|
||||
@@ -580,7 +584,7 @@ struct llama_server_context
|
||||
slot->sparams.mirostat_eta = json_value(data, "mirostat_eta", default_sparams.mirostat_eta);
|
||||
slot->sparams.penalize_nl = json_value(data, "penalize_nl", default_sparams.penalize_nl);
|
||||
slot->params.n_keep = json_value(data, "n_keep", slot->params.n_keep);
|
||||
slot->sparams.seed = json_value(data, "seed", default_params.seed);
|
||||
slot->params.seed = json_value(data, "seed", default_params.seed);
|
||||
slot->sparams.grammar = json_value(data, "grammar", default_sparams.grammar);
|
||||
slot->sparams.n_probs = json_value(data, "n_probs", default_sparams.n_probs);
|
||||
slot->sparams.min_keep = json_value(data, "min_keep", default_sparams.min_keep);
|
||||
@@ -807,6 +811,7 @@ struct llama_server_context
|
||||
llama_sampling_free(slot->ctx_sampling);
|
||||
}
|
||||
slot->ctx_sampling = llama_sampling_init(slot->sparams);
|
||||
llama_set_rng_seed(ctx, slot->params.seed);
|
||||
slot->command = LOAD_PROMPT;
|
||||
|
||||
all_slots_are_idle = false;
|
||||
@@ -830,7 +835,7 @@ struct llama_server_context
|
||||
system_tokens.clear();
|
||||
|
||||
if (!system_prompt.empty()) {
|
||||
system_tokens = ::llama_tokenize(ctx, system_prompt, true);
|
||||
system_tokens = ::llama_tokenize(ctx, system_prompt, add_bos_token);
|
||||
|
||||
llama_batch_clear(batch);
|
||||
|
||||
@@ -1651,7 +1656,7 @@ struct llama_server_context
|
||||
slot.t_start_process_prompt = ggml_time_us();
|
||||
slot.t_start_genereration = 0;
|
||||
|
||||
prompt_tokens = tokenize(slot.prompt, system_prompt.empty()); // add BOS if there isn't system prompt
|
||||
prompt_tokens = tokenize(slot.prompt, system_prompt.empty() && add_bos_token); // add BOS if there isn't system prompt
|
||||
|
||||
slot.n_prompt_tokens = prompt_tokens.size();
|
||||
|
||||
@@ -2335,9 +2340,9 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, g
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
#ifndef GGML_USE_CUDA
|
||||
fprintf(stderr, "warning: llama.cpp was compiled without CUDA. Setting the split mode has no effect.\n");
|
||||
#endif // GGML_USE_CUDA
|
||||
#ifndef GGML_USE_CUBLAS
|
||||
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. Setting the split mode has no effect.\n");
|
||||
#endif // GGML_USE_CUBLAS
|
||||
}
|
||||
else if (arg == "--tensor-split" || arg == "-ts")
|
||||
{
|
||||
@@ -2346,7 +2351,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, g
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
#if defined(GGML_USE_CUDA) || defined(GGML_USE_SYCL)
|
||||
#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_SYCL)
|
||||
std::string arg_next = argv[i];
|
||||
|
||||
// split string by , and /
|
||||
@@ -2367,8 +2372,8 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, g
|
||||
}
|
||||
}
|
||||
#else
|
||||
LOG_WARNING("llama.cpp was compiled without CUDA. It is not possible to set a tensor split.\n", {});
|
||||
#endif // GGML_USE_CUDA
|
||||
LOG_WARNING("llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n", {});
|
||||
#endif // GGML_USE_CUBLAS
|
||||
}
|
||||
else if (arg == "--main-gpu" || arg == "-mg")
|
||||
{
|
||||
@@ -2377,7 +2382,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams, g
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
#if defined(GGML_USE_CUDA) || defined(GGML_USE_SYCL)
|
||||
#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_SYCL)
|
||||
params.main_gpu = std::stoi(argv[i]);
|
||||
#else
|
||||
LOG_WARNING("llama.cpp was compiled without cuBLAS. It is not possible to set a main GPU.", {});
|
||||
|
||||
@@ -211,7 +211,7 @@ if [ -z "${ONEAPI_ROOT}" ]; then
|
||||
ONEAPI_ROOT=/opt/intel/oneapi
|
||||
fi
|
||||
|
||||
if [ -z "${OLLAMA_SKIP_ONEAPI_GENERATE}" -a -d "${ONEAPI_ROOT}" ]; then
|
||||
if [ -d "${ONEAPI_ROOT}" ]; then
|
||||
echo "OneAPI libraries detected - building dynamic OneAPI library"
|
||||
init_vars
|
||||
source ${ONEAPI_ROOT}/setvars.sh --force # set up environment variables for oneAPI
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
#!powershell
|
||||
|
||||
$ErrorActionPreference = "Stop"
|
||||
|
||||
function amdGPUs {
|
||||
if ($env:AMDGPU_TARGETS) {
|
||||
return $env:AMDGPU_TARGETS
|
||||
@@ -83,9 +85,9 @@ function init_vars {
|
||||
function git_module_setup {
|
||||
# TODO add flags to skip the init/patch logic to make it easier to mod llama.cpp code in-repo
|
||||
& git submodule init
|
||||
if ($LASTEXITCODE -ne 0) { throw($LASTEXITCODE)}
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
& git submodule update --force "${script:llamacppDir}"
|
||||
if ($LASTEXITCODE -ne 0) { throw($LASTEXITCODE)}
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
}
|
||||
|
||||
function apply_patches {
|
||||
@@ -119,15 +121,10 @@ function build {
|
||||
write-host "generating config with: cmake -S ${script:llamacppDir} -B $script:buildDir $script:cmakeDefs"
|
||||
& cmake --version
|
||||
& cmake -S "${script:llamacppDir}" -B $script:buildDir $script:cmakeDefs
|
||||
if ($LASTEXITCODE -ne 0) { throw($LASTEXITCODE)}
|
||||
if ($cmakeDefs -contains "-G") {
|
||||
$extra=@("-j8")
|
||||
} else {
|
||||
$extra= @("--", "/p:CL_MPcount=8")
|
||||
}
|
||||
write-host "building with: cmake --build $script:buildDir --config $script:config $($script:cmakeTargets | ForEach-Object { `"--target`", $_ }) $extra"
|
||||
& cmake --build $script:buildDir --config $script:config ($script:cmakeTargets | ForEach-Object { "--target", $_ }) $extra
|
||||
if ($LASTEXITCODE -ne 0) { write-host "cmake build exit status $LASTEXITCODE"; throw($LASTEXITCODE)}
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
write-host "building with: cmake --build $script:buildDir --config $script:config $($script:cmakeTargets | ForEach-Object { `"--target`", $_ })"
|
||||
& cmake --build $script:buildDir --config $script:config ($script:cmakeTargets | ForEach-Object { "--target", $_ })
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
# Rearrange output to be consistent between different generators
|
||||
if ($null -ne ${script:config} -And (test-path -path "${script:buildDir}/bin/${script:config}" ) ) {
|
||||
mv -force "${script:buildDir}/bin/${script:config}/*" "${script:buildDir}/bin/"
|
||||
@@ -141,7 +138,7 @@ function sign {
|
||||
foreach ($file in @(get-childitem "${script:buildDir}/bin/*.exe") + @(get-childitem "${script:buildDir}/bin/*.dll")){
|
||||
& "${script:SignTool}" sign /v /fd sha256 /t http://timestamp.digicert.com /f "${script:OLLAMA_CERT}" `
|
||||
/csp "Google Cloud KMS Provider" /kc "${env:KEY_CONTAINER}" $file
|
||||
if ($LASTEXITCODE -ne 0) { throw($LASTEXITCODE)}
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -216,13 +213,7 @@ function build_static() {
|
||||
}
|
||||
}
|
||||
|
||||
function build_cpu() {
|
||||
if ($script:ARCH -eq "arm64") {
|
||||
$gen_arch = "ARM64"
|
||||
} else { # amd64
|
||||
$gen_arch = "x64"
|
||||
}
|
||||
|
||||
function build_cpu($gen_arch) {
|
||||
if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu"))) {
|
||||
# remaining llama.cpp builds use MSVC
|
||||
init_vars
|
||||
@@ -279,15 +270,7 @@ function build_cuda() {
|
||||
init_vars
|
||||
$script:buildDir="../build/windows/${script:ARCH}/cuda$script:CUDA_VARIANT"
|
||||
$script:distDir="$script:DIST_BASE\cuda$script:CUDA_VARIANT"
|
||||
$script:cmakeDefs += @(
|
||||
"-A", "x64",
|
||||
"-DLLAMA_CUDA=ON",
|
||||
"-DLLAMA_AVX=on",
|
||||
"-DLLAMA_AVX2=off",
|
||||
"-DCUDAToolkit_INCLUDE_DIR=$script:CUDA_INCLUDE_DIR",
|
||||
"-DCMAKE_CUDA_FLAGS=-t8"
|
||||
"-DCMAKE_CUDA_ARCHITECTURES=${script:CMAKE_CUDA_ARCHITECTURES}"
|
||||
)
|
||||
$script:cmakeDefs += @("-A", "x64", "-DLLAMA_CUDA=ON", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DCUDAToolkit_INCLUDE_DIR=$script:CUDA_INCLUDE_DIR", "-DCMAKE_CUDA_ARCHITECTURES=${script:CMAKE_CUDA_ARCHITECTURES}")
|
||||
if ($null -ne $env:OLLAMA_CUSTOM_CUDA_DEFS) {
|
||||
write-host "OLLAMA_CUSTOM_CUDA_DEFS=`"${env:OLLAMA_CUSTOM_CUDA_DEFS}`""
|
||||
$script:cmakeDefs +=@("${env:OLLAMA_CUSTOM_CUDA_DEFS}")
|
||||
@@ -307,7 +290,7 @@ function build_cuda() {
|
||||
}
|
||||
|
||||
function build_oneapi() {
|
||||
if ((-not "${env:OLLAMA_SKIP_ONEAPI_GENERATE}") -and ("${env:ONEAPI_ROOT}")) {
|
||||
if ((-not "${env:OLLAMA_SKIP_CUDA_GENERATE}") -and ("${env:ONEAPI_ROOT}")) {
|
||||
# Get oneAPI version
|
||||
$script:ONEAPI_VERSION = icpx --version
|
||||
$script:ONEAPI_VERSION = [regex]::Match($script:ONEAPI_VERSION, '(?<=oneAPI DPC\+\+/C\+\+ Compiler )(?<version>\d+\.\d+\.\d+)').Value
|
||||
@@ -408,29 +391,16 @@ init_vars
|
||||
if ($($args.count) -eq 0) {
|
||||
git_module_setup
|
||||
apply_patches
|
||||
|
||||
$tasks = @("build_static", "build_cpu")
|
||||
$jobs = @()
|
||||
if ($script:ARCH -ne "arm64") {
|
||||
$tasks += $("build_cpu_avx", "build_cpu_avx2", "build_cuda", "build_oneapi", "build_rocm")
|
||||
}
|
||||
foreach ($t in $tasks) {
|
||||
$jobs += @(Start-ThreadJob -ThrottleLimit 12 -FilePath .\gen_windows.ps1 -ArgumentList $t -Name $t)
|
||||
}
|
||||
get-job
|
||||
foreach ($job in $jobs) {
|
||||
write-host "----" $job.Name output follows
|
||||
receive-job -wait -job $job
|
||||
write-host "----" $job.Name $job.State
|
||||
write-host ""
|
||||
if ($job.State -contains 'Failed') {
|
||||
cleanup
|
||||
write-host "Terminating remaining jobs (this takes a while, you can ^C)"
|
||||
# TODO find some way to kill the spawned cmake processes faster
|
||||
remove-job -force -job $jobs
|
||||
exit(-1)
|
||||
}
|
||||
get-job
|
||||
build_static
|
||||
if ($script:ARCH -eq "arm64") {
|
||||
build_cpu("ARM64")
|
||||
} else { # amd64
|
||||
build_cpu("x64")
|
||||
build_cpu_avx
|
||||
build_cpu_avx2
|
||||
build_cuda
|
||||
build_oneapi
|
||||
build_rocm
|
||||
}
|
||||
|
||||
cleanup
|
||||
|
||||
@@ -81,11 +81,6 @@ func (kv KV) ContextLength() uint64 {
|
||||
return kv.u64(fmt.Sprintf("%s.context_length", kv.Architecture()))
|
||||
}
|
||||
|
||||
func (kv KV) ChatTemplate() string {
|
||||
s, _ := kv["tokenizer.chat_template"].(string)
|
||||
return s
|
||||
}
|
||||
|
||||
type Tensors []*Tensor
|
||||
|
||||
func (ts Tensors) Layers() map[string]Layer {
|
||||
@@ -307,7 +302,6 @@ func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload ui
|
||||
|
||||
partialOffload = 4 * batch * embedding
|
||||
partialOffload += max(
|
||||
// 4*batch*(4+6*embedding+context*(2*heads)+llm.KV().GQA()),
|
||||
4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embedding/heads*headsKV),
|
||||
4*batch*(embedding+vocab)+embedding*vocab*105/128,
|
||||
)
|
||||
|
||||
26
llm/gguf.go
26
llm/gguf.go
@@ -592,8 +592,8 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
|
||||
return err
|
||||
}
|
||||
|
||||
var dims int
|
||||
for cnt := range len(tensor.Shape) {
|
||||
dims := 0
|
||||
for cnt := 0; cnt < len(tensor.Shape); cnt++ {
|
||||
if tensor.Shape[cnt] > 0 {
|
||||
dims++
|
||||
}
|
||||
@@ -603,8 +603,8 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
|
||||
return err
|
||||
}
|
||||
|
||||
for i := range dims {
|
||||
if err := binary.Write(ws, llm.ByteOrder, tensor.Shape[dims-1-i]); err != nil {
|
||||
for i := 0; i < dims; i++ {
|
||||
if err := binary.Write(ws, llm.ByteOrder, uint64(tensor.Shape[dims-1-i])); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -618,8 +618,22 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
|
||||
}
|
||||
}
|
||||
|
||||
offset, err := ws.Seek(0, io.SeekCurrent)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var alignment int64 = 32
|
||||
padding := llm.padding(offset, alignment)
|
||||
if err := binary.Write(ws, llm.ByteOrder, bytes.Repeat([]byte{0}, int(padding))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, tensor := range tensors {
|
||||
if _, err := tensor.WriteTo(ws); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
offset, err := ws.Seek(0, io.SeekCurrent)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -629,10 +643,6 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
|
||||
if err := binary.Write(ws, llm.ByteOrder, bytes.Repeat([]byte{0}, int(padding))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := tensor.WriteTo(ws); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
|
||||
263
llm/memory.go
263
llm/memory.go
@@ -1,13 +1,13 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/gpu"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
// This algorithm looks for a complete fit to determine if we need to unload other models
|
||||
@@ -16,8 +16,7 @@ func PredictServerFit(allGpus gpu.GpuInfoList, ggml *GGML, adapters, projectors
|
||||
var estimatedVRAM uint64
|
||||
for _, gpus := range allGpus.ByLibrary() {
|
||||
var layerCount int
|
||||
estimate := EstimateGPULayers(gpus, ggml, projectors, opts)
|
||||
layerCount, estimatedVRAM = estimate.Layers, estimate.VRAMSize
|
||||
layerCount, estimatedVRAM, _ = EstimateGPULayers(gpus, ggml, projectors, opts)
|
||||
if opts.NumGPU < 0 {
|
||||
if layerCount > 0 && layerCount >= int(ggml.KV().BlockCount()+1) {
|
||||
return true, estimatedVRAM
|
||||
@@ -31,64 +30,24 @@ func PredictServerFit(allGpus gpu.GpuInfoList, ggml *GGML, adapters, projectors
|
||||
return false, estimatedVRAM
|
||||
}
|
||||
|
||||
type MemoryEstimate struct {
|
||||
// How many layers we predict we can load
|
||||
Layers int
|
||||
|
||||
// The size of the graph which occupies the main GPU
|
||||
Graph uint64
|
||||
|
||||
// How much VRAM will be allocated given the number of layers we predict
|
||||
VRAMSize uint64
|
||||
|
||||
// The total size of the model if loaded into VRAM. If all layers are loaded, VRAMSize == TotalSize
|
||||
TotalSize uint64
|
||||
|
||||
// For multi-GPU scenarios, this provides the tensor split parameter
|
||||
TensorSplit string
|
||||
|
||||
// For multi-GPU scenarios, this is the size in bytes per GPU
|
||||
GPUSizes []uint64
|
||||
}
|
||||
|
||||
// Given a model and one or more GPU targets, predict how many layers and bytes we can load, and the total size
|
||||
// The GPUs provided must all be the same Library
|
||||
func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts api.Options) MemoryEstimate {
|
||||
// Graph size for a partial offload, applies to all GPUs
|
||||
var graphPartialOffload uint64
|
||||
|
||||
// Graph size when all layers are offloaded, applies to all GPUs
|
||||
var graphFullOffload uint64
|
||||
|
||||
// Final graph offload once we know full or partial
|
||||
var graphOffload uint64
|
||||
|
||||
// Projectors loaded into GPU0 only
|
||||
var projectorSize uint64
|
||||
|
||||
// Conditional output size on GPU 0
|
||||
var memoryLayerOutput uint64
|
||||
|
||||
// The sizes of a layer
|
||||
var layerSize uint64
|
||||
|
||||
// The sum of all the layer sizes (just for logging)
|
||||
var memoryWeights uint64
|
||||
|
||||
// True if all the layers are loaded
|
||||
var fullyLoaded bool
|
||||
|
||||
// Overflow that didn't fit into the GPU
|
||||
var overflow uint64
|
||||
|
||||
availableList := make([]string, len(gpus))
|
||||
for i, gpu := range gpus {
|
||||
availableList[i] = format.HumanBytes2(gpu.FreeMemory)
|
||||
func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts api.Options) (int, uint64, uint64) {
|
||||
var memoryAvailable uint64
|
||||
for _, info := range gpus {
|
||||
memoryAvailable += info.FreeMemory
|
||||
}
|
||||
slog.Debug("evaluating", "library", gpus[0].Library, "gpu_count", len(gpus), "available", availableList)
|
||||
if envconfig.MaxVRAM > 0 {
|
||||
memoryAvailable = envconfig.MaxVRAM
|
||||
}
|
||||
|
||||
slog.Debug("evaluating", "library", gpus[0].Library, "gpu_count", len(gpus), "available", format.HumanBytes2(memoryAvailable))
|
||||
|
||||
// TODO - this is probably wrong, first GPU vs secondaries will have different overheads
|
||||
memoryMinimum := gpus[0].MinimumMemory
|
||||
|
||||
for _, projector := range projectors {
|
||||
projectorSize += projectorMemoryRequirements(projector)
|
||||
memoryMinimum += projectorMemoryRequirements(projector)
|
||||
|
||||
// multimodal models require at least 2048 context
|
||||
opts.NumCtx = max(opts.NumCtx, 2048)
|
||||
@@ -97,160 +56,79 @@ func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts
|
||||
layers := ggml.Tensors().Layers()
|
||||
// add one layer worth of memory as a buffer
|
||||
if blk0, ok := layers["blk.0"]; ok {
|
||||
layerSize = blk0.size()
|
||||
} else {
|
||||
slog.Warn("model missing blk.0 layer size")
|
||||
memoryMinimum += blk0.size()
|
||||
}
|
||||
|
||||
// fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv
|
||||
var kv uint64 = 2 * 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * ggml.KV().EmbeddingLength() / ggml.KV().HeadCount() * ggml.KV().HeadCountKV()
|
||||
|
||||
// KV is proportional to the number of layers
|
||||
layerSize += kv / ggml.KV().BlockCount()
|
||||
|
||||
graphPartialOffload, graphFullOffload = ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)))
|
||||
graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)))
|
||||
if graphPartialOffload == 0 {
|
||||
graphPartialOffload = ggml.KV().GQA() * kv / 6
|
||||
}
|
||||
|
||||
if graphFullOffload == 0 {
|
||||
graphFullOffload = graphPartialOffload
|
||||
}
|
||||
|
||||
graphFullOffload *= uint64(len(gpus))
|
||||
graphPartialOffload *= uint64(len(gpus))
|
||||
|
||||
// on metal there's no partial offload overhead
|
||||
if gpus[0].Library == "metal" {
|
||||
graphPartialOffload = graphFullOffload
|
||||
} else if len(gpus) > 1 {
|
||||
// multigpu should always use the partial graph size
|
||||
graphFullOffload = graphPartialOffload
|
||||
}
|
||||
|
||||
// memoryRequiredTotal represents the memory required for full GPU offloading (all layers)
|
||||
memoryRequiredTotal := memoryMinimum + graphFullOffload
|
||||
|
||||
// memoryRequiredPartial represents the memory required for partial GPU offloading (n > 0, n < layers)
|
||||
memoryRequiredPartial := memoryMinimum + graphPartialOffload
|
||||
|
||||
var memoryLayerOutput uint64
|
||||
if layer, ok := layers["output_norm"]; ok {
|
||||
memoryLayerOutput += layer.size()
|
||||
}
|
||||
|
||||
if layer, ok := layers["output"]; ok {
|
||||
memoryLayerOutput += layer.size()
|
||||
} else if layer, ok := layers["token_embd"]; ok {
|
||||
memoryLayerOutput += layer.size()
|
||||
}
|
||||
|
||||
// Output layer handled at the end if we have space
|
||||
gpuZeroOverhead := projectorSize
|
||||
if gpus[0].Library == "metal" && opts.UseMMap {
|
||||
// memory is preallocated for output tensors
|
||||
memoryRequiredTotal += memoryLayerOutput
|
||||
memoryRequiredPartial += memoryLayerOutput
|
||||
}
|
||||
|
||||
// Reduce set of GPUs to only those that have sufficient space to fit overhead and at least one layer
|
||||
var layerCount int
|
||||
layerCounts := make([]int, len(gpus))
|
||||
gpuAllocations := make([]uint64, len(gpus))
|
||||
type gs struct {
|
||||
i int
|
||||
g *gpu.GpuInfo
|
||||
}
|
||||
gpusWithSpace := []gs{}
|
||||
for i := range gpus {
|
||||
var gzo uint64
|
||||
if len(gpusWithSpace) == 0 {
|
||||
gzo = gpuZeroOverhead
|
||||
}
|
||||
// Only include GPUs that can fit the graph, gpu minimum, the layer buffer and at least more layer
|
||||
if gpus[i].FreeMemory < gzo+max(graphPartialOffload, graphFullOffload)+gpus[i].MinimumMemory+2*layerSize {
|
||||
slog.Debug("gpu has too little memory to allocate any layers", "gpu", gpus[i])
|
||||
continue
|
||||
}
|
||||
gpusWithSpace = append(gpusWithSpace, gs{i, &gpus[i]})
|
||||
gpuAllocations[i] += gpus[i].MinimumMemory + layerSize // We hold off on graph until we know partial vs. full
|
||||
}
|
||||
for i := 0; i < int(ggml.KV().BlockCount()); i++ {
|
||||
if blk, ok := layers[fmt.Sprintf("blk.%d", i)]; ok {
|
||||
memoryLayer := blk.size()
|
||||
|
||||
var gpuZeroID int
|
||||
if len(gpusWithSpace) > 0 {
|
||||
gpuZeroID = gpusWithSpace[0].i
|
||||
gpuAllocations[gpuZeroID] += gpuZeroOverhead
|
||||
}
|
||||
// KV is proportional to the number of layers
|
||||
memoryLayer += kv / ggml.KV().BlockCount()
|
||||
|
||||
// For all the layers, find where they can fit on the GPU(s)
|
||||
for i := range int(ggml.KV().BlockCount()) {
|
||||
memoryWeights += layerSize
|
||||
|
||||
if opts.NumGPU >= 0 && layerCount >= opts.NumGPU {
|
||||
// Stop allocating on GPU(s) once we hit the users target NumGPU
|
||||
continue
|
||||
}
|
||||
|
||||
// distribute the layers across the GPU(s) that have space
|
||||
for j := len(gpusWithSpace); j > 0; j-- {
|
||||
g := gpusWithSpace[i%j]
|
||||
used := gpuAllocations[g.i] + max(graphPartialOffload, graphFullOffload)
|
||||
if g.g.FreeMemory > used+layerSize {
|
||||
gpuAllocations[g.i] += layerSize
|
||||
layerCounts[g.i]++
|
||||
memoryRequiredTotal += memoryLayer
|
||||
if (opts.NumGPU >= 0 && layerCount+1 <= opts.NumGPU) || (opts.NumGPU < 0 && memoryAvailable > memoryRequiredPartial+memoryLayer) {
|
||||
memoryRequiredPartial += memoryLayer
|
||||
layerCount++
|
||||
break
|
||||
} else {
|
||||
gpusWithSpace = append(gpusWithSpace[:i%j], gpusWithSpace[i%j+1:]...)
|
||||
}
|
||||
}
|
||||
}
|
||||
if layerCount >= int(ggml.KV().BlockCount()) {
|
||||
fullyLoaded = true
|
||||
} else {
|
||||
for i := layerCount; i < int(ggml.KV().BlockCount()); i++ {
|
||||
overflow += layerSize
|
||||
}
|
||||
|
||||
if gpus[0].Library != "metal" || !opts.UseMMap {
|
||||
// memory was not preallocated for output tensors
|
||||
memoryRequiredTotal += memoryLayerOutput
|
||||
}
|
||||
|
||||
// Determine if we need to consider output then find where it fits
|
||||
if memoryLayerOutput > 0 && (opts.NumGPU < 0 || layerCount < opts.NumGPU) {
|
||||
for j := len(gpusWithSpace); j > 0; j-- {
|
||||
g := gpusWithSpace[layerCount%j]
|
||||
used := gpuAllocations[g.i] + max(graphPartialOffload, graphFullOffload)
|
||||
if g.g.FreeMemory > used+memoryLayerOutput {
|
||||
gpuAllocations[g.i] += memoryLayerOutput
|
||||
layerCounts[g.i]++
|
||||
layerCount++
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if layerCount < int(ggml.KV().BlockCount())+1 {
|
||||
fullyLoaded = false
|
||||
overflow += memoryLayerOutput
|
||||
}
|
||||
if (opts.NumGPU >= 0 && layerCount+1 <= opts.NumGPU) || (opts.NumGPU < 0 && memoryAvailable > memoryRequiredTotal) {
|
||||
layerCount = int(ggml.KV().BlockCount()) + 1
|
||||
memoryRequiredPartial = memoryRequiredTotal
|
||||
}
|
||||
|
||||
// Add the applicable (full or partial) graph allocations
|
||||
for i := range gpus {
|
||||
if layerCounts[i] <= 0 {
|
||||
continue
|
||||
}
|
||||
if fullyLoaded {
|
||||
gpuAllocations[i] += graphFullOffload
|
||||
} else {
|
||||
gpuAllocations[i] += graphPartialOffload
|
||||
}
|
||||
}
|
||||
if fullyLoaded {
|
||||
graphOffload = graphFullOffload
|
||||
} else {
|
||||
graphOffload = graphPartialOffload
|
||||
}
|
||||
|
||||
// Summaries for the log
|
||||
var memoryRequiredPartial, memoryRequiredTotal uint64
|
||||
for i := range gpuAllocations {
|
||||
memoryRequiredPartial += gpuAllocations[i]
|
||||
}
|
||||
memoryRequiredTotal = memoryRequiredPartial + overflow
|
||||
|
||||
tensorSplit := ""
|
||||
if len(gpus) > 1 {
|
||||
splits := make([]string, len(gpus))
|
||||
for i, count := range layerCounts {
|
||||
splits[i] = strconv.Itoa(count)
|
||||
}
|
||||
tensorSplit = strings.Join(splits, ",")
|
||||
}
|
||||
allocationsList := []string{}
|
||||
for _, a := range gpuAllocations {
|
||||
allocationsList = append(allocationsList, format.HumanBytes2(a))
|
||||
}
|
||||
memoryWeights := memoryRequiredTotal - memoryMinimum - graphFullOffload - kv
|
||||
|
||||
slog.Info(
|
||||
"offload to gpu",
|
||||
@@ -258,17 +136,13 @@ func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts
|
||||
"layers",
|
||||
// requested number of layers to offload
|
||||
"requested", opts.NumGPU,
|
||||
// The number of layers the model has (including output)
|
||||
"model", int(ggml.KV().BlockCount())+1,
|
||||
// estimated number of layers that can be offloaded
|
||||
"offload", layerCount,
|
||||
// multi-gpu split for tesnors
|
||||
"split", tensorSplit,
|
||||
"real", layerCount,
|
||||
),
|
||||
slog.Group(
|
||||
"memory",
|
||||
// memory available by GPU for offloading
|
||||
"available", availableList,
|
||||
// memory available for offloading
|
||||
"available", format.HumanBytes2(memoryAvailable),
|
||||
slog.Group(
|
||||
"required",
|
||||
// memory required for full offloading
|
||||
@@ -277,8 +151,6 @@ func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts
|
||||
"partial", format.HumanBytes2(memoryRequiredPartial),
|
||||
// memory of KV cache
|
||||
"kv", format.HumanBytes2(kv),
|
||||
// Allocations across the GPUs
|
||||
"allocations", allocationsList,
|
||||
),
|
||||
slog.Group(
|
||||
"weights",
|
||||
@@ -299,31 +171,12 @@ func EstimateGPULayers(gpus []gpu.GpuInfo, ggml *GGML, projectors []string, opts
|
||||
),
|
||||
)
|
||||
if gpus[0].Library == "cpu" {
|
||||
return MemoryEstimate{
|
||||
Layers: 0,
|
||||
Graph: 0,
|
||||
VRAMSize: 0,
|
||||
TotalSize: memoryRequiredTotal,
|
||||
GPUSizes: []uint64{},
|
||||
}
|
||||
return 0, 0, memoryRequiredTotal
|
||||
}
|
||||
if layerCount == 0 {
|
||||
if memoryRequiredPartial > memoryAvailable {
|
||||
slog.Debug("insufficient VRAM to load any model layers")
|
||||
return MemoryEstimate{
|
||||
Layers: 0,
|
||||
Graph: 0,
|
||||
VRAMSize: 0,
|
||||
TotalSize: memoryRequiredTotal,
|
||||
GPUSizes: []uint64{},
|
||||
}
|
||||
return 0, 0, memoryRequiredTotal
|
||||
}
|
||||
|
||||
return MemoryEstimate{
|
||||
Layers: layerCount,
|
||||
Graph: graphOffload,
|
||||
VRAMSize: memoryRequiredPartial,
|
||||
TotalSize: memoryRequiredTotal,
|
||||
TensorSplit: tensorSplit,
|
||||
GPUSizes: gpuAllocations,
|
||||
}
|
||||
return layerCount, memoryRequiredPartial, memoryRequiredTotal
|
||||
}
|
||||
|
||||
@@ -1,127 +0,0 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/gpu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestEstimateGPULayers(t *testing.T) {
|
||||
envconfig.Debug = true
|
||||
modelName := "dummy"
|
||||
f, err := os.CreateTemp(t.TempDir(), modelName)
|
||||
require.NoError(t, err)
|
||||
defer f.Close()
|
||||
gguf := NewGGUFV3(binary.LittleEndian)
|
||||
inputLayerCount := 5
|
||||
tensors := []Tensor{
|
||||
{Name: "blk.0.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
|
||||
{Name: "blk.1.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
|
||||
{Name: "blk.2.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
|
||||
{Name: "blk.3.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
|
||||
{Name: "blk.4.attn.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
|
||||
{Name: "output.weight", Kind: uint32(0), Offset: uint64(0), Shape: []uint64{1, 1, 1, 1}, WriterTo: &bytes.Reader{}},
|
||||
}
|
||||
assert.Len(t, tensors, inputLayerCount+1)
|
||||
err = gguf.Encode(f, KV{
|
||||
"general.architecture": "llama",
|
||||
"general.name": "name",
|
||||
"llama.context_length": uint32(32),
|
||||
"llama.embedding_length": uint32(4096),
|
||||
"llama.block_count": uint32(inputLayerCount),
|
||||
"llama.attention.head_count": uint32(32),
|
||||
"llama.attention.head_count_kv": uint32(32),
|
||||
"tokenizer.ggml.tokens": []string{" "},
|
||||
"tokenizer.ggml.scores": []float32{0},
|
||||
"tokenizer.ggml.token_type": []int32{0},
|
||||
}, tensors)
|
||||
require.NoError(t, err)
|
||||
|
||||
ggml, err := LoadModel(f.Name())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Simple CPU scenario
|
||||
gpus := []gpu.GpuInfo{
|
||||
{
|
||||
Library: "cpu",
|
||||
},
|
||||
}
|
||||
projectors := []string{}
|
||||
opts := api.DefaultOptions()
|
||||
t.Run("cpu", func(t *testing.T) {
|
||||
estimate := EstimateGPULayers(gpus, ggml, projectors, opts)
|
||||
assert.Equal(t, 0, estimate.Layers)
|
||||
assert.Equal(t, uint64(0), estimate.Graph)
|
||||
})
|
||||
|
||||
// derived from the dummy ggml file above
|
||||
graphPartialOffload := uint64(202377216)
|
||||
graphFullOffload := uint64(171968512)
|
||||
layerSize := uint64(33554436)
|
||||
projectorSize := uint64(0)
|
||||
memoryLayerOutput := uint64(4)
|
||||
|
||||
// Dual CUDA scenario with assymetry
|
||||
gpuMinimumMemory := uint64(2048)
|
||||
gpus = []gpu.GpuInfo{
|
||||
{
|
||||
Library: "cuda",
|
||||
MinimumMemory: gpuMinimumMemory,
|
||||
},
|
||||
{
|
||||
Library: "cuda",
|
||||
MinimumMemory: gpuMinimumMemory,
|
||||
},
|
||||
}
|
||||
// Nested array: GPU0 layer space, GPU1 layer space, expected gpu0, expected gpu1
|
||||
for i, s := range []struct {
|
||||
layer0, layer1 uint64
|
||||
expect0, expect1 uint64
|
||||
}{
|
||||
{1, 1, 1, 1},
|
||||
{2, 1, 2, 1},
|
||||
{2, 2, 2, 2},
|
||||
{1, 2, 1, 2},
|
||||
{3, 3, 3, 3},
|
||||
{4, 4, 3, 3},
|
||||
{6, 6, 3, 3},
|
||||
{0, 3, 0, 3},
|
||||
} {
|
||||
t.Run(fmt.Sprintf("%v", s), func(t *testing.T) {
|
||||
gpus[0].FreeMemory = 0
|
||||
gpus[1].FreeMemory = 0
|
||||
gpus[0].FreeMemory += projectorSize
|
||||
if s.layer0 > 0 {
|
||||
gpus[0].FreeMemory += memoryLayerOutput
|
||||
} else {
|
||||
gpus[1].FreeMemory += memoryLayerOutput
|
||||
}
|
||||
gpus[0].FreeMemory += gpuMinimumMemory + layerSize + s.layer0*layerSize + 1
|
||||
gpus[1].FreeMemory += gpuMinimumMemory + layerSize + s.layer1*layerSize + 1
|
||||
gpus[0].FreeMemory += max(graphFullOffload, graphPartialOffload)
|
||||
gpus[1].FreeMemory += max(graphFullOffload, graphPartialOffload)
|
||||
estimate := EstimateGPULayers(gpus, ggml, projectors, opts)
|
||||
assert.Equal(t, int(s.expect0+s.expect1), estimate.Layers, "scenario %d: %v", i, s)
|
||||
assert.Equal(t, fmt.Sprintf("%d,%d", s.expect0, s.expect1), estimate.TensorSplit, "scenario %d: %v", i, s)
|
||||
var layerSums uint64
|
||||
for _, b := range estimate.GPUSizes {
|
||||
layerSums += b
|
||||
}
|
||||
if estimate.Layers < inputLayerCount+1 {
|
||||
assert.Less(t, estimate.VRAMSize, estimate.TotalSize, "scenario %d: %v %+v", i, s, estimate)
|
||||
assert.Equal(t, estimate.VRAMSize, layerSums, "scenario %d: %v %+v", i, s, estimate)
|
||||
} else {
|
||||
assert.Equal(t, estimate.VRAMSize, estimate.TotalSize, "scenario %d: %v %+v", i, s, estimate)
|
||||
assert.Equal(t, estimate.TotalSize, layerSums, "scenario %d: %v %+v", i, s, estimate)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -1,13 +0,0 @@
|
||||
diff --git a/llama.cpp b/llama.cpp
|
||||
index 40d2ec2c..f34eb79a 100644
|
||||
--- a/llama.cpp
|
||||
+++ b/llama.cpp
|
||||
@@ -6943,7 +6943,7 @@ static struct ggml_tensor * llm_build_kqv(
|
||||
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
|
||||
cb(kq, "kq", il);
|
||||
|
||||
- if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
|
||||
+ if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX || model.arch == LLM_ARCH_QWEN2) {
|
||||
// for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
|
||||
// ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
|
||||
ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
|
||||
@@ -10,9 +10,9 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/exp/slices"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/ollama/ollama/gpu"
|
||||
@@ -82,8 +82,8 @@ func serversForGpu(info gpu.GpuInfo) []string {
|
||||
// glob workDir for files that start with ollama_
|
||||
availableServers := availableServers()
|
||||
requested := info.Library
|
||||
if info.Variant != gpu.CPUCapabilityNone {
|
||||
requested += "_" + info.Variant.String()
|
||||
if info.Variant != "" {
|
||||
requested += "_" + info.Variant
|
||||
}
|
||||
|
||||
servers := []string{}
|
||||
@@ -117,14 +117,14 @@ func serversForGpu(info gpu.GpuInfo) []string {
|
||||
|
||||
// Load up the best CPU variant if not primary requested
|
||||
if info.Library != "cpu" {
|
||||
variant := gpu.GetCPUCapability()
|
||||
variant := gpu.GetCPUVariant()
|
||||
// If no variant, then we fall back to default
|
||||
// If we have a variant, try that if we find an exact match
|
||||
// Attempting to run the wrong CPU instructions will panic the
|
||||
// process
|
||||
if variant != gpu.CPUCapabilityNone {
|
||||
if variant != "" {
|
||||
for cmp := range availableServers {
|
||||
if cmp == "cpu_"+variant.String() {
|
||||
if cmp == "cpu_"+variant {
|
||||
servers = append(servers, cmp)
|
||||
break
|
||||
}
|
||||
@@ -146,11 +146,11 @@ func serverForCpu() string {
|
||||
if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" {
|
||||
return "metal"
|
||||
}
|
||||
variant := gpu.GetCPUCapability()
|
||||
variant := gpu.GetCPUVariant()
|
||||
availableServers := availableServers()
|
||||
if variant != gpu.CPUCapabilityNone {
|
||||
if variant != "" {
|
||||
for cmp := range availableServers {
|
||||
if cmp == "cpu_"+variant.String() {
|
||||
if cmp == "cpu_"+variant {
|
||||
return cmp
|
||||
}
|
||||
}
|
||||
|
||||
110
llm/server.go
110
llm/server.go
@@ -37,9 +37,8 @@ type LlamaServer interface {
|
||||
Tokenize(ctx context.Context, content string) ([]int, error)
|
||||
Detokenize(ctx context.Context, tokens []int) (string, error)
|
||||
Close() error
|
||||
EstimatedVRAM() uint64 // Total VRAM across all GPUs
|
||||
EstimatedVRAM() uint64
|
||||
EstimatedTotal() uint64
|
||||
EstimatedVRAMByGPU(gpuID string) uint64
|
||||
}
|
||||
|
||||
// llmServer is an instance of the llama.cpp server
|
||||
@@ -50,12 +49,13 @@ type llmServer struct {
|
||||
status *StatusWriter
|
||||
options api.Options
|
||||
|
||||
estimate MemoryEstimate
|
||||
totalLayers uint64
|
||||
// gpuCount int
|
||||
gpus gpu.GpuInfoList // Recorded just before the model loaded, free space will be incorrect
|
||||
loadDuration time.Duration // Record how long it took the model to load
|
||||
loadProgress float32
|
||||
// TODO - this should be broken down by GPU
|
||||
estimatedVRAM uint64 // Estimated usage of VRAM by the loaded model
|
||||
estimatedTotal uint64 // Total size of model
|
||||
totalLayers uint64
|
||||
gpuCount int
|
||||
loadDuration time.Duration // Record how long it took the model to load
|
||||
loadProgress float32
|
||||
|
||||
sem *semaphore.Weighted
|
||||
}
|
||||
@@ -80,16 +80,17 @@ func LoadModel(model string) (*GGML, error) {
|
||||
func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, projectors []string, opts api.Options) (LlamaServer, error) {
|
||||
var err error
|
||||
var cpuRunner string
|
||||
var estimate MemoryEstimate
|
||||
var estimatedVRAM uint64
|
||||
var estimatedTotal uint64
|
||||
var systemMemory uint64
|
||||
gpuCount := len(gpus)
|
||||
if (len(gpus) == 1 && gpus[0].Library == "cpu") || opts.NumGPU == 0 {
|
||||
|
||||
// TODO evaluate system memory to see if we should block the load, or force an unload of another CPU runner
|
||||
|
||||
// If the user wants zero GPU layers, reset the gpu list to be CPU/system ram info
|
||||
if opts.NumGPU == 0 {
|
||||
gpus = gpu.GetCPUInfo()
|
||||
}
|
||||
if len(gpus) == 1 && gpus[0].Library == "cpu" {
|
||||
cpuRunner = serverForCpu()
|
||||
estimate = EstimateGPULayers(gpus, ggml, projectors, opts)
|
||||
gpuCount = 0
|
||||
_, _, estimatedTotal = EstimateGPULayers(gpus, ggml, projectors, opts)
|
||||
} else {
|
||||
if gpus[0].Library == "metal" {
|
||||
memInfo, err := gpu.GetCPUMem()
|
||||
@@ -100,24 +101,24 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||
slog.Debug("system memory", "total", format.HumanBytes2(systemMemory))
|
||||
}
|
||||
}
|
||||
estimate = EstimateGPULayers(gpus, ggml, projectors, opts)
|
||||
var layers int
|
||||
layers, estimatedVRAM, estimatedTotal = EstimateGPULayers(gpus, ggml, projectors, opts)
|
||||
|
||||
switch {
|
||||
case gpus[0].Library == "metal" && estimate.VRAMSize > systemMemory:
|
||||
if gpus[0].Library == "metal" && estimatedVRAM > systemMemory {
|
||||
// disable partial offloading when model is greater than total system memory as this
|
||||
// can lead to locking up the system
|
||||
opts.NumGPU = 0
|
||||
case gpus[0].Library != "metal" && estimate.Layers == 0:
|
||||
} else if gpus[0].Library != "metal" && layers == 0 {
|
||||
// Don't bother loading into the GPU if no layers can fit
|
||||
cpuRunner = serverForCpu()
|
||||
gpus = gpu.GetCPUInfo()
|
||||
case opts.NumGPU < 0 && estimate.Layers > 0 && gpus[0].Library != "cpu":
|
||||
opts.NumGPU = estimate.Layers
|
||||
gpuCount = 0
|
||||
} else if opts.NumGPU < 0 && layers > 0 && gpus[0].Library != "cpu" {
|
||||
opts.NumGPU = layers
|
||||
}
|
||||
}
|
||||
|
||||
// Loop through potential servers
|
||||
finalErr := errors.New("no suitable llama servers found")
|
||||
finalErr := fmt.Errorf("no suitable llama servers found")
|
||||
|
||||
if len(adapters) > 1 {
|
||||
return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
|
||||
@@ -231,15 +232,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||
|
||||
params = append(params, "--parallel", fmt.Sprintf("%d", numParallel))
|
||||
|
||||
if estimate.TensorSplit != "" {
|
||||
params = append(params, "--tensor-split", estimate.TensorSplit)
|
||||
}
|
||||
|
||||
if estimate.TensorSplit != "" {
|
||||
params = append(params, "--tensor-split", estimate.TensorSplit)
|
||||
}
|
||||
|
||||
for i := range len(servers) {
|
||||
for i := 0; i < len(servers); i++ {
|
||||
dir := availableServers[servers[i]]
|
||||
if dir == "" {
|
||||
// Shouldn't happen
|
||||
@@ -249,7 +242,8 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||
}
|
||||
|
||||
if strings.HasPrefix(servers[i], "cpu") {
|
||||
gpus = gpu.GetCPUInfo()
|
||||
// TODO if we tried a gpu runner first, and it failed, record the error and bubble that back up
|
||||
gpuCount = 0
|
||||
}
|
||||
|
||||
// Find an availableServers port, retry on each iteration in case the failure was a port conflict race
|
||||
@@ -290,7 +284,7 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||
|
||||
server := filepath.Join(dir, "ollama_llama_server")
|
||||
if runtime.GOOS == "windows" {
|
||||
server += ".exe"
|
||||
server = server + ".exe"
|
||||
}
|
||||
|
||||
// Detect tmp cleaners wiping out the file
|
||||
@@ -305,26 +299,23 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||
}
|
||||
|
||||
s := &llmServer{
|
||||
port: port,
|
||||
cmd: exec.Command(server, finalParams...),
|
||||
status: NewStatusWriter(os.Stderr),
|
||||
options: opts,
|
||||
estimate: estimate,
|
||||
sem: semaphore.NewWeighted(int64(numParallel)),
|
||||
totalLayers: ggml.KV().BlockCount() + 1,
|
||||
gpus: gpus,
|
||||
done: make(chan error, 1),
|
||||
port: port,
|
||||
cmd: exec.Command(server, finalParams...),
|
||||
status: NewStatusWriter(os.Stderr),
|
||||
options: opts,
|
||||
estimatedVRAM: estimatedVRAM,
|
||||
estimatedTotal: estimatedTotal,
|
||||
sem: semaphore.NewWeighted(int64(numParallel)),
|
||||
totalLayers: ggml.KV().BlockCount() + 1,
|
||||
gpuCount: gpuCount,
|
||||
done: make(chan error, 1),
|
||||
}
|
||||
|
||||
s.cmd.Env = os.Environ()
|
||||
s.cmd.Stdout = os.Stdout
|
||||
s.cmd.Stderr = s.status
|
||||
|
||||
envWorkarounds := [][2]string{}
|
||||
for _, gpu := range gpus {
|
||||
envWorkarounds = append(envWorkarounds, gpu.EnvWorkarounds...)
|
||||
}
|
||||
visibleDevicesEnv, visibleDevicesEnvVal := gpus.GetVisibleDevicesEnv()
|
||||
visibleDevicesEnv, visibleDevicesEnvVal := gpu.GpuInfoList(gpus).GetVisibleDevicesEnv()
|
||||
pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator))
|
||||
|
||||
// Update or add the path and visible devices variable with our adjusted version
|
||||
@@ -338,12 +329,6 @@ func NewLlamaServer(gpus gpu.GpuInfoList, model string, ggml *GGML, adapters, pr
|
||||
} else if devicesNeeded && strings.EqualFold(cmp[0], visibleDevicesEnv) {
|
||||
s.cmd.Env[i] = visibleDevicesEnv + "=" + visibleDevicesEnvVal
|
||||
devicesNeeded = false
|
||||
} else if len(envWorkarounds) != 0 {
|
||||
for _, kv := range envWorkarounds {
|
||||
if strings.EqualFold(cmp[0], kv[0]) {
|
||||
s.cmd.Env[i] = kv[0] + "=" + kv[1]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if pathNeeded {
|
||||
@@ -474,7 +459,7 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
return ServerStatusNotResponding, errors.New("server not responding")
|
||||
return ServerStatusNotResponding, fmt.Errorf("server not responding")
|
||||
}
|
||||
return ServerStatusError, fmt.Errorf("health resp: %w", err)
|
||||
}
|
||||
@@ -621,7 +606,7 @@ array ::=
|
||||
|
||||
string ::=
|
||||
"\"" (
|
||||
[^"\\\x7F\x00-\x1F] |
|
||||
[^"\\] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
|
||||
)* "\"" ws
|
||||
|
||||
@@ -1019,20 +1004,11 @@ func (s *llmServer) Close() error {
|
||||
}
|
||||
|
||||
func (s *llmServer) EstimatedVRAM() uint64 {
|
||||
return s.estimate.VRAMSize
|
||||
return s.estimatedVRAM
|
||||
}
|
||||
|
||||
func (s *llmServer) EstimatedTotal() uint64 {
|
||||
return s.estimate.TotalSize
|
||||
}
|
||||
|
||||
func (s *llmServer) EstimatedVRAMByGPU(gpuID string) uint64 {
|
||||
for i, gpu := range s.gpus {
|
||||
if gpu.ID == gpuID {
|
||||
return s.estimate.GPUSizes[i]
|
||||
}
|
||||
}
|
||||
return 0
|
||||
return s.estimatedTotal
|
||||
}
|
||||
|
||||
func parseDurationMs(ms float64) time.Duration {
|
||||
|
||||
@@ -178,6 +178,9 @@ func fromRequest(r ChatCompletionRequest) api.ChatRequest {
|
||||
|
||||
if r.Seed != nil {
|
||||
options["seed"] = *r.Seed
|
||||
|
||||
// temperature=0 is required for reproducible outputs
|
||||
options["temperature"] = 0.0
|
||||
}
|
||||
|
||||
if r.FrequencyPenalty != nil {
|
||||
@@ -242,6 +245,7 @@ func (w *writer) writeResponse(data []byte) (int, error) {
|
||||
d, err := json.Marshal(toChunk(w.id, chatResponse))
|
||||
if err != nil {
|
||||
return 0, err
|
||||
|
||||
}
|
||||
|
||||
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
|
||||
|
||||
@@ -8,9 +8,7 @@ import (
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/text/encoding/unicode"
|
||||
"golang.org/x/text/transform"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
type File struct {
|
||||
@@ -71,11 +69,14 @@ func ParseFile(r io.Reader) (*File, error) {
|
||||
var b bytes.Buffer
|
||||
var role string
|
||||
|
||||
var lineCount int
|
||||
var linePos int
|
||||
|
||||
var utf16 bool
|
||||
|
||||
var f File
|
||||
|
||||
tr := unicode.BOMOverride(unicode.UTF8.NewDecoder())
|
||||
br := bufio.NewReader(transform.NewReader(r, tr))
|
||||
|
||||
br := bufio.NewReader(r)
|
||||
for {
|
||||
r, _, err := br.ReadRune()
|
||||
if errors.Is(err, io.EOF) {
|
||||
@@ -84,6 +85,17 @@ func ParseFile(r io.Reader) (*File, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// the utf16 byte order mark will be read as "unreadable" by ReadRune()
|
||||
if isUnreadable(r) && lineCount == 0 && linePos == 0 {
|
||||
utf16 = true
|
||||
continue
|
||||
}
|
||||
|
||||
// skip the second byte if we're reading utf16
|
||||
if utf16 && r == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
next, r, err := parseRuneForState(r, curr)
|
||||
if errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
return nil, fmt.Errorf("%w: %s", err, b.String())
|
||||
@@ -91,6 +103,13 @@ func ParseFile(r io.Reader) (*File, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if isNewline(r) {
|
||||
lineCount++
|
||||
linePos = 0
|
||||
} else {
|
||||
linePos++
|
||||
}
|
||||
|
||||
// process the state transition, some transitions need to be intercepted and redirected
|
||||
if next != curr {
|
||||
switch curr {
|
||||
@@ -290,6 +309,10 @@ func isNewline(r rune) bool {
|
||||
return r == '\r' || r == '\n'
|
||||
}
|
||||
|
||||
func isUnreadable(r rune) bool {
|
||||
return r == unicode.ReplacementChar
|
||||
}
|
||||
|
||||
func isValidMessageRole(role string) bool {
|
||||
return role == "system" || role == "user" || role == "assistant"
|
||||
}
|
||||
|
||||
@@ -10,9 +10,6 @@ import (
|
||||
"unicode/utf16"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/text/encoding"
|
||||
"golang.org/x/text/encoding/unicode"
|
||||
)
|
||||
|
||||
func TestParseFileFile(t *testing.T) {
|
||||
@@ -28,7 +25,7 @@ TEMPLATE template1
|
||||
reader := strings.NewReader(input)
|
||||
|
||||
modelfile, err := ParseFile(reader)
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
expectedCommands := []Command{
|
||||
{Name: "model", Args: "model1"},
|
||||
@@ -91,7 +88,7 @@ func TestParseFileFrom(t *testing.T) {
|
||||
for _, c := range cases {
|
||||
t.Run("", func(t *testing.T) {
|
||||
modelfile, err := ParseFile(strings.NewReader(c.input))
|
||||
require.ErrorIs(t, err, c.err)
|
||||
assert.ErrorIs(t, err, c.err)
|
||||
if modelfile != nil {
|
||||
assert.Equal(t, c.expected, modelfile.Commands)
|
||||
}
|
||||
@@ -108,7 +105,7 @@ PARAMETER param1
|
||||
reader := strings.NewReader(input)
|
||||
|
||||
_, err := ParseFile(reader)
|
||||
require.ErrorIs(t, err, io.ErrUnexpectedEOF)
|
||||
assert.ErrorIs(t, err, io.ErrUnexpectedEOF)
|
||||
}
|
||||
|
||||
func TestParseFileBadCommand(t *testing.T) {
|
||||
@@ -117,7 +114,8 @@ FROM foo
|
||||
BADCOMMAND param1 value1
|
||||
`
|
||||
_, err := ParseFile(strings.NewReader(input))
|
||||
require.ErrorIs(t, err, errInvalidCommand)
|
||||
assert.ErrorIs(t, err, errInvalidCommand)
|
||||
|
||||
}
|
||||
|
||||
func TestParseFileMessages(t *testing.T) {
|
||||
@@ -203,7 +201,7 @@ MESSAGE system`,
|
||||
for _, c := range cases {
|
||||
t.Run("", func(t *testing.T) {
|
||||
modelfile, err := ParseFile(strings.NewReader(c.input))
|
||||
require.ErrorIs(t, err, c.err)
|
||||
assert.ErrorIs(t, err, c.err)
|
||||
if modelfile != nil {
|
||||
assert.Equal(t, c.expected, modelfile.Commands)
|
||||
}
|
||||
@@ -357,7 +355,7 @@ TEMPLATE """
|
||||
for _, c := range cases {
|
||||
t.Run("", func(t *testing.T) {
|
||||
modelfile, err := ParseFile(strings.NewReader(c.multiline))
|
||||
require.ErrorIs(t, err, c.err)
|
||||
assert.ErrorIs(t, err, c.err)
|
||||
if modelfile != nil {
|
||||
assert.Equal(t, c.expected, modelfile.Commands)
|
||||
}
|
||||
@@ -415,7 +413,7 @@ func TestParseFileParameters(t *testing.T) {
|
||||
fmt.Fprintln(&b, "FROM foo")
|
||||
fmt.Fprintln(&b, "PARAMETER", k)
|
||||
modelfile, err := ParseFile(&b)
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, []Command{
|
||||
{Name: "model", Args: "foo"},
|
||||
@@ -444,7 +442,7 @@ FROM foo
|
||||
for _, c := range cases {
|
||||
t.Run("", func(t *testing.T) {
|
||||
modelfile, err := ParseFile(strings.NewReader(c.input))
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, c.expected, modelfile.Commands)
|
||||
})
|
||||
}
|
||||
@@ -503,14 +501,15 @@ SYSTEM ""
|
||||
for _, c := range cases {
|
||||
t.Run("", func(t *testing.T) {
|
||||
modelfile, err := ParseFile(strings.NewReader(c))
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
modelfile2, err := ParseFile(strings.NewReader(modelfile.String()))
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, modelfile, modelfile2)
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestParseFileUTF16ParseFile(t *testing.T) {
|
||||
@@ -519,6 +518,14 @@ PARAMETER param1 1
|
||||
PARAMETER param2 4096
|
||||
SYSTEM You are a utf16 file.
|
||||
`
|
||||
// simulate a utf16 le file
|
||||
utf16File := utf16.Encode(append([]rune{'\ufffe'}, []rune(data)...))
|
||||
buf := new(bytes.Buffer)
|
||||
err := binary.Write(buf, binary.LittleEndian, utf16File)
|
||||
assert.NoError(t, err)
|
||||
|
||||
actual, err := ParseFile(buf)
|
||||
assert.NoError(t, err)
|
||||
|
||||
expected := []Command{
|
||||
{Name: "model", Args: "bob"},
|
||||
@@ -527,52 +534,14 @@ SYSTEM You are a utf16 file.
|
||||
{Name: "system", Args: "You are a utf16 file."},
|
||||
}
|
||||
|
||||
t.Run("le", func(t *testing.T) {
|
||||
var b bytes.Buffer
|
||||
require.NoError(t, binary.Write(&b, binary.LittleEndian, []byte{0xff, 0xfe}))
|
||||
require.NoError(t, binary.Write(&b, binary.LittleEndian, utf16.Encode([]rune(data))))
|
||||
assert.Equal(t, expected, actual.Commands)
|
||||
|
||||
actual, err := ParseFile(&b)
|
||||
require.NoError(t, err)
|
||||
// simulate a utf16 be file
|
||||
buf = new(bytes.Buffer)
|
||||
err = binary.Write(buf, binary.BigEndian, utf16File)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, expected, actual.Commands)
|
||||
})
|
||||
|
||||
t.Run("be", func(t *testing.T) {
|
||||
var b bytes.Buffer
|
||||
require.NoError(t, binary.Write(&b, binary.BigEndian, []byte{0xfe, 0xff}))
|
||||
require.NoError(t, binary.Write(&b, binary.BigEndian, utf16.Encode([]rune(data))))
|
||||
|
||||
actual, err := ParseFile(&b)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, expected, actual.Commands)
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseMultiByte(t *testing.T) {
|
||||
input := `FROM test
|
||||
SYSTEM 你好👋`
|
||||
|
||||
expect := []Command{
|
||||
{Name: "model", Args: "test"},
|
||||
{Name: "system", Args: "你好👋"},
|
||||
}
|
||||
|
||||
encodings := []encoding.Encoding{
|
||||
unicode.UTF8,
|
||||
unicode.UTF16(unicode.LittleEndian, unicode.UseBOM),
|
||||
unicode.UTF16(unicode.BigEndian, unicode.UseBOM),
|
||||
}
|
||||
|
||||
for _, encoding := range encodings {
|
||||
t.Run(fmt.Sprintf("%s", encoding), func(t *testing.T) {
|
||||
s, err := encoding.NewEncoder().String(input)
|
||||
require.NoError(t, err)
|
||||
|
||||
actual, err := ParseFile(strings.NewReader(s))
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, expect, actual.Commands)
|
||||
})
|
||||
}
|
||||
actual, err = ParseFile(buf)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expected, actual.Commands)
|
||||
}
|
||||
|
||||
@@ -59,7 +59,7 @@ func (p *Progress) StopAndClear() bool {
|
||||
stopped := p.stop()
|
||||
if stopped {
|
||||
// clear all progress lines
|
||||
for i := range p.pos {
|
||||
for i := 0; i < p.pos; i++ {
|
||||
if i > 0 {
|
||||
fmt.Fprint(p.w, "\033[A")
|
||||
}
|
||||
@@ -85,7 +85,7 @@ func (p *Progress) render() {
|
||||
defer fmt.Fprint(p.w, "\033[?25h")
|
||||
|
||||
// clear already rendered progress lines
|
||||
for i := range p.pos {
|
||||
for i := 0; i < p.pos; i++ {
|
||||
if i > 0 {
|
||||
fmt.Fprint(p.w, "\033[A")
|
||||
}
|
||||
|
||||
@@ -52,6 +52,7 @@ func (b *Buffer) GetLineSpacing(line int) bool {
|
||||
}
|
||||
|
||||
return hasSpace.(bool)
|
||||
|
||||
}
|
||||
|
||||
func (b *Buffer) MoveLeft() {
|
||||
@@ -116,12 +117,15 @@ func (b *Buffer) MoveRight() {
|
||||
|
||||
if b.DisplayPos%b.LineWidth == 0 {
|
||||
fmt.Printf(CursorDown + CursorBOL + cursorRightN(len(b.Prompt.prompt())))
|
||||
|
||||
} else if (b.DisplayPos-rLength)%b.LineWidth == b.LineWidth-1 && hasSpace {
|
||||
fmt.Printf(CursorDown + CursorBOL + cursorRightN(len(b.Prompt.prompt())+rLength))
|
||||
b.DisplayPos += 1
|
||||
|
||||
} else if b.LineHasSpace.Size() > 0 && b.DisplayPos%b.LineWidth == b.LineWidth-1 && hasSpace {
|
||||
fmt.Printf(CursorDown + CursorBOL + cursorRightN(len(b.Prompt.prompt())))
|
||||
b.DisplayPos += 1
|
||||
|
||||
} else {
|
||||
fmt.Print(cursorRightN(rLength))
|
||||
}
|
||||
@@ -150,7 +154,7 @@ func (b *Buffer) MoveToStart() {
|
||||
if b.Pos > 0 {
|
||||
currLine := b.DisplayPos / b.LineWidth
|
||||
if currLine > 0 {
|
||||
for range currLine {
|
||||
for cnt := 0; cnt < currLine; cnt++ {
|
||||
fmt.Print(CursorUp)
|
||||
}
|
||||
}
|
||||
@@ -165,7 +169,7 @@ func (b *Buffer) MoveToEnd() {
|
||||
currLine := b.DisplayPos / b.LineWidth
|
||||
totalLines := b.DisplaySize() / b.LineWidth
|
||||
if currLine < totalLines {
|
||||
for range totalLines - currLine {
|
||||
for cnt := 0; cnt < totalLines-currLine; cnt++ {
|
||||
fmt.Print(CursorDown)
|
||||
}
|
||||
remainder := b.DisplaySize() % b.LineWidth
|
||||
@@ -181,7 +185,7 @@ func (b *Buffer) MoveToEnd() {
|
||||
|
||||
func (b *Buffer) DisplaySize() int {
|
||||
sum := 0
|
||||
for i := range b.Buf.Size() {
|
||||
for i := 0; i < b.Buf.Size(); i++ {
|
||||
if e, ok := b.Buf.Get(i); ok {
|
||||
if r, ok := e.(rune); ok {
|
||||
sum += runewidth.RuneWidth(r)
|
||||
@@ -193,6 +197,7 @@ func (b *Buffer) DisplaySize() int {
|
||||
}
|
||||
|
||||
func (b *Buffer) Add(r rune) {
|
||||
|
||||
if b.Pos == b.Buf.Size() {
|
||||
b.AddChar(r, false)
|
||||
} else {
|
||||
@@ -205,6 +210,7 @@ func (b *Buffer) AddChar(r rune, insert bool) {
|
||||
b.DisplayPos += rLength
|
||||
|
||||
if b.Pos > 0 {
|
||||
|
||||
if b.DisplayPos%b.LineWidth == 0 {
|
||||
fmt.Printf("%c", r)
|
||||
fmt.Printf("\n%s", b.Prompt.AltPrompt)
|
||||
@@ -229,6 +235,7 @@ func (b *Buffer) AddChar(r rune, insert bool) {
|
||||
} else {
|
||||
b.LineHasSpace.Add(true)
|
||||
}
|
||||
|
||||
} else {
|
||||
fmt.Printf("%c", r)
|
||||
}
|
||||
@@ -349,6 +356,7 @@ func (b *Buffer) drawRemaining() {
|
||||
|
||||
func (b *Buffer) Remove() {
|
||||
if b.Buf.Size() > 0 && b.Pos > 0 {
|
||||
|
||||
if e, ok := b.Buf.Get(b.Pos - 1); ok {
|
||||
if r, ok := e.(rune); ok {
|
||||
rLength := runewidth.RuneWidth(r)
|
||||
@@ -374,6 +382,7 @@ func (b *Buffer) Remove() {
|
||||
} else {
|
||||
fmt.Print(" " + CursorLeft)
|
||||
}
|
||||
|
||||
} else if (b.DisplayPos-rLength)%b.LineWidth == 0 && hasSpace {
|
||||
fmt.Printf(CursorBOL + ClearToEOL)
|
||||
fmt.Printf(CursorUp + CursorBOL + cursorRightN(b.Width))
|
||||
@@ -382,9 +391,10 @@ func (b *Buffer) Remove() {
|
||||
b.LineHasSpace.Remove(b.DisplayPos/b.LineWidth - 1)
|
||||
}
|
||||
b.DisplayPos -= 1
|
||||
|
||||
} else {
|
||||
fmt.Print(cursorLeftN(rLength))
|
||||
for range rLength {
|
||||
for i := 0; i < rLength; i++ {
|
||||
fmt.Print(" ")
|
||||
}
|
||||
fmt.Print(cursorLeftN(rLength))
|
||||
@@ -441,7 +451,7 @@ func (b *Buffer) DeleteBefore() {
|
||||
func (b *Buffer) DeleteRemaining() {
|
||||
if b.DisplaySize() > 0 && b.Pos < b.DisplaySize() {
|
||||
charsToDel := b.Buf.Size() - b.Pos
|
||||
for range charsToDel {
|
||||
for cnt := 0; cnt < charsToDel; cnt++ {
|
||||
b.Delete()
|
||||
}
|
||||
}
|
||||
@@ -485,7 +495,7 @@ func (b *Buffer) ClearScreen() {
|
||||
if currPos > 0 {
|
||||
targetLine := currPos / b.LineWidth
|
||||
if targetLine > 0 {
|
||||
for range targetLine {
|
||||
for cnt := 0; cnt < targetLine; cnt++ {
|
||||
fmt.Print(CursorDown)
|
||||
}
|
||||
}
|
||||
@@ -515,7 +525,7 @@ func (b *Buffer) Replace(r []rune) {
|
||||
|
||||
fmt.Printf(CursorBOL + ClearToEOL)
|
||||
|
||||
for range lineNums {
|
||||
for i := 0; i < lineNums; i++ {
|
||||
fmt.Print(CursorUp + CursorBOL + ClearToEOL)
|
||||
}
|
||||
|
||||
|
||||
@@ -91,7 +91,7 @@ func (h *History) Add(l []rune) {
|
||||
func (h *History) Compact() {
|
||||
s := h.Buf.Size()
|
||||
if s > h.Limit {
|
||||
for range s - h.Limit {
|
||||
for cnt := 0; cnt < s-h.Limit; cnt++ {
|
||||
h.Buf.Remove(0)
|
||||
}
|
||||
}
|
||||
@@ -139,7 +139,7 @@ func (h *History) Save() error {
|
||||
defer f.Close()
|
||||
|
||||
buf := bufio.NewWriter(f)
|
||||
for cnt := range h.Size() {
|
||||
for cnt := 0; cnt < h.Size(); cnt++ {
|
||||
v, _ := h.Buf.Get(cnt)
|
||||
line, _ := v.([]rune)
|
||||
if _, err := buf.WriteString(string(line) + "\n"); err != nil {
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"syscall"
|
||||
)
|
||||
|
||||
type Prompt struct {
|
||||
@@ -62,7 +63,7 @@ func New(prompt Prompt) (*Instance, error) {
|
||||
|
||||
func (i *Instance) Readline() (string, error) {
|
||||
if !i.Terminal.rawmode {
|
||||
fd := os.Stdin.Fd()
|
||||
fd := int(syscall.Stdin)
|
||||
termios, err := SetRawMode(fd)
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -79,8 +80,8 @@ func (i *Instance) Readline() (string, error) {
|
||||
fmt.Print(prompt)
|
||||
|
||||
defer func() {
|
||||
fd := os.Stdin.Fd()
|
||||
//nolint:errcheck
|
||||
fd := int(syscall.Stdin)
|
||||
// nolint: errcheck
|
||||
UnsetRawMode(fd, i.Terminal.termios)
|
||||
i.Terminal.rawmode = false
|
||||
}()
|
||||
@@ -135,7 +136,7 @@ func (i *Instance) Readline() (string, error) {
|
||||
buf.MoveRight()
|
||||
case CharBracketedPaste:
|
||||
var code string
|
||||
for range 3 {
|
||||
for cnt := 0; cnt < 3; cnt++ {
|
||||
r, err = i.Terminal.Read()
|
||||
if err != nil {
|
||||
return "", io.EOF
|
||||
@@ -197,7 +198,7 @@ func (i *Instance) Readline() (string, error) {
|
||||
buf.Remove()
|
||||
case CharTab:
|
||||
// todo: convert back to real tabs
|
||||
for range 8 {
|
||||
for cnt := 0; cnt < 8; cnt++ {
|
||||
buf.Add(' ')
|
||||
}
|
||||
case CharDelete:
|
||||
@@ -215,7 +216,7 @@ func (i *Instance) Readline() (string, error) {
|
||||
case CharCtrlW:
|
||||
buf.DeleteWord()
|
||||
case CharCtrlZ:
|
||||
fd := os.Stdin.Fd()
|
||||
fd := int(syscall.Stdin)
|
||||
return handleCharCtrlZ(fd, i.Terminal.termios)
|
||||
case CharEnter, CharCtrlJ:
|
||||
output := buf.String()
|
||||
@@ -247,7 +248,7 @@ func (i *Instance) HistoryDisable() {
|
||||
}
|
||||
|
||||
func NewTerminal() (*Terminal, error) {
|
||||
fd := os.Stdin.Fd()
|
||||
fd := int(syscall.Stdin)
|
||||
termios, err := SetRawMode(fd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"syscall"
|
||||
)
|
||||
|
||||
func handleCharCtrlZ(fd uintptr, termios any) (string, error) {
|
||||
func handleCharCtrlZ(fd int, termios any) (string, error) {
|
||||
t := termios.(*Termios)
|
||||
if err := UnsetRawMode(fd, t); err != nil {
|
||||
return "", err
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
package readline
|
||||
|
||||
func handleCharCtrlZ(fd uintptr, state any) (string, error) {
|
||||
func handleCharCtrlZ(fd int, state any) (string, error) {
|
||||
// not supported
|
||||
return "", nil
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
|
||||
type Termios syscall.Termios
|
||||
|
||||
func SetRawMode(fd uintptr) (*Termios, error) {
|
||||
func SetRawMode(fd int) (*Termios, error) {
|
||||
termios, err := getTermios(fd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -25,13 +25,13 @@ func SetRawMode(fd uintptr) (*Termios, error) {
|
||||
return termios, setTermios(fd, &newTermios)
|
||||
}
|
||||
|
||||
func UnsetRawMode(fd uintptr, termios any) error {
|
||||
func UnsetRawMode(fd int, termios any) error {
|
||||
t := termios.(*Termios)
|
||||
return setTermios(fd, t)
|
||||
}
|
||||
|
||||
// IsTerminal returns true if the given file descriptor is a terminal.
|
||||
func IsTerminal(fd uintptr) bool {
|
||||
func IsTerminal(fd int) bool {
|
||||
_, err := getTermios(fd)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
@@ -7,17 +7,17 @@ import (
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
func getTermios(fd uintptr) (*Termios, error) {
|
||||
func getTermios(fd int) (*Termios, error) {
|
||||
termios := new(Termios)
|
||||
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, fd, syscall.TIOCGETA, uintptr(unsafe.Pointer(termios)), 0, 0, 0)
|
||||
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), syscall.TIOCGETA, uintptr(unsafe.Pointer(termios)), 0, 0, 0)
|
||||
if err != 0 {
|
||||
return nil, err
|
||||
}
|
||||
return termios, nil
|
||||
}
|
||||
|
||||
func setTermios(fd uintptr, termios *Termios) error {
|
||||
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, fd, syscall.TIOCSETA, uintptr(unsafe.Pointer(termios)), 0, 0, 0)
|
||||
func setTermios(fd int, termios *Termios) error {
|
||||
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), syscall.TIOCSETA, uintptr(unsafe.Pointer(termios)), 0, 0, 0)
|
||||
if err != 0 {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -10,17 +10,17 @@ import (
|
||||
const tcgets = 0x5401
|
||||
const tcsets = 0x5402
|
||||
|
||||
func getTermios(fd uintptr) (*Termios, error) {
|
||||
func getTermios(fd int) (*Termios, error) {
|
||||
termios := new(Termios)
|
||||
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, fd, tcgets, uintptr(unsafe.Pointer(termios)), 0, 0, 0)
|
||||
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), tcgets, uintptr(unsafe.Pointer(termios)), 0, 0, 0)
|
||||
if err != 0 {
|
||||
return nil, err
|
||||
}
|
||||
return termios, nil
|
||||
}
|
||||
|
||||
func setTermios(fd uintptr, termios *Termios) error {
|
||||
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, fd, tcsets, uintptr(unsafe.Pointer(termios)), 0, 0, 0)
|
||||
func setTermios(fd int, termios *Termios) error {
|
||||
_, _, err := syscall.Syscall6(syscall.SYS_IOCTL, uintptr(fd), tcsets, uintptr(unsafe.Pointer(termios)), 0, 0, 0)
|
||||
if err != 0 {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -9,13 +9,13 @@ type State struct {
|
||||
}
|
||||
|
||||
// IsTerminal checks if the given file descriptor is associated with a terminal
|
||||
func IsTerminal(fd uintptr) bool {
|
||||
func IsTerminal(fd int) bool {
|
||||
var st uint32
|
||||
err := windows.GetConsoleMode(windows.Handle(fd), &st)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func SetRawMode(fd uintptr) (*State, error) {
|
||||
func SetRawMode(fd int) (*State, error) {
|
||||
var st uint32
|
||||
if err := windows.GetConsoleMode(windows.Handle(fd), &st); err != nil {
|
||||
return nil, err
|
||||
@@ -32,7 +32,7 @@ func SetRawMode(fd uintptr) (*State, error) {
|
||||
return &State{st}, nil
|
||||
}
|
||||
|
||||
func UnsetRawMode(fd uintptr, state any) error {
|
||||
func UnsetRawMode(fd int, state any) error {
|
||||
s := state.(*State)
|
||||
return windows.SetConsoleMode(windows.Handle(fd), s.mode)
|
||||
}
|
||||
|
||||
@@ -159,8 +159,8 @@ check_gpu() {
|
||||
esac ;;
|
||||
lshw)
|
||||
case $2 in
|
||||
nvidia) available lshw && $SUDO lshw -c display -numeric -disable network | grep -q 'vendor: .* \[10DE\]' || return 1 ;;
|
||||
amdgpu) available lshw && $SUDO lshw -c display -numeric -disable network | grep -q 'vendor: .* \[1002\]' || return 1 ;;
|
||||
nvidia) available lshw && $SUDO lshw -c display -numeric | grep -q 'vendor: .* \[10DE\]' || return 1 ;;
|
||||
amdgpu) available lshw && $SUDO lshw -c display -numeric | grep -q 'vendor: .* \[1002\]' || return 1 ;;
|
||||
esac ;;
|
||||
nvidia-smi) available nvidia-smi || return 1 ;;
|
||||
esac
|
||||
|
||||
@@ -340,17 +340,17 @@ type downloadOpts struct {
|
||||
}
|
||||
|
||||
// downloadBlob downloads a blob from the registry and stores it in the blobs directory
|
||||
func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ error) {
|
||||
func downloadBlob(ctx context.Context, opts downloadOpts) error {
|
||||
fp, err := GetBlobsPath(opts.digest)
|
||||
if err != nil {
|
||||
return false, err
|
||||
return err
|
||||
}
|
||||
|
||||
fi, err := os.Stat(fp)
|
||||
switch {
|
||||
case errors.Is(err, os.ErrNotExist):
|
||||
case err != nil:
|
||||
return false, err
|
||||
return err
|
||||
default:
|
||||
opts.fn(api.ProgressResponse{
|
||||
Status: fmt.Sprintf("pulling %s", opts.digest[7:19]),
|
||||
@@ -359,7 +359,7 @@ func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ erro
|
||||
Completed: fi.Size(),
|
||||
})
|
||||
|
||||
return true, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
data, ok := blobDownloadManager.LoadOrStore(opts.digest, &blobDownload{Name: fp, Digest: opts.digest})
|
||||
@@ -369,12 +369,12 @@ func downloadBlob(ctx context.Context, opts downloadOpts) (cacheHit bool, _ erro
|
||||
requestURL = requestURL.JoinPath("v2", opts.mp.GetNamespaceRepository(), "blobs", opts.digest)
|
||||
if err := download.Prepare(ctx, requestURL, opts.regOpts); err != nil {
|
||||
blobDownloadManager.Delete(opts.digest)
|
||||
return false, err
|
||||
return err
|
||||
}
|
||||
|
||||
//nolint:contextcheck
|
||||
// nolint: contextcheck
|
||||
go download.Run(context.Background(), requestURL, opts.regOpts)
|
||||
}
|
||||
|
||||
return false, download.Wait(ctx, opts.fn)
|
||||
return download.Wait(ctx, opts.fn)
|
||||
}
|
||||
|
||||
@@ -18,16 +18,17 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/exp/slices"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/auth"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/version"
|
||||
@@ -314,7 +315,7 @@ func realpath(rel, from string) string {
|
||||
return abspath
|
||||
}
|
||||
|
||||
func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantization string, modelfile *parser.File, fn func(resp api.ProgressResponse)) (err error) {
|
||||
func CreateModel(ctx context.Context, name, modelFileDir, quantization string, modelfile *parser.File, fn func(resp api.ProgressResponse)) (err error) {
|
||||
config := ConfigV2{
|
||||
OS: "linux",
|
||||
Architecture: "amd64",
|
||||
@@ -332,7 +333,7 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
|
||||
|
||||
switch c.Name {
|
||||
case "model", "adapter":
|
||||
var baseLayers []*layerGGML
|
||||
var baseLayers []*layerWithGGML
|
||||
if name := model.ParseName(c.Args); name.IsValid() {
|
||||
baseLayers, err = parseFromModel(ctx, name, fn)
|
||||
if err != nil {
|
||||
@@ -439,27 +440,19 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
|
||||
layers = append(layers, baseLayer.Layer)
|
||||
}
|
||||
case "license", "template", "system":
|
||||
if c.Name != "license" {
|
||||
// replace
|
||||
layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
|
||||
if layer.MediaType != mediatype {
|
||||
return false
|
||||
}
|
||||
|
||||
if err := layer.Remove(); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
blob := strings.NewReader(c.Args)
|
||||
layer, err := NewLayer(blob, mediatype)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if c.Name != "license" {
|
||||
// replace
|
||||
layers = slices.DeleteFunc(layers, func(layer *Layer) bool {
|
||||
return layer.MediaType == mediatype
|
||||
})
|
||||
}
|
||||
|
||||
layers = append(layers, layer)
|
||||
case "message":
|
||||
role, content, ok := strings.Cut(c.Args, ": ")
|
||||
@@ -578,15 +571,26 @@ func CreateModel(ctx context.Context, name model.Name, modelFileDir, quantizatio
|
||||
}
|
||||
}
|
||||
|
||||
old, _ := ParseNamedManifest(name)
|
||||
unref := make(map[string]struct{})
|
||||
if manifest, _, err := GetManifest(ParseModelPath(name)); err == nil {
|
||||
for _, layer := range manifest.Layers {
|
||||
if !slices.Contains(digests, layer.Digest) {
|
||||
unref[layer.Digest] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
if manifest.Config.Digest != layer.Digest {
|
||||
unref[manifest.Config.Digest] = struct{}{}
|
||||
}
|
||||
}
|
||||
|
||||
fn(api.ProgressResponse{Status: "writing manifest"})
|
||||
if err := WriteManifest(name, layer, layers); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !envconfig.NoPrune && old != nil {
|
||||
if err := old.RemoveLayers(); err != nil {
|
||||
if !envconfig.NoPrune {
|
||||
if err := deleteUnusedLayers(nil, unref); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -658,7 +662,7 @@ func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]struct{})
|
||||
// save (i.e. delete from the deleteMap) any files used in other manifests
|
||||
manifest, _, err := GetManifest(fmp)
|
||||
if err != nil {
|
||||
//nolint:nilerr
|
||||
// nolint: nilerr
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -853,27 +857,23 @@ func PullModel(ctx context.Context, name string, regOpts *registryOptions, fn fu
|
||||
layers = append(layers, manifest.Layers...)
|
||||
layers = append(layers, manifest.Config)
|
||||
|
||||
skipVerify := make(map[string]bool)
|
||||
for _, layer := range layers {
|
||||
cacheHit, err := downloadBlob(ctx, downloadOpts{
|
||||
mp: mp,
|
||||
digest: layer.Digest,
|
||||
regOpts: regOpts,
|
||||
fn: fn,
|
||||
})
|
||||
if err != nil {
|
||||
if err := downloadBlob(
|
||||
ctx,
|
||||
downloadOpts{
|
||||
mp: mp,
|
||||
digest: layer.Digest,
|
||||
regOpts: regOpts,
|
||||
fn: fn,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
skipVerify[layer.Digest] = cacheHit
|
||||
delete(deleteMap, layer.Digest)
|
||||
}
|
||||
delete(deleteMap, manifest.Config.Digest)
|
||||
|
||||
fn(api.ProgressResponse{Status: "verifying sha256 digest"})
|
||||
for _, layer := range layers {
|
||||
if skipVerify[layer.Digest] {
|
||||
continue
|
||||
}
|
||||
if err := verifyBlob(layer.Digest); err != nil {
|
||||
if errors.Is(err, errDigestMismatch) {
|
||||
// something went wrong, delete the blob
|
||||
@@ -960,6 +960,7 @@ var errUnauthorized = fmt.Errorf("unauthorized: access denied")
|
||||
func getTokenSubject(token string) string {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
slog.Error("jwt token does not contain 3 parts")
|
||||
return ""
|
||||
}
|
||||
|
||||
@@ -987,7 +988,7 @@ func getTokenSubject(token string) string {
|
||||
|
||||
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *registryOptions) (*http.Response, error) {
|
||||
anonymous := true // access will default to anonymous if no user is found associated with the public key
|
||||
for range 2 {
|
||||
for i := 0; i < 2; i++ {
|
||||
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
|
||||
if err != nil {
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
@@ -34,6 +34,12 @@ func (m *Manifest) Remove() error {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, layer := range append(m.Layers, m.Config) {
|
||||
if err := layer.Remove(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
manifests, err := GetManifestPath()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -42,18 +48,6 @@ func (m *Manifest) Remove() error {
|
||||
return PruneDirectory(manifests)
|
||||
}
|
||||
|
||||
func (m *Manifest) RemoveLayers() error {
|
||||
for _, layer := range append(m.Layers, m.Config) {
|
||||
if err := layer.Remove(); errors.Is(err, os.ErrNotExist) {
|
||||
slog.Debug("layer does not exist", "digest", layer.Digest)
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ParseNamedManifest(n model.Name) (*Manifest, error) {
|
||||
if !n.IsFullyQualified() {
|
||||
return nil, model.Unqualified(n)
|
||||
@@ -91,31 +85,30 @@ func ParseNamedManifest(n model.Name) (*Manifest, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func WriteManifest(name model.Name, config *Layer, layers []*Layer) error {
|
||||
manifests, err := GetManifestPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
p := filepath.Join(manifests, name.Filepath())
|
||||
if err := os.MkdirAll(filepath.Dir(p), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
f, err := os.Create(p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
m := ManifestV2{
|
||||
func WriteManifest(name string, config *Layer, layers []*Layer) error {
|
||||
manifest := ManifestV2{
|
||||
SchemaVersion: 2,
|
||||
MediaType: "application/vnd.docker.distribution.manifest.v2+json",
|
||||
Config: config,
|
||||
Layers: layers,
|
||||
}
|
||||
|
||||
return json.NewEncoder(f).Encode(m)
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(manifest); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
modelpath := ParseModelPath(name)
|
||||
manifestPath, err := modelpath.GetManifestPath()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(manifestPath), 0o755); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return os.WriteFile(manifestPath, b.Bytes(), 0o644)
|
||||
}
|
||||
|
||||
func Manifests() (map[model.Name]*Manifest, error) {
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
@@ -108,7 +107,6 @@ func TestManifests(t *testing.T) {
|
||||
t.Run(n, func(t *testing.T) {
|
||||
d := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", d)
|
||||
envconfig.LoadConfig()
|
||||
|
||||
for _, p := range wants.ps {
|
||||
createManifest(t, d, p)
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -15,26 +14,27 @@ import (
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/convert"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/templates"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
var intermediateBlobs map[string]string = make(map[string]string)
|
||||
|
||||
type layerGGML struct {
|
||||
type layerWithGGML struct {
|
||||
*Layer
|
||||
*llm.GGML
|
||||
}
|
||||
|
||||
func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) {
|
||||
m, err := ParseNamedManifest(name)
|
||||
func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
|
||||
modelpath := ParseModelPath(name.String())
|
||||
manifest, _, err := GetManifest(modelpath)
|
||||
switch {
|
||||
case errors.Is(err, os.ErrNotExist):
|
||||
if err := PullModel(ctx, name.String(), ®istryOptions{}, fn); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
m, err = ParseNamedManifest(name)
|
||||
modelpath = ParseModelPath(name.String())
|
||||
manifest, _, err = GetManifest(modelpath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -42,8 +42,8 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, layer := range m.Layers {
|
||||
layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, name.DisplayShortest())
|
||||
for _, layer := range manifest.Layers {
|
||||
layer, err := NewLayerFromLayer(layer.Digest, layer.MediaType, modelpath.GetShortTagname())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -68,16 +68,17 @@ func parseFromModel(ctx context.Context, name model.Name, fn func(api.ProgressRe
|
||||
return nil, err
|
||||
}
|
||||
|
||||
layers = append(layers, &layerGGML{layer, ggml})
|
||||
layers = append(layers, &layerWithGGML{layer, ggml})
|
||||
default:
|
||||
layers = append(layers, &layerGGML{layer, nil})
|
||||
layers = append(layers, &layerWithGGML{layer, nil})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) {
|
||||
func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
|
||||
stat, err := file.Stat()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -181,13 +182,13 @@ func parseFromZipFile(_ context.Context, file *os.File, digest string, fn func(a
|
||||
return nil, err
|
||||
}
|
||||
|
||||
layers = append(layers, &layerGGML{layer, ggml})
|
||||
layers = append(layers, &layerWithGGML{layer, ggml})
|
||||
|
||||
intermediateBlobs[digest] = layer.Digest
|
||||
return detectChatTemplate(layers)
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerGGML, err error) {
|
||||
func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(api.ProgressResponse)) (layers []*layerWithGGML, err error) {
|
||||
sr := io.NewSectionReader(file, 0, 512)
|
||||
contentType, err := detectContentType(sr)
|
||||
if err != nil {
|
||||
@@ -229,30 +230,10 @@ func parseFromFile(ctx context.Context, file *os.File, digest string, fn func(ap
|
||||
return nil, err
|
||||
}
|
||||
|
||||
layers = append(layers, &layerGGML{layer, ggml})
|
||||
layers = append(layers, &layerWithGGML{layer, ggml})
|
||||
offset = n
|
||||
}
|
||||
|
||||
return detectChatTemplate(layers)
|
||||
}
|
||||
|
||||
func detectChatTemplate(layers []*layerGGML) ([]*layerGGML, error) {
|
||||
for _, layer := range layers {
|
||||
if s := layer.GGML.KV().ChatTemplate(); s != "" {
|
||||
if t, err := templates.NamedTemplate(s); err != nil {
|
||||
slog.Debug("template detection", "error", err)
|
||||
} else {
|
||||
tmpl, err := NewLayer(t.Reader(), "application/vnd.ollama.image.template")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tmpl.status = fmt.Sprintf("using autodetected template %s", t.Name)
|
||||
layers = append(layers, &layerGGML{tmpl, nil})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return layers, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -8,8 +8,6 @@ import (
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
type ModelPath struct {
|
||||
@@ -106,7 +104,14 @@ func (mp ModelPath) GetShortTagname() string {
|
||||
// modelsDir returns the value of the OLLAMA_MODELS environment variable or the user's home directory if OLLAMA_MODELS is not set.
|
||||
// The models directory is where Ollama stores its model files and manifests.
|
||||
func modelsDir() (string, error) {
|
||||
return envconfig.ModelsDir, nil
|
||||
if models, exists := os.LookupEnv("OLLAMA_MODELS"); exists {
|
||||
return models, nil
|
||||
}
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return filepath.Join(home, ".ollama", "models"), nil
|
||||
}
|
||||
|
||||
// GetManifestPath returns the path to the manifest file for the given model path, it is up to the caller to create the directory if it does not exist.
|
||||
|
||||
@@ -6,15 +6,12 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
func TestGetBlobsPath(t *testing.T) {
|
||||
// GetBlobsPath expects an actual directory to exist
|
||||
dir, err := os.MkdirTemp("", "ollama-test")
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, err)
|
||||
defer os.RemoveAll(dir)
|
||||
|
||||
tests := []struct {
|
||||
@@ -63,11 +60,10 @@ func TestGetBlobsPath(t *testing.T) {
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MODELS", dir)
|
||||
envconfig.LoadConfig()
|
||||
|
||||
got, err := GetBlobsPath(tc.digest)
|
||||
|
||||
require.ErrorIs(t, tc.err, err, tc.name)
|
||||
assert.ErrorIs(t, tc.err, err, tc.name)
|
||||
assert.Equal(t, tc.expected, got, tc.name)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -16,7 +16,6 @@ import (
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
@@ -24,6 +23,7 @@ import (
|
||||
|
||||
"github.com/gin-contrib/cors"
|
||||
"github.com/gin-gonic/gin"
|
||||
"golang.org/x/exp/slices"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
@@ -77,6 +77,7 @@ func isSupportedImageType(image []byte) bool {
|
||||
}
|
||||
|
||||
func (s *Server) GenerateHandler(c *gin.Context) {
|
||||
|
||||
checkpointStart := time.Now()
|
||||
var req api.GenerateRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
@@ -523,8 +524,8 @@ func checkNameExists(name model.Name) error {
|
||||
}
|
||||
|
||||
func (s *Server) CreateModelHandler(c *gin.Context) {
|
||||
var r api.CreateRequest
|
||||
if err := c.ShouldBindJSON(&r); errors.Is(err, io.EOF) {
|
||||
var req api.CreateRequest
|
||||
if err := c.ShouldBindJSON(&req); errors.Is(err, io.EOF) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "missing request body"})
|
||||
return
|
||||
} else if err != nil {
|
||||
@@ -532,7 +533,7 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
name := model.ParseName(cmp.Or(r.Model, r.Name))
|
||||
name := model.ParseName(cmp.Or(req.Model, req.Name))
|
||||
if !name.IsValid() {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
|
||||
return
|
||||
@@ -543,24 +544,24 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if r.Path == "" && r.Modelfile == "" {
|
||||
if req.Path == "" && req.Modelfile == "" {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "path or modelfile are required"})
|
||||
return
|
||||
}
|
||||
|
||||
var sr io.Reader = strings.NewReader(r.Modelfile)
|
||||
if r.Path != "" && r.Modelfile == "" {
|
||||
f, err := os.Open(r.Path)
|
||||
var r io.Reader = strings.NewReader(req.Modelfile)
|
||||
if req.Path != "" && req.Modelfile == "" {
|
||||
f, err := os.Open(req.Path)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("error reading modelfile: %s", err)})
|
||||
return
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
sr = f
|
||||
r = f
|
||||
}
|
||||
|
||||
f, err := parser.ParseFile(sr)
|
||||
modelfile, err := parser.ParseFile(r)
|
||||
if err != nil {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -576,13 +577,17 @@ func (s *Server) CreateModelHandler(c *gin.Context) {
|
||||
ctx, cancel := context.WithCancel(c.Request.Context())
|
||||
defer cancel()
|
||||
|
||||
quantization := cmp.Or(r.Quantize, r.Quantization)
|
||||
if err := CreateModel(ctx, name, filepath.Dir(r.Path), strings.ToUpper(quantization), f, fn); err != nil {
|
||||
quantization := req.Quantization
|
||||
if req.Quantize != "" {
|
||||
quantization = req.Quantize
|
||||
}
|
||||
|
||||
if err := CreateModel(ctx, name.String(), filepath.Dir(req.Path), strings.ToUpper(quantization), modelfile, fn); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
}()
|
||||
|
||||
if r.Stream != nil && !*r.Stream {
|
||||
if req.Stream != nil && !*req.Stream {
|
||||
waitForStream(c, ch)
|
||||
return
|
||||
}
|
||||
@@ -616,11 +621,6 @@ func (s *Server) DeleteModelHandler(c *gin.Context) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if err := m.RemoveLayers(); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) ShowModelHandler(c *gin.Context) {
|
||||
@@ -730,7 +730,7 @@ func (s *Server) ListModelsHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
models := []api.ListModelResponse{}
|
||||
models := []api.ModelResponse{}
|
||||
for n, m := range ms {
|
||||
f, err := m.Config.Open()
|
||||
if err != nil {
|
||||
@@ -746,7 +746,7 @@ func (s *Server) ListModelsHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
// tag should never be masked
|
||||
models = append(models, api.ListModelResponse{
|
||||
models = append(models, api.ModelResponse{
|
||||
Model: n.DisplayShortest(),
|
||||
Name: n.DisplayShortest(),
|
||||
Size: m.Size(),
|
||||
@@ -762,7 +762,7 @@ func (s *Server) ListModelsHandler(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
slices.SortStableFunc(models, func(i, j api.ListModelResponse) int {
|
||||
slices.SortStableFunc(models, func(i, j api.ModelResponse) int {
|
||||
// most recently modified first
|
||||
return cmp.Compare(j.ModifiedAt.Unix(), i.ModifiedAt.Unix())
|
||||
})
|
||||
@@ -942,7 +942,7 @@ func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
|
||||
}
|
||||
|
||||
if allowedHost(host) {
|
||||
if c.Request.Method == http.MethodOptions {
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
c.AbortWithStatus(http.StatusNoContent)
|
||||
return
|
||||
}
|
||||
@@ -960,10 +960,6 @@ func (s *Server) GenerateRoutes() http.Handler {
|
||||
config.AllowWildcard = true
|
||||
config.AllowBrowserExtensions = true
|
||||
config.AllowHeaders = []string{"Authorization", "Content-Type", "User-Agent", "Accept", "X-Requested-With"}
|
||||
openAIProperties := []string{"lang", "package-version", "os", "arch", "runtime", "runtime-version", "async"}
|
||||
for _, prop := range openAIProperties {
|
||||
config.AllowHeaders = append(config.AllowHeaders, "x-stainless-"+prop)
|
||||
}
|
||||
config.AllowOrigins = envconfig.AllowOrigins
|
||||
|
||||
r := gin.Default()
|
||||
@@ -1143,7 +1139,7 @@ func streamResponse(c *gin.Context, ch chan any) {
|
||||
}
|
||||
|
||||
func (s *Server) ProcessHandler(c *gin.Context) {
|
||||
models := []api.ProcessModelResponse{}
|
||||
models := []api.ModelResponse{}
|
||||
|
||||
for _, v := range s.sched.loaded {
|
||||
model := v.model
|
||||
@@ -1155,7 +1151,7 @@ func (s *Server) ProcessHandler(c *gin.Context) {
|
||||
QuantizationLevel: model.Config.FileType,
|
||||
}
|
||||
|
||||
mr := api.ProcessModelResponse{
|
||||
mr := api.ModelResponse{
|
||||
Model: model.ShortName,
|
||||
Name: model.ShortName,
|
||||
Size: int64(v.estimatedTotal),
|
||||
@@ -1175,7 +1171,7 @@ func (s *Server) ProcessHandler(c *gin.Context) {
|
||||
models = append(models, mr)
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, api.ProcessResponse{Models: models})
|
||||
c.JSON(http.StatusOK, api.ListResponse{Models: models})
|
||||
}
|
||||
|
||||
// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
|
||||
@@ -1310,6 +1306,7 @@ func (s *Server) ChatHandler(c *gin.Context) {
|
||||
defer close(ch)
|
||||
|
||||
fn := func(r llm.CompletionResponse) {
|
||||
|
||||
resp := api.ChatResponse{
|
||||
Model: req.Model,
|
||||
CreatedAt: time.Now().UTC(),
|
||||
|
||||
@@ -15,13 +15,11 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/llm"
|
||||
)
|
||||
|
||||
var stream bool = false
|
||||
|
||||
func createBinFile(t *testing.T, kv map[string]any, ti []llm.Tensor) string {
|
||||
func createBinFile(t *testing.T) string {
|
||||
t.Helper()
|
||||
|
||||
f, err := os.CreateTemp(t.TempDir(), "")
|
||||
@@ -30,7 +28,19 @@ func createBinFile(t *testing.T, kv map[string]any, ti []llm.Tensor) string {
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if err := llm.NewGGUFV3(binary.LittleEndian).Encode(f, kv, ti); err != nil {
|
||||
if err := binary.Write(f, binary.LittleEndian, []byte("GGUF")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := binary.Write(f, binary.LittleEndian, uint32(3)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := binary.Write(f, binary.LittleEndian, uint64(0)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := binary.Write(f, binary.LittleEndian, uint64(0)); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
@@ -87,12 +97,11 @@ func checkFileExists(t *testing.T, p string, expect []string) {
|
||||
func TestCreateFromBin(t *testing.T) {
|
||||
p := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", p)
|
||||
envconfig.LoadConfig()
|
||||
|
||||
var s Server
|
||||
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||
Name: "test",
|
||||
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
|
||||
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
@@ -113,12 +122,11 @@ func TestCreateFromBin(t *testing.T) {
|
||||
func TestCreateFromModel(t *testing.T) {
|
||||
p := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", p)
|
||||
envconfig.LoadConfig()
|
||||
var s Server
|
||||
|
||||
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||
Name: "test",
|
||||
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
|
||||
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
@@ -150,421 +158,3 @@ func TestCreateFromModel(t *testing.T) {
|
||||
filepath.Join(p, "blobs", "sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116"),
|
||||
})
|
||||
}
|
||||
|
||||
func TestCreateRemovesLayers(t *testing.T) {
|
||||
p := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", p)
|
||||
envconfig.LoadConfig()
|
||||
var s Server
|
||||
|
||||
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||
Name: "test",
|
||||
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt }}", createBinFile(t, nil, nil)),
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
||||
})
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
|
||||
filepath.Join(p, "blobs", "sha256-b507b9c2f6ca642bffcd06665ea7c91f235fd32daeefdf875a0f938db05fb315"),
|
||||
filepath.Join(p, "blobs", "sha256-bc80b03733773e0728011b2f4adf34c458b400e1aad48cb28d61170f3a2ad2d6"),
|
||||
})
|
||||
|
||||
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||
Name: "test",
|
||||
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .System }} {{ .Prompt }}", createBinFile(t, nil, nil)),
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
||||
})
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||
filepath.Join(p, "blobs", "sha256-8f2c2167d789c6b2302dff965160fa5029f6a24096d262c1cbb469f21a045382"),
|
||||
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
|
||||
filepath.Join(p, "blobs", "sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"),
|
||||
})
|
||||
}
|
||||
|
||||
func TestCreateUnsetsSystem(t *testing.T) {
|
||||
p := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", p)
|
||||
envconfig.LoadConfig()
|
||||
var s Server
|
||||
|
||||
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||
Name: "test",
|
||||
Modelfile: fmt.Sprintf("FROM %s\nSYSTEM Say hi!", createBinFile(t, nil, nil)),
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
||||
})
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||
filepath.Join(p, "blobs", "sha256-8585df945d1069bc78b79bd10bb73ba07fbc29b0f5479a31a601c0d12731416e"),
|
||||
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
|
||||
filepath.Join(p, "blobs", "sha256-f29e82a8284dbdf5910b1555580ff60b04238b8da9d5e51159ada67a4d0d5851"),
|
||||
})
|
||||
|
||||
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||
Name: "test",
|
||||
Modelfile: fmt.Sprintf("FROM %s\nSYSTEM \"\"", createBinFile(t, nil, nil)),
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
||||
})
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||
filepath.Join(p, "blobs", "sha256-67d4b8d106af2a5b100a46e9bdc038c71eef2a35c9abac784092654212f97cf5"),
|
||||
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
|
||||
filepath.Join(p, "blobs", "sha256-e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"),
|
||||
})
|
||||
|
||||
bts, err := os.ReadFile(filepath.Join(p, "blobs", "sha256-e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if string(bts) != "" {
|
||||
t.Fatalf("expected empty string, actual %s", string(bts))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateMergeParameters(t *testing.T) {
|
||||
p := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", p)
|
||||
envconfig.LoadConfig()
|
||||
var s Server
|
||||
|
||||
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||
Name: "test",
|
||||
Modelfile: fmt.Sprintf("FROM %s\nPARAMETER temperature 1\nPARAMETER top_k 10\nPARAMETER stop USER:\nPARAMETER stop ASSISTANT:", createBinFile(t, nil, nil)),
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
||||
})
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||
filepath.Join(p, "blobs", "sha256-1d0ad71299d48c2fb7ae2b98e683643e771f8a5b72be34942af90d97a91c1e37"),
|
||||
filepath.Join(p, "blobs", "sha256-4a384beaf47a9cbe452dfa5ab70eea691790f3b35a832d12933a1996685bf2b6"),
|
||||
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
|
||||
})
|
||||
|
||||
// in order to merge parameters, the second model must be created FROM the first
|
||||
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||
Name: "test2",
|
||||
Modelfile: "FROM test\nPARAMETER temperature 0.6\nPARAMETER top_p 0.7",
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
|
||||
})
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||
filepath.Join(p, "blobs", "sha256-1d0ad71299d48c2fb7ae2b98e683643e771f8a5b72be34942af90d97a91c1e37"),
|
||||
filepath.Join(p, "blobs", "sha256-4a384beaf47a9cbe452dfa5ab70eea691790f3b35a832d12933a1996685bf2b6"),
|
||||
filepath.Join(p, "blobs", "sha256-4cd9d4ba6b734d9b4cbd1e5caa60374c00722e993fce5e1e2d15a33698f71187"),
|
||||
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
|
||||
filepath.Join(p, "blobs", "sha256-e29a7b3c47287a2489c895d21fe413c20f859a85d20e749492f52a838e36e1ba"),
|
||||
})
|
||||
|
||||
actual, err := os.ReadFile(filepath.Join(p, "blobs", "sha256-e29a7b3c47287a2489c895d21fe413c20f859a85d20e749492f52a838e36e1ba"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expect, err := json.Marshal(map[string]any{"temperature": 0.6, "top_k": 10, "top_p": 0.7, "stop": []string{"USER:", "ASSISTANT:"}})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(bytes.TrimSpace(expect), bytes.TrimSpace(actual)) {
|
||||
t.Errorf("expected %s, actual %s", string(expect), string(actual))
|
||||
}
|
||||
|
||||
// slices are replaced
|
||||
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||
Name: "test2",
|
||||
Modelfile: "FROM test\nPARAMETER temperature 0.6\nPARAMETER top_p 0.7\nPARAMETER stop <|endoftext|>",
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
|
||||
})
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||
filepath.Join(p, "blobs", "sha256-12f58bb75cb3042d69a7e013ab87fb3c3c7088f50ddc62f0c77bd332f0d44d35"),
|
||||
filepath.Join(p, "blobs", "sha256-1d0ad71299d48c2fb7ae2b98e683643e771f8a5b72be34942af90d97a91c1e37"),
|
||||
filepath.Join(p, "blobs", "sha256-257aa726584f24970a4f240765e75a7169bfbe7f4966c1f04513d6b6c860583a"),
|
||||
filepath.Join(p, "blobs", "sha256-4a384beaf47a9cbe452dfa5ab70eea691790f3b35a832d12933a1996685bf2b6"),
|
||||
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
|
||||
})
|
||||
|
||||
actual, err = os.ReadFile(filepath.Join(p, "blobs", "sha256-12f58bb75cb3042d69a7e013ab87fb3c3c7088f50ddc62f0c77bd332f0d44d35"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expect, err = json.Marshal(map[string]any{"temperature": 0.6, "top_k": 10, "top_p": 0.7, "stop": []string{"<|endoftext|>"}})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(bytes.TrimSpace(expect), bytes.TrimSpace(actual)) {
|
||||
t.Errorf("expected %s, actual %s", string(expect), string(actual))
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateReplacesMessages(t *testing.T) {
|
||||
p := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", p)
|
||||
envconfig.LoadConfig()
|
||||
var s Server
|
||||
|
||||
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||
Name: "test",
|
||||
Modelfile: fmt.Sprintf("FROM %s\nMESSAGE assistant \"What is my purpose?\"\nMESSAGE user \"You run tests.\"\nMESSAGE assistant \"Oh, my god.\"", createBinFile(t, nil, nil)),
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
||||
})
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||
filepath.Join(p, "blobs", "sha256-298baeaf6928a60cf666d88d64a1ba606feb43a2865687c39e40652e407bffc4"),
|
||||
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
|
||||
filepath.Join(p, "blobs", "sha256-e0e27d47045063ccb167ae852c51d49a98eab33fabaee4633fdddf97213e40b5"),
|
||||
})
|
||||
|
||||
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||
Name: "test2",
|
||||
Modelfile: "FROM test\nMESSAGE assistant \"You're a test, Harry.\"\nMESSAGE user \"I-I'm a what?\"\nMESSAGE assistant \"A test. And a thumping good one at that, I'd wager.\"",
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test2", "latest"),
|
||||
})
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||
filepath.Join(p, "blobs", "sha256-298baeaf6928a60cf666d88d64a1ba606feb43a2865687c39e40652e407bffc4"),
|
||||
filepath.Join(p, "blobs", "sha256-4f48b25fe9969564c82f58eb1cedbdff6484cc0baf474bc6c2a9b37c8da3362a"),
|
||||
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
|
||||
filepath.Join(p, "blobs", "sha256-a60ecc9da299ec7ede453f99236e5577fd125e143689b646d9f0ddc9971bf4db"),
|
||||
filepath.Join(p, "blobs", "sha256-e0e27d47045063ccb167ae852c51d49a98eab33fabaee4633fdddf97213e40b5"),
|
||||
})
|
||||
|
||||
type message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
f, err := os.Open(filepath.Join(p, "blobs", "sha256-a60ecc9da299ec7ede453f99236e5577fd125e143689b646d9f0ddc9971bf4db"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var actual []message
|
||||
if err := json.NewDecoder(f).Decode(&actual); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expect := []message{
|
||||
{Role: "assistant", Content: "You're a test, Harry."},
|
||||
{Role: "user", Content: "I-I'm a what?"},
|
||||
{Role: "assistant", Content: "A test. And a thumping good one at that, I'd wager."},
|
||||
}
|
||||
|
||||
if !slices.Equal(actual, expect) {
|
||||
t.Errorf("expected %s, actual %s", expect, actual)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateTemplateSystem(t *testing.T) {
|
||||
p := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", p)
|
||||
envconfig.LoadConfig()
|
||||
var s Server
|
||||
|
||||
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||
Name: "test",
|
||||
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .Prompt }}\nSYSTEM Say hello!\nTEMPLATE {{ .System }} {{ .Prompt }}\nSYSTEM Say bye!", createBinFile(t, nil, nil)),
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
||||
})
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||
filepath.Join(p, "blobs", "sha256-2b5e330885117c82f3fd75169ea323e141070a2947c11ddb9f79ee0b01c589c1"),
|
||||
filepath.Join(p, "blobs", "sha256-4c5f51faac758fecaff8db42f0b7382891a4d0c0bb885f7b86be88c814a7cc86"),
|
||||
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
|
||||
filepath.Join(p, "blobs", "sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"),
|
||||
})
|
||||
|
||||
template, err := os.ReadFile(filepath.Join(p, "blobs", "sha256-fe7ac77b725cda2ccad03f88a880ecdfd7a33192d6cae08fce2c0ee1455991ed"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if string(template) != "{{ .System }} {{ .Prompt }}" {
|
||||
t.Errorf("expected \"{{ .System }} {{ .Prompt }}\", actual %s", template)
|
||||
}
|
||||
|
||||
system, err := os.ReadFile(filepath.Join(p, "blobs", "sha256-4c5f51faac758fecaff8db42f0b7382891a4d0c0bb885f7b86be88c814a7cc86"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if string(system) != "Say bye!" {
|
||||
t.Errorf("expected \"Say bye!\", actual %s", system)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateLicenses(t *testing.T) {
|
||||
p := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", p)
|
||||
envconfig.LoadConfig()
|
||||
var s Server
|
||||
|
||||
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||
Name: "test",
|
||||
Modelfile: fmt.Sprintf("FROM %s\nLICENSE MIT\nLICENSE Apache-2.0", createBinFile(t, nil, nil)),
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{
|
||||
filepath.Join(p, "manifests", "registry.ollama.ai", "library", "test", "latest"),
|
||||
})
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||
filepath.Join(p, "blobs", "sha256-2af71558e438db0b73a20beab92dc278a94e1bbe974c00c1a33e3ab62d53a608"),
|
||||
filepath.Join(p, "blobs", "sha256-79a39c37536ddee29cbadd5d5e2dcba8ed7f03e431f626ff38432c1c866bb7e2"),
|
||||
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
|
||||
filepath.Join(p, "blobs", "sha256-e5dcffe836b6ec8a58e492419b550e65fb8cbdc308503979e5dacb33ac7ea3b7"),
|
||||
})
|
||||
|
||||
mit, err := os.ReadFile(filepath.Join(p, "blobs", "sha256-e5dcffe836b6ec8a58e492419b550e65fb8cbdc308503979e5dacb33ac7ea3b7"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if string(mit) != "MIT" {
|
||||
t.Errorf("expected MIT, actual %s", mit)
|
||||
}
|
||||
|
||||
apache, err := os.ReadFile(filepath.Join(p, "blobs", "sha256-2af71558e438db0b73a20beab92dc278a94e1bbe974c00c1a33e3ab62d53a608"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if string(apache) != "Apache-2.0" {
|
||||
t.Errorf("expected Apache-2.0, actual %s", apache)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateDetectTemplate(t *testing.T) {
|
||||
p := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", p)
|
||||
envconfig.LoadConfig()
|
||||
var s Server
|
||||
|
||||
t.Run("matched", func(t *testing.T) {
|
||||
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||
Name: "test",
|
||||
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, llm.KV{
|
||||
"tokenizer.chat_template": "{{ bos_token }}{% for message in messages %}{{'<|' + message['role'] + '|>' + '\n' + message['content'] + '<|end|>\n' }}{% endfor %}{% if add_generation_prompt %}{{ '<|assistant|>\n' }}{% else %}{{ eos_token }}{% endif %}",
|
||||
}, nil)),
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||
filepath.Join(p, "blobs", "sha256-2f8e594e6f34b1b4d36a246628eeb3365ce442303d656f1fcc69e821722acea0"),
|
||||
filepath.Join(p, "blobs", "sha256-542b217f179c7825eeb5bca3c77d2b75ed05bafbd3451d9188891a60a85337c6"),
|
||||
filepath.Join(p, "blobs", "sha256-553c4a3f747b3d22a4946875f1cc8ed011c2930d83f864a0c7265f9ec0a20413"),
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("unmatched", func(t *testing.T) {
|
||||
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||
Name: "test",
|
||||
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{
|
||||
filepath.Join(p, "blobs", "sha256-a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99"),
|
||||
filepath.Join(p, "blobs", "sha256-ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116"),
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,28 +1,22 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
)
|
||||
|
||||
func TestDelete(t *testing.T) {
|
||||
p := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", p)
|
||||
envconfig.LoadConfig()
|
||||
|
||||
var s Server
|
||||
|
||||
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||
Name: "test",
|
||||
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
|
||||
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
@@ -31,7 +25,7 @@ func TestDelete(t *testing.T) {
|
||||
|
||||
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||
Name: "test2",
|
||||
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .System }} {{ .Prompt }}", createBinFile(t, nil, nil)),
|
||||
Modelfile: fmt.Sprintf("FROM %s\nTEMPLATE {{ .System }} {{ .Prompt }}", createBinFile(t)),
|
||||
})
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
@@ -75,33 +69,3 @@ func TestDelete(t *testing.T) {
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{})
|
||||
checkFileExists(t, filepath.Join(p, "blobs", "*"), []string{})
|
||||
}
|
||||
|
||||
func TestDeleteDuplicateLayers(t *testing.T) {
|
||||
p := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", p)
|
||||
var s Server
|
||||
|
||||
n := model.ParseName("test")
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(&ConfigV2{}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
config, err := NewLayer(&b, "application/vnd.docker.container.image.v1+json")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// create a manifest with duplicate layers
|
||||
if err := WriteManifest(n, config, []*Layer{config}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
w := createRequest(t, s.DeleteModelHandler, api.DeleteRequest{Name: "test"})
|
||||
if w.Code != http.StatusOK {
|
||||
t.Errorf("expected status code 200, actual %d", w.Code)
|
||||
}
|
||||
|
||||
checkFileExists(t, filepath.Join(p, "manifests", "*", "*", "*", "*"), []string{})
|
||||
}
|
||||
|
||||
@@ -8,12 +8,10 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
)
|
||||
|
||||
func TestList(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
||||
envconfig.LoadConfig()
|
||||
|
||||
expectNames := []string{
|
||||
"mistral:7b-instruct-q4_0",
|
||||
@@ -31,7 +29,7 @@ func TestList(t *testing.T) {
|
||||
for _, n := range expectNames {
|
||||
createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||
Name: n,
|
||||
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
|
||||
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -15,12 +15,9 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
|
||||
@@ -28,20 +25,20 @@ func createTestFile(t *testing.T, name string) string {
|
||||
t.Helper()
|
||||
|
||||
f, err := os.CreateTemp(t.TempDir(), name)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, err)
|
||||
defer f.Close()
|
||||
|
||||
err = binary.Write(f, binary.LittleEndian, []byte("GGUF"))
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = binary.Write(f, binary.LittleEndian, uint32(3))
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = binary.Write(f, binary.LittleEndian, uint64(0))
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = binary.Write(f, binary.LittleEndian, uint64(0))
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, err)
|
||||
|
||||
return f.Name()
|
||||
}
|
||||
@@ -56,18 +53,16 @@ func Test_Routes(t *testing.T) {
|
||||
}
|
||||
|
||||
createTestModel := func(t *testing.T, name string) {
|
||||
t.Helper()
|
||||
|
||||
fname := createTestFile(t, "ollama-model")
|
||||
|
||||
r := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname))
|
||||
modelfile, err := parser.ParseFile(r)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, err)
|
||||
fn := func(resp api.ProgressResponse) {
|
||||
t.Logf("Status: %s", resp.Status)
|
||||
}
|
||||
err = CreateModel(context.TODO(), model.ParseName(name), "", "", modelfile, fn)
|
||||
require.NoError(t, err)
|
||||
err = CreateModel(context.TODO(), name, "", "", modelfile, fn)
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
testCases := []testCase{
|
||||
@@ -79,9 +74,9 @@ func Test_Routes(t *testing.T) {
|
||||
},
|
||||
Expected: func(t *testing.T, resp *http.Response) {
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
assert.Equal(t, "application/json; charset=utf-8", contentType)
|
||||
assert.Equal(t, contentType, "application/json; charset=utf-8")
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, fmt.Sprintf(`{"version":"%s"}`, version.Version), string(body))
|
||||
},
|
||||
},
|
||||
@@ -91,17 +86,17 @@ func Test_Routes(t *testing.T) {
|
||||
Path: "/api/tags",
|
||||
Expected: func(t *testing.T, resp *http.Response) {
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
assert.Equal(t, "application/json; charset=utf-8", contentType)
|
||||
assert.Equal(t, contentType, "application/json; charset=utf-8")
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, err)
|
||||
|
||||
var modelList api.ListResponse
|
||||
|
||||
err = json.Unmarshal(body, &modelList)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.NotNil(t, modelList.Models)
|
||||
assert.Empty(t, len(modelList.Models))
|
||||
assert.Equal(t, 0, len(modelList.Models))
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -113,18 +108,16 @@ func Test_Routes(t *testing.T) {
|
||||
},
|
||||
Expected: func(t *testing.T, resp *http.Response) {
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
assert.Equal(t, "application/json; charset=utf-8", contentType)
|
||||
assert.Equal(t, contentType, "application/json; charset=utf-8")
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotContains(t, string(body), "expires_at")
|
||||
assert.Nil(t, err)
|
||||
|
||||
var modelList api.ListResponse
|
||||
err = json.Unmarshal(body, &modelList)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, err)
|
||||
|
||||
assert.Len(t, modelList.Models, 1)
|
||||
assert.Equal(t, "test-model:latest", modelList.Models[0].Name)
|
||||
assert.Equal(t, 1, len(modelList.Models))
|
||||
assert.Equal(t, modelList.Models[0].Name, "test-model:latest")
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -141,7 +134,7 @@ func Test_Routes(t *testing.T) {
|
||||
Stream: &stream,
|
||||
}
|
||||
jsonData, err := json.Marshal(createReq)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, err)
|
||||
|
||||
req.Body = io.NopCloser(bytes.NewReader(jsonData))
|
||||
},
|
||||
@@ -149,11 +142,11 @@ func Test_Routes(t *testing.T) {
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
assert.Equal(t, "application/json", contentType)
|
||||
_, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 200, resp.StatusCode)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, resp.StatusCode, 200)
|
||||
|
||||
model, err := GetModel("t-bone")
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "t-bone:latest", model.ShortName)
|
||||
},
|
||||
},
|
||||
@@ -168,13 +161,13 @@ func Test_Routes(t *testing.T) {
|
||||
Destination: "beefsteak",
|
||||
}
|
||||
jsonData, err := json.Marshal(copyReq)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, err)
|
||||
|
||||
req.Body = io.NopCloser(bytes.NewReader(jsonData))
|
||||
},
|
||||
Expected: func(t *testing.T, resp *http.Response) {
|
||||
model, err := GetModel("beefsteak")
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, "beefsteak:latest", model.ShortName)
|
||||
},
|
||||
},
|
||||
@@ -186,18 +179,18 @@ func Test_Routes(t *testing.T) {
|
||||
createTestModel(t, "show-model")
|
||||
showReq := api.ShowRequest{Model: "show-model"}
|
||||
jsonData, err := json.Marshal(showReq)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, err)
|
||||
req.Body = io.NopCloser(bytes.NewReader(jsonData))
|
||||
},
|
||||
Expected: func(t *testing.T, resp *http.Response) {
|
||||
contentType := resp.Header.Get("Content-Type")
|
||||
assert.Equal(t, "application/json; charset=utf-8", contentType)
|
||||
assert.Equal(t, contentType, "application/json; charset=utf-8")
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, err)
|
||||
|
||||
var showResp api.ShowResponse
|
||||
err = json.Unmarshal(body, &showResp)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, err)
|
||||
|
||||
var params []string
|
||||
paramsSplit := strings.Split(showResp.Parameters, "\n")
|
||||
@@ -217,7 +210,6 @@ func Test_Routes(t *testing.T) {
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
||||
envconfig.LoadConfig()
|
||||
|
||||
s := &Server{}
|
||||
router := s.GenerateRoutes()
|
||||
@@ -229,14 +221,14 @@ func Test_Routes(t *testing.T) {
|
||||
t.Run(tc.Name, func(t *testing.T) {
|
||||
u := httpSrv.URL + tc.Path
|
||||
req, err := http.NewRequestWithContext(context.TODO(), tc.Method, u, nil)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, err)
|
||||
|
||||
if tc.Setup != nil {
|
||||
tc.Setup(t, req)
|
||||
}
|
||||
|
||||
resp, err := httpSrv.Client().Do(req)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, err)
|
||||
defer resp.Body.Close()
|
||||
|
||||
if tc.Expected != nil {
|
||||
@@ -248,7 +240,6 @@ func Test_Routes(t *testing.T) {
|
||||
|
||||
func TestCase(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
||||
envconfig.LoadConfig()
|
||||
|
||||
cases := []string{
|
||||
"mistral",
|
||||
@@ -264,7 +255,7 @@ func TestCase(t *testing.T) {
|
||||
t.Run(tt, func(t *testing.T) {
|
||||
w := createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||
Name: tt,
|
||||
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
|
||||
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
@@ -280,7 +271,7 @@ func TestCase(t *testing.T) {
|
||||
t.Run("create", func(t *testing.T) {
|
||||
w = createRequest(t, s.CreateModelHandler, api.CreateRequest{
|
||||
Name: strings.ToUpper(tt),
|
||||
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t, nil, nil)),
|
||||
Modelfile: fmt.Sprintf("FROM %s", createBinFile(t)),
|
||||
Stream: &stream,
|
||||
})
|
||||
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user