Compare commits


1 Commit

Author SHA1 Message Date
Bruce MacDonald
057cc54b66 benchmark: compare backend graph computation times
Track execution time of individual tensor operations (views, copies, reshapes, etc.)
during LLM forward passes using CGo bindings to the native graph runtime. This
helps identify performance bottlenecks in the computation graph and optimize memory
operations that can significantly impact inference latency.
2025-02-19 15:22:53 -08:00
22 changed files with 411 additions and 589 deletions
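
The mechanism, in brief: when OLLAMA_BENCHMARK is set, the ggml backend registers a scheduler eval callback that records a wall-clock duration for each graph node as it executes. The Go side exposes the recorded entries through a new `Context.Timing()` method, and a new benchmark reports them as per-operation metrics. A plausible invocation, assuming it is run from the new benchmark package's directory (the path is not shown in this view): `go test -bench BenchmarkGGMLOperations -m <model name>`.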

View File

@@ -160,10 +160,6 @@ jobs:
echo "$hipPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append echo "$hipPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "CC=$hipPath\bin\clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append echo "CC=$hipPath\bin\clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append echo "CXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
- if: matrix.preset == 'CPU'
run: |
echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CXX=clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
- if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }} - if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }}
uses: actions/cache/save@v4 uses: actions/cache/save@v4
with: with:

View File

@@ -384,8 +384,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [OpenDeepResearcher-via-searxng](https://github.com/benhaotang/OpenDeepResearcher-via-searxng) (A Deep Research equivalent endpoint with Ollama support for running locally)
 - [AntSK](https://github.com/AIDotNet/AntSK) (Out-of-the-box & Adaptable RAG Chatbot)
 - [MaxKB](https://github.com/1Panel-dev/MaxKB/) (Ready-to-use & flexible RAG Chatbot)
-- [yla](https://github.com/danielekp/yla) (Web interface to freely interact with your customized models)
-- [LangBot](https://github.com/RockChinQ/LangBot) (LLM-based instant messaging bots platform, with Agents, RAG features, supports multiple platforms)
 ### Cloud

View File

@@ -132,7 +132,7 @@ func (c *Client) do(ctx context.Context, method, path string, reqData, respData
 const maxBufferSize = 512 * format.KiloByte

 func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
-	var buf io.Reader
+	var buf *bytes.Buffer
 	if data != nil {
 		bts, err := json.Marshal(data)
 		if err != nil {

View File

@@ -1,13 +1,6 @@
 package api

 import (
-	"context"
-	"encoding/json"
-	"fmt"
-	"net/http"
-	"net/http/httptest"
-	"net/url"
-	"strings"
 	"testing"
 )
@@ -50,206 +43,3 @@ func TestClientFromEnvironment(t *testing.T) {
 		})
 	}
 }
-
-// testError represents an internal error type with status code and message
-// this is used since the error response from the server is not a standard error struct
-type testError struct {
-	message    string
-	statusCode int
-}
-
-func (e testError) Error() string {
-	return e.message
-}
-
-func TestClientStream(t *testing.T) {
-	testCases := []struct {
-		name      string
-		responses []any
-		wantErr   string
-	}{
-		{
-			name: "immediate error response",
-			responses: []any{
-				testError{
-					message:    "test error message",
-					statusCode: http.StatusBadRequest,
-				},
-			},
-			wantErr: "test error message",
-		},
-		{
-			name: "error after successful chunks, ok response",
-			responses: []any{
-				ChatResponse{Message: Message{Content: "partial response 1"}},
-				ChatResponse{Message: Message{Content: "partial response 2"}},
-				testError{
-					message:    "mid-stream error",
-					statusCode: http.StatusOK,
-				},
-			},
-			wantErr: "mid-stream error",
-		},
-		{
-			name: "successful stream completion",
-			responses: []any{
-				ChatResponse{Message: Message{Content: "chunk 1"}},
-				ChatResponse{Message: Message{Content: "chunk 2"}},
-				ChatResponse{
-					Message:    Message{Content: "final chunk"},
-					Done:       true,
-					DoneReason: "stop",
-				},
-			},
-		},
-	}
-
-	for _, tc := range testCases {
-		t.Run(tc.name, func(t *testing.T) {
-			ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-				flusher, ok := w.(http.Flusher)
-				if !ok {
-					t.Fatal("expected http.Flusher")
-				}
-
-				w.Header().Set("Content-Type", "application/x-ndjson")
-				for _, resp := range tc.responses {
-					if errResp, ok := resp.(testError); ok {
-						w.WriteHeader(errResp.statusCode)
-						err := json.NewEncoder(w).Encode(map[string]string{
-							"error": errResp.message,
-						})
-						if err != nil {
-							t.Fatal("failed to encode error response:", err)
-						}
-						return
-					}
-
-					if err := json.NewEncoder(w).Encode(resp); err != nil {
-						t.Fatalf("failed to encode response: %v", err)
-					}
-					flusher.Flush()
-				}
-			}))
-			defer ts.Close()
-
-			client := NewClient(&url.URL{Scheme: "http", Host: ts.Listener.Addr().String()}, http.DefaultClient)
-
-			var receivedChunks []ChatResponse
-			err := client.stream(context.Background(), http.MethodPost, "/v1/chat", nil, func(chunk []byte) error {
-				var resp ChatResponse
-				if err := json.Unmarshal(chunk, &resp); err != nil {
-					return fmt.Errorf("failed to unmarshal chunk: %w", err)
-				}
-				receivedChunks = append(receivedChunks, resp)
-				return nil
-			})
-
-			if tc.wantErr != "" {
-				if err == nil {
-					t.Fatal("expected error but got nil")
-				}
-				if !strings.Contains(err.Error(), tc.wantErr) {
-					t.Errorf("expected error containing %q, got %v", tc.wantErr, err)
-				}
-				return
-			}
-
-			if err != nil {
-				t.Errorf("unexpected error: %v", err)
-			}
-		})
-	}
-}
-
-func TestClientDo(t *testing.T) {
-	testCases := []struct {
-		name     string
-		response any
-		wantErr  string
-	}{
-		{
-			name: "immediate error response",
-			response: testError{
-				message:    "test error message",
-				statusCode: http.StatusBadRequest,
-			},
-			wantErr: "test error message",
-		},
-		{
-			name: "server error response",
-			response: testError{
-				message:    "internal error",
-				statusCode: http.StatusInternalServerError,
-			},
-			wantErr: "internal error",
-		},
-		{
-			name: "successful response",
-			response: struct {
-				ID      string `json:"id"`
-				Success bool   `json:"success"`
-			}{
-				ID:      "msg_123",
-				Success: true,
-			},
-		},
-	}
-
-	for _, tc := range testCases {
-		t.Run(tc.name, func(t *testing.T) {
-			ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
-				if errResp, ok := tc.response.(testError); ok {
-					w.WriteHeader(errResp.statusCode)
-					err := json.NewEncoder(w).Encode(map[string]string{
-						"error": errResp.message,
-					})
-					if err != nil {
-						t.Fatal("failed to encode error response:", err)
-					}
-					return
-				}
-
-				w.Header().Set("Content-Type", "application/json")
-				if err := json.NewEncoder(w).Encode(tc.response); err != nil {
-					t.Fatalf("failed to encode response: %v", err)
-				}
-			}))
-			defer ts.Close()
-
-			client := NewClient(&url.URL{Scheme: "http", Host: ts.Listener.Addr().String()}, http.DefaultClient)
-
-			var resp struct {
-				ID      string `json:"id"`
-				Success bool   `json:"success"`
-			}
-			err := client.do(context.Background(), http.MethodPost, "/v1/messages", nil, &resp)
-
-			if tc.wantErr != "" {
-				if err == nil {
-					t.Fatalf("got nil, want error %q", tc.wantErr)
-				}
-				if err.Error() != tc.wantErr {
-					t.Errorf("error message mismatch: got %q, want %q", err.Error(), tc.wantErr)
-				}
-				return
-			}
-
-			if err != nil {
-				t.Fatalf("got error %q, want nil", err)
-			}
-
-			if expectedResp, ok := tc.response.(struct {
-				ID      string `json:"id"`
-				Success bool   `json:"success"`
-			}); ok {
-				if resp.ID != expectedResp.ID {
-					t.Errorf("response ID mismatch: got %q, want %q", resp.ID, expectedResp.ID)
-				}
-				if resp.Success != expectedResp.Success {
-					t.Errorf("response Success mismatch: got %v, want %v", resp.Success, expectedResp.Success)
-				}
-			}
-		})
-	}
-}

View File

@@ -0,0 +1,86 @@
+package backend
+
+import (
+	"flag"
+	"fmt"
+	"io"
+	"log"
+	"os"
+	"testing"
+
+	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/model"
+	"github.com/ollama/ollama/server"
+
+	_ "github.com/ollama/ollama/model/models/llama"
+)
+
+var modelName = flag.String("m", "", "Name of the model to benchmark")
+
+func suppressOutput() (cleanup func()) {
+	oldStdout, oldStderr := os.Stdout, os.Stderr
+	os.Stdout, os.Stderr = nil, nil
+	log.SetOutput(io.Discard)
+
+	return func() {
+		os.Stdout, os.Stderr = oldStdout, oldStderr
+		log.SetOutput(os.Stderr)
+	}
+}
+
+func setupModel(b *testing.B) model.Model {
+	if *modelName == "" {
+		b.Fatal("Error: -m flag is required for benchmark tests")
+	}
+
+	sm, err := server.GetModel(*modelName)
+	if err != nil {
+		b.Fatal(err)
+	}
+
+	m, err := model.New(sm.ModelPath)
+	if err != nil {
+		b.Fatal(err)
+	}
+
+	m.Config().Cache.Init(m.Backend(), ml.DTypeF32, 2048)
+
+	return m
+}
+
+func BenchmarkGGMLOperations(b *testing.B) {
+	// loading the GGML back-end logs to standard out and makes the bench output messy
+	cleanup := suppressOutput()
+	defer cleanup()
+
+	b.Setenv("OLLAMA_BENCHMARK", "1")
+	b.Setenv("OLLAMA_BACKEND", "ggml")
+
+	m := setupModel(b)
+
+	// Sample input data
+	inputIDs := []int32{1, 2, 3, 4, 5}
+	options := model.Options{
+		Inputs:    inputIDs,
+		Positions: []int32{1, 2, 3, 4, 5},
+		Sequences: []int{1, 1, 1, 1, 1},
+		Outputs:   []int32{int32(len(inputIDs) - 1)},
+	}
+
+	b.ResetTimer()
+
+	for range b.N {
+		ctx := m.Backend().NewContext()
+		defer ctx.Close()
+
+		modelOutput, err := model.Forward(ctx, m, options)
+		if err != nil {
+			b.Fatal(fmt.Errorf("forward pass failed: %v", err))
+		}
+
+		ctx.Compute(modelOutput)
+
+		for _, op := range ctx.Timing() {
+			b.ReportMetric(op.Duration, fmt.Sprintf("%s_ms", op.Type))
+		}
+	}
+}
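
One caveat worth noting on the reporting loop above: `testing.B.ReportMetric` keeps only the last value written for a given unit, so per-op calls with a repeated `<Type>_ms` unit overwrite one another. A minimal sketch of an alternative (the `reportAggregated` helper is hypothetical, not part of this change) that sums durations per operation type first:

```go
package backend

import (
	"testing"

	"github.com/ollama/ollama/ml"
)

// reportAggregated sums the per-op durations by operation type and reports
// one metric per type, avoiding the last-write-wins behavior of calling
// b.ReportMetric once per individual operation.
func reportAggregated(b *testing.B, timings []ml.OpTiming) {
	totals := make(map[ml.OpType]float64)
	for _, op := range timings {
		totals[op.Type] += op.Duration // Duration is in milliseconds
	}
	for opType, total := range totals {
		b.ReportMetric(total, string(opType)+"_total_ms")
	}
}
```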

View File

@@ -46,6 +46,15 @@ Install prerequisites:
   - (Optional) NVIDIA GPU support
     - [CUDA SDK](https://developer.nvidia.com/cuda-downloads?target_os=Windows&target_arch=x86_64&target_version=11&target_type=exe_network)

+> [!IMPORTANT]
+> Ensure prerequisites are in `PATH` before running CMake.
+
+> [!IMPORTANT]
+> ROCm is not compatible with Visual Studio CMake generators. Use `-GNinja` when configuring the project.
+
+> [!IMPORTANT]
+> CUDA is only compatible with Visual Studio CMake generators.
+
 Then, configure and build the project:

 ```shell
@@ -53,14 +62,6 @@ cmake -B build
 cmake --build build --config Release
 ```

-> [!IMPORTANT]
-> Building for ROCm requires additional flags:
-> ```
-> cmake -B build -G Ninja -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++
-> cmake --build build --config Release
-> ```
-
 Lastly, run Ollama:

 ```shell

View File

@@ -53,8 +53,8 @@ func Host() *url.URL {
 	}
 }

-// AllowedOrigins returns a list of allowed origins. AllowedOrigins can be configured via the OLLAMA_ORIGINS environment variable.
-func AllowedOrigins() (origins []string) {
+// Origins returns a list of allowed origins. Origins can be configured via the OLLAMA_ORIGINS environment variable.
+func Origins() (origins []string) {
 	if s := Var("OLLAMA_ORIGINS"); s != "" {
 		origins = strings.Split(s, ",")
 	}
@@ -167,6 +167,8 @@ var (
 	MultiUserCache = Bool("OLLAMA_MULTIUSER_CACHE")
 	// Enable the new Ollama engine
 	NewEngine = Bool("OLLAMA_NEW_ENGINE")
+	// Ollama is running in a benchmark context, additional timing data will be collected.
+	Benchmark = Bool("OLLAMA_BENCHMARK")
 )

 func String(s string) func() string {
@@ -249,7 +251,7 @@ func AsMap() map[string]EnvVar {
 		"OLLAMA_NOHISTORY":       {"OLLAMA_NOHISTORY", NoHistory(), "Do not preserve readline history"},
 		"OLLAMA_NOPRUNE":         {"OLLAMA_NOPRUNE", NoPrune(), "Do not prune model blobs on startup"},
 		"OLLAMA_NUM_PARALLEL":    {"OLLAMA_NUM_PARALLEL", NumParallel(), "Maximum number of parallel requests"},
-		"OLLAMA_ORIGINS":         {"OLLAMA_ORIGINS", AllowedOrigins(), "A comma separated list of allowed origins"},
+		"OLLAMA_ORIGINS":         {"OLLAMA_ORIGINS", Origins(), "A comma separated list of allowed origins"},
 		"OLLAMA_SCHED_SPREAD":    {"OLLAMA_SCHED_SPREAD", SchedSpread(), "Always schedule model across all GPUs"},
 		"OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"},
 		"OLLAMA_NEW_ENGINE":      {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"},

View File

@@ -134,7 +134,7 @@ func TestOrigins(t *testing.T) {
 		t.Run(tt.value, func(t *testing.T) {
 			t.Setenv("OLLAMA_ORIGINS", tt.value)

-			if diff := cmp.Diff(AllowedOrigins(), tt.expect); diff != "" {
+			if diff := cmp.Diff(Origins(), tt.expect); diff != "" {
 				t.Errorf("%s: mismatch (-want +got):\n%s", tt.value, diff)
 			}
 		})

View File

@@ -352,6 +352,10 @@ func (c *testContext) MaxTensors() int {
 	return 10
 }

+func (c *testContext) Timing() []ml.OpTiming {
+	return []ml.OpTiming{}
+}
+
 func (c *testContext) Close() {}

 type testTensor struct {

View File

@@ -4,23 +4,17 @@ Date: Sun, 16 Feb 2025 20:00:22 -0500
  Subject: [PATCH] use std::filesystem::path instead of wstring
  ---
-  ggml/src/ggml-backend-reg.cpp | 144 ++++++++++++++--------------------
-  1 file changed, 58 insertions(+), 86 deletions(-)
+  ggml/src/ggml-backend-reg.cpp | 116 ++++++++++++----------------------
+  1 file changed, 40 insertions(+), 76 deletions(-)
  diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
- index 84b21dd8..e35a6936 100644
+ index 84b21dd8..de78feae 100644
  --- a/ggml/src/ggml-backend-reg.cpp
  +++ b/ggml/src/ggml-backend-reg.cpp
- @@ -66,26 +66,6 @@
-  #include "ggml-kompute.h"
+ @@ -72,16 +72,6 @@
+  # pragma clang diagnostic ignored "-Wdeprecated-declarations"
   #endif
- -// disable C++17 deprecation warning for std::codecvt_utf8
- -#if defined(__clang__)
- -# pragma clang diagnostic push
- -# pragma clang diagnostic ignored "-Wdeprecated-declarations"
- -#endif
- -
  -static std::wstring utf8_to_utf16(const std::string & str) {
  -    std::wstring_convert<std::codecvt_utf8_utf16<wchar_t>> converter;
  -    return converter.from_bytes(str);
@@ -31,14 +25,10 @@ index 84b21dd8..e35a6936 100644
  -    return converter.to_bytes(str);
  -}
  -
- -#if defined(__clang__)
- -# pragma clang diagnostic pop
- -#endif
- -
+  #if defined(__clang__)
+  # pragma clang diagnostic pop
+  #endif
- @@ -96,7 +76,7 @@ struct dl_handle_deleter {
+ @@ -96,12 +86,12 @@ struct dl_handle_deleter {
+  #ifdef _WIN32
+  using dl_handle = std::remove_pointer_t<HMODULE>;
   }
   };
@@ -47,44 +37,24 @@ index 84b21dd8..e35a6936 100644
   // suppress error dialogs for missing DLLs
   DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
   SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
- @@ -129,8 +109,8 @@ struct dl_handle_deleter {
+ @@ -129,8 +119,8 @@ struct dl_handle_deleter {
  -    HMODULE handle = LoadLibraryW(path.c_str());
  +    HMODULE handle = LoadLibraryW(path.wstring().c_str());
   SetErrorMode(old_mode);
   }
   };
  -static void * dl_load_library(const std::wstring & path) {
  -    dl_handle * handle = dlopen(utf16_to_utf8(path).c_str(), RTLD_NOW | RTLD_LOCAL);
  +static void * dl_load_library(const std::filesystem::path & path) {
- +    dl_handle * handle = dlopen(path.c_str(), RTLD_NOW | RTLD_LOCAL);
+ +    dl_handle * handle = dlopen(path.string().c_str(), RTLD_NOW | RTLD_LOCAL);
   return handle;
   }
- @@ -141,6 +121,25 @@ static void * dl_get_sym(dl_handle * handle, const char * name) {
+ @@ -222,11 +212,11 @@ struct ggml_backend_registry {
   #endif
- +static std::string path_to_string(const std::filesystem::path & path)
- +{
- +#ifdef _WIN32
- +    const std::wstring wstr = path.wstring();
- +    const int size_needed = WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), -1, nullptr, 0, nullptr, nullptr);
- +    if (size_needed <= 0) {
- +        return std::string();
- +    }
- +
- +    // size_needed includes the null terminator
- +    std::string str(size_needed - 1, '\0');
- +    WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), -1, str.data(), size_needed, nullptr, nullptr);
- +    return str;
- +#else
- +    return path.string();
- +#endif
- +}
- +
- +
-  using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>;
-  struct ggml_backend_reg_entry {
- @@ -222,11 +221,11 @@ struct ggml_backend_registry {
   );
   }
@@ -94,49 +64,49 @@ index 84b21dd8..e35a6936 100644
   if (!handle) {
   if (!silent) {
  -            GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(path).c_str());
- +            GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(path).c_str());
+ +            GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path.string().c_str());
   }
   return nullptr;
   }
- @@ -234,7 +233,7 @@ struct ggml_backend_registry {
+ @@ -234,7 +224,7 @@ struct ggml_backend_registry {
   auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
   if (score_fn && score_fn() == 0) {
   if (!silent) {
  -            GGML_LOG_INFO("%s: backend %s is not supported on this system\n", __func__, utf16_to_utf8(path).c_str());
- +            GGML_LOG_INFO("%s: backend %s is not supported on this system\n", __func__, path_to_string(path).c_str());
+ +            GGML_LOG_INFO("%s: backend %s is not supported on this system\n", __func__, path.string().c_str());
   }
   return nullptr;
   }
- @@ -242,7 +241,7 @@ struct ggml_backend_registry {
+ @@ -242,7 +232,7 @@ struct ggml_backend_registry {
   auto backend_init_fn = (ggml_backend_init_t) dl_get_sym(handle.get(), "ggml_backend_init");
   if (!backend_init_fn) {
   if (!silent) {
  -            GGML_LOG_ERROR("%s: failed to find ggml_backend_init in %s\n", __func__, utf16_to_utf8(path).c_str());
- +            GGML_LOG_ERROR("%s: failed to find ggml_backend_init in %s\n", __func__, path_to_string(path).c_str());
+ +            GGML_LOG_ERROR("%s: failed to find ggml_backend_init in %s\n", __func__, path.string().c_str());
   }
   return nullptr;
   }
- @@ -251,16 +250,16 @@ struct ggml_backend_registry {
+ @@ -251,16 +241,16 @@ struct ggml_backend_registry {
   if (!reg || reg->api_version != GGML_BACKEND_API_VERSION) {
   if (!silent) {
   if (!reg) {
  -                GGML_LOG_ERROR("%s: failed to initialize backend from %s: ggml_backend_init returned NULL\n", __func__, utf16_to_utf8(path).c_str());
- +                GGML_LOG_ERROR("%s: failed to initialize backend from %s: ggml_backend_init returned NULL\n", __func__, path_to_string(path).c_str());
+ +                GGML_LOG_ERROR("%s: failed to initialize backend from %s: ggml_backend_init returned NULL\n", __func__, path.string().c_str());
   } else {
   GGML_LOG_ERROR("%s: failed to initialize backend from %s: incompatible API version (backend: %d, current: %d)\n",
  -                    __func__, utf16_to_utf8(path).c_str(), reg->api_version, GGML_BACKEND_API_VERSION);
- +                    __func__, path_to_string(path).c_str(), reg->api_version, GGML_BACKEND_API_VERSION);
+ +                    __func__, path.string().c_str(), reg->api_version, GGML_BACKEND_API_VERSION);
   }
   }
   return nullptr;
   }
  -    GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), utf16_to_utf8(path).c_str());
- +    GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), path_to_string(path).c_str());
+ +    GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), path.string().c_str());
   register_backend(reg, score_fn ? score_fn() : -1, std::move(handle));
- @@ -396,14 +395,14 @@ ggml_backend_t ggml_backend_init_best(void) {
+ @@ -396,14 +386,14 @@ ggml_backend_t ggml_backend_init_best(void) {
   // Dynamic loading
   ggml_backend_reg_t ggml_backend_load(const char * path) {
@@ -153,7 +123,7 @@ index 84b21dd8..e35a6936 100644
   #if defined(__APPLE__)
   // get executable path
   std::vector<char> path;
- @@ -415,15 +414,9 @@ static std::wstring get_executable_path() {
+ @@ -415,15 +405,9 @@ static std::wstring get_executable_path() {
   }
   path.resize(size);
   }
@@ -171,7 +141,7 @@ index 84b21dd8..e35a6936 100644
   std::vector<char> path(1024);
   while (true) {
   // get executable path
- @@ -436,76 +429,55 @@ static std::wstring get_executable_path() {
+ @@ -436,76 +420,56 @@ static std::wstring get_executable_path() {
   break;
   }
   if (len < (ssize_t) path.size()) {
@@ -209,11 +179,11 @@ index 84b21dd8..e35a6936 100644
  -static std::wstring backend_filename_prefix() {
  -#ifdef _WIN32
  -    return L"ggml-";
- -#else
- -    return L"libggml-";
  +    return std::filesystem::path(path.data()).parent_path();
-  #endif
+  #else
+ -    return L"libggml-";
  +    return {};
+  #endif
   }
  -static std::wstring backend_filename_suffix() {
   for (const auto & search_path : search_paths) {
   if (!fs::exists(search_path)) {
   continue;
- @@ -514,31 +486,31 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
+ @@ -514,31 +478,31 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
   for (const auto & entry : dir_it) {
   try {
   if (entry.is_regular_file()) {
@@ -277,20 +247,20 @@ index 84b21dd8..e35a6936 100644
  +            dl_handle_ptr handle { dl_load_library(entry.path()) };
   if (!handle) {
  -                GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
- +                GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(entry.path()).c_str());
+ +                GGML_LOG_ERROR("%s: failed to load %s\n", __func__, entry.path().string().c_str());
   continue;
   }
   auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
   if (!score_fn) {
  -                GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
- +                GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, path_to_string(entry.path()).c_str());
+ +                GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, entry.path().string().c_str());
   continue;
   }
   int s = score_fn();
  -            GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
- +            GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_to_string(entry.path()).c_str(), s);
+ +            GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, entry.path().string().c_str(), s);
   if (s > best_score) {
   best_score = s;
  -                best_path = entry.path().wstring();
@@ -300,11 +270,11 @@ index 84b21dd8..e35a6936 100644
   }
   } catch (const std::exception & e) {
  -            GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), e.what());
- +            GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, path_to_string(entry.path()).c_str(), e.what());
+ +            GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, entry.path().string().c_str(), e.what());
   }
   }
   }
- @@ -546,7 +518,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
+ @@ -546,7 +510,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
   if (best_score == 0) {
   // try to load the base backend
   for (const auto & search_path : search_paths) {

View File

@@ -2,6 +2,7 @@ package ml
 import (
 	"bytes"
+	"cmp"
 	"encoding/binary"
 	"fmt"
 	"os"
@@ -26,24 +27,9 @@ type Backend interface {
 	SystemInfo() string
 }

-// BackendParams controls how the backend loads and executes models
-type BackendParams struct {
-	// NumThreads sets the number of threads to use if running on the CPU
-	NumThreads int
-
-	// MainGPU is the index of the primary GPU to use
-	MainGPU int
-
-	// NumGPULayers is the number of layers to offload to GPUs
-	NumGPULayers int
-
-	// TensorSplit is the fraction of the model to offload to each GPU
-	TensorSplit []float32
-}
-
-var backends = make(map[string]func(*os.File, BackendParams) (Backend, error))
+var backends = make(map[string]func(*os.File) (Backend, error))

-func RegisterBackend(name string, f func(*os.File, BackendParams) (Backend, error)) {
+func RegisterBackend(name string, f func(*os.File) (Backend, error)) {
 	if _, ok := backends[name]; ok {
 		panic("backend: backend already registered")
 	}
@@ -51,9 +37,9 @@ func RegisterBackend(name string, f func(*os.File) (Backend, error)) {
 	backends[name] = f
 }

-func NewBackend(f *os.File, params BackendParams) (Backend, error) {
-	if backend, ok := backends["ggml"]; ok {
-		return backend(f, params)
+func NewBackend(f *os.File) (Backend, error) {
+	if backend, ok := backends[cmp.Or(os.Getenv("OLLAMA_BACKEND"), "ggml")]; ok {
+		return backend(f)
 	}

 	return nil, fmt.Errorf("unsupported backend")
@@ -68,6 +54,30 @@ type Context interface {
 	Compute(...Tensor)
 	MaxTensors() int
 	Close()
+	Timing() []OpTiming
+}
+
+// OpType is the type of operation performed during a forward pass.
+type OpType string
+
+const (
+	View       OpType = "View"
+	Copy       OpType = "Copy"
+	Reshape    OpType = "Reshape"
+	Permute    OpType = "Permute"
+	Contiguous OpType = "Contiguous"
+	Input      OpType = "Input"
+	ComputeOp  OpType = "Compute"
+	Transpose  OpType = "Transpose"
+)
+
+// OpTiming stores the timing information for a single operation.
+type OpTiming struct {
+	Type      OpType
+	Operation string
+	Duration  float64
+	Order     int
 }

 type Tensor interface {
@@ -111,26 +121,6 @@ type Tensor interface {
 	Copy(ctx Context, t2 Tensor) Tensor
 }

-// ScaledDotProductAttention implements a fused attention
-// operation equivalent to following code on a tensor named
-// query:
-//
-//	kq := key.MulmatFullPrec(ctx, query)
-//
-//	kq = kq.Scale(ctx, scale)
-//
-//	if mask != nil {
-//		kq = kq.Add(ctx, mask)
-//	}
-//
-//	kq = kq.Softmax(ctx)
-//
-//	kqv := value.Mulmat(ctx, kq)
-//	return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-type ScaledDotProductAttention interface {
-	ScaledDotProductAttention(ctx Context, key, value, mask Tensor, scale float64) Tensor
-}
-
 type number interface {
 	~int | ~int8 | ~int16 | ~int32 | ~int64 |
 	~uint | ~uint8 | ~uint16 | ~uint32 | ~uint64 |
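
For orientation, a small sketch (not part of this change; `printTrace` and the sample data are hypothetical) of how the `OpTiming` records a `Context` collects might be replayed in execution order using the `Order` field:

```go
package main

import (
	"fmt"
	"sort"

	"github.com/ollama/ollama/ml"
)

// printTrace sorts timings by their Order field and prints one line per
// operation, reconstructing the execution sequence of a forward pass.
func printTrace(timings []ml.OpTiming) {
	sort.Slice(timings, func(i, j int) bool { return timings[i].Order < timings[j].Order })
	for _, op := range timings {
		fmt.Printf("%3d  %-10s  %-30s  %8.3f ms\n", op.Order, op.Type, op.Operation, op.Duration)
	}
}

func main() {
	printTrace([]ml.OpTiming{
		{Type: ml.ComputeOp, Operation: "Computation node_0", Duration: 1.37, Order: 1},
		{Type: ml.View, Operation: "Memory view", Duration: 0.02, Order: 0},
	})
}
```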

View File

@@ -4,6 +4,8 @@ package ggml
 #cgo CPPFLAGS: -I${SRCDIR}/ggml/include
 #include <stdlib.h>
 #include <stdint.h>
+#include <time.h>
+#include <string.h>
 #include "ggml.h"
 #include "ggml-cpu.h"
 #include "ggml-backend.h"
@@ -21,6 +23,54 @@ COMPILER inline get_compiler() {
 #endif
 }
+
+// Define a fixed-size struct to store timing data
+#define MAX_TENSOR_NAME 256
+#define MAX_TIMINGS 1000
+
+typedef struct {
+    char tensor_name[MAX_TENSOR_NAME];
+    double duration_ms;
+} timing_entry;
+
+typedef struct {
+    timing_entry entries[MAX_TIMINGS];
+    int count;
+} timing_data;
+
+// Global timing data structure
+timing_data g_timings = {0};
+
+double get_time_ms() {
+    struct timespec ts;
+    clock_gettime(CLOCK_MONOTONIC, &ts);
+    return ts.tv_sec * 1000.0 + ts.tv_nsec / 1000000.0;
+}
+
+bool debug_callback(struct ggml_tensor * t, bool ask, void * user_data) {
+    static double start_time;
+    static char current_tensor[MAX_TENSOR_NAME];
+
+    if (ask) {
+        start_time = get_time_ms();
+        strncpy(current_tensor, t->name, MAX_TENSOR_NAME - 1);
+        current_tensor[MAX_TENSOR_NAME - 1] = '\0';
+    } else {
+        double end_time = get_time_ms();
+        double duration = end_time - start_time;
+
+        if (g_timings.count < MAX_TIMINGS) {
+            strncpy(g_timings.entries[g_timings.count].tensor_name, current_tensor, MAX_TENSOR_NAME - 1);
+            g_timings.entries[g_timings.count].duration_ms = duration;
+            g_timings.count++;
+        }
+    }
+    return true;
+}
+
+void clear_timings() {
+    g_timings.count = 0;
+}
 */
 import "C"
@@ -29,9 +79,11 @@ import (
 	"io"
 	"log/slog"
 	"os"
+	"strings"
 	"sync"
 	"unsafe"

+	"github.com/ollama/ollama/envconfig"
 	"github.com/ollama/ollama/format"
 	fs "github.com/ollama/ollama/fs/ggml"
 	"github.com/ollama/ollama/ml"
@@ -82,11 +134,9 @@ type Backend struct {
 	meta *fs.GGML
 	cpus, gpus []Context
 	tensors map[string]*Context
-
-	sched *C.struct_ggml_backend_sched
 }

-func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
+func New(r *os.File) (ml.Backend, error) {
 	meta, n, err := fs.Decode(r, -1)
 	if err != nil {
 		return nil, err
@@ -184,24 +234,10 @@ func New(r *os.File) (ml.Backend, error) {
 		return nil, err
 	}

-	backends := make([]*C.struct_ggml_backend, len(gpus)+len(cpus))
-	bufts := make([]*C.struct_ggml_backend_buffer_type, len(gpus)+len(cpus))
-	for i, c := range append(gpus, cpus...) {
-		backends[i] = c.backend
-		bufts[i] = C.ggml_backend_get_default_buffer_type(c.backend)
-	}
-
 	return &Backend{
 		meta: meta,
 		cpus: cpus,
 		gpus: gpus,
-		sched: C.ggml_backend_sched_new(
-			(*C.ggml_backend_t)(unsafe.Pointer(&backends[0])),
-			(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&bufts[0])),
-			C.int(len(backends)),
-			C.size_t(max(8192, len(meta.Tensors().Items())*5)),
-			true,
-		),
 	}, nil
 }
@@ -235,23 +271,31 @@ func (b *Backend) NewContext() ml.Context {
 	})

 	backends := make([]*C.struct_ggml_backend, len(b.gpus)+len(b.cpus))
+	bufts := make([]*C.struct_ggml_backend_buffer_type, len(b.gpus)+len(b.cpus))
 	for i, c := range append(b.gpus, b.cpus...) {
 		backends[i] = c.backend
+		bufts[i] = C.ggml_backend_get_default_buffer_type(c.backend)
 	}

 	return &Context{
-		b:       b,
 		ctx:     c,
 		backend: backends[0],
 		nodes:   nodes,
+		sched: C.ggml_backend_sched_new(
+			(*C.ggml_backend_t)(unsafe.Pointer(&backends[0])),
+			(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&bufts[0])),
+			C.int(len(backends)),
+			C.size_t(nodes),
+			true,
+		),
 	}
 }

 type Context struct {
-	b       *Backend
 	ctx     *C.struct_ggml_context
 	backend *C.struct_ggml_backend
+	sched   *C.struct_ggml_backend_sched
 	graph   *C.struct_ggml_cgraph
 	nodes   int
 }
@@ -264,14 +308,68 @@ func (c *Context) Forward(t ml.Tensor) {
 	C.ggml_build_forward_expand(c.graph, t.(*Tensor).t)
 }

+// Timing retrieves the collected timing data
+func (c *Context) Timing() []ml.OpTiming {
+	sequence := make([]ml.OpTiming, C.g_timings.count)
+
+	for i := range int(C.g_timings.count) {
+		entry := C.g_timings.entries[i]
+		tensorName := C.GoString(&entry.tensor_name[0])
+
+		// Determine operation type and description based on tensor name
+		var opType ml.OpType
+		var opDesc string
+
+		switch {
+		case strings.Contains(tensorName, "(view)"):
+			opType, opDesc = ml.View, "Memory view"
+		case strings.Contains(tensorName, "(copy)") || strings.Contains(tensorName, "(copy of"):
+			opType, opDesc = ml.Copy, "Memory copy"
+		case strings.Contains(tensorName, "(reshaped)"):
+			opType, opDesc = ml.Reshape, "Reshape"
+		case strings.Contains(tensorName, "(permuted)"):
+			opType, opDesc = ml.Permute, "Permute dimensions"
+		case strings.Contains(tensorName, "(cont)"):
+			opType, opDesc = ml.Contiguous, "Make contiguous"
+		case strings.Contains(tensorName, "(transposed)"):
+			opType, opDesc = ml.Transpose, "Transpose"
+		case strings.HasPrefix(tensorName, "leaf_"):
+			opType, opDesc = ml.Input, fmt.Sprintf("Input tensor %s", tensorName)
+		case strings.HasPrefix(tensorName, "node_"):
+			opType, opDesc = ml.ComputeOp, fmt.Sprintf("Computation %s", tensorName)
+		default:
+			opType, opDesc = "Unknown", tensorName
+		}
+
+		sequence[i] = ml.OpTiming{
+			Type:      opType,
+			Operation: opDesc,
+			Duration:  float64(entry.duration_ms),
+			Order:     i,
+		}
+	}
+
+	return sequence
+}
+
 func (c *Context) Compute(tensors ...ml.Tensor) {
-	C.ggml_backend_sched_graph_compute_async(c.b.sched, c.graph)
-	C.ggml_backend_sched_reset(c.b.sched)
+	if envconfig.Benchmark() {
+		// Clear previous timings before new computation
+		C.clear_timings()
+		C.ggml_backend_sched_set_eval_callback(
+			c.sched,
+			C.ggml_backend_eval_callback(C.debug_callback),
+			nil,
+		)
+	}
+
+	C.ggml_backend_sched_graph_compute_async(c.sched, c.graph)

 	needSync := true
 	sync := func() {
 		if needSync {
-			C.ggml_backend_sched_synchronize(c.b.sched)
+			C.ggml_backend_sched_synchronize(c.sched)
 			needSync = false
 		}
 	}
@@ -359,6 +457,7 @@ func (c Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
 func (c *Context) Close() {
 	if c != nil {
+		C.ggml_backend_sched_free(c.sched)
 		C.ggml_free(c.ctx)
 	}
 }
@@ -485,7 +584,7 @@ func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
 }

 func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
-	return (&Tensor{t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
+	return (&Tensor{t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
 }

 func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
@@ -651,21 +750,6 @@ func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
 	}
 }

-func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.Tensor, scale float64) ml.Tensor {
-	var kqMask *C.struct_ggml_tensor
-	if mask != nil {
-		kqMask = mask.(*Tensor).t
-	}
-
-	kq := key.MulmatFullPrec(ctx, t)
-	kq = &Tensor{
-		t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
-	}
-
-	kqv := value.Mulmat(ctx, kq)
-	return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-}
-
 func (b *Backend) SystemInfo() string {
 	var compiler string
 	switch C.get_compiler() {
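
As I read the callback contract here, `ggml_backend_sched_set_eval_callback` causes the scheduler to call `debug_callback` twice per node: once with `ask=true` before the node runs (when the start time and tensor name are captured) and once with `ask=false` after it completes (when the duration is recorded). The static variables in the callback therefore assume nodes are observed one at a time, in order.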

View File

@@ -66,6 +66,16 @@
#include "ggml-kompute.h" #include "ggml-kompute.h"
#endif #endif
// disable C++17 deprecation warning for std::codecvt_utf8
#if defined(__clang__)
# pragma clang diagnostic push
# pragma clang diagnostic ignored "-Wdeprecated-declarations"
#endif
#if defined(__clang__)
# pragma clang diagnostic pop
#endif
#ifdef _WIN32 #ifdef _WIN32
using dl_handle = std::remove_pointer_t<HMODULE>; using dl_handle = std::remove_pointer_t<HMODULE>;
@@ -81,7 +91,7 @@ static dl_handle * dl_load_library(const std::filesystem::path & path) {
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS); DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS); SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
HMODULE handle = LoadLibraryW(path.c_str()); HMODULE handle = LoadLibraryW(path.wstring().c_str());
SetErrorMode(old_mode); SetErrorMode(old_mode);
@@ -110,7 +120,7 @@ struct dl_handle_deleter {
}; };
static void * dl_load_library(const std::filesystem::path & path) { static void * dl_load_library(const std::filesystem::path & path) {
dl_handle * handle = dlopen(path.c_str(), RTLD_NOW | RTLD_LOCAL); dl_handle * handle = dlopen(path.string().c_str(), RTLD_NOW | RTLD_LOCAL);
return handle; return handle;
} }
@@ -121,25 +131,6 @@ static void * dl_get_sym(dl_handle * handle, const char * name) {
#endif #endif
static std::string path_to_string(const std::filesystem::path & path)
{
#ifdef _WIN32
const std::wstring wstr = path.wstring();
const int size_needed = WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), -1, nullptr, 0, nullptr, nullptr);
if (size_needed <= 0) {
return std::string();
}
// size_needed includes the null terminator
std::string str(size_needed - 1, '\0');
WideCharToMultiByte(CP_UTF8, 0, wstr.c_str(), -1, str.data(), size_needed, nullptr, nullptr);
return str;
#else
return path.string();
#endif
}
using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>; using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>;
struct ggml_backend_reg_entry { struct ggml_backend_reg_entry {
@@ -225,7 +216,7 @@ struct ggml_backend_registry {
dl_handle_ptr handle { dl_load_library(path) }; dl_handle_ptr handle { dl_load_library(path) };
if (!handle) { if (!handle) {
if (!silent) { if (!silent) {
GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(path).c_str()); GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path.string().c_str());
} }
return nullptr; return nullptr;
} }
@@ -233,7 +224,7 @@ struct ggml_backend_registry {
auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score"); auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
if (score_fn && score_fn() == 0) { if (score_fn && score_fn() == 0) {
if (!silent) { if (!silent) {
GGML_LOG_INFO("%s: backend %s is not supported on this system\n", __func__, path_to_string(path).c_str()); GGML_LOG_INFO("%s: backend %s is not supported on this system\n", __func__, path.string().c_str());
} }
return nullptr; return nullptr;
} }
@@ -241,7 +232,7 @@ struct ggml_backend_registry {
auto backend_init_fn = (ggml_backend_init_t) dl_get_sym(handle.get(), "ggml_backend_init"); auto backend_init_fn = (ggml_backend_init_t) dl_get_sym(handle.get(), "ggml_backend_init");
if (!backend_init_fn) { if (!backend_init_fn) {
if (!silent) { if (!silent) {
GGML_LOG_ERROR("%s: failed to find ggml_backend_init in %s\n", __func__, path_to_string(path).c_str()); GGML_LOG_ERROR("%s: failed to find ggml_backend_init in %s\n", __func__, path.string().c_str());
} }
return nullptr; return nullptr;
} }
@@ -250,16 +241,16 @@ struct ggml_backend_registry {
if (!reg || reg->api_version != GGML_BACKEND_API_VERSION) { if (!reg || reg->api_version != GGML_BACKEND_API_VERSION) {
if (!silent) { if (!silent) {
if (!reg) { if (!reg) {
GGML_LOG_ERROR("%s: failed to initialize backend from %s: ggml_backend_init returned NULL\n", __func__, path_to_string(path).c_str()); GGML_LOG_ERROR("%s: failed to initialize backend from %s: ggml_backend_init returned NULL\n", __func__, path.string().c_str());
} else { } else {
GGML_LOG_ERROR("%s: failed to initialize backend from %s: incompatible API version (backend: %d, current: %d)\n", GGML_LOG_ERROR("%s: failed to initialize backend from %s: incompatible API version (backend: %d, current: %d)\n",
__func__, path_to_string(path).c_str(), reg->api_version, GGML_BACKEND_API_VERSION); __func__, path.string().c_str(), reg->api_version, GGML_BACKEND_API_VERSION);
} }
} }
return nullptr; return nullptr;
} }
GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), path_to_string(path).c_str()); GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, ggml_backend_reg_name(reg), path.string().c_str());
register_backend(reg, score_fn ? score_fn() : -1, std::move(handle)); register_backend(reg, score_fn ? score_fn() : -1, std::move(handle));
@@ -441,8 +432,9 @@ static std::filesystem::path get_executable_path() {
} }
return std::filesystem::path(path.data()).parent_path(); return std::filesystem::path(path.data()).parent_path();
#endif #else
return {}; return {};
#endif
} }
static std::string backend_filename_prefix() { static std::string backend_filename_prefix() {
@@ -491,18 +483,18 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) { if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
dl_handle_ptr handle { dl_load_library(entry.path()) }; dl_handle_ptr handle { dl_load_library(entry.path()) };
if (!handle) { if (!handle) {
GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(entry.path()).c_str()); GGML_LOG_ERROR("%s: failed to load %s\n", __func__, entry.path().string().c_str());
continue; continue;
} }
auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score"); auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
if (!score_fn) { if (!score_fn) {
GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, path_to_string(entry.path()).c_str()); GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, entry.path().string().c_str());
continue; continue;
} }
int s = score_fn(); int s = score_fn();
GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_to_string(entry.path()).c_str(), s); GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, entry.path().string().c_str(), s);
if (s > best_score) { if (s > best_score) {
best_score = s; best_score = s;
best_path = entry.path(); best_path = entry.path();
@@ -510,7 +502,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
} }
} }
} catch (const std::exception & e) { } catch (const std::exception & e) {
GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, path_to_string(entry.path()).c_str(), e.what()); GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, entry.path().string().c_str(), e.what());
} }
} }
} }

View File

@@ -1,59 +0,0 @@
-package nn
-
-import (
-	"fmt"
-
-	"github.com/ollama/ollama/ml"
-)
-
-// Attention implements scaled dot-product attention for transformer models:
-// Attention(Q, K, V) = softmax(QK^T/√d_k)V
-//
-// Parameters:
-//   - ctx: Context for tensor operations
-//   - query: Query tensor (Q) with shape [d_k, seq_len_q, heads]
-//   - key: Key tensor (K) with shape [d_k, seq_len_k, kv_heads]
-//   - value: Value tensor (V) with shape [seq_len_k, d_v, kv_heads]
-//   - mask: Optional attention mask that is added to the attention score. If
-//     provided, should broadcast to [seq_len_k, seq_len_q, heads]
-//   - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension
-//
-// Returns:
-//
-//	Attention output with shape [d_v, heads, seq_len_q]
-func Attention(ctx ml.Context, query, key, value, mask ml.Tensor, scale float64) ml.Tensor {
-	if query.Dim(0) != key.Dim(0) {
-		panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
-	}
-
-	if mask != nil && query.Dim(1) != mask.Dim(1) {
-		panic(fmt.Errorf("seq_len_q in attention operation does not match between query(%v) and mask(%v)", query.Dim(1), mask.Dim(1)))
-	}
-
-	if key.Dim(1) != value.Dim(0) {
-		panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(0)))
-	}
-
-	if mask != nil && key.Dim(1) != mask.Dim(0) {
-		panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and mask(%v)", key.Dim(1), mask.Dim(0)))
-	}
-
-	if key.Dim(2) != value.Dim(2) {
-		panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
-	}
-
-	if sdpa, ok := query.(ml.ScaledDotProductAttention); ok {
-		return sdpa.ScaledDotProductAttention(ctx, key, value, mask, scale)
-	} else {
-		kq := key.MulmatFullPrec(ctx, query)
-		kq = kq.Scale(ctx, scale)
-
-		if mask != nil {
-			kq = kq.Add(ctx, mask)
-		}
-
-		kq = kq.Softmax(ctx)
-
-		kqv := value.Mulmat(ctx, kq)
-		return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
-	}
-}
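
With this helper removed, the llama and mllama forward passes below inline the equivalent sequence directly: multiply key by query at full precision, scale by 1/√d_k, add the mask where present, softmax, multiply by value, then permute back to the output layout.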

View File

@@ -70,14 +70,14 @@ func Register(name string, f func(ml.Config) (Model, error)) {
 }

 // New initializes a new model instance with the provided configuration based on the metadata in the model file
-func New(modelPath string, params ml.BackendParams) (Model, error) {
+func New(modelPath string) (Model, error) {
 	r, err := os.Open(modelPath)
 	if err != nil {
 		return nil, err
 	}
 	defer r.Close()

-	b, err := ml.NewBackend(r, params)
+	b, err := ml.NewBackend(r)
 	if err != nil {
 		return nil, err
 	}

View File

@@ -86,8 +86,13 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
 	k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)

-	scaleFactor := 1.0 / math.Sqrt(float64(headDim))
-	kqv := nn.Attention(ctx, q, k, v, mask, scaleFactor)
+	kq := k.MulmatFullPrec(ctx, q)
+	kq = kq.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
+	kq = kq.Add(ctx, mask)
+	kq = kq.Softmax(ctx)
+
+	kqv := v.Mulmat(ctx, kq)
+	kqv = kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
 	kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)

 	return sa.Output.Forward(ctx, kqv)
@@ -115,19 +120,11 @@ type Layer struct {
 	MLP           *MLP
 }

-func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
+func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
 	residual := hiddenState

 	hiddenState = l.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
 	hiddenState = l.SelfAttention.Forward(ctx, hiddenState, positionIDs, cache, opts)
-
-	// In the final layer (outputs != nil), optimize by pruning to just the token positions
-	// we need logits for.
-	if outputs != nil {
-		hiddenState = hiddenState.Rows(ctx, outputs)
-		residual = residual.Rows(ctx, outputs)
-	}
-
 	hiddenState = hiddenState.Add(ctx, residual)
 	residual = hiddenState

@@ -147,26 +144,22 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
 		return nil, err
 	}

+	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
+	for i, layer := range m.Layers {
+		m.Cache.SetLayer(i)
+		hiddenState = layer.Forward(ctx, hiddenState, positions, m.Cache, m.Options)
+	}
+
+	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
+	hiddenState = m.Output.Forward(ctx, hiddenState)
+
 	outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
 	if err != nil {
 		return nil, err
 	}

-	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
-	for i, layer := range m.Layers {
-		m.Cache.SetLayer(i)
-
-		var lastLayerOutputs ml.Tensor
-		if i == len(m.Layers)-1 {
-			lastLayerOutputs = outputs
-		}
-
-		hiddenState = layer.Forward(ctx, hiddenState, positions, lastLayerOutputs, m.Cache, m.Options)
-	}
-
-	hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
-	return m.Output.Forward(ctx, hiddenState), nil
+	return hiddenState.Rows(ctx, outputs), nil
 }

 func init() {

View File

@@ -93,13 +93,15 @@ func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
 		return nil, err
 	}

+	// TODO: attention mask, cross attention mask
+	hiddenState := m.TextModel.Forward(ctx, inputs, positions, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache))
+
 	outputs, err := ctx.FromIntSlice(opts.Outputs, len(opts.Outputs))
 	if err != nil {
 		return nil, err
 	}

-	// TODO: attention mask, cross attention mask
-	return m.TextModel.Forward(ctx, inputs, positions, outputs, nil, crossAttentionStates, nil, m.Cache.(*kvcache.WrapperCache)), nil
+	return hiddenState.Rows(ctx, outputs), nil
 }

 func init() {

View File

@@ -38,8 +38,13 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
scaleFactor := 1.0 / math.Sqrt(float64(headDim)) scores := key.MulmatFullPrec(ctx, query)
attention := nn.Attention(ctx, query, key, value, mask, scaleFactor) scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
scores = scores.Add(ctx, mask)
scores = scores.Softmax(ctx)
attention := value.Mulmat(ctx, scores)
attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize) attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return sa.Output.Forward(ctx, attention) return sa.Output.Forward(ctx, attention)
@@ -69,19 +74,11 @@ type TextSelfAttentionDecoderLayer struct {
MLP *TextMLP MLP *TextMLP
} }
func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, outputs, mask, _, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor { func (d *TextSelfAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, positions, mask, _, _ ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
residual := hiddenState residual := hiddenState
hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps) hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = d.SelfAttention.Forward(ctx, hiddenState, positions, mask, cache, opts) hiddenState = d.SelfAttention.Forward(ctx, hiddenState, positions, mask, cache, opts)
// In the final layer (outputs != nil), optimize by pruning to just the token positions
// we need logits for.
if outputs != nil {
hiddenState = hiddenState.Rows(ctx, outputs)
residual = residual.Rows(ctx, outputs)
}
hiddenState = hiddenState.Add(ctx, residual) hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState residual = hiddenState
@@ -107,7 +104,7 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize) query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
query = ca.QueryNorm.Forward(ctx, query, opts.eps) query = ca.QueryNorm.Forward(ctx, query, opts.eps)
var key, value, mask ml.Tensor var key, value ml.Tensor
if crossAttentionStates != nil { if crossAttentionStates != nil {
numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2) numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)
@@ -120,15 +117,19 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
cache.Put(ctx, key, value) cache.Put(ctx, key, value)
} else { } else {
key, value, mask = cache.Get(ctx) key, value, _ = cache.Get(ctx)
} }
query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx) value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
scaleFactor := 1.0 / math.Sqrt(float64(headDim)) scores := key.Mulmat(ctx, query)
attention := nn.Attention(ctx, query, key, value, mask, scaleFactor) scores = scores.Scale(ctx, 1.0/math.Sqrt(float64(headDim)))
scores = scores.Softmax(ctx)
attention := value.Mulmat(ctx, scores)
attention = attention.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize) attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return ca.Output.Forward(ctx, attention) return ca.Output.Forward(ctx, attention)
@@ -144,7 +145,7 @@ type TextCrossAttentionDecoderLayer struct {
MLPGate ml.Tensor `gguf:"cross_attn_mlp_gate"` MLPGate ml.Tensor `gguf:"cross_attn_mlp_gate"`
} }
func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor { func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _, _, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
residual := hiddenState residual := hiddenState
hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps) hiddenState = d.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
@@ -160,14 +161,14 @@ func (d *TextCrossAttentionDecoderLayer) Forward(ctx ml.Context, hiddenState, _,
} }
type TextDecoderLayer interface { type TextDecoderLayer interface {
Forward(ctx ml.Context, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor
} }
type TextDecoder struct { type TextDecoder struct {
Layers []TextDecoderLayer Layers []TextDecoderLayer
} }
func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor { func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache, opts *TextModelOptions) ml.Tensor {
for i, layer := range d.Layers { for i, layer := range d.Layers {
layerType := selfAttentionLayer layerType := selfAttentionLayer
if slices.Contains(opts.crossAttentionLayers, uint32(i)) { if slices.Contains(opts.crossAttentionLayers, uint32(i)) {
@@ -178,12 +179,7 @@ func (d *TextDecoder) Forward(ctx ml.Context, hiddenState, positionIDs, outputs,
cache.SetLayerType(layerType) cache.SetLayerType(layerType)
if layerType == selfAttentionLayer || crossAttentionStates != nil || cache.UnderlyingCache().(*kvcache.EncoderCache).EncoderCached() { if layerType == selfAttentionLayer || crossAttentionStates != nil || cache.UnderlyingCache().(*kvcache.EncoderCache).EncoderCached() {
var lastLayerOutputs ml.Tensor hiddenState = layer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, opts)
if i == len(d.Layers)-1 {
lastLayerOutputs = outputs
}
hiddenState = layer.Forward(ctx, hiddenState, positionIDs, lastLayerOutputs, mask, crossAttentionStates, crossAttentionMask, cache, opts)
} }
} }
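What the right-hand side removes here is an output-trimming optimization: the left-hand side threads an outputs tensor to the final layer only, letting that layer restrict the hidden state to the token positions whose logits the caller actually needs. A stand-in sketch of the dropped pattern — Tensor and Layer are illustrative simplifications of ml.Tensor and the TextDecoderLayer interface, which also carries contexts, caches, and options:

package main

import "fmt"

// Tensor and Layer are illustrative stand-ins.
type Tensor []float64

type Layer func(hidden, outputs Tensor) Tensor

// forward mirrors the left-hand side of the diff: only the final layer is
// handed the outputs tensor; intermediate layers never pay for it.
func forward(layers []Layer, hidden, outputs Tensor) Tensor {
	for i, layer := range layers {
		var lastLayerOutputs Tensor
		if i == len(layers)-1 {
			lastLayerOutputs = outputs
		}
		hidden = layer(hidden, lastLayerOutputs)
	}
	return hidden
}

func main() {
	double := func(h, _ Tensor) Tensor {
		out := make(Tensor, len(h))
		for i, v := range h {
			out[i] = 2 * v
		}
		return out
	}
	fmt.Println(forward([]Layer{double, double}, Tensor{1, 2}, Tensor{0}))
}

Trimming before the vocabulary projection matters because logits are vocab-sized; computing them for every batch position is wasted work when only the last token is sampled.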
@@ -209,9 +205,9 @@ type TextModel struct {
*TextModelOptions *TextModelOptions
} }
func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache) ml.Tensor { func (m *TextModel) Forward(ctx ml.Context, inputIDs, positionIDs, mask, crossAttentionStates, crossAttentionMask ml.Tensor, cache *kvcache.WrapperCache) ml.Tensor {
hiddenState := m.TokenEmbedding.Forward(ctx, inputIDs) hiddenState := m.TokenEmbedding.Forward(ctx, inputIDs)
hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, outputs, mask, crossAttentionStates, crossAttentionMask, cache, m.TextModelOptions) hiddenState = m.Transformer.Forward(ctx, hiddenState, positionIDs, mask, crossAttentionStates, crossAttentionMask, cache, m.TextModelOptions)
hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps) hiddenState = m.OutputNorm.Forward(ctx, hiddenState, m.eps)
return m.Output.Forward(ctx, hiddenState) return m.Output.Forward(ctx, hiddenState)
} }
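Both sides keep the same four-stage top-level flow in TextModel.Forward: token embedding, decoder stack, output norm, output projection. A stand-in sketch of that pipeline shape — Tensor, module, and addOne are illustrative, and the real Forward also threads positions, masks, cross-attention states, and the KV cache:

package main

import "fmt"

// Tensor and module are illustrative stand-ins for ml.Tensor and the
// repo's layer types.
type Tensor []float64

type module interface {
	Forward(Tensor) Tensor
}

type addOne struct{}

func (addOne) Forward(t Tensor) Tensor {
	out := make(Tensor, len(t))
	for i, v := range t {
		out[i] = v + 1
	}
	return out
}

// textModel mirrors the four stages of TextModel.Forward on both sides of
// the diff: embedding -> decoder stack -> output norm -> output projection.
type textModel struct {
	embedding, decoder, norm, output module
}

func (m *textModel) Forward(inputIDs Tensor) Tensor {
	hidden := m.embedding.Forward(inputIDs)
	hidden = m.decoder.Forward(hidden)
	hidden = m.norm.Forward(hidden)
	return m.output.Forward(hidden)
}

func main() {
	m := &textModel{addOne{}, addOne{}, addOne{}, addOne{}}
	fmt.Println(m.Forward(Tensor{0})) // [4]
}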

View File

@@ -25,7 +25,6 @@ import (
"golang.org/x/sync/semaphore" "golang.org/x/sync/semaphore"
"github.com/ollama/ollama/api" "github.com/ollama/ollama/api"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model" "github.com/ollama/ollama/model"
"github.com/ollama/ollama/runner/common" "github.com/ollama/ollama/runner/common"
"github.com/ollama/ollama/sample" "github.com/ollama/ollama/sample"
@@ -802,7 +801,6 @@ func (m *multiLPath) String() string {
func (s *Server) loadModel( func (s *Server) loadModel(
mpath string, mpath string,
params ml.BackendParams,
lpath multiLPath, lpath multiLPath,
parallel int, parallel int,
kvCacheType string, kvCacheType string,
@@ -810,12 +808,12 @@ func (s *Server) loadModel(
multiUserCache bool, multiUserCache bool,
) { ) {
var err error var err error
s.model, err = model.New(mpath, params) s.model, err = model.New(mpath)
if err != nil { if err != nil {
panic(err) panic(err)
} }
slog.Info("system", "info", s.model.Backend().SystemInfo(), "threads", params.NumThreads) slog.Info("system", "info", s.model.Backend().SystemInfo() /* "threads", *threads */)
// TODO(jessegross): LoRA loading // TODO(jessegross): LoRA loading
if lpath.String() != "" { if lpath.String() != "" {
@@ -845,17 +843,17 @@ func Execute(args []string) error {
mpath := fs.String("model", "", "Path to model binary file") mpath := fs.String("model", "", "Path to model binary file")
parallel := fs.Int("parallel", 1, "Number of sequences to handle simultaneously") parallel := fs.Int("parallel", 1, "Number of sequences to handle simultaneously")
batchSize := fs.Int("batch-size", 512, "Batch size") batchSize := fs.Int("batch-size", 512, "Batch size")
numGPULayers := fs.Int("n-gpu-layers", 0, "Number of layers to offload to GPU") _ = fs.Int("n-gpu-layers", 0, "Number of layers to offload to GPU")
mainGPU := fs.Int("main-gpu", 0, "Main GPU") _ = fs.Int("main-gpu", 0, "Main GPU")
_ = fs.Bool("flash-attn", false, "Enable flash attention") _ = fs.Bool("flash-attn", false, "Enable flash attention")
kvSize := fs.Int("ctx-size", 2048, "Context (or KV cache) size") kvSize := fs.Int("ctx-size", 2048, "Context (or KV cache) size")
kvCacheType := fs.String("kv-cache-type", "", "quantization type for KV cache (default: f16)") kvCacheType := fs.String("kv-cache-type", "", "quantization type for KV cache (default: f16)")
port := fs.Int("port", 8080, "Port to expose the server on") port := fs.Int("port", 8080, "Port to expose the server on")
threads := fs.Int("threads", runtime.NumCPU(), "Number of threads to use during generation") _ = fs.Int("threads", runtime.NumCPU(), "Number of threads to use during generation")
verbose := fs.Bool("verbose", false, "verbose output (default: disabled)") verbose := fs.Bool("verbose", false, "verbose output (default: disabled)")
_ = fs.Bool("no-mmap", false, "do not memory-map model (slower load but may reduce pageouts if not using mlock)") _ = fs.Bool("no-mmap", false, "do not memory-map model (slower load but may reduce pageouts if not using mlock)")
_ = fs.Bool("mlock", false, "force system to keep model in RAM rather than swapping or compressing") _ = fs.Bool("mlock", false, "force system to keep model in RAM rather than swapping or compressing")
tensorSplit := fs.String("tensor-split", "", "fraction of the model to offload to each GPU, comma-separated list of proportions") _ = fs.String("tensor-split", "", "fraction of the model to offload to each GPU, comma-separated list of proportions")
multiUserCache := fs.Bool("multiuser-cache", false, "optimize input cache algorithm for multiple users") multiUserCache := fs.Bool("multiuser-cache", false, "optimize input cache algorithm for multiple users")
var lpaths multiLPath var lpaths multiLPath
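The right-hand side keeps the retired flags registered but binds them to the blank identifier, so existing command lines that pass -threads or -tensor-split still parse while the code makes clear nothing reads them yet. A minimal sketch of the pattern with a hypothetical flag set:

package main

import (
	"flag"
	"fmt"
	"os"
)

func main() {
	fs := flag.NewFlagSet("runner", flag.ContinueOnError)
	port := fs.Int("port", 8080, "Port to expose the server on")
	// Registered but deliberately unused: callers may still pass -threads
	// and parsing succeeds, but nothing reads the value yet.
	_ = fs.Int("threads", 4, "Number of threads to use during generation")

	if err := fs.Parse(os.Args[1:]); err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
	fmt.Println("port:", *port)
}

Dropping a flag outright would instead make fs.Parse fail with "flag provided but not defined" for older callers.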
@@ -892,11 +890,15 @@ func Execute(args []string) error {
} }
// TODO(jessegross): Parameters that need to be implemented: // TODO(jessegross): Parameters that need to be implemented:
// n-gpu-layers
// main-gpu
// flash-attn // flash-attn
// threads
// no-mmap // no-mmap
// mlock // mlock
// tensor-split
var tensorSplitFloats []float32 /*var tensorSplitFloats []float32
if *tensorSplit != "" { if *tensorSplit != "" {
stringFloats := regexp.MustCompile(",").Split(*tensorSplit, -1) stringFloats := regexp.MustCompile(",").Split(*tensorSplit, -1)
@@ -905,17 +907,10 @@ func Execute(args []string) error {
f, _ := strconv.ParseFloat(s, 32) f, _ := strconv.ParseFloat(s, 32)
tensorSplitFloats = append(tensorSplitFloats, float32(f)) tensorSplitFloats = append(tensorSplitFloats, float32(f))
} }
} }*/
params := ml.BackendParams{
NumThreads: *threads,
NumGPULayers: *numGPULayers,
MainGPU: *mainGPU,
TensorSplit: tensorSplitFloats,
}
server.ready.Add(1) server.ready.Add(1)
go server.loadModel(*mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache) go server.loadModel(*mpath, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache)
server.cond = sync.NewCond(&server.mu) server.cond = sync.NewCond(&server.mu)
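The commented-out block above parses -tensor-split with a regexp and silently discards ParseFloat errors. If it is revived, plain strings.Split is enough and surfacing the error is cheap — a sketch under those assumptions (parseTensorSplit is a hypothetical helper, not repo code):

package main

import (
	"fmt"
	"strconv"
	"strings"
)

// parseTensorSplit converts a comma-separated list of proportions such as
// "0.7,0.3" into float32 values, reporting malformed entries instead of
// silently zeroing them as the regexp-based original did.
func parseTensorSplit(s string) ([]float32, error) {
	if s == "" {
		return nil, nil
	}
	parts := strings.Split(s, ",")
	out := make([]float32, 0, len(parts))
	for _, p := range parts {
		f, err := strconv.ParseFloat(strings.TrimSpace(p), 32)
		if err != nil {
			return nil, fmt.Errorf("invalid tensor-split value %q: %w", p, err)
		}
		out = append(out, float32(f))
	}
	return out, nil
}

func main() {
	fmt.Println(parseTensorSplit("0.7,0.3"))
}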

View File

@@ -1127,72 +1127,54 @@ func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
} }
func (s *Server) GenerateRoutes() http.Handler { func (s *Server) GenerateRoutes() http.Handler {
corsConfig := cors.DefaultConfig() config := cors.DefaultConfig()
corsConfig.AllowWildcard = true config.AllowWildcard = true
corsConfig.AllowBrowserExtensions = true config.AllowBrowserExtensions = true
corsConfig.AllowHeaders = []string{ config.AllowHeaders = []string{"Authorization", "Content-Type", "User-Agent", "Accept", "X-Requested-With"}
"Authorization", openAIProperties := []string{"lang", "package-version", "os", "arch", "retry-count", "runtime", "runtime-version", "async", "helper-method", "poll-helper", "custom-poll-interval"}
"Content-Type", for _, prop := range openAIProperties {
"User-Agent", config.AllowHeaders = append(config.AllowHeaders, "x-stainless-"+prop)
"Accept",
"X-Requested-With",
// OpenAI compatibility headers
"x-stainless-lang",
"x-stainless-package-version",
"x-stainless-os",
"x-stainless-arch",
"x-stainless-retry-count",
"x-stainless-runtime",
"x-stainless-runtime-version",
"x-stainless-async",
"x-stainless-helper-method",
"x-stainless-poll-helper",
"x-stainless-custom-poll-interval",
"x-stainless-timeout",
} }
corsConfig.AllowOrigins = envconfig.AllowedOrigins() config.AllowOrigins = envconfig.Origins()
r := gin.Default() r := gin.Default()
r.Use( r.Use(
cors.New(corsConfig), cors.New(config),
allowedHostsMiddleware(s.addr), allowedHostsMiddleware(s.addr),
) )
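Both header lists allow the same x-stainless-* properties, except that the left-hand side also permitted x-stainless-timeout, which the right-hand property list does not generate. Expanding the loop makes the resulting list easy to audit — a small self-contained sketch:

package main

import "fmt"

func main() {
	allowHeaders := []string{"Authorization", "Content-Type", "User-Agent", "Accept", "X-Requested-With"}
	// Each OpenAI SDK property becomes an x-stainless-* header; note the
	// absence of "timeout" relative to the left-hand list.
	openAIProperties := []string{
		"lang", "package-version", "os", "arch", "retry-count",
		"runtime", "runtime-version", "async", "helper-method",
		"poll-helper", "custom-poll-interval",
	}
	for _, prop := range openAIProperties {
		allowHeaders = append(allowHeaders, "x-stainless-"+prop)
	}
	fmt.Println(len(allowHeaders), allowHeaders)
}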
// General
r.HEAD("/", func(c *gin.Context) { c.String(http.StatusOK, "Ollama is running") })
r.GET("/", func(c *gin.Context) { c.String(http.StatusOK, "Ollama is running") })
r.HEAD("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
r.GET("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
// Local model cache management
r.POST("/api/pull", s.PullHandler) r.POST("/api/pull", s.PullHandler)
r.POST("/api/push", s.PushHandler)
r.DELETE("/api/delete", s.DeleteHandler)
r.HEAD("/api/tags", s.ListHandler)
r.GET("/api/tags", s.ListHandler)
r.POST("/api/show", s.ShowHandler)
// Create
r.POST("/api/create", s.CreateHandler)
r.POST("/api/blobs/:digest", s.CreateBlobHandler)
r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
r.POST("/api/copy", s.CopyHandler)
// Inference
r.GET("/api/ps", s.PsHandler)
r.POST("/api/generate", s.GenerateHandler) r.POST("/api/generate", s.GenerateHandler)
r.POST("/api/chat", s.ChatHandler) r.POST("/api/chat", s.ChatHandler)
r.POST("/api/embed", s.EmbedHandler) r.POST("/api/embed", s.EmbedHandler)
r.POST("/api/embeddings", s.EmbeddingsHandler) r.POST("/api/embeddings", s.EmbeddingsHandler)
r.POST("/api/create", s.CreateHandler)
r.POST("/api/push", s.PushHandler)
r.POST("/api/copy", s.CopyHandler)
r.DELETE("/api/delete", s.DeleteHandler)
r.POST("/api/show", s.ShowHandler)
r.POST("/api/blobs/:digest", s.CreateBlobHandler)
r.HEAD("/api/blobs/:digest", s.HeadBlobHandler)
r.GET("/api/ps", s.PsHandler)
// Inference (OpenAI compatibility) // Compatibility endpoints
r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler) r.POST("/v1/chat/completions", openai.ChatMiddleware(), s.ChatHandler)
r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler) r.POST("/v1/completions", openai.CompletionsMiddleware(), s.GenerateHandler)
r.POST("/v1/embeddings", openai.EmbeddingsMiddleware(), s.EmbedHandler) r.POST("/v1/embeddings", openai.EmbeddingsMiddleware(), s.EmbedHandler)
r.GET("/v1/models", openai.ListMiddleware(), s.ListHandler) r.GET("/v1/models", openai.ListMiddleware(), s.ListHandler)
r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowHandler) r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowHandler)
for _, method := range []string{http.MethodGet, http.MethodHead} {
r.Handle(method, "/", func(c *gin.Context) {
c.String(http.StatusOK, "Ollama is running")
})
r.Handle(method, "/api/tags", s.ListHandler)
r.Handle(method, "/api/version", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"version": version.Version})
})
}
return r return r
} }
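The right-hand side collapses the GET/HEAD pairs for the liveness, tags, and version endpoints into one registration loop. A minimal sketch of the same idea on Go 1.22+ net/http rather than gin — the paths and placeholder version string are illustrative, and in net/http a GET pattern already matches HEAD, so the explicit HEAD registration is only for symmetry with the gin code, which registers each method separately:

package main

import (
	"encoding/json"
	"fmt"
	"log"
	"net/http"
)

// registerDual mirrors the loop on the right-hand side of the diff: one
// handler answers both GET and HEAD for a path.
func registerDual(mux *http.ServeMux, path string, h http.HandlerFunc) {
	for _, method := range []string{http.MethodGet, http.MethodHead} {
		mux.HandleFunc(method+" "+path, h)
	}
}

func main() {
	mux := http.NewServeMux()
	registerDual(mux, "/{$}", func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprint(w, "Ollama is running")
	})
	registerDual(mux, "/api/version", func(w http.ResponseWriter, r *http.Request) {
		json.NewEncoder(w).Encode(map[string]string{"version": "0.0.0"}) // placeholder
	})
	log.Fatal(http.ListenAndServe(":11434", mux))
}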

View File

@@ -179,7 +179,7 @@ func (s *Scheduler) processPending(ctx context.Context) {
if allReliable { if allReliable {
// HACK // HACK
os.Setenv("OLLAMA_MAX_LOADED_MODELS", strconv.Itoa(defaultModelsPerGPU*len(gpus))) os.Setenv("OLLAMA_MAX_LOADED_MODELS", strconv.Itoa(defaultModelsPerGPU*len(gpus)))
slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", envconfig.MaxRunners(), "gpu_count", len(gpus)) slog.Debug("updating default concurrency", "OLLAMA_MAX_LOADED_MODELS", envconfig.MaxRunners, "gpu_count", len(gpus))
} else { } else {
// HACK // HACK
os.Setenv("OLLAMA_MAX_LOADED_MODELS", strconv.Itoa(len(gpus))) os.Setenv("OLLAMA_MAX_LOADED_MODELS", strconv.Itoa(len(gpus)))