Compare commits
1 Commits
jmorganca/
...
jmorganca/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ad7e641815 |
10
api/types.go
10
api/types.go
@@ -159,15 +159,17 @@ type Runner struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type EmbeddingRequest struct {
|
type EmbeddingRequest struct {
|
||||||
Model string `json:"model"`
|
Model string `json:"model"`
|
||||||
Prompt string `json:"prompt"`
|
Prompt string `json:"prompt,omitempty"`
|
||||||
KeepAlive *Duration `json:"keep_alive,omitempty"`
|
PromptBatch []string `json:"prompt_batch,omitempty"`
|
||||||
|
KeepAlive *Duration `json:"keep_alive,omitempty"`
|
||||||
|
|
||||||
Options map[string]interface{} `json:"options"`
|
Options map[string]interface{} `json:"options"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type EmbeddingResponse struct {
|
type EmbeddingResponse struct {
|
||||||
Embedding []float64 `json:"embedding"`
|
Embedding []float64 `json:"embedding,omitempty"`
|
||||||
|
EmbeddingBatch [][]float64 `json:"embedding_batch,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type CreateRequest struct {
|
type CreateRequest struct {
|
||||||
|
|||||||
33
docs/api.md
33
docs/api.md
@@ -1010,7 +1010,8 @@ Generate embeddings from a model
|
|||||||
### Parameters
|
### Parameters
|
||||||
|
|
||||||
- `model`: name of model to generate embeddings from
|
- `model`: name of model to generate embeddings from
|
||||||
- `prompt`: text to generate embeddings for
|
- `prompt`: string to generate the embedding for
|
||||||
|
- `prompts`: array of strings to generate a batch of embeddings for
|
||||||
|
|
||||||
Advanced parameters:
|
Advanced parameters:
|
||||||
|
|
||||||
@@ -1038,3 +1039,33 @@ curl http://localhost:11434/api/embeddings -d '{
|
|||||||
]
|
]
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
#### Request (batch)
|
||||||
|
|
||||||
|
```shell
|
||||||
|
curl http://localhost:11434/api/embeddings -d '{
|
||||||
|
"model": "all-minilm",
|
||||||
|
"prompt_batch": [
|
||||||
|
"Here is an article about llamas...",
|
||||||
|
"Here is another article about llamas..."
|
||||||
|
]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Response
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"embedding_batch": [
|
||||||
|
[
|
||||||
|
0.5670403838157654, 0.009260174818336964, 0.23178744316101074, -0.2916173040866852, -0.8924556970596313,
|
||||||
|
0.8785552978515625, -0.34576427936553955, 0.5742510557174683, -0.04222835972905159, -0.137906014919281
|
||||||
|
],
|
||||||
|
[
|
||||||
|
0.5670403838157654, 0.009260174818336964, 0.23178744316101074, -0.2916173040866852, -0.8924556970596313,
|
||||||
|
0.8785552978515625, -0.34576427936553955, 0.5742510557174683, -0.04222835972905159, -0.137906014919281
|
||||||
|
],
|
||||||
|
]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|||||||
@@ -32,25 +32,9 @@ func PayloadsDir() (string, error) {
|
|||||||
slog.Error("failed to lookup executable path", "error", err)
|
slog.Error("failed to lookup executable path", "error", err)
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
cwd, err := os.Getwd()
|
|
||||||
if err != nil {
|
|
||||||
slog.Error("failed to lookup working directory", "error", err)
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
var paths []string
|
|
||||||
for _, root := range []string{appExe, cwd} {
|
|
||||||
paths = append(paths,
|
|
||||||
filepath.Join(root),
|
|
||||||
filepath.Join(root, "windows-"+runtime.GOARCH),
|
|
||||||
filepath.Join(root, "dist", "windows-"+runtime.GOARCH),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try a few variations to improve developer experience when building from source in the local tree
|
// Try a few variations to improve developer experience when building from source in the local tree
|
||||||
for _, p := range paths {
|
for _, d := range []string{".", "windows-" + runtime.GOARCH, "dist\\windows-" + runtime.GOARCH} {
|
||||||
candidate := filepath.Join(p, "ollama_runners")
|
candidate := filepath.Join(filepath.Dir(appExe), d, "ollama_runners")
|
||||||
_, err := os.Stat(candidate)
|
_, err := os.Stat(candidate)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
runnersDir = candidate
|
runnersDir = candidate
|
||||||
|
|||||||
64
integration/embedding_test.go
Normal file
64
integration/embedding_test.go
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
//go:build integration
|
||||||
|
|
||||||
|
package integration
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/ollama/ollama/api"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAllMiniLMEmbedding(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
req := api.EmbeddingRequest{
|
||||||
|
Model: "all-minilm",
|
||||||
|
Prompt: "why is the sky blue?",
|
||||||
|
Options: map[string]interface{}{
|
||||||
|
"temperature": 0,
|
||||||
|
"seed": 123,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
res := EmbeddingTestHelper(ctx, t, &http.Client{}, req)
|
||||||
|
|
||||||
|
if len(res.Embedding) != 384 {
|
||||||
|
t.Fatalf("Expected 384 floats to be returned, got %v", len(res.Embedding))
|
||||||
|
}
|
||||||
|
|
||||||
|
if res.Embedding[0] != 0.146763876080513 {
|
||||||
|
t.Fatalf("Expected first embedding float to be 0.146763876080513, got %v", res.Embedding[0])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAllMiniLMEmbeddings(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
req := api.EmbeddingRequest{
|
||||||
|
Model: "all-minilm",
|
||||||
|
Prompts: []string{"why is the sky blue?", "why is the sky blue?"},
|
||||||
|
Options: map[string]interface{}{
|
||||||
|
"temperature": 0,
|
||||||
|
"seed": 123,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
res := EmbeddingTestHelper(ctx, t, &http.Client{}, req)
|
||||||
|
|
||||||
|
if len(res.Embeddings) != 2 {
|
||||||
|
t.Fatal("Expected 2 embeddings to be returned")
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(res.Embeddings[0]) != 384 {
|
||||||
|
t.Fatalf("Expected first embedding to have 384 floats, got %v", len(res.Embeddings[0]))
|
||||||
|
}
|
||||||
|
|
||||||
|
if res.Embeddings[0][0] != 0.146763876080513 && res.Embeddings[1][0] != 0.146763876080513 {
|
||||||
|
t.Fatalf("Expected first embedding floats to be 0.146763876080513, got %v, %v", res.Embeddings[0][0], res.Embeddings[1][0])
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -5,6 +5,7 @@ package integration
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
@@ -24,6 +25,7 @@ import (
|
|||||||
|
|
||||||
"github.com/ollama/ollama/api"
|
"github.com/ollama/ollama/api"
|
||||||
"github.com/ollama/ollama/app/lifecycle"
|
"github.com/ollama/ollama/app/lifecycle"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -285,6 +287,7 @@ func DoGenerate(ctx context.Context, t *testing.T, client *api.Client, genReq ap
|
|||||||
// Generate a set of requests
|
// Generate a set of requests
|
||||||
// By default each request uses orca-mini as the model
|
// By default each request uses orca-mini as the model
|
||||||
func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
||||||
|
stream := false
|
||||||
return []api.GenerateRequest{
|
return []api.GenerateRequest{
|
||||||
{
|
{
|
||||||
Model: "orca-mini",
|
Model: "orca-mini",
|
||||||
@@ -336,3 +339,83 @@ func GenerateRequests() ([]api.GenerateRequest, [][]string) {
|
|||||||
[]string{"nitrogen", "oxygen", "carbon", "dioxide"},
|
[]string{"nitrogen", "oxygen", "carbon", "dioxide"},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func EmbeddingTestHelper(ctx context.Context, t *testing.T, client *http.Client, req api.EmbeddingRequest) api.EmbeddingResponse {
|
||||||
|
|
||||||
|
// TODO maybe stuff in an init routine?
|
||||||
|
lifecycle.InitLogging()
|
||||||
|
|
||||||
|
requestJSON, err := json.Marshal(req)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error serializing request: %v", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
|
||||||
|
defer serverProcMutex.Unlock()
|
||||||
|
if t.Failed() {
|
||||||
|
fp, err := os.Open(lifecycle.ServerLogFile)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("failed to open server log", "logfile", lifecycle.ServerLogFile, "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
data, err := io.ReadAll(fp)
|
||||||
|
if err != nil {
|
||||||
|
slog.Error("failed to read server log", "logfile", lifecycle.ServerLogFile, "error", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
slog.Warn("SERVER LOG FOLLOWS")
|
||||||
|
os.Stderr.Write(data)
|
||||||
|
slog.Warn("END OF SERVER")
|
||||||
|
}
|
||||||
|
err = os.Remove(lifecycle.ServerLogFile)
|
||||||
|
if err != nil && !os.IsNotExist(err) {
|
||||||
|
slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
scheme, testEndpoint := GetTestEndpoint()
|
||||||
|
|
||||||
|
if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
|
||||||
|
serverProcMutex.Lock()
|
||||||
|
fp, err := os.CreateTemp("", "ollama-server-*.log")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to generate log file: %s", err)
|
||||||
|
}
|
||||||
|
lifecycle.ServerLogFile = fp.Name()
|
||||||
|
fp.Close()
|
||||||
|
assert.NoError(t, StartServer(ctx, testEndpoint))
|
||||||
|
}
|
||||||
|
|
||||||
|
err = PullIfMissing(ctx, client, scheme, testEndpoint, req.Model)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error pulling model: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make the request and get the response
|
||||||
|
httpReq, err := http.NewRequest("POST", scheme+"://"+testEndpoint+"/api/embeddings", bytes.NewReader(requestJSON))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error creating request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the content type for the request
|
||||||
|
httpReq.Header.Set("Content-Type", "application/json")
|
||||||
|
|
||||||
|
// Make the request with the HTTP client
|
||||||
|
response, err := client.Do(httpReq.WithContext(ctx))
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Error making request: %v", err)
|
||||||
|
}
|
||||||
|
defer response.Body.Close()
|
||||||
|
body, err := io.ReadAll(response.Body)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, response.StatusCode, 200, string(body))
|
||||||
|
|
||||||
|
// Verify the response is valid JSON
|
||||||
|
var res api.EmbeddingResponse
|
||||||
|
err = json.Unmarshal(body, &res)
|
||||||
|
if err != nil {
|
||||||
|
assert.NoError(t, err, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
return res
|
||||||
|
}
|
||||||
|
|||||||
55
llm/ext_server/server.cpp
vendored
55
llm/ext_server/server.cpp
vendored
@@ -3209,54 +3209,27 @@ int main(int argc, char **argv) {
|
|||||||
return res.set_content(data.dump(), "application/json; charset=utf-8");
|
return res.set_content(data.dump(), "application/json; charset=utf-8");
|
||||||
});
|
});
|
||||||
|
|
||||||
svr.Post("/embedding", [&llama](const httplib::Request &req, httplib::Response &res)
|
svr.Post("/embeddings", [&llama](const httplib::Request &req, httplib::Response &res)
|
||||||
{
|
{
|
||||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||||
const json body = json::parse(req.body);
|
const json body = json::parse(req.body);
|
||||||
json prompt;
|
|
||||||
if (body.count("content") != 0)
|
const int id = llama.queue_tasks.get_new_id();
|
||||||
{
|
llama.queue_results.add_waiting_task_id(id);
|
||||||
prompt = body["content"];
|
llama.request_completion(id, {{"prompt", body["contents"]}}, false, true, -1);
|
||||||
}
|
|
||||||
else
|
task_result recv = llama.queue_results.recv(id);
|
||||||
{
|
llama.queue_results.remove_waiting_task_id(id);
|
||||||
prompt = "";
|
|
||||||
|
json embeddings = json::array();
|
||||||
|
for (auto & elem : recv.result_json["results"]) {
|
||||||
|
embeddings.push_back(json_value(elem, "embedding", json::array()));
|
||||||
}
|
}
|
||||||
|
|
||||||
json image_data;
|
json result = json{{"embeddings", embeddings}};
|
||||||
if (body.count("image_data") != 0) {
|
return res.set_content(result.dump(), "application/json; charset=utf-8");
|
||||||
image_data = body["image_data"];
|
|
||||||
}
|
|
||||||
else
|
|
||||||
{
|
|
||||||
image_data = "";
|
|
||||||
}
|
|
||||||
|
|
||||||
// create and queue the task
|
|
||||||
const int task_id = llama.queue_tasks.get_new_id();
|
|
||||||
llama.queue_results.add_waiting_task_id(task_id);
|
|
||||||
llama.request_completion(task_id, { {"prompt", prompt}, { "n_predict", 0}, {"image_data", image_data} }, false, true, -1);
|
|
||||||
|
|
||||||
// get the result
|
|
||||||
task_result result = llama.queue_results.recv(task_id);
|
|
||||||
llama.queue_results.remove_waiting_task_id(task_id);
|
|
||||||
|
|
||||||
// send the result
|
|
||||||
return res.set_content(result.result_json.dump(), "application/json; charset=utf-8");
|
|
||||||
});
|
});
|
||||||
|
|
||||||
// GG: if I put the main loop inside a thread, it crashes on the first request when build in Debug!?
|
|
||||||
// "Bus error: 10" - this is on macOS, it does not crash on Linux
|
|
||||||
//std::thread t2([&]()
|
|
||||||
/*{
|
|
||||||
bool running = true;
|
|
||||||
while (running)
|
|
||||||
{
|
|
||||||
running = llama.update_slots();
|
|
||||||
}
|
|
||||||
}*/
|
|
||||||
//);
|
|
||||||
|
|
||||||
if (sparams.n_threads_http < 1) {
|
if (sparams.n_threads_http < 1) {
|
||||||
// +2 threads for monitoring endpoints
|
// +2 threads for monitoring endpoints
|
||||||
sparams.n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1);
|
sparams.n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1);
|
||||||
|
|||||||
@@ -44,7 +44,6 @@ function init_vars {
|
|||||||
$script:commonCpuDefs = @("-DCMAKE_POSITION_INDEPENDENT_CODE=on")
|
$script:commonCpuDefs = @("-DCMAKE_POSITION_INDEPENDENT_CODE=on")
|
||||||
$script:ARCH = "amd64" # arm not yet supported.
|
$script:ARCH = "amd64" # arm not yet supported.
|
||||||
$script:DIST_BASE = "${script:SRC_DIR}\dist\windows-${script:ARCH}\ollama_runners"
|
$script:DIST_BASE = "${script:SRC_DIR}\dist\windows-${script:ARCH}\ollama_runners"
|
||||||
md "$script:DIST_BASE" -ea 0 > $null
|
|
||||||
if ($env:CGO_CFLAGS -contains "-g") {
|
if ($env:CGO_CFLAGS -contains "-g") {
|
||||||
$script:cmakeDefs += @("-DCMAKE_VERBOSE_MAKEFILE=on", "-DLLAMA_SERVER_VERBOSE=on", "-DCMAKE_BUILD_TYPE=RelWithDebInfo")
|
$script:cmakeDefs += @("-DCMAKE_VERBOSE_MAKEFILE=on", "-DLLAMA_SERVER_VERBOSE=on", "-DCMAKE_BUILD_TYPE=RelWithDebInfo")
|
||||||
$script:config = "RelWithDebInfo"
|
$script:config = "RelWithDebInfo"
|
||||||
@@ -182,7 +181,7 @@ function cleanup {
|
|||||||
|
|
||||||
|
|
||||||
function build_static() {
|
function build_static() {
|
||||||
if ((-not "${env:OLLAMA_SKIP_STATIC_GENERATE}") -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "static"))) {
|
if ($null -eq ${env:OLLAMA_SKIP_CPU_GENERATE}) {
|
||||||
# GCC build for direct linking into the Go binary
|
# GCC build for direct linking into the Go binary
|
||||||
init_vars
|
init_vars
|
||||||
# cmake will silently fallback to msvc compilers if mingw isn't in the path, so detect and fail fast
|
# cmake will silently fallback to msvc compilers if mingw isn't in the path, so detect and fail fast
|
||||||
@@ -214,7 +213,7 @@ function build_static() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
function build_cpu() {
|
function build_cpu() {
|
||||||
if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu"))) {
|
if ($null -eq ${env:OLLAMA_SKIP_CPU_GENERATE}) {
|
||||||
# remaining llama.cpp builds use MSVC
|
# remaining llama.cpp builds use MSVC
|
||||||
init_vars
|
init_vars
|
||||||
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=off", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
|
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=off", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
|
||||||
@@ -230,7 +229,7 @@ function build_cpu() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
function build_cpu_avx() {
|
function build_cpu_avx() {
|
||||||
if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu_avx"))) {
|
if ($null -eq ${env:OLLAMA_SKIP_CPU_GENERATE}) {
|
||||||
init_vars
|
init_vars
|
||||||
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
|
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
|
||||||
$script:buildDir="../build/windows/${script:ARCH}/cpu_avx"
|
$script:buildDir="../build/windows/${script:ARCH}/cpu_avx"
|
||||||
@@ -240,12 +239,12 @@ function build_cpu_avx() {
|
|||||||
sign
|
sign
|
||||||
install
|
install
|
||||||
} else {
|
} else {
|
||||||
write-host "Skipping CPU AVX generation step as requested"
|
write-host "Skipping CPU generation step as requested"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function build_cpu_avx2() {
|
function build_cpu_avx2() {
|
||||||
if ((-not "${env:OLLAMA_SKIP_CPU_GENERATE}" ) -and ((-not "${env:OLLAMA_CPU_TARGET}") -or ("${env:OLLAMA_CPU_TARGET}" -eq "cpu_avx2"))) {
|
if ($null -eq ${env:OLLAMA_SKIP_CPU_GENERATE}) {
|
||||||
init_vars
|
init_vars
|
||||||
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=on", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=on", "-DLLAMA_F16C=on") + $script:cmakeDefs
|
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=on", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=on", "-DLLAMA_F16C=on") + $script:cmakeDefs
|
||||||
$script:buildDir="../build/windows/${script:ARCH}/cpu_avx2"
|
$script:buildDir="../build/windows/${script:ARCH}/cpu_avx2"
|
||||||
@@ -255,12 +254,12 @@ function build_cpu_avx2() {
|
|||||||
sign
|
sign
|
||||||
install
|
install
|
||||||
} else {
|
} else {
|
||||||
write-host "Skipping CPU AVX2 generation step as requested"
|
write-host "Skipping CPU generation step as requested"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function build_cuda() {
|
function build_cuda() {
|
||||||
if ((-not "${env:OLLAMA_SKIP_CUDA_GENERATE}") -and ("${script:CUDA_LIB_DIR}")) {
|
if ($null -ne $script:CUDA_LIB_DIR) {
|
||||||
# Then build cuda as a dynamically loaded library
|
# Then build cuda as a dynamically loaded library
|
||||||
$nvcc = "$script:CUDA_LIB_DIR\nvcc.exe"
|
$nvcc = "$script:CUDA_LIB_DIR\nvcc.exe"
|
||||||
$script:CUDA_VERSION=(get-item ($nvcc | split-path | split-path)).Basename
|
$script:CUDA_VERSION=(get-item ($nvcc | split-path | split-path)).Basename
|
||||||
@@ -284,13 +283,11 @@ function build_cuda() {
|
|||||||
cp "${script:CUDA_LIB_DIR}\cudart64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\"
|
cp "${script:CUDA_LIB_DIR}\cudart64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\"
|
||||||
cp "${script:CUDA_LIB_DIR}\cublas64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\"
|
cp "${script:CUDA_LIB_DIR}\cublas64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\"
|
||||||
cp "${script:CUDA_LIB_DIR}\cublasLt64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\"
|
cp "${script:CUDA_LIB_DIR}\cublasLt64_*.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\"
|
||||||
} else {
|
|
||||||
write-host "Skipping CUDA generation step"
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
function build_rocm() {
|
function build_rocm() {
|
||||||
if ((-not "${env:OLLAMA_SKIP_ROCM_GENERATE}") -and ("${env:HIP_PATH}")) {
|
if ($null -ne $env:HIP_PATH) {
|
||||||
$script:ROCM_VERSION=(get-item $env:HIP_PATH).Basename
|
$script:ROCM_VERSION=(get-item $env:HIP_PATH).Basename
|
||||||
if ($null -ne $script:ROCM_VERSION) {
|
if ($null -ne $script:ROCM_VERSION) {
|
||||||
$script:ROCM_VARIANT="_v"+$script:ROCM_VERSION
|
$script:ROCM_VARIANT="_v"+$script:ROCM_VERSION
|
||||||
@@ -339,8 +336,6 @@ function build_rocm() {
|
|||||||
cp "${env:HIP_PATH}\bin\rocblas.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"
|
cp "${env:HIP_PATH}\bin\rocblas.dll" "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\"
|
||||||
# amdhip64.dll dependency comes from the driver and must be installed on the host to use AMD GPUs
|
# amdhip64.dll dependency comes from the driver and must be installed on the host to use AMD GPUs
|
||||||
cp "${env:HIP_PATH}\bin\rocblas\library\*" "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\rocblas\library\"
|
cp "${env:HIP_PATH}\bin\rocblas\library\*" "${script:SRC_DIR}\dist\windows-${script:ARCH}\rocm\rocblas\library\"
|
||||||
} else {
|
|
||||||
write-host "Skipping ROCm generation step"
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ type LlamaServer interface {
|
|||||||
Ping(ctx context.Context) error
|
Ping(ctx context.Context) error
|
||||||
WaitUntilRunning(ctx context.Context) error
|
WaitUntilRunning(ctx context.Context) error
|
||||||
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
|
Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error
|
||||||
Embedding(ctx context.Context, prompt string) ([]float64, error)
|
Embeddings(ctx context.Context, prompt []string) ([][]float64, error)
|
||||||
Tokenize(ctx context.Context, content string) ([]int, error)
|
Tokenize(ctx context.Context, content string) ([]int, error)
|
||||||
Detokenize(ctx context.Context, tokens []int) (string, error)
|
Detokenize(ctx context.Context, tokens []int) (string, error)
|
||||||
Close() error
|
Close() error
|
||||||
@@ -736,15 +736,15 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
|
|||||||
return fmt.Errorf("max retries exceeded")
|
return fmt.Errorf("max retries exceeded")
|
||||||
}
|
}
|
||||||
|
|
||||||
type EmbeddingRequest struct {
|
type EmbeddingsRequest struct {
|
||||||
Content string `json:"content"`
|
Contents []string `json:"contents"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type EmbeddingResponse struct {
|
type EmbeddingsResponse struct {
|
||||||
Embedding []float64 `json:"embedding"`
|
Embeddings [][]float64 `json:"embeddings"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
|
func (s *llmServer) Embeddings(ctx context.Context, prompts []string) ([][]float64, error) {
|
||||||
if err := s.sem.Acquire(ctx, 1); err != nil {
|
if err := s.sem.Acquire(ctx, 1); err != nil {
|
||||||
slog.Error("Failed to acquire semaphore", "error", err)
|
slog.Error("Failed to acquire semaphore", "error", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -758,12 +758,12 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
|
|||||||
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
|
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
|
||||||
}
|
}
|
||||||
|
|
||||||
data, err := json.Marshal(TokenizeRequest{Content: prompt})
|
data, err := json.Marshal(EmbeddingsRequest{Contents: prompts})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error marshaling embed data: %w", err)
|
return nil, fmt.Errorf("error marshaling embed data: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data))
|
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embeddings", s.port), bytes.NewBuffer(data))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error creating embed request: %w", err)
|
return nil, fmt.Errorf("error creating embed request: %w", err)
|
||||||
}
|
}
|
||||||
@@ -780,17 +780,19 @@ func (s *llmServer) Embedding(ctx context.Context, prompt string) ([]float64, er
|
|||||||
return nil, fmt.Errorf("error reading embed response: %w", err)
|
return nil, fmt.Errorf("error reading embed response: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fmt.Println("embeddings response", string(body))
|
||||||
|
|
||||||
if resp.StatusCode >= 400 {
|
if resp.StatusCode >= 400 {
|
||||||
log.Printf("llm encode error: %s", body)
|
log.Printf("llm encode error: %s", body)
|
||||||
return nil, fmt.Errorf("%s", body)
|
return nil, fmt.Errorf("%s", body)
|
||||||
}
|
}
|
||||||
|
|
||||||
var embedding EmbeddingResponse
|
var embedding EmbeddingsResponse
|
||||||
if err := json.Unmarshal(body, &embedding); err != nil {
|
if err := json.Unmarshal(body, &embedding); err != nil {
|
||||||
return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
|
return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return embedding.Embedding, nil
|
return embedding.Embeddings, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type TokenizeRequest struct {
|
type TokenizeRequest struct {
|
||||||
|
|||||||
@@ -403,23 +403,39 @@ func (s *Server) EmbeddingsHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// an empty request loads the model
|
switch {
|
||||||
if req.Prompt == "" {
|
// single embedding
|
||||||
c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: []float64{}})
|
case len(req.Prompt) > 0:
|
||||||
return
|
embeddings, err := runner.llama.Embeddings(c.Request.Context(), []string{req.Prompt})
|
||||||
}
|
if err != nil {
|
||||||
|
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
embedding, err := runner.llama.Embedding(c.Request.Context(), req.Prompt)
|
resp := api.EmbeddingResponse{Embedding: embeddings[0]}
|
||||||
if err != nil {
|
c.JSON(http.StatusOK, resp)
|
||||||
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
|
|
||||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := api.EmbeddingResponse{
|
// batch embeddings
|
||||||
Embedding: embedding,
|
case len(req.PromptBatch) > 0:
|
||||||
|
embeddings, err := runner.llama.Embeddings(c.Request.Context(), req.PromptBatch)
|
||||||
|
if err != nil {
|
||||||
|
slog.Info(fmt.Sprintf("batch embedding generation failed: %v", err))
|
||||||
|
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := api.EmbeddingResponse{EmbeddingBatch: embeddings}
|
||||||
|
c.JSON(http.StatusOK, resp)
|
||||||
|
|
||||||
|
// empty prompt loads the model
|
||||||
|
default:
|
||||||
|
if req.PromptBatch != nil {
|
||||||
|
c.JSON(http.StatusOK, api.EmbeddingResponse{EmbeddingBatch: [][]float64{}})
|
||||||
|
} else {
|
||||||
|
c.JSON(http.StatusOK, api.EmbeddingResponse{Embedding: []float64{}})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
c.JSON(http.StatusOK, resp)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) PullModelHandler(c *gin.Context) {
|
func (s *Server) PullModelHandler(c *gin.Context) {
|
||||||
|
|||||||
@@ -530,7 +530,7 @@ type mockLlm struct {
|
|||||||
pingResp error
|
pingResp error
|
||||||
waitResp error
|
waitResp error
|
||||||
completionResp error
|
completionResp error
|
||||||
embeddingResp []float64
|
embeddingResp [][]float64
|
||||||
embeddingRespErr error
|
embeddingRespErr error
|
||||||
tokenizeResp []int
|
tokenizeResp []int
|
||||||
tokenizeRespErr error
|
tokenizeRespErr error
|
||||||
@@ -546,7 +546,7 @@ func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitRes
|
|||||||
func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
|
func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error {
|
||||||
return s.completionResp
|
return s.completionResp
|
||||||
}
|
}
|
||||||
func (s *mockLlm) Embedding(ctx context.Context, prompt string) ([]float64, error) {
|
func (s *mockLlm) Embeddings(ctx context.Context, prompts []string) ([][]float64, error) {
|
||||||
return s.embeddingResp, s.embeddingRespErr
|
return s.embeddingResp, s.embeddingRespErr
|
||||||
}
|
}
|
||||||
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
|
func (s *mockLlm) Tokenize(ctx context.Context, content string) ([]int, error) {
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ package model
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"cmp"
|
"cmp"
|
||||||
|
"encoding/hex"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
@@ -170,6 +171,11 @@ func Merge(a, b Name) Name {
|
|||||||
return a
|
return a
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Digest returns the result of [ParseDigest] with the RawDigest field.
|
||||||
|
func (n Name) Digest() Digest {
|
||||||
|
return ParseDigest(n.RawDigest)
|
||||||
|
}
|
||||||
|
|
||||||
// String returns the name string, in the format that [ParseNameNoDefaults]
|
// String returns the name string, in the format that [ParseNameNoDefaults]
|
||||||
// accepts as valid, if [Name.IsValid] reports true; otherwise the empty
|
// accepts as valid, if [Name.IsValid] reports true; otherwise the empty
|
||||||
// string is returned.
|
// string is returned.
|
||||||
@@ -198,7 +204,7 @@ func (n Name) String() string {
|
|||||||
// IsValid reports whether all parts of the name are present and valid. The
|
// IsValid reports whether all parts of the name are present and valid. The
|
||||||
// digest is a special case, and is checked for validity only if present.
|
// digest is a special case, and is checked for validity only if present.
|
||||||
func (n Name) IsValid() bool {
|
func (n Name) IsValid() bool {
|
||||||
if n.RawDigest != "" && !isValidPart(kindDigest, n.RawDigest) {
|
if n.RawDigest != "" && !ParseDigest(n.RawDigest).IsValid() {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return n.IsFullyQualified()
|
return n.IsFullyQualified()
|
||||||
@@ -276,7 +282,7 @@ func isValidPart(kind partKind, s string) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
case ':':
|
case ':':
|
||||||
if kind != kindHost && kind != kindDigest {
|
if kind != kindHost {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
@@ -311,3 +317,75 @@ func cutPromised(s, sep string) (before, after string, ok bool) {
|
|||||||
}
|
}
|
||||||
return cmp.Or(before, MissingPart), cmp.Or(after, MissingPart), true
|
return cmp.Or(before, MissingPart), cmp.Or(after, MissingPart), true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DigestType identifies the hash algorithm of a [Digest]. Only SHA-256 is
// currently supported; the zero value marks an invalid or unknown digest.
type DigestType int

const (
	DigestTypeInvalid DigestType = iota
	DigestTypeSHA256
)

// String returns the canonical lowercase name of the digest type, or
// "unknown" for any unsupported value (including DigestTypeInvalid).
func (t DigestType) String() string {
	if t == DigestTypeSHA256 {
		return "sha256"
	}
	return "unknown"
}

// Digest represents a type and hash of a digest. It is comparable and can
// be used as a map key.
type Digest struct {
	Type DigestType
	Hash [32]byte
}

// ParseDigest parses a digest string into a Digest struct. It accepts both
// the forms:
//
//	sha256:deadbeef
//	sha256-deadbeef
//
// The hash part must be exactly 64 hex characters long. Any malformed input
// yields the zero Digest, for which [Digest.IsValid] reports false.
//
// The form "type:hash" does not round trip through [Digest.String].
func ParseDigest(s string) Digest {
	// Split at the last ':' or '-'. For any well-formed digest the type
	// ("sha256") contains neither separator and the hex hash cannot
	// contain one either, so the last occurrence is always the
	// type/hash boundary.
	i := strings.LastIndexAny(s, ":-")
	if i < 0 {
		return Digest{}
	}
	typ, hash := s[:i], s[i+1:]
	if typ != "sha256" {
		return Digest{}
	}
	var d Digest
	// Enforce the documented 64-character length BEFORE decoding:
	// hex.Decode writes into the fixed-size array and would panic
	// (index out of range) on a longer hash.
	if len(hash) != hex.EncodedLen(len(d.Hash)) {
		return Digest{}
	}
	if _, err := hex.Decode(d.Hash[:], []byte(hash)); err != nil {
		return Digest{}
	}
	d.Type = DigestTypeSHA256
	return d
}

// IsValid returns true if the digest has a valid Type and Hash.
//
// The all-zero hash is treated as invalid, so the zero Digest never
// reports valid.
func (d Digest) IsValid() bool {
	if d.Type != DigestTypeSHA256 {
		return false
	}
	return d.Hash != [32]byte{}
}

// String returns the digest as a string in the form "type-hash". The hash
// is encoded as a hex string.
func (d Digest) String() string {
	var b strings.Builder
	b.WriteString(d.Type.String())
	b.WriteByte('-')
	b.WriteString(hex.EncodeToString(d.Hash[:]))
	return b.String()
}

// LogValue implements [slog.LogValuer], logging the digest in its
// canonical "type-hash" string form.
func (d Digest) LogValue() slog.Value {
	return slog.StringValue(d.String())
}
|
|||||||
@@ -2,6 +2,7 @@ package model
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"reflect"
|
"reflect"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -81,10 +82,10 @@ func TestParseNameParts(t *testing.T) {
|
|||||||
wantValidDigest: false,
|
wantValidDigest: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
in: "model@sha256:123",
|
in: "model@sha256:" + validSHA256Hex,
|
||||||
want: Name{
|
want: Name{
|
||||||
Model: "model",
|
Model: "model",
|
||||||
RawDigest: "sha256:123",
|
RawDigest: "sha256:" + validSHA256Hex,
|
||||||
},
|
},
|
||||||
wantValidDigest: true,
|
wantValidDigest: true,
|
||||||
},
|
},
|
||||||
@@ -96,6 +97,9 @@ func TestParseNameParts(t *testing.T) {
|
|||||||
if !reflect.DeepEqual(got, tt.want) {
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
t.Errorf("parseName(%q) = %v; want %v", tt.in, got, tt.want)
|
t.Errorf("parseName(%q) = %v; want %v", tt.in, got, tt.want)
|
||||||
}
|
}
|
||||||
|
if got.Digest().IsValid() != tt.wantValidDigest {
|
||||||
|
t.Errorf("parseName(%q).Digest().IsValid() = %v; want %v", tt.in, got.Digest().IsValid(), tt.wantValidDigest)
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -235,3 +239,57 @@ func FuzzName(f *testing.F) {
|
|||||||
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const validSHA256Hex = "abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789"
|
||||||
|
|
||||||
|
func TestParseDigest(t *testing.T) {
|
||||||
|
cases := map[string]bool{
|
||||||
|
"sha256-1000000000000000000000000000000000000000000000000000000000000000": true,
|
||||||
|
"sha256:1000000000000000000000000000000000000000000000000000000000000000": true,
|
||||||
|
"sha256:0000000000000000000000000000000000000000000000000000000000000000": false,
|
||||||
|
|
||||||
|
"sha256:" + validSHA256Hex: true,
|
||||||
|
"sha256-" + validSHA256Hex: true,
|
||||||
|
|
||||||
|
"": false,
|
||||||
|
"sha134:" + validSHA256Hex: false,
|
||||||
|
"sha256:" + validSHA256Hex + "x": false,
|
||||||
|
"sha256:x" + validSHA256Hex: false,
|
||||||
|
"sha256-" + validSHA256Hex + "x": false,
|
||||||
|
"sha256-x": false,
|
||||||
|
}
|
||||||
|
|
||||||
|
for s, want := range cases {
|
||||||
|
t.Run(s, func(t *testing.T) {
|
||||||
|
d := ParseDigest(s)
|
||||||
|
if d.IsValid() != want {
|
||||||
|
t.Errorf("ParseDigest(%q).IsValid() = %v; want %v", s, d.IsValid(), want)
|
||||||
|
}
|
||||||
|
norm := strings.ReplaceAll(s, ":", "-")
|
||||||
|
if d.IsValid() && d.String() != norm {
|
||||||
|
t.Errorf("ParseDigest(%q).String() = %q; want %q", s, d.String(), norm)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDigestString(t *testing.T) {
|
||||||
|
cases := []struct {
|
||||||
|
in string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{in: "sha256:" + validSHA256Hex, want: "sha256-" + validSHA256Hex},
|
||||||
|
{in: "sha256-" + validSHA256Hex, want: "sha256-" + validSHA256Hex},
|
||||||
|
{in: "", want: "unknown-0000000000000000000000000000000000000000000000000000000000000000"},
|
||||||
|
{in: "blah-100000000000000000000000000000000000000000000000000000000000000", want: "unknown-0000000000000000000000000000000000000000000000000000000000000000"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range cases {
|
||||||
|
t.Run(tt.in, func(t *testing.T) {
|
||||||
|
d := ParseDigest(tt.in)
|
||||||
|
if d.String() != tt.want {
|
||||||
|
t.Errorf("ParseDigest(%q).String() = %q; want %q", tt.in, d.String(), tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user