Compare commits
46 Commits
v0.5.13-rc...brucemacd/
| SHA1 |
|---|
| 81888abbe4 |
| 05a01fdecb |
| 8fe6f69f28 |
| 1fdb351c37 |
| 7a01ad7614 |
| 55ab9f371a |
| fefbf8f74b |
| b428ddd796 |
| ba7d31240e |
| d25efe3954 |
| 36dfb906bb |
| a6f0f908b9 |
| 3b1ddb2b3a |
| 1579c4f06d |
| 3519dd1c6e |
| e41c4cbea7 |
| ee048b76d4 |
| af68d60a58 |
| 21aa666a1e |
| ee141cc821 |
| 55e5776c44 |
| 854a9195f3 |
| 96a97adf9b |
| e75c6126e9 |
| cda6f5c66c |
| bebb6823c0 |
| 31e472baa4 |
| 657685e85d |
| a14912858e |
| eed11ded30 |
| b42aba40ed |
| 25885e5335 |
| 98d44fa39d |
| 2099e2d267 |
| 0c1041ad85 |
| c245b0406f |
| 8b194b7520 |
| 3e8b8a1933 |
| 41dc280491 |
| 53d2990d9b |
| e185c08ad9 |
| 2412adf42b |
| be2ac1ed93 |
| dc13813a03 |
| d6af13efed |
| a59f665235 |
CMakeLists.txt

@@ -23,6 +23,7 @@ set(GGML_SCHED_MAX_COPIES 4)
set(GGML_LLAMAFILE ON)
set(GGML_CUDA_PEER_MAX_BATCH_SIZE 128)
set(GGML_CUDA_GRAPHS ON)
set(GGML_CUDA_FA ON)

if((CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
OR (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_SYSTEM_PROCESSOR MATCHES "arm|aarch64|ARM64|ARMv[0-9]+"))

@@ -105,9 +106,11 @@ if(CMAKE_HIP_COMPILER)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip)

if (WIN32)
target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY=1)
target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY)
endif()

target_compile_definitions(ggml-hip PRIVATE GGML_HIP_NO_VMM)

set(OLLAMA_HIP_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/rocm)
install(TARGETS ggml-hip
RUNTIME_DEPENDENCIES
CMakePresets.json

@@ -28,7 +28,7 @@
"name": "CUDA 12",
"inherits": [ "CUDA" ],
"cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "50;60;61;70;75;80;86;87;89;90;90a;100"
"CMAKE_CUDA_ARCHITECTURES": "50;60;61;70;75;80;86;87;89;90;90a;120"
}
},
{
CONTRIBUTING.md

@@ -6,8 +6,6 @@ Thank you for your interest in contributing to Ollama! Here are a few guidelines

See the [development documentation](./docs/development.md) for instructions on how to build and run Ollama locally.

## Pull requests

### Ideal issues

* [Bugs](https://github.com/ollama/ollama/issues?q=is%3Aissue+is%3Aopen+label%3Abug): issues where Ollama stops working or where it results in an unexpected error.

@@ -26,11 +24,64 @@ See the [development documentation](./docs/development.md) for instructions on h
* Changes that add significant friction to the user experience
* Changes that create a large future maintenance burden for maintainers and contributors

### Best practices
## Proposing a (non-trivial) change

* Commit messages: please leave both a title and a description in your commit messages. The title should be a short summary of the changes, with a leading word that explains the section of the code being changed (e.g. `api: fix parsing of prompt field`). In the description, leave a short 2-3 sentences that explain more about the change and its impact.
* Tests: please add test coverage to changes where possible.
* Minimize dependencies: avoid adding new dependencies unless absolutely necessary.
> By "non-trivial", we mean a change that is not a bug fix or small
> documentation update. If you are unsure, please ask us on our [Discord
> server](https://discord.gg/ollama).

Before opening a non-trivial Pull Request, please open an issue to discuss the change and
get feedback from the maintainers. This helps us understand the context of the
change and how it fits into Ollama's roadmap and prevents us from duplicating
work or you from spending time on a change that we may not be able to accept.

Tips for proposals:

* Explain the problem you are trying to solve, not what you are trying to do.
* Explain why the change is important.
* Explain how the change will be used.
* Explain how the change will be tested.

Additionally, for bonus points: Provide draft documentation you would expect to
see if the change were accepted.

## Pull requests

**Commit messages**

The title should look like:

    <package>: <short description>

The package is the most affected Go package. If the change does not affect Go
code, then use the directory name instead. Changes to a single well-known
file in the root directory may use the file name.

The short description should start with a lowercase letter and be a
continuation of the sentence:

    "This changes Ollama to..."

Examples:

    llm/backend/mlx: support the llama architecture
    CONTRIBUTING: provide clairity on good commit messages, and bad

Bad Examples:

    feat: add more emoji
    fix: was not using famous web framework
    chore: generify code

**Tests**

Please include tests. Strive to test behavior, not implementation.

**New dependencies**

Dependencies should be added sparingly. If you are adding a new dependency,
please explain why it is necessary and what other ways you attempted that
did not work without it.

## Need help?
Dockerfile

@@ -12,7 +12,7 @@ FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base
RUN yum install -y yum-utils \
&& yum-config-manager --add-repo https://dl.rockylinux.org/vault/rocky/8.5/AppStream/\$basearch/os/ \
&& rpm --import https://dl.rockylinux.org/pub/rocky/RPM-GPG-KEY-Rocky-8 \
&& dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 \
&& dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 gcc-toolset-10-binutils-2.35-11.el8 \
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
ENV PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH

@@ -86,10 +86,11 @@ RUN --mount=type=cache,target=/root/.ccache \
&& cmake --install build --component CUDA --strip --parallel 8

FROM base AS build
ARG GOVERSION=1.23.4
RUN curl -fsSL https://golang.org/dl/go${GOVERSION}.linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
ENV PATH=/usr/local/go/bin:$PATH
WORKDIR /go/src/github.com/ollama/ollama
COPY go.mod go.sum .
RUN curl -fsSL https://golang.org/dl/go$(awk '/^go/ { print $2 }' go.mod).linux-$(case $(uname -m) in x86_64) echo amd64 ;; aarch64) echo arm64 ;; esac).tar.gz | tar xz -C /usr/local
ENV PATH=/usr/local/go/bin:$PATH
RUN go mod download
COPY . .
ARG GOFLAGS="'-ldflags=-w -s'"
ENV CGO_ENABLED=1
README.md (11 changes)

@@ -1,5 +1,5 @@
<div align="center">
<a href="https://ollama.com" />
<a href="https://ollama.com">
<img alt="ollama" height="200px" src="https://github.com/ollama/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
</a>
</div>

@@ -64,7 +64,7 @@ Here are some example models that can be downloaded:
| Llama 3.1 | 8B | 4.7GB | `ollama run llama3.1` |
| Llama 3.1 | 405B | 231GB | `ollama run llama3.1:405b` |
| Phi 4 | 14B | 9.1GB | `ollama run phi4` |
| Phi 3 Mini | 3.8B | 2.3GB | `ollama run phi3` |
| Phi 4 Mini | 3.8B | 2.5GB | `ollama run phi4-mini` |
| Gemma 2 | 2B | 1.6GB | `ollama run gemma2:2b` |
| Gemma 2 | 9B | 5.5GB | `ollama run gemma2` |
| Gemma 2 | 27B | 16GB | `ollama run gemma2:27b` |

@@ -75,7 +75,7 @@ Here are some example models that can be downloaded:
| Code Llama | 7B | 3.8GB | `ollama run codellama` |
| Llama 2 Uncensored | 7B | 3.8GB | `ollama run llama2-uncensored` |
| LLaVA | 7B | 4.5GB | `ollama run llava` |
| Solar | 10.7B | 6.1GB | `ollama run solar` |
| Granite-3.2 | 8B | 4.9GB | `ollama run granite3.2` |

> [!NOTE]
> You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.

@@ -386,6 +386,9 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [MaxKB](https://github.com/1Panel-dev/MaxKB/) (Ready-to-use & flexible RAG Chatbot)
- [yla](https://github.com/danielekp/yla) (Web interface to freely interact with your customized models)
- [LangBot](https://github.com/RockChinQ/LangBot) (LLM-based instant messaging bots platform, with Agents, RAG features, supports multiple platforms)
- [1Panel](https://github.com/1Panel-dev/1Panel/) (Web-based Linux Server Management Tool)
- [AstrBot](https://github.com/Soulter/AstrBot/) (User-friendly LLM-based multi-platform chatbot with a WebUI, supporting RAG, LLM agents, and plugins integration)
- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)

### Cloud

@@ -510,6 +513,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Maid](https://github.com/Mobile-Artificial-Intelligence/maid)
- [Ollama App](https://github.com/JHubi1/ollama-app) (Modern and easy-to-use multi-platform client for Ollama)
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy focused LLM chat interface with optional encryption)
- [Ollama Android Chat](https://github.com/sunshine0523/OllamaServer) (No need for Termux, start the Ollama service with one click on an Android device)
- [Reins](https://github.com/ibrahimcetin/reins) (Easily tweak parameters, customize system prompts per chat, and enhance your AI experiments with reasoning model support.)

### Extensions & Plugins
api/client.go

@@ -10,7 +10,7 @@
// repository].
//
// [the API documentation]: https://github.com/ollama/ollama/blob/main/docs/api.md
// [in the GitHub repository]: https://github.com/ollama/ollama/tree/main/examples
// [in the GitHub repository]: https://github.com/ollama/ollama/tree/main/api/examples
package api

import (
cmd/cmd.go (17 changes)

@@ -34,7 +34,6 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/envconfig"
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/llama"
"github.com/ollama/ollama/parser"
"github.com/ollama/ollama/progress"
"github.com/ollama/ollama/runner"

@@ -256,6 +255,7 @@ func StopHandler(cmd *cobra.Command, args []string) error {
if strings.Contains(err.Error(), "not found") {
return fmt.Errorf("couldn't find model \"%s\" to stop", args[0])
}
return err
}
return nil
}

@@ -338,10 +338,16 @@ func RunHandler(cmd *cobra.Command, args []string) error {
return err
}

// TODO(jessegross): We should either find another way to know if this is
// a vision model or remove the logic. Also consider that other modalities will
// need different behavior anyways.
opts.MultiModal = len(info.ProjectorInfo) != 0 || envconfig.NewEngine()
if len(info.ProjectorInfo) != 0 {
opts.MultiModal = true
}
for k := range info.ModelInfo {
if strings.Contains(k, ".vision.") {
opts.MultiModal = true
break
}
}

opts.ParentModel = info.Details.ParentModel

if interactive {

@@ -1274,7 +1280,6 @@ func NewCLI() *cobra.Command {

runnerCmd := &cobra.Command{
Use: "runner",
Short: llama.PrintSystemInfo(),
Hidden: true,
RunE: func(cmd *cobra.Command, args []string) error {
return runner.Execute(os.Args[1:])
docs/development.md

@@ -118,6 +118,35 @@ To run tests, use `go test`:
go test ./...
```

> NOTE: In rare circumstances, you may need to change a package using the new
> "synctest" package in go1.24.
>
> If you do not have the "synctest" package enabled, you will not see build or
> test failures resulting from your change(s), if any, locally, but CI will
> break.
>
> If you see failures in CI, you can either keep pushing changes to see if the
> CI build passes, or you can enable the "synctest" package locally to see the
> failures before pushing.
>
> To enable the "synctest" package for testing, run the following command:
>
> ```shell
> GOEXPERIMENT=synctest go test ./...
> ```
>
> If you wish to enable synctest for all go commands, you can set the
> `GOEXPERIMENT` environment variable in your shell profile or by using:
>
> ```shell
> go env -w GOEXPERIMENT=synctest
> ```
>
> Which will enable the "synctest" package for all go commands without needing
> to set it for all shell sessions.
>
> The synctest package is not required for production builds.
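To make the note above concrete, here is a minimal sketch of a test gated on the experiment; the test name and the five-second timer are illustrative, not taken from the Ollama test suite:

```go
//go:build goexperiment.synctest

package example

import (
	"testing"
	"testing/synctest"
	"time"
)

// Inside synctest.Run, time is virtualized: the sleep below completes
// instantly and deterministically instead of taking five real seconds.
func TestTimerFires(t *testing.T) {
	synctest.Run(func() {
		fired := false
		time.AfterFunc(5*time.Second, func() { fired = true })

		time.Sleep(5 * time.Second)
		synctest.Wait() // wait for the timer goroutine to go idle

		if !fired {
			t.Fatal("timer did not fire")
		}
	})
}
```

With `GOEXPERIMENT=synctest` unset, the build tag excludes the file entirely, which is why such failures only show up in CI unless the experiment is enabled locally.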
## Library detection

Ollama looks for acceleration libraries in the following paths relative to the `ollama` executable:
envconfig/config.go

@@ -73,6 +73,7 @@ func AllowedOrigins() (origins []string) {
"file://*",
"tauri://*",
"vscode-webview://*",
"vscode-file://*",
)

return origins
envconfig/config_test.go

@@ -69,6 +69,7 @@ func TestOrigins(t *testing.T) {
"file://*",
"tauri://*",
"vscode-webview://*",
"vscode-file://*",
}},
{"http://10.0.0.1", []string{
"http://10.0.0.1",

@@ -88,6 +89,7 @@ func TestOrigins(t *testing.T) {
"file://*",
"tauri://*",
"vscode-webview://*",
"vscode-file://*",
}},
{"http://172.16.0.1,https://192.168.0.1", []string{
"http://172.16.0.1",

@@ -108,6 +110,7 @@ func TestOrigins(t *testing.T) {
"file://*",
"tauri://*",
"vscode-webview://*",
"vscode-file://*",
}},
{"http://totally.safe,http://definitely.legit", []string{
"http://totally.safe",

@@ -128,6 +131,7 @@ func TestOrigins(t *testing.T) {
"file://*",
"tauri://*",
"vscode-webview://*",
"vscode-file://*",
}},
}
for _, tt := range cases {
fs/ggml/ggml.go

@@ -100,6 +100,10 @@ func (kv KV) Float(key string, defaultValue ...float32) float32 {
return keyValue(kv, key, append(defaultValue, 0)...)
}

func (kv KV) Bool(key string, defaultValue ...bool) bool {
return keyValue(kv, key, append(defaultValue, false)...)
}

func (kv KV) Strings(key string, defaultValue ...[]string) []string {
r := keyValue(kv, key, &array{})
s := make([]string, r.size)

@@ -120,7 +124,7 @@ func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
return s
}

func keyValue[T string | uint32 | uint64 | float32 | *array](kv KV, key string, defaultValue ...T) T {
func keyValue[T string | uint32 | uint64 | float32 | *array | bool](kv KV, key string, defaultValue ...T) T {
if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
key = kv.Architecture() + "." + key
}

@@ -561,6 +565,43 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
return
}

func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
switch llm.KV().Architecture() {
case "mllama":
for _, layer := range llm.Tensors().GroupLayers()["v"] {
weights += layer.Size()
}

kv := func(n string) uint64 {
if v, ok := llm.KV()["mllama.vision."+n].(uint32); ok {
return uint64(v)
}

return 0
}

imageSize := kv("image_size")

maxNumTiles := kv("max_num_tiles")
embeddingLength := kv("embedding_length")
headCount := kv("attention.head_count")

numPatches := (imageSize / kv("patch_size")) * (imageSize / kv("patch_size"))
if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
numPatches++
}

numPaddedPatches := numPatches + 8 - (numPatches%8)%8

graphSize = 4 * (8 +
imageSize*imageSize*kv("num_channels")*maxNumTiles +
embeddingLength*numPatches*maxNumTiles +
9*embeddingLength*numPaddedPatches*maxNumTiles +
numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount)
}
return weights, graphSize
}

// SupportsKVCacheType checks if the requested cache type is supported
func (f GGML) SupportsKVCacheType(cacheType string) bool {
return slices.Contains([]string{"f16", "q8_0", "q4_0"}, cacheType)
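For orientation, here is a minimal sketch of how the newly added `KV.Bool` accessor behaves once `bool` joins the type set of `keyValue`. The key name is hypothetical and only illustrates the architecture-prefixing rule, and the sketch assumes the package containing `KV` is imported as `ggml`:

```go
// Keys outside the "tokenizer." and "general." namespaces are prefixed
// with the model architecture, so for architecture "llama" this looks up
// "llama.attention.causal" and falls back to true when the key is absent.
func isCausal(kv ggml.KV) bool {
	return kv.Bool("attention.causal", true) // hypothetical key
}
```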
go.mod (16 changes)

@@ -1,6 +1,6 @@
module github.com/ollama/ollama

go 1.24
go 1.24.0

require (
github.com/containerd/console v1.0.3

@@ -11,7 +11,7 @@ require (
github.com/spf13/cobra v1.7.0
github.com/stretchr/testify v1.9.0
github.com/x448/float16 v0.8.4
golang.org/x/sync v0.10.0
golang.org/x/sync v0.11.0
)

require (

@@ -69,12 +69,12 @@ require (
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.31.0
golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa
golang.org/x/net v0.25.0 // indirect
golang.org/x/sys v0.28.0
golang.org/x/term v0.27.0
golang.org/x/text v0.21.0
golang.org/x/crypto v0.33.0
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa
golang.org/x/net v0.35.0 // indirect
golang.org/x/sys v0.30.0
golang.org/x/term v0.29.0
golang.org/x/text v0.22.0
google.golang.org/protobuf v1.34.1
gopkg.in/yaml.v3 v3.0.1 // indirect
)
go.sum (28 changes)

@@ -214,16 +214,16 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3/go.mod h1:NOZ3BPKG0ec/BKJQgnvsSFpcKLM5xXVWnvZS97DWHgE=
golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa h1:FRnLl4eNAQl8hwxVVC17teOw8kdjVDVAiFMtgUdTSRQ=
golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE=
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa h1:t2QcU6V556bFjYgu4L6C+6VrCPyJZ+eyRsABUPs1mz4=
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa/go.mod h1:BHOTPb3L19zxehTsLoJXVaTktb06DFgmdW6Wb9s8jqk=
golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs=
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=

@@ -257,8 +257,8 @@ golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81R
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=

@@ -268,8 +268,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=

@@ -285,17 +285,17 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.27.0 h1:WP60Sv1nlK1T6SupCHbXzSaN0b9wUmsPoRS9b61A23Q=
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
golang.org/x/term v0.29.0 h1:L6pJp37ocefwRRtYPKSWOWzOtWSxVajvz2ldH/xi3iU=
golang.org/x/term v0.29.0/go.mod h1:6bl4lRlvVuDgSf3179VpIxBF0o10JUpXWOnI7nErv7s=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
kvcache/cache.go

@@ -29,6 +29,17 @@ type Cache interface {
// cache implementation used.
Put(ctx ml.Context, key, value ml.Tensor)

// SetConfig controls optimizations (mostly backend-specific) that may transform
// the output of the cache to work better with specific kernels. If not called,
// the backend settings will be used. This works well when calling Attention.
//
// The config can be overridden by models, especially if they require vanilla
// output when implementing their own version of attention. To do this, pass
// an empty ml.CacheConfig.
//
// Most models will not need to use this.
SetConfig(ml.CacheConfig)

// ** cache management **

// Init sets up runtime parameters
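As a rough sketch of the override path described in the comment above, a model that implements its own attention can pin the zero-value config before `Init` runs; the function name here is hypothetical, and the sketch assumes the `kvcache` and `ml` packages are imported:

```go
// Hypothetical model setup: pinning the zero-value config forces the cache
// to return untransformed ("vanilla") tensors regardless of any
// backend-preferred layout or dtype.
func configureCache(c kvcache.Cache) {
	c.SetConfig(ml.CacheConfig{})
}
```

Once set this way, a later attempt by the backend to install its own config panics, which is exactly the "cannot be changed after being previously set" guard visible in the `Causal` and `EncoderCache` implementations below.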
kvcache/causal.go

@@ -22,6 +22,9 @@ type Causal struct {
Capacity int32
windowSize int32

// config controls mostly backend-specific optimizations
config *ml.CacheConfig

// ** current forward pass **

// the active layer for Get and Put

@@ -75,14 +78,42 @@ func NewSWACache(windowSize int32, shift shiftFn) *Causal {
}

func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
if c.config == nil {
var config ml.CacheConfig
if cc, ok := backend.(ml.BackendCacheConfig); ok {
config = cc.CacheConfig()
}
c.config = &config
}

if c.config.CachePadding == 0 {
c.config.CachePadding = 1
}

if c.config.MaskBatchPadding == 0 {
c.config.MaskBatchPadding = 1
}

if c.config.MaskDType == ml.DTypeOther {
c.config.MaskDType = ml.DTypeF32
}

c.DType = dtype
c.Capacity = capacity
c.cells = make([]cacheCell, capacity)
c.Capacity = int32(roundUp(int(capacity), c.config.CachePadding))
c.cells = make([]cacheCell, c.Capacity)
c.cellRanges = make(map[int]cellRange)
c.backend = backend
c.cacheCtx = backend.NewContext()
}

func (c *Causal) SetConfig(config ml.CacheConfig) {
if c.config != nil {
panic("config cannot be changed after being previously set, either by the model or backend")
}

c.config = &config
}

func (c *Causal) Close() {
c.cacheCtx.Close()
}

@@ -157,36 +188,91 @@ func (c *Causal) findStartLoc() (int, error) {
return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity)
}

func roundDown(length, pad int) int {
return (length / pad) * pad
}

func roundUp(length, pad int) int {
return ((length + pad - 1) / pad) * pad
}

// Builds a mask of history x batch indicating whether for each token in the batch the
// token in the history should apply. This is based on both the sequence and causality (the
// position of the history is not ahead of the token in the batch).
func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Tensor, error) {
// TODO(jessegross): This does not do padding, which is required for flash attention
len := c.curCellRange.max - c.curCellRange.min + 1
mask := make([]float32, c.curBatchSize*len)
// Align and pad the two dimensions as required by the backend
batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)

c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1

length := c.curCellRange.max - c.curCellRange.min + 1
mask := make([]float32, batchSize*length)

for i := range c.curBatchSize {
for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
if !slices.Contains(c.cells[j].sequences, seqs[i]) || c.cells[j].pos > positions[i] ||
c.cells[j].pos < positions[i]-c.windowSize {
mask[i*len+(j-c.curCellRange.min)] = float32(math.Inf(-1))
mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
}
}
}

return ctx.FromFloatSlice(mask, len, c.curBatchSize)
// Mask out any padding tokens we added. For padding that we added to the cache history, this
// has already been masked out because the sequence doesn't match.
for i := c.curBatchSize * length; i < len(mask); i++ {
mask[i] = float32(math.Inf(-1))
}

maskTensor, err := ctx.FromFloatSlice(mask, length, batchSize)
if err != nil {
return nil, err
}

if c.config.MaskDType != ml.DTypeF32 {
out := ctx.Empty(c.config.MaskDType, maskTensor.Shape()...)
ctx.Forward(maskTensor.Copy(ctx, out))
maskTensor = out
}

return maskTensor, nil
}

func moveCell(ctx ml.Context, objs []ml.Tensor, src, dst, len int) {
for _, obj := range objs {
if obj == nil {
func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
for i := range c.keys {
if c.keys[i] == nil {
continue
}

srcView := obj.View(ctx, obj.Stride(2)*src, obj.Dim(0)*obj.Dim(1)*len)
dstView := obj.View(ctx, obj.Stride(2)*dst, obj.Dim(0)*obj.Dim(1)*len)
key := c.keys[i]

ctx.Forward(srcView.Copy(ctx, dstView))
kHeadDim := key.Dim(0)
numKVHeads := key.Dim(1)
rowSize := key.Stride(2)

kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*len)
kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*len)

value := c.values[i]
var vSrcView, vDstView ml.Tensor
if c.config.PermutedV {
vHeadDim := value.Dim(1)
elemSize := value.Stride(0)

vSrcView = value.View(ctx, elemSize*src, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
vDstView = value.View(ctx, elemSize*dst, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
} else {
vHeadDim := value.Dim(0)
rowSize := value.Stride(2)

vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*len)
vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*len)
}

ctx.Forward(
kSrcView.Copy(ctx, kDstView),
vSrcView.Copy(ctx, vDstView),
)
}
}

@@ -238,8 +324,7 @@ func (c *Causal) defrag() {
pendingLen++
break
} else {
moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
moves++
}
}

@@ -263,8 +348,7 @@ func (c *Causal) defrag() {
}

if pendingLen > 0 {
moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
moves++
}

@@ -305,33 +389,73 @@ func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
key := c.keys[c.curLayer]
value := c.values[c.curLayer]

key = key.View(ctx, key.Stride(2)*c.curCellRange.min,
key.Dim(0), key.Stride(1),
key.Dim(1), key.Stride(2),
c.curMask.Dim(0),
kHeadDim := key.Dim(0)
numKVHeads := key.Dim(1)
rowSize := key.Stride(2)
cachedSize := c.curMask.Dim(0)

key = key.View(ctx, rowSize*c.curCellRange.min,
kHeadDim, key.Stride(1),
numKVHeads, key.Stride(2),
cachedSize,
)

value = value.View(ctx, key.Stride(2)*c.curCellRange.min,
value.Dim(0), value.Stride(1),
value.Dim(1), value.Stride(2),
c.curMask.Dim(0),
)
if c.config.PermutedV {
vHeadDim := value.Dim(1)
elemSize := value.Stride(0)

value = value.View(ctx, elemSize*c.curCellRange.min,
cachedSize, value.Stride(1),
vHeadDim, value.Stride(2),
numKVHeads,
)
} else {
vHeadDim := value.Dim(0)
rowSize := value.Stride(2)

value = value.View(ctx, rowSize*c.curCellRange.min,
vHeadDim, value.Stride(1),
numKVHeads, value.Stride(2),
cachedSize,
)
}

return key, value, c.curMask
}

func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
if c.curBatchSize != key.Dim(2) {
panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, key.Dim(2)))
kHeadDim := key.Dim(0)
vHeadDim := value.Dim(0)
numKVHeads := key.Dim(1)
batchSize := key.Dim(2)

if c.curBatchSize != batchSize {
panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize))
}

if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, key.Dim(0), key.Dim(1), int(c.Capacity))
c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, value.Dim(0), value.Dim(1), int(c.Capacity))
c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity))

if c.config.PermutedV {
c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads)
} else {
c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity))
}
}

ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, c.keys[c.curLayer].Stride(2)*c.curLoc, key.Dim(0)*key.Dim(1)*key.Dim(2))))
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, c.values[c.curLayer].Stride(2)*c.curLoc, value.Dim(0)*value.Dim(1)*value.Dim(2))))
rowSize := c.keys[c.curLayer].Stride(2)
ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, kHeadDim*numKVHeads*batchSize)))

if c.config.PermutedV {
elemSize := c.values[c.curLayer].Stride(0)

value = value.Permute(ctx, 1, 2, 0, 3)
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)))
} else {
rowSize := c.values[c.curLayer].Stride(2)

ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, vHeadDim*numKVHeads*batchSize)))
}
}

func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {

@@ -387,9 +511,13 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
continue
}

key = key.View(ctx, key.Stride(2)*seqRange.min,
key.Dim(0), key.Stride(1),
key.Dim(1), key.Stride(2),
kHeadDim := key.Dim(0)
numKVHeads := key.Dim(1)
rowSize := key.Stride(2)

key = key.View(ctx, rowSize*seqRange.min,
kHeadDim, key.Stride(1),
numKVHeads, key.Stride(2),
size,
)
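A quick, self-contained illustration of the two padding helpers added above; the pad value of 32 is arbitrary, chosen only to show the arithmetic:

```go
package main

import "fmt"

func roundDown(length, pad int) int { return (length / pad) * pad }
func roundUp(length, pad int) int   { return ((length + pad - 1) / pad) * pad }

func main() {
	// With CachePadding = 32, a cell range [5, 70] widens to [0, 95],
	// so cache views start and end on backend-friendly boundaries.
	min := roundDown(5, 32)
	max := roundUp(70+1, 32) - 1
	fmt.Println(min, max, max-min+1) // 0 95 96
}
```

The same helpers pad the mask's batch dimension; the extra mask entries are filled with negative infinity so padded positions never receive attention weight.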
kvcache/causal_test.go

@@ -280,9 +280,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)

out, _, mask := cache.Get(context)

context.Forward(out)
context.Forward(mask)
context.Compute(out, mask)
context.Forward(out, mask).Compute(out, mask)

if !slices.Equal(out.Floats(), test.expected) || !slices.Equal(out.Shape(), test.expectedShape) || !slices.Equal(mask.Floats(), test.expectedMask) {
t.Errorf("TestCache: have %v (shape %v); want %v (shape %v); mask: have %v (shape %v) want %v", out.Floats(), out.Shape(), test.expected, test.expectedShape, mask.Floats(), mask.Shape(), test.expectedMask)

@@ -311,7 +309,7 @@ func (b *testBackend) SystemInfo() string {

type testContext struct{}

func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor {
total := 0

if len(shape) > 0 {

@@ -324,8 +322,12 @@ func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
return &testTensor{dtype: dtype, elementSize: 4, data: make([]float32, total), shape: shape}
}

func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
return c.Empty(dtype, shape...)
}

func (c *testContext) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
t := c.Zeros(ml.DTypeF32, shape...).(*testTensor)
t := c.Empty(ml.DTypeF32, shape...).(*testTensor)

copy(t.data, s)

@@ -344,7 +346,7 @@ func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
return out, nil
}

func (c *testContext) Forward(ml.Tensor) {}
func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }

func (c *testContext) Compute(...ml.Tensor) {}

@@ -393,7 +395,7 @@ func (t *testTensor) Floats() []float32 {
}

func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
out := ctx.Zeros(t.DType(), t.Shape()...).(*testTensor)
out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)

for i := range out.data {
out.data[i] = t.data[i] + t2.(*testTensor).data[i]

@@ -470,7 +472,7 @@ func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {

context := &testContext{}

view := context.Zeros(t.dtype, s...).(*testTensor)
view := context.Empty(t.dtype, s...).(*testTensor)
view.data = t.data[offset : offset+len(view.data)]

return view
kvcache/encoder.go

@@ -1,6 +1,8 @@
package kvcache

import (
"fmt"

"github.com/ollama/ollama/ml"
)

@@ -11,6 +13,9 @@ import (
//
// Not currently safe for multiple sequences
type EncoderCache struct {
// config controls mostly backend-specific optimizations
config *ml.CacheConfig

// ** current forward pass **

// the active layer for Get and Put

@@ -40,9 +45,29 @@ func NewEncoderCache() *EncoderCache {
}

func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
if c.config == nil {
var config ml.CacheConfig
if cc, ok := backend.(ml.BackendCacheConfig); ok {
config = cc.CacheConfig()
}
c.config = &config
}

if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
}

c.cacheCtx = backend.NewContext()
}

func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
if c.config != nil {
panic("config cannot be changed after being previously set, either by the model or backend")
}

c.config = &config
}

func (c *EncoderCache) Close() {
c.cacheCtx.Close()
}

@@ -75,13 +100,19 @@ func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
c.encoderPos = c.curPos
c.encoderCached = true

if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
c.keys[c.curLayer] = c.cacheCtx.Zeros(key.DType(), key.Shape()...)
c.values[c.curLayer] = c.cacheCtx.Zeros(value.DType(), value.Shape()...)
if c.config.PermutedV {
value = value.Permute(ctx, 1, 2, 0, 3)
}

ctx.Forward(key.Copy(ctx, c.keys[c.curLayer]))
ctx.Forward(value.Copy(ctx, c.values[c.curLayer]))
if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
c.keys[c.curLayer] = c.cacheCtx.Empty(key.DType(), key.Shape()...)
c.values[c.curLayer] = c.cacheCtx.Empty(value.DType(), value.Shape()...)
}

ctx.Forward(
key.Copy(ctx, c.keys[c.curLayer]),
value.Copy(ctx, c.values[c.curLayer]),
)
}

func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
kvcache/wrapper.go

@@ -28,6 +28,12 @@ func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, capacity int32)
}
}

func (c *WrapperCache) SetConfig(config ml.CacheConfig) {
for _, cache := range c.caches {
cache.SetConfig(config)
}
}

func (c *WrapperCache) Close() {
for _, cache := range c.caches {
cache.Close()
llama/grammar.go (new file, 135 lines)

@@ -0,0 +1,135 @@
package llama

/*
#cgo CFLAGS: -std=c11
#cgo CXXFLAGS: -std=c++17
#cgo CPPFLAGS: -I${SRCDIR}/../llama/llama.cpp/include
#cgo CPPFLAGS: -I${SRCDIR}/../llama/llama.cpp/common
#cgo CPPFLAGS: -I${SRCDIR}/../llama/llama.cpp/src
#cgo CPPFLAGS: -I${SRCDIR}

#include <stdlib.h>
#include <stdbool.h>
#include "llama.h"
#include "grammar_ext.h"

// Helper function to handle Go string arrays to C
static char** makeCharArray(int size) {
return (char**)malloc(size * sizeof(char*));
}

static void setArrayString(char** a, int i, const char* s) {
a[i] = (char*)s;
}

static void freeCharArray(char** a, int size) {
free(a);
}
*/
import "C"

import (
"errors"
"runtime"
"unsafe"
)

// Grammar represents the interface for grammar-based sampling
type Grammar interface {
Apply(logits []float32) ([]float32, error)
Close() error
}

// CGrammar is a wrapper around the C++ grammar implementation
type CGrammar struct {
grammar *C.struct_llama_grammar
model *C.struct_llama_model
closed bool
}

// NewGrammarWithTokens creates a new grammar using a custom vocabulary defined by tokens
func NewGrammarWithTokens(grammarStr, grammarRoot string, tokens []string) (Grammar, error) {
if grammarStr == "" {
return nil, errors.New("empty grammar string")
}

if len(tokens) == 0 {
return nil, errors.New("empty token list")
}

// Create C array of strings for tokens
cTokens := C.makeCharArray(C.int(len(tokens)))
defer C.freeCharArray(cTokens, C.int(len(tokens)))

// Convert Go strings to C strings and set them in the array
cStrings := make([]*C.char, len(tokens))
for i, token := range tokens {
cStrings[i] = C.CString(token)
C.setArrayString(cTokens, C.int(i), cStrings[i])
}

// Create vocabulary from tokens
cVocab := C.vocab_bridge_from_tokens((**C.char)(unsafe.Pointer(cTokens)), C.int(len(tokens)))

// Free the C strings after creating the vocab
for _, str := range cStrings {
C.free(unsafe.Pointer(str))
}

if cVocab == nil {
return nil, errors.New("failed to create vocabulary from tokens")
}

// Make sure to free the vocabulary when we're done
defer C.vocab_bridge_free(cVocab)

cGrammarStr := C.CString(grammarStr)
defer C.free(unsafe.Pointer(cGrammarStr))

cGrammarRoot := C.CString(grammarRoot)
defer C.free(unsafe.Pointer(cGrammarRoot))

// Create grammar using our C wrapper function with the correct signature
grammar := C.grammar_create_from_string(cVocab, cGrammarStr, cGrammarRoot)
if grammar == nil {
return nil, errors.New("failed to initialize grammar")
}

cg := &CGrammar{
grammar: grammar,
closed: false,
}

// Set up finalizer to free resources when the object is garbage collected
runtime.SetFinalizer(cg, func(g *CGrammar) {
g.Close()
})

return cg, nil
}

// Apply applies grammar constraints to logits
func (g *CGrammar) Apply(logits []float32) ([]float32, error) {
if g.closed || g.grammar == nil {
return nil, errors.New("grammar not initialized or already closed")
}

// Create a copy of logits to modify
result := make([]float32, len(logits))
copy(result, logits)

// Apply grammar constraints using our C wrapper function
C.grammar_apply_to_logits(g.grammar, (*C.float)(&result[0]), C.int(len(result)))

return result, nil
}

// Close releases resources associated with the grammar
func (g *CGrammar) Close() error {
if !g.closed && g.grammar != nil {
C.grammar_free(g.grammar)
g.grammar = nil
g.closed = true
}
return nil
}
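To show how the new API is meant to be driven, here is a hedged usage sketch; the GBNF grammar, the three-token vocabulary, and the logits are all illustrative stand-ins for what a real caller would pass:

```go
package main

import (
	"fmt"
	"log"

	"github.com/ollama/ollama/llama"
)

func main() {
	// Illustrative only: a real caller would pass the model's full vocabulary.
	tokens := []string{"yes", "no", "maybe"}

	g, err := llama.NewGrammarWithTokens(`root ::= "yes" | "no"`, "root", tokens)
	if err != nil {
		log.Fatal(err)
	}
	defer g.Close()

	// Grammar-disallowed tokens come back with -Inf logits, so ordinary
	// sampling over the result can never pick them.
	constrained, err := g.Apply([]float32{0.1, 0.2, 0.7})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(constrained)
}
```

Note that `Apply` copies the logits rather than mutating the caller's slice, and the `runtime.SetFinalizer` hook makes `Close` a safety net rather than the only release path.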
llama/grammar_ext.cpp (new file, vendored, 83 lines)

@@ -0,0 +1,83 @@
#include <stdlib.h>
#include <string>
#include <vector>
#include <cstdint>
#include <stdexcept>

#include "llama-sampling.h"
#include "llama-grammar.h"
#include "llama-vocab.h"
#include "grammar_ext.h"

extern "C" {

struct llama_grammar* grammar_create_from_string(const struct llama_vocab* vocab, const char* grammar_str, const char* grammar_root) {
try {
// Initialize grammar sampler directly with the model
struct llama_sampler* sampler = llama_sampler_init_grammar(vocab, grammar_str, grammar_root);
if (!sampler) {
return nullptr;
}

// Cast the sampler to a grammar and return it
return (struct llama_grammar*)sampler;
} catch (const std::exception &err) {
return nullptr;
}
}

void grammar_apply_to_logits(struct llama_grammar* grammar, float* logits, int n_logits) {
if (!grammar || !logits || n_logits <= 0) {
return;
}

// Create token data array for the grammar application
llama_token_data* token_data = (llama_token_data*)malloc(n_logits * sizeof(llama_token_data));
if (!token_data) {
return;
}

// Initialize token data from logits
for (int i = 0; i < n_logits; i++) {
token_data[i].id = i;
token_data[i].logit = logits[i];
token_data[i].p = 0.0f;
}

// Create token data array structure
llama_token_data_array arr = {
.data = token_data,
.size = (size_t)n_logits,
.sorted = false,
.selected = -1
};

// Apply grammar constraints to the token data array
llama_grammar_apply_impl(*grammar, &arr);

// Copy back the modified logits
for (int i = 0; i < n_logits; i++) {
logits[i] = token_data[i].logit;
}

free(token_data);
}

void grammar_free(struct llama_grammar* grammar) {
if (grammar) {
// Free the grammar as a sampler
llama_sampler_free((struct llama_sampler*)grammar);
}
}

struct llama_vocab* vocab_bridge_from_tokens(const char** tokens, int n_tokens) {
// Call the C++ function from llama-vocab.cpp
return llama_vocab_from_tokens(tokens, n_tokens);
}

void vocab_bridge_free(struct llama_vocab* vocab) {
// Call the C++ function from llama-vocab.cpp
llama_vocab_free(vocab);
}

} // extern "C"
llama/grammar_ext.h (new file, vendored, 33 lines)

@@ -0,0 +1,33 @@
#ifndef GRAMMAR_EXT_H
#define GRAMMAR_EXT_H

#include "llama.h"

#ifdef __cplusplus
extern "C" {
#endif

// Forward declarations
struct llama_grammar;
struct llama_vocab;

// Create a new grammar from a string (returns a grammar implemented as a sampler)
struct llama_grammar* grammar_create_from_string(const struct llama_vocab* vocab, const char* grammar_str, const char* grammar_root);

// Apply grammar constraints to logits
void grammar_apply_to_logits(struct llama_grammar* grammar, float* logits, int n_logits);

// Free grammar resources (frees the underlying sampler)
void grammar_free(struct llama_grammar* grammar);

// C wrapper for llama_vocab_from_tokens
struct llama_vocab* vocab_bridge_from_tokens(const char** tokens, int n_tokens);

// C wrapper for llama_vocab_free
void vocab_bridge_free(struct llama_vocab* vocab);

#ifdef __cplusplus
}
#endif

#endif // GRAMMAR_EXT_H
llama/llama.cpp/include/llama.h (1 change, vendored)

@@ -105,6 +105,7 @@ extern "C" {
LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26,
LLAMA_VOCAB_PRE_TYPE_MINERVA = 27,
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28,
LLAMA_VOCAB_PRE_TYPE_GPT4O = 29,
};

enum llama_rope_type {
llama/llama.cpp/src/llama-model.cpp (10 changes, vendored)

@@ -2283,7 +2283,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {

// output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
// if output is NULL, init from the input tok embed
if (output == NULL) {
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
}

for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];

@@ -2298,8 +2302,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * n_ff }, 0);

layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
}
} break;
case LLM_ARCH_PHIMOE:
llama/llama.cpp/src/llama-vocab.cpp (11 changes, vendored)

@@ -392,6 +392,13 @@ struct llm_tokenizer_bpe : llm_tokenizer {
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
};
break;
case LLAMA_VOCAB_PRE_TYPE_GPT4O:
// original regex from tokenizer.json
// [^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+
regex_exprs = {
"[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
};
break;
default:
// default regex for BPE tokenization pre-processing
regex_exprs = {

@@ -1583,6 +1590,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
} else if (
tokenizer_pre == "megrez") {
pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2;
} else if (
tokenizer_pre == "gpt-4o") {
pre_type = LLAMA_VOCAB_PRE_TYPE_GPT4O;
clean_spaces = false;
} else {
LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__);
pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
@@ -18,80 +18,49 @@ package llama

#include "mllama.h"
#include "sampling_ext.h"
#include "grammar_ext.h"

extern bool llamaProgressCallback(float progress, void *user_data);
extern void llamaLog(int level, char* text, void* user_data);

typedef enum {COMP_UNKNOWN,COMP_GCC,COMP_CLANG} COMPILER;
COMPILER inline get_compiler() {
#if defined(__clang__)
return COMP_CLANG;
#elif defined(__GNUC__)
return COMP_GCC;
#else
return UNKNOWN_COMPILER;
#endif
}

*/
import "C"

import (
"context"
_ "embed"
"errors"
"fmt"
"log/slog"
"os"
"runtime"
"runtime/cgo"
"slices"
"strings"
"sync/atomic"
"unsafe"

_ "github.com/ollama/ollama/llama/llama.cpp/common"
_ "github.com/ollama/ollama/llama/llama.cpp/examples/llava"
_ "github.com/ollama/ollama/llama/llama.cpp/src"
"github.com/ollama/ollama/ml/backend/ggml/ggml/src"
ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
)

func init() {
C.llama_log_set(C.ggml_log_callback(C.llamaLog), nil)
}

//export llamaLog
func llamaLog(level C.int, text *C.char, _ unsafe.Pointer) {
// slog levels zeros INFO and are multiples of 4
if slog.Default().Enabled(context.TODO(), slog.Level(int(level-C.GGML_LOG_LEVEL_INFO)*4)) {
fmt.Fprint(os.Stderr, C.GoString(text))
}
}

func BackendInit() {
ggml.OnceLoad()
C.llama_backend_init()
}

func PrintSystemInfo() string {
var compiler string
switch C.get_compiler() {
case C.COMP_UNKNOWN:
compiler = "cgo(unknown_compiler)"
case C.COMP_GCC:
compiler = "cgo(gcc)"
case C.COMP_CLANG:
compiler = "cgo(clang)"
}
return C.GoString(C.llama_print_system_info()) + compiler
}

var logLevel atomic.Int32

func init() {
logLevel.Store(int32(C.GGML_LOG_LEVEL_INFO))
C.llama_log_set((C.ggml_log_callback)(C.llamaLog), nil)
}

func EnableDebug() {
logLevel.Store(int32(C.GGML_LOG_LEVEL_DEBUG))
}

//export llamaLog
func llamaLog(level int32, text *C.char, _ unsafe.Pointer) {
if level < logLevel.Load() {
return
}

fmt.Fprint(os.Stderr, C.GoString(text))
}

func GetModelArch(modelPath string) (string, error) {
mp := C.CString(modelPath)
defer C.free(unsafe.Pointer(mp))
@@ -269,7 +238,7 @@ func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
cparams.progress_callback_user_data = unsafe.Pointer(&handle)
}

m := Model{c: C.llama_load_model_from_file(C.CString(modelPath), cparams)}
m := Model{c: C.llama_model_load_from_file(C.CString(modelPath), cparams)}
if m.c == nil {
return nil, fmt.Errorf("unable to load model: %s", modelPath)
}
@@ -278,12 +247,12 @@ func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
}

func FreeModel(model *Model) {
C.llama_free_model(model.c)
C.llama_model_free(model.c)
}

func NewContextWithModel(model *Model, params ContextParams) (*Context, error) {
c := Context{
c: C.llama_new_context_with_model(model.c, params.c),
c: C.llama_init_from_model(model.c, params.c),
numThreads: int(params.c.n_threads),
}
if c.c == nil {
@@ -294,15 +263,15 @@ func NewContextWithModel(model *Model, params ContextParams) (*Context, error) {
}

func (m *Model) NumVocab() int {
return int(C.llama_n_vocab(m.Vocab()))
return int(C.llama_vocab_n_tokens(m.Vocab()))
}

func (m *Model) TokenIsEog(token int) bool {
return bool(C.llama_token_is_eog(m.Vocab(), C.llama_token(token)))
return bool(C.llama_vocab_is_eog(m.Vocab(), C.llama_token(token)))
}

func (m *Model) AddBOSToken() bool {
return bool(C.llama_add_bos_token(m.Vocab()))
return bool(C.llama_vocab_get_add_bos(m.Vocab()))
}

func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float32, threads int) error {
@@ -485,7 +454,7 @@ func (m *Model) Tokenize(text string, addSpecial bool, parseSpecial bool) ([]int
}

func (m *Model) NEmbd() int {
return int(C.llama_n_embd(m.c))
return int(C.llama_model_n_embd(m.c))
}

func Quantize(infile, outfile string, ftype uint32) error {
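The logging change above replaces the slog-based filter with an atomic threshold that the exported cgo callback checks before writing. A minimal standalone sketch of the same pattern in plain Go, with no cgo; the constant values here are illustrative stand-ins, not the actual GGML_LOG_LEVEL_* C values:

```go
package main

import (
	"fmt"
	"os"
	"sync/atomic"
)

// Severity follows the ggml convention where debug sorts below info, so
// dropping anything under the stored threshold silences debug by default.
const (
	levelDebug int32 = 1
	levelInfo  int32 = 2
)

var logLevel atomic.Int32

func init() { logLevel.Store(levelInfo) }

// EnableDebug lowers the threshold so debug messages pass the filter.
func EnableDebug() { logLevel.Store(levelDebug) }

// emit mimics the exported llamaLog callback: one atomic load on the hot
// path, no locking, safe to call from any goroutine or foreign callback.
func emit(level int32, text string) {
	if level < logLevel.Load() {
		return
	}
	fmt.Fprint(os.Stderr, text)
}

func main() {
	emit(levelDebug, "dropped by default\n")
	EnableDebug()
	emit(levelDebug, "now visible\n")
}
```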
@@ -1,69 +0,0 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Michael Yang <mxyng@pm.me>
Date: Tue, 11 Feb 2025 14:06:36 -0800
Subject: [PATCH] try/catch backend load

---
ggml/src/ggml-backend-reg.cpp | 45 ++++++++++++++++++-----------------
1 file changed, 23 insertions(+), 22 deletions(-)

diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
index 98d5e14d..1c19129a 100644
--- a/ggml/src/ggml-backend-reg.cpp
+++ b/ggml/src/ggml-backend-reg.cpp
@@ -512,32 +512,33 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
}
fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
for (const auto & entry : dir_it) {
- if (entry.is_regular_file()) {
- std::wstring filename = entry.path().filename().wstring();
- std::wstring ext = entry.path().extension().wstring();
- if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
- dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
- if (!handle && !silent) {
- GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
- }
- if (handle) {
+ try {
+ if (entry.is_regular_file()) {
+ std::wstring filename = entry.path().filename().wstring();
+ std::wstring ext = entry.path().extension().wstring();
+ if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
+ dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
+ if (!handle) {
+ GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
+ continue;
+ }
+
auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
- if (score_fn) {
- int s = score_fn();
-#ifndef NDEBUG
- GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
-#endif
- if (s > best_score) {
- best_score = s;
- best_path = entry.path().wstring();
- }
- } else {
- if (!silent) {
- GGML_LOG_INFO("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
- }
+ if (!score_fn) {
+ GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
+ continue;
+ }
+
+ int s = score_fn();
+ GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
+ if (s > best_score) {
+ best_score = s;
+ best_path = entry.path().wstring();
}
}
}
+ } catch (const std::exception & e) {
+ GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), e.what());
}
}
}
@@ -4,11 +4,11 @@ Date: Sun, 16 Feb 2025 20:00:22 -0500
Subject: [PATCH] use std::filesystem::path instead of wstring

---
ggml/src/ggml-backend-reg.cpp | 144 ++++++++++++++--------------------
1 file changed, 58 insertions(+), 86 deletions(-)
ggml/src/ggml-backend-reg.cpp | 199 +++++++++++++++-------------------
1 file changed, 88 insertions(+), 111 deletions(-)

diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp
index 1c19129a..c854e6bb 100644
index 98d5e14d..799af5f3 100644
--- a/ggml/src/ggml-backend-reg.cpp
+++ b/ggml/src/ggml-backend-reg.cpp
@@ -66,26 +66,6 @@
@@ -264,47 +264,55 @@ index 1c19129a..c854e6bb 100644
for (const auto & search_path : search_paths) {
if (!fs::exists(search_path)) {
continue;
@@ -514,31 +486,31 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
@@ -513,29 +485,26 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
for (const auto & entry : dir_it) {
try {
if (entry.is_regular_file()) {
- std::wstring filename = entry.path().filename().wstring();
- std::wstring ext = entry.path().extension().wstring();
+ std::string filename = entry.path().filename().string();
+ std::string ext = entry.path().extension().string();
if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
- dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
+ dl_handle_ptr handle { dl_load_library(entry.path()) };
if (!handle) {
- GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
+ GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(entry.path()).c_str());
continue;
}

auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
if (!score_fn) {
- GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
+ GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, path_to_string(entry.path()).c_str());
continue;
}

int s = score_fn();
- GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
+ GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_to_string(entry.path()).c_str(), s);
if (s > best_score) {
best_score = s;
- best_path = entry.path().wstring();
+ best_path = entry.path();
}
if (entry.is_regular_file()) {
- std::wstring filename = entry.path().filename().wstring();
- std::wstring ext = entry.path().extension().wstring();
+ std::string filename = entry.path().filename().string();
+ std::string ext = entry.path().extension().string();
if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
- dl_handle_ptr handle { dl_load_library(entry.path().wstring()) };
- if (!handle && !silent) {
- GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
+ dl_handle_ptr handle { dl_load_library(entry.path()) };
+ if (!handle) {
+ GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(entry.path()).c_str());
+ continue;
}
- if (handle) {
- auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
- if (score_fn) {
- int s = score_fn();
-#ifndef NDEBUG
- GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s);
-#endif
- if (s > best_score) {
- best_score = s;
- best_path = entry.path().wstring();
- }
- } else {
- if (!silent) {
- GGML_LOG_INFO("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str());
- }
- }
+
+ auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
+ if (!score_fn) {
+ GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, path_to_string(entry.path()).c_str());
+ continue;
+ }
+
+ int s = score_fn();
+ GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_to_string(entry.path()).c_str(), s);
+ if (s > best_score) {
+ best_score = s;
+ best_path = entry.path();
}
}
} catch (const std::exception & e) {
- GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), e.what());
+ GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, path_to_string(entry.path()).c_str(), e.what());
}
}
}
@@ -546,7 +518,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
@@ -545,7 +514,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
if (best_score == 0) {
// try to load the base backend
for (const auto & search_path : search_paths) {
@@ -313,3 +321,49 @@ index 1c19129a..c854e6bb 100644
if (fs::exists(path)) {
return get_reg().load_backend(path, silent);
}
@@ -560,6 +529,14 @@ void ggml_backend_load_all() {
ggml_backend_load_all_from_path(nullptr);
}

+static void ggml_backend_try_load_best(const char * name, bool silent, const char * user_search_path) {
+ try {
+ ggml_backend_load_best(name, silent, user_search_path);
+ } catch (const std::exception & e) {
+ GGML_LOG_DEBUG("%s: failed to load %s: %s\n", __func__, name, e.what());
+ }
+}
+
void ggml_backend_load_all_from_path(const char * dir_path) {
#ifdef NDEBUG
bool silent = true;
@@ -567,18 +544,18 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
bool silent = false;
#endif

- ggml_backend_load_best("blas", silent, dir_path);
- ggml_backend_load_best("cann", silent, dir_path);
- ggml_backend_load_best("cuda", silent, dir_path);
- ggml_backend_load_best("hip", silent, dir_path);
- ggml_backend_load_best("kompute", silent, dir_path);
- ggml_backend_load_best("metal", silent, dir_path);
- ggml_backend_load_best("rpc", silent, dir_path);
- ggml_backend_load_best("sycl", silent, dir_path);
- ggml_backend_load_best("vulkan", silent, dir_path);
- ggml_backend_load_best("opencl", silent, dir_path);
- ggml_backend_load_best("musa", silent, dir_path);
- ggml_backend_load_best("cpu", silent, dir_path);
+ ggml_backend_try_load_best("blas", silent, dir_path);
+ ggml_backend_try_load_best("cann", silent, dir_path);
+ ggml_backend_try_load_best("cuda", silent, dir_path);
+ ggml_backend_try_load_best("hip", silent, dir_path);
+ ggml_backend_try_load_best("kompute", silent, dir_path);
+ ggml_backend_try_load_best("metal", silent, dir_path);
+ ggml_backend_try_load_best("rpc", silent, dir_path);
+ ggml_backend_try_load_best("sycl", silent, dir_path);
+ ggml_backend_try_load_best("vulkan", silent, dir_path);
+ ggml_backend_try_load_best("opencl", silent, dir_path);
+ ggml_backend_try_load_best("musa", silent, dir_path);
+ ggml_backend_try_load_best("cpu", silent, dir_path);
// check the environment variable GGML_BACKEND_PATH to load an out-of-tree backend
const char * backend_path = std::getenv("GGML_BACKEND_PATH");
if (backend_path) {
80
llama/patches/0018-add-phi4-support.patch
Normal file
@@ -0,0 +1,80 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: jmorganca <jmorganca@gmail.com>
Date: Thu, 27 Feb 2025 15:12:26 -0800
Subject: [PATCH] add phi4 support

---
include/llama.h | 1 +
src/llama-model.cpp | 10 +++++++---
src/llama-vocab.cpp | 11 +++++++++++
3 files changed, 19 insertions(+), 3 deletions(-)

diff --git a/include/llama.h b/include/llama.h
index cc948005..16774711 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -105,6 +105,7 @@ extern "C" {
LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26,
LLAMA_VOCAB_PRE_TYPE_MINERVA = 27,
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28,
+ LLAMA_VOCAB_PRE_TYPE_GPT4O = 29,
};

enum llama_rope_type {
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 21819080..ab1a07d1 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -2283,7 +2283,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {

// output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
- output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0);
+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+ // if output is NULL, init from the input tok embed
+ if (output == NULL) {
+ output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+ }

for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
@@ -2298,8 +2302,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * n_ff }, 0);

- layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
- layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+ layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+ layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
}
} break;
case LLM_ARCH_PHIMOE:
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
index 1ca827eb..c7ff28be 100644
--- a/src/llama-vocab.cpp
+++ b/src/llama-vocab.cpp
@@ -392,6 +392,13 @@ struct llm_tokenizer_bpe : llm_tokenizer {
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
};
break;
+ case LLAMA_VOCAB_PRE_TYPE_GPT4O:
+ // original regex from tokenizer.json
+ // [^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+
+ regex_exprs = {
+ "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
+ };
+ break;
default:
// default regex for BPE tokenization pre-processing
regex_exprs = {
@@ -1583,6 +1590,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
} else if (
tokenizer_pre == "megrez") {
pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2;
+ } else if (
+ tokenizer_pre == "gpt-4o") {
+ pre_type = LLAMA_VOCAB_PRE_TYPE_GPT4O;
+ clean_spaces = false;
} else {
LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__);
pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
117
llama/patches/0019-expose-llama_vocab-from-tokens.patch
Normal file
@@ -0,0 +1,117 @@
From 668a974433edccf2c5fcc2192c39aed601e575f2 Mon Sep 17 00:00:00 2001
From: Bruce MacDonald <brucewmacdonald@gmail.com>
Date: Thu, 6 Mar 2025 21:07:06 -0800
Subject: [PATCH] expose llama_vocab from tokens

---
llama/llama.cpp/src/llama-vocab.cpp | 73 +++++++++++++++++++++++++++++
llama/llama.cpp/src/llama-vocab.h | 11 ++++-
2 files changed, 83 insertions(+), 1 deletion(-)

diff --git a/llama/llama.cpp/src/llama-vocab.cpp b/llama/llama.cpp/src/llama-vocab.cpp
index c7ff28be..ad6e7ad8 100644
--- a/llama/llama.cpp/src/llama-vocab.cpp
+++ b/llama/llama.cpp/src/llama-vocab.cpp
@@ -3253,3 +3253,76 @@ int32_t llama_detokenize(
return vocab->detokenize(tokens, n_tokens, text, text_len_max, remove_special, unparse_special);
}

+struct llama_vocab *llama_vocab_from_tokens(const char **tokens, int n_tokens)
+{
+ if (!tokens || n_tokens <= 0)
+ {
+ return nullptr;
+ }
+
+ try
+ {
+ // Create a new vocabulary instance
+ llama_vocab *vocab = new llama_vocab();
+ vocab->pimpl = std::make_unique<llama_vocab::impl>(*vocab);
+
+ // Resize the token data vectors
+ vocab->pimpl->id_to_token.resize(n_tokens);
+
+ // Create mappings for all tokens
+ for (int i = 0; i < n_tokens; i++)
+ {
+ std::string word = tokens[i];
+ if (word.empty())
+ {
+ word = "[EMPTY_" + std::to_string(i) + "]";
+ }
+
+ // Add to token mappings
+ vocab->pimpl->token_to_id[word] = i;
+
+ // Set up token data
+ auto &token_data = vocab->pimpl->id_to_token[i];
+ token_data.text = std::move(word);
+ token_data.score = 0.0f; // Default score
+ token_data.attr = LLAMA_TOKEN_ATTR_NORMAL;
+
+ // Detect special tokens
+ if (word == "<s>" || word == "<bos>")
+ {
+ vocab->pimpl->special_bos_id = i;
+ }
+ else if (word == "</s>" || word == "<eos>" || word == "<|endoftext|>")
+ {
+ vocab->pimpl->special_eos_id = i;
+ vocab->pimpl->special_eog_ids.insert(i);
+ }
+ else if (word == "<unk>")
+ {
+ vocab->pimpl->special_unk_id = i;
+ }
+ }
+
+ // Initialize the token-to-piece cache
+ vocab->pimpl->cache_token_to_piece.resize(n_tokens);
+ for (int i = 0; i < n_tokens; i++)
+ {
+ vocab->pimpl->cache_token_to_piece[i] = vocab->pimpl->id_to_token[i].text;
+ }
+
+ return vocab;
+ }
+ catch (const std::exception &err)
+ {
+ return nullptr;
+ }
+}
+
+// Helper function to free the vocab
+void llama_vocab_free(struct llama_vocab *vocab)
+{
+ if (vocab)
+ {
+ delete vocab;
+ }
+}
\ No newline at end of file
diff --git a/llama/llama.cpp/src/llama-vocab.h b/llama/llama.cpp/src/llama-vocab.h
index 5ce35521..eceb28f3 100644
--- a/llama/llama.cpp/src/llama-vocab.h
+++ b/llama/llama.cpp/src/llama-vocab.h
@@ -119,7 +119,16 @@ struct llama_vocab {

void print_info() const;

-private:
struct impl;
std::unique_ptr<impl> pimpl;
};
+
+// Create a vocabulary from an array of token strings
+// tokens: Array of token strings
+// n_tokens: Number of tokens in the array
+// Returns: A new llama_vocab instance, or nullptr on failure
+// The caller is responsible for freeing the vocabulary using llama_vocab_free
+LLAMA_API struct llama_vocab * llama_vocab_from_tokens(const char ** tokens, int n_tokens);
+
+// Free a vocabulary created with llama_vocab_from_tokens
+LLAMA_API void llama_vocab_free(struct llama_vocab * vocab);
--
2.39.3 (Apple Git-145)
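The patch builds bidirectional token maps and tags special tokens by string matching. A rough Go analogue of that construction, for illustration only; the real code populates llama.cpp's pimpl structures rather than these stand-in types:

```go
package main

import "fmt"

// vocab mirrors the two mappings the patch fills in (token string -> id
// and id -> token string) plus the special-token ids it detects.
type vocab struct {
	tokenToID map[string]int
	idToToken []string
	bosID     int
	eosID     int
	unkID     int
}

func vocabFromTokens(tokens []string) *vocab {
	if len(tokens) == 0 {
		return nil
	}
	v := &vocab{
		tokenToID: make(map[string]int, len(tokens)),
		idToToken: make([]string, len(tokens)),
		bosID:     -1,
		eosID:     -1,
		unkID:     -1,
	}
	for i, word := range tokens {
		if word == "" {
			// mirror the patch's placeholder for empty strings
			word = fmt.Sprintf("[EMPTY_%d]", i)
		}
		v.tokenToID[word] = i
		v.idToToken[i] = word
		switch word {
		case "<s>", "<bos>":
			v.bosID = i
		case "</s>", "<eos>", "<|endoftext|>":
			v.eosID = i
		case "<unk>":
			v.unkID = i
		}
	}
	return v
}

func main() {
	v := vocabFromTokens([]string{"<s>", "hello", "</s>"})
	fmt.Println(v.tokenToID["hello"], v.bosID, v.eosID) // 1 0 2
}
```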
@@ -115,6 +115,9 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin
// multimodal models require at least 2048 context
opts.NumCtx = max(opts.NumCtx, 2048)
}
if projectorWeights == 0 && projectorGraph == 0 {
projectorWeights, projectorGraph = f.VisionGraphSize()
}

layers := f.Tensors().GroupLayers()
// add one layer worth of memory as a buffer
278
llm/server.go
@@ -30,6 +30,7 @@ import (
"github.com/ollama/ollama/format"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/llama"
"github.com/ollama/ollama/model"
)

type LlamaServer interface {
@@ -54,8 +55,15 @@ type llmServer struct {
options api.Options
numParallel int
modelPath string
modelLock sync.Mutex // Temporary until we switch fully to Go server
model *llama.Model // If non-nil, the runner is a new Go server

// llamaModel is an instance of the cgo llama.cpp model definition
// nil if this server is running the new engine
llamaModel *llama.Model
llamaModelLock sync.Mutex

// textProcessor handles text encoding/decoding for the model in the Ollama engine
// nil if this server is running the llama.cpp based engine
textProcessor model.TextProcessor

estimate MemoryEstimate
totalLayers uint64
@@ -89,7 +97,7 @@ func LoadModel(model string, maxArraySize int) (*ggml.GGML, error) {

// NewLlamaServer will run a server for the given GPUs
// The gpu list must be a single family.
func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) {
systemInfo := discover.GetSystemInfo()
systemTotalMemory := systemInfo.System.TotalMemory
systemFreeMemory := systemInfo.System.FreeMemory
@@ -130,7 +138,7 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt
slog.Info("offload", "", estimate)

params := []string{
"--model", model,
"--model", modelPath,
"--ctx-size", strconv.Itoa(opts.NumCtx),
"--batch-size", strconv.Itoa(opts.NumBatch),
}
@@ -153,11 +161,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt
}
}

if len(projectors) > 0 {
// TODO: applying multiple projectors is not supported by the llama.cpp server yet
params = append(params, "--mmproj", projectors[0])
}

defaultThreads := systemInfo.GetOptimalThreadCount()
if opts.NumThread > 0 {
params = append(params, "--threads", strconv.Itoa(opts.NumThread))
@@ -257,6 +260,34 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt
}
}
slog.Debug("compatible gpu libraries", "compatible", compatible)
exe, err := os.Executable()
if err != nil {
return nil, fmt.Errorf("unable to lookup executable path: %w", err)
}

if eval, err := filepath.EvalSymlinks(exe); err == nil {
exe = eval
}

var llamaModel *llama.Model
var textProcessor model.TextProcessor
if envconfig.NewEngine() {
textProcessor, err = model.NewTextProcessor(modelPath)
if err != nil {
// To prepare for opt-out mode, instead of treating this as an error, we fallback to the old runner
slog.Debug("model not yet supported by Ollama engine, switching to compatibility mode", "model", modelPath, "error", err)
}
}
if textProcessor == nil {
llamaModel, err = llama.LoadModelFromFile(modelPath, llama.ModelParams{VocabOnly: true})
if err != nil {
return nil, err
}
}

if len(projectors) > 0 && llamaModel != nil {
params = append(params, "--mmproj", projectors[0])
}

// iterate through compatible GPU libraries such as 'cuda_v12', 'cuda_v11', 'rocm', etc.
// adding each library's respective path to the LD_LIBRARY_PATH, until finally running
@@ -275,7 +306,9 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt
port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
}
finalParams := []string{"runner"}
if envconfig.NewEngine() {
if textProcessor != nil {
// New engine
// TODO - if we have failure to load scenarios, add logic to retry with the old runner
finalParams = append(finalParams, "--ollama-engine")
}
finalParams = append(finalParams, params...)
@@ -315,28 +348,20 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt
// finally, add the root library path
libraryPaths = append(libraryPaths, discover.LibOllamaPath)

exe, err := os.Executable()
if err != nil {
return nil, fmt.Errorf("unable to lookup executable path: %w", err)
}

if eval, err := filepath.EvalSymlinks(exe); err == nil {
exe = eval
}

// TODO - once fully switched to the Go runner, load the model here for tokenize/detokenize cgo access
s := &llmServer{
port: port,
cmd: exec.Command(exe, finalParams...),
status: NewStatusWriter(os.Stderr),
options: opts,
modelPath: model,
estimate: estimate,
numParallel: numParallel,
sem: semaphore.NewWeighted(int64(numParallel)),
totalLayers: f.KV().BlockCount() + 1,
gpus: gpus,
done: make(chan error, 1),
port: port,
cmd: exec.Command(exe, finalParams...),
status: NewStatusWriter(os.Stderr),
options: opts,
modelPath: modelPath,
llamaModel: llamaModel,
textProcessor: textProcessor,
estimate: estimate,
numParallel: numParallel,
sem: semaphore.NewWeighted(int64(numParallel)),
totalLayers: f.KV().BlockCount() + 1,
gpus: gpus,
done: make(chan error, 1),
}

s.cmd.Env = os.Environ()
@@ -405,6 +430,9 @@ func NewLlamaServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapt
}
err := fmt.Errorf("error starting runner: %v %s", err, msg)
if len(compatible) == 0 {
if llamaModel != nil {
llama.FreeModel(llamaModel)
}
return nil, err
}

@@ -701,24 +729,29 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
}

if len(req.Format) > 0 {
switch string(req.Format) {
case `null`, `""`:
// Field was set, but "missing" a value. We accept
// these as "not set".
break
case `"json"`:
request["grammar"] = grammarJSON
default:
if req.Format[0] != '{' {
return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format)
}
format := string(req.Format)
if format != `null` && format != `""` {
if s.textProcessor != nil {
// New engine handles this on the backend
request["format"] = req.Format
} else {
// old engine
switch format {
case `"json"`:
request["grammar"] = grammarJSON
default:
if req.Format[0] != '{' {
return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format)
}

// User provided a JSON schema
g := llama.SchemaToGrammar(req.Format)
if g == nil {
return fmt.Errorf("invalid JSON schema in format")
// User provided a JSON schema
g := llama.SchemaToGrammar(req.Format)
if g == nil {
return fmt.Errorf("invalid JSON schema in format")
}
request["grammar"] = string(g)
}
}
request["grammar"] = string(g)
}
}

@@ -933,64 +966,25 @@ type TokenizeResponse struct {
}

func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error) {
s.modelLock.Lock()
defer s.modelLock.Unlock()
if s.model != nil {
return s.model.Tokenize(content, false, true)
}
s.llamaModelLock.Lock()
defer s.llamaModelLock.Unlock()

// Make sure the server is ready
status, err := s.getServerStatus(ctx)
if err != nil {
return nil, err
} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
return nil, fmt.Errorf("unexpected server status: %s", status.ToString())
if s.llamaModel != nil {
return s.llamaModel.Tokenize(content, false, true)
}

data, err := json.Marshal(TokenizeRequest{Content: content})
if err != nil {
return nil, fmt.Errorf("marshaling encode data: %w", err)
}

req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/tokenize", s.port), bytes.NewBuffer(data))
if err != nil {
return nil, fmt.Errorf("encode request: %w", err)
}
req.Header.Set("Content-Type", "application/json")

resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("do encode request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
if s.model == nil {
slog.Debug("new runner detected, loading model for cgo tokenization")
m, err := llama.LoadModelFromFile(s.modelPath, llama.ModelParams{VocabOnly: true})
if err != nil {
return nil, err
}
s.model = m
if s.textProcessor != nil {
tokens, err := s.textProcessor.Encode(content)
if err != nil {
return nil, err
}
return s.model.Tokenize(content, false, true)
toks := make([]int, len(tokens))
for i, t := range tokens {
toks[i] = int(t)
}
return toks, nil
}

body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read encode request: %w", err)
}

if resp.StatusCode >= 400 {
log.Printf("llm encode error: %s", body)
return nil, fmt.Errorf("%s", body)
}

var encoded TokenizeResponse
if err := json.Unmarshal(body, &encoded); err != nil {
return nil, fmt.Errorf("unmarshal encode response: %w", err)
}

return encoded.Tokens, nil
// not reached
return nil, fmt.Errorf("no tokenizer configured")
}

type DetokenizeRequest struct {
@@ -1002,80 +996,38 @@ type DetokenizeResponse struct {
}

func (s *llmServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
s.modelLock.Lock()
defer s.modelLock.Unlock()
if s.model != nil {
s.llamaModelLock.Lock()
defer s.llamaModelLock.Unlock()

if s.llamaModel != nil {
var resp string
for _, token := range tokens {
resp += s.model.TokenToPiece(token)
resp += s.llamaModel.TokenToPiece(token)
}
return resp, nil
}
// Make sure the server is ready
status, err := s.getServerStatus(ctx)
if err != nil {
return "", err
} else if status != ServerStatusReady && status != ServerStatusNoSlotsAvailable {
return "", fmt.Errorf("unexpected server status: %s", status.ToString())
}

data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
if err != nil {
return "", fmt.Errorf("marshaling decode data: %w", err)
}

req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/detokenize", s.port), bytes.NewBuffer(data))
if err != nil {
return "", fmt.Errorf("decode request: %w", err)
}
req.Header.Set("Content-Type", "application/json")

resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", fmt.Errorf("do decode request: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotFound {
if s.model == nil {
slog.Debug("new runner detected, loading model for cgo tokenization")
m, err := llama.LoadModelFromFile(s.modelPath, llama.ModelParams{VocabOnly: true})
if err != nil {
return "", err
}
s.model = m
if s.textProcessor != nil {
toks := make([]int32, len(tokens))
for i, t := range tokens {
toks[i] = int32(t)
}
var resp string
for _, token := range tokens {
resp += s.model.TokenToPiece(token)
content, err := s.textProcessor.Decode(toks)
if err != nil {
return "", err
}
return resp, nil
return content, nil
}

body, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("read decode request: %w", err)
}

if resp.StatusCode >= 400 {
log.Printf("llm decode error: %s", body)
return "", fmt.Errorf("%s", body)
}

var decoded DetokenizeResponse
if err := json.Unmarshal(body, &decoded); err != nil {
return "", fmt.Errorf("unmarshal encode response: %w", err)
}

return decoded.Content, nil
// not reached
return "", fmt.Errorf("no tokenizer configured")
}

func (s *llmServer) Close() error {
s.modelLock.Lock()
if s.model != nil {
llama.FreeModel(s.model)
s.model = nil
s.llamaModelLock.Lock()
if s.llamaModel != nil {
llama.FreeModel(s.llamaModel)
s.llamaModel = nil
}
s.modelLock.Unlock()
s.llamaModelLock.Unlock()

if s.cmd != nil {
slog.Debug("stopping llama server")
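After this change, Tokenize and Detokenize no longer round-trip through the runner's HTTP endpoints; they dispatch on whichever in-process implementation is loaded. A condensed sketch of that dispatch with simplified stand-in types (the real interfaces live in github.com/ollama/ollama/model and github.com/ollama/ollama/llama):

```go
package main

import (
	"errors"
	"fmt"
	"sync"
)

// textProcessor and llamaModel stand in for the two mutually exclusive
// tokenizer backends the server can hold.
type textProcessor interface {
	Encode(s string) ([]int32, error)
}

type llamaModel interface {
	Tokenize(s string, addSpecial, parseSpecial bool) ([]int, error)
}

type server struct {
	mu sync.Mutex
	lm llamaModel    // non-nil when running the llama.cpp engine
	tp textProcessor // non-nil when running the new Ollama engine
}

func (s *server) Tokenize(content string) ([]int, error) {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.lm != nil { // old engine: cgo tokenizer
		return s.lm.Tokenize(content, false, true)
	}
	if s.tp != nil { // new engine: pure-Go text processor
		t32, err := s.tp.Encode(content)
		if err != nil {
			return nil, err
		}
		toks := make([]int, len(t32))
		for i, t := range t32 {
			toks[i] = int(t)
		}
		return toks, nil
	}
	return nil, errors.New("no tokenizer configured")
}

func main() {
	var s server
	_, err := s.Tokenize("hi")
	fmt.Println(err) // no tokenizer configured
}
```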
@@ -14,6 +14,7 @@ type Config interface {
String(string, ...string) string
Uint(string, ...uint32) uint32
Float(string, ...float32) float32
Bool(string, ...bool) bool

Strings(string, ...[]string) []string
Uints(string, ...[]uint32) []uint32
@@ -23,7 +24,35 @@ type Backend interface {
Config() Config
Get(name string) Tensor
NewContext() Context
SystemInfo() string
}

// BackendCacheConfig should be implemented by backends that need special output
// from the cache to meet specific requirements. It is frequently implemented in
// conjunction with ScaledDotProductAttention.
type BackendCacheConfig interface {
CacheConfig() CacheConfig
}

// CacheConfig controls optimizations (mostly backend-specific) that may transform
// the output the cache to work better with specific kernels.
type CacheConfig struct {
// CachePadding specifies the multiple for the number of tokens of cache history
// that will be returned from cache Get for k, v and mask. The capacity of the
// cache itself will also be increased to a multiple of this size if needed.
CachePadding int

// PermutedV performs Permute(ctx, 1, 2, 0, 3) on v tensors stored via Put
// and return the permuted version via Get. This uses the cache copy operation
// to avoid a Contiguous call on the permuted tensor.
PermutedV bool

// MaskDType specifies the data type for generating the mask. If unset it will
// default to DTypeF32.
MaskDType DType

// MaskBatchPadding specifies the multiple for the batch size dimension in the mask.
// Any position that does not correspond to an actual token will be filled with -Inf.
MaskBatchPadding int
}

// BackendParams controls how the backend loads and executes models
@@ -39,6 +68,9 @@ type BackendParams struct {

// TensorSplit is the fraction of the model to offload to each GPU
TensorSplit []float32

// FlashAttention indicates that we should use a fused flash attention kernel
FlashAttention bool
}

var backends = make(map[string]func(*os.File, BackendParams) (Backend, error))
@@ -60,11 +92,12 @@ func NewBackend(f *os.File, params BackendParams) (Backend, error) {
}

type Context interface {
Empty(dtype DType, shape ...int) Tensor
Zeros(dtype DType, shape ...int) Tensor
FromFloatSlice(s []float32, shape ...int) (Tensor, error)
FromIntSlice(s []int32, shape ...int) (Tensor, error)

Forward(Tensor)
Forward(...Tensor) Context
Compute(...Tensor)
MaxTensors() int
Close()
@@ -115,6 +148,10 @@ type Tensor interface {
// operation equivalent to following code on a tensor named
// query:
//
// query = query.Permute(ctx, 0, 2, 1, 3)
// key = key.Permute(ctx, 0, 2, 1, 3)
// value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
//
// kq := key.MulmatFullPrec(ctx, query)
//
// kq = kq.Scale(ctx, scale)
@@ -169,7 +206,7 @@ func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {
return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
})
case DTypeF16:
f32 := ctx.Zeros(DTypeF32, t.Shape()...)
f32 := ctx.Empty(DTypeF32, t.Shape()...)
f32 = t.Copy(ctx, f32)
return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string {
return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
@@ -185,8 +222,7 @@ func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {

func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) string {
if t.Bytes() == nil {
ctx.Forward(t)
ctx.Compute(t)
ctx.Forward(t).Compute(t)
}

s := make(S, mul(t.Shape()...))
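BackendCacheConfig is an optional capability interface: the cache layer type-asserts the backend and, when the assertion succeeds, adapts its padding and mask layout. A minimal sketch of how a backend would opt in, using stand-in types for the ml package; the field values come from the flash-attention branch of this diff, except the mask padding constant, which stands in for C.GGML_KQ_MASK_PAD:

```go
package main

import "fmt"

// Stand-ins for ml.DType and ml.CacheConfig.
type DType int

const (
	DTypeF32 DType = iota
	DTypeF16
)

type CacheConfig struct {
	CachePadding     int
	PermutedV        bool
	MaskDType        DType
	MaskBatchPadding int
}

// BackendCacheConfig is the optional interface a backend implements to
// steer how the KV cache lays out its data.
type BackendCacheConfig interface {
	CacheConfig() CacheConfig
}

type backend struct{ flashAttention bool }

// CacheConfig mirrors the ggml backend in the diff: flash attention wants
// larger padding and an f16 mask; the fallback path wants permuted V.
func (b *backend) CacheConfig() CacheConfig {
	if b.flashAttention {
		return CacheConfig{CachePadding: 256, MaskDType: DTypeF16, MaskBatchPadding: 32}
	}
	return CacheConfig{CachePadding: 32, PermutedV: true}
}

func main() {
	var b interface{} = &backend{flashAttention: true}
	// Callers feature-detect with a type assertion and fall back to
	// defaults when the backend does not implement the interface.
	if cc, ok := b.(BackendCacheConfig); ok {
		fmt.Printf("%+v\n", cc.CacheConfig())
	}
}
```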
@@ -1,27 +1,11 @@
package ggml

/*
#cgo CPPFLAGS: -I${SRCDIR}/ggml/include
#include <stdlib.h>
#include <stdint.h>
#include "ggml.h"
#include "ggml-cpu.h"
#include "ggml-backend.h"
static struct ggml_backend_feature * getBackendFeatures(void *fp, ggml_backend_reg_t reg) {return ((ggml_backend_get_features_t)(fp))(reg);}
static struct ggml_backend_feature * getNextBackendFeatures(struct ggml_backend_feature * feature) { return &feature[1];}

typedef enum {COMP_UNKNOWN,COMP_GCC,COMP_CLANG} COMPILER;
COMPILER inline get_compiler() {
#if defined(__clang__)
return COMP_CLANG;
#elif defined(__GNUC__)
return COMP_GCC;
#else
return UNKNOWN_COMPILER;
#endif
}

*/
// #cgo CPPFLAGS: -I${SRCDIR}/ggml/include
// #include <stdlib.h>
// #include <stdint.h>
// #include "ggml.h"
// #include "ggml-cpu.h"
// #include "ggml-backend.h"
import "C"

import (
@@ -79,6 +63,8 @@ var devices = sync.OnceValue(func() []device {
})

type Backend struct {
flashAttention bool

meta *fs.GGML
cpus, gpus []Context
tensors map[string]*Context
@@ -192,9 +178,10 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
}

return &Backend{
meta: meta,
cpus: cpus,
gpus: gpus,
flashAttention: params.FlashAttention,
meta: meta,
cpus: cpus,
gpus: gpus,
sched: C.ggml_backend_sched_new(
(*C.ggml_backend_t)(unsafe.Pointer(&backends[0])),
(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&bufts[0])),
@@ -219,7 +206,7 @@ func (b *Backend) Get(name string) ml.Tensor {

for _, c := range append(b.gpus, b.cpus...) {
if t := C.ggml_get_tensor(c.ctx, cname); t != nil {
return &Tensor{t: t}
return &Tensor{b: b, t: t}
}
}

@@ -247,6 +234,14 @@ func (b *Backend) NewContext() ml.Context {
}
}

func (b *Backend) CacheConfig() ml.CacheConfig {
if b.flashAttention {
return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16, MaskBatchPadding: C.GGML_KQ_MASK_PAD}
} else {
return ml.CacheConfig{CachePadding: 32, PermutedV: true}
}
}

type Context struct {
b *Backend
ctx *C.struct_ggml_context
@@ -256,12 +251,16 @@ type Context struct {
nodes int
}

func (c *Context) Forward(t ml.Tensor) {
func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
if c.graph == nil {
c.graph = C.ggml_new_graph_custom(c.ctx, C.size_t(c.nodes), false)
}

C.ggml_build_forward_expand(c.graph, t.(*Tensor).t)
for _, tensor := range tensors {
C.ggml_build_forward_expand(c.graph, tensor.(*Tensor).t)
}

return c
}

func (c *Context) Compute(tensors ...ml.Tensor) {
@@ -296,7 +295,7 @@ func shapeToGGML(shape []int) *C.int64_t {
return &sh[0]
}

func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
func newTensor(ctx Context, dtype ml.DType, zero bool, shape []int) ml.Tensor {
if len(shape) < 1 || len(shape) > 4 {
panic("unsupported number of dimensions")
}
@@ -310,19 +309,29 @@ func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {

var t *C.struct_ggml_tensor
switch dtype {
case ml.DTypeF32:
t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F32, C.int(len(shape)), shapeToGGML(shape))
t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_F32, C.int(len(shape)), shapeToGGML(shape))
case ml.DTypeF16:
t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F16, C.int(len(shape)), shapeToGGML(shape))
t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_F16, C.int(len(shape)), shapeToGGML(shape))
case ml.DTypeI32:
t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_I32, C.int(len(shape)), shapeToGGML(shape))
t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_I32, C.int(len(shape)), shapeToGGML(shape))
default:
panic("unsupported dtype")
}

b := C.ggml_backend_alloc_buffer(c.backend, C.ggml_nbytes(t))
b := C.ggml_backend_alloc_buffer(ctx.backend, C.ggml_nbytes(t))
C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
C.ggml_set_zero(t)
return &Tensor{t: t}
if zero {
C.ggml_set_zero(t)
}
return &Tensor{b: ctx.b, t: t}
}

func (c Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
return newTensor(c, dtype, false, shape)
}

func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
return newTensor(c, dtype, true, shape)
}

func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype uint32) (ml.Tensor, error) {
@@ -331,7 +340,7 @@ func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype u
if n == 0 {
var shape C.int64_t = 0
t := C.ggml_new_tensor(ctx.ctx, dtype, 1, &shape)
return &Tensor{t: t}, nil
return &Tensor{b: ctx.b, t: t}, nil
}

for _, v := range shape {
@@ -346,7 +355,7 @@ func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype u
b := C.ggml_backend_alloc_buffer(ctx.backend, C.ggml_nbytes(t))
C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
C.ggml_backend_tensor_set(t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t))
return &Tensor{t: t}, nil
return &Tensor{b: ctx.b, t: t}, nil
}

func (c Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
@@ -364,6 +373,7 @@ func (c *Context) Close() {
}

type Tensor struct {
b *Backend
t *C.struct_ggml_tensor
sync func()
}
@@ -430,6 +440,7 @@ func (t *Tensor) DType() ml.DType {

func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_add(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
}
}
@@ -444,24 +455,28 @@ func (t *Tensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {

func (t *Tensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_concat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(dim)),
}
}

func (t *Tensor) Contiguous(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_cont(ctx.(*Context).ctx, t.t),
}
}

func (t *Tensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_mul(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
}
}

func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
}
}
@@ -471,12 +486,13 @@ func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
C.ggml_mul_mat_set_prec(mul, C.GGML_PREC_F32)

return &Tensor{
b: t.b,
t: mul,
}
}

func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
tt := (&Tensor{t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
tt := (&Tensor{b: t.b, t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
if b != nil {
tt = tt.Add(ctx, b)
}
@@ -485,7 +501,7 @@ func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tenso
}

func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
return (&Tensor{t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
return (&Tensor{b: t.b, t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
}

func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
@@ -494,6 +510,7 @@ func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
}

return &Tensor{
b: t.b,
t: C.ggml_pad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
}
}
@@ -504,18 +521,21 @@ func (t *Tensor) Permute(ctx ml.Context, shape ...int) ml.Tensor {
}

return &Tensor{
b: t.b,
t: C.ggml_permute(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
}
}

func (t *Tensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_get_rows(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
}
}

func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_cpy(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
}
}
@@ -524,18 +544,22 @@ func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
switch len(shape) {
case 1:
return &Tensor{
b: t.b,
t: C.ggml_reshape_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0])),
}
case 2:
return &Tensor{
b: t.b,
t: C.ggml_reshape_2d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1])),
}
case 3:
return &Tensor{
b: t.b,
t: C.ggml_reshape_3d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2])),
}
case 4:
return &Tensor{
b: t.b,
t: C.ggml_reshape_4d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2]), C.int64_t(shape[3])),
}
default:
@@ -545,18 +569,21 @@ func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {

func (t *Tensor) Scale(ctx ml.Context, s float64) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_scale(ctx.(*Context).ctx, t.t, (C.float)(s)),
}
}

func (t *Tensor) Softmax(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_soft_max(ctx.(*Context).ctx, t.t),
}
}

func (t *Tensor) Tanh(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_tanh_inplace(ctx.(*Context).ctx, t.t),
}
}
@@ -567,6 +594,7 @@ func (t *Tensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
}

return &Tensor{
b: t.b,
t: C.ggml_unpad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
}
}
@@ -575,10 +603,12 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
switch len(shape) {
case 1:
return &Tensor{
b: t.b,
t: C.ggml_view_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.size_t(offset)),
}
case 3:
return &Tensor{
b: t.b,
t: C.ggml_view_2d(ctx.(*Context).ctx, t.t,
C.int64_t(shape[0]), C.int64_t(shape[2]),
C.size_t(shape[1]),
@@ -586,6 +616,7 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
}
case 5:
return &Tensor{
b: t.b,
t: C.ggml_view_3d(ctx.(*Context).ctx, t.t,
C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]),
C.size_t(shape[1]), C.size_t(shape[3]),
@@ -593,6 +624,7 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
}
case 7:
return &Tensor{
b: t.b,
t: C.ggml_view_4d(ctx.(*Context).ctx, t.t,
C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]), C.int64_t(shape[6]),
C.size_t(shape[1]), C.size_t(shape[3]), C.size_t(shape[5]),
@@ -609,7 +641,7 @@ const (

func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, ropeBase, ropeScale float32) ml.Tensor {
if ropeFactors == nil {
ropeFactors = &Tensor{}
ropeFactors = &Tensor{b: t.b}
}

dequant := t.t
@@ -618,6 +650,7 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
}

return &Tensor{
b: t.b,
t: C.ggml_rope_ext(
ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
C.int(ropeDim),
@@ -635,18 +668,21 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi

func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t),
}
}

func (t *Tensor) SILU(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t),
}
}

func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_conv_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1)),
}
}
@@ -657,42 +693,23 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.T
kqMask = mask.(*Tensor).t
}

kq := key.MulmatFullPrec(ctx, t)
kq = &Tensor{
t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
}
query := t.Permute(ctx, 0, 2, 1, 3)
key = key.Permute(ctx, 0, 2, 1, 3)

kqv := value.Mulmat(ctx, kq)
return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
}
if t.b.flashAttention {
value = value.Permute(ctx, 0, 2, 1, 3)

func (b *Backend) SystemInfo() string {
var compiler string
switch C.get_compiler() {
case C.COMP_UNKNOWN:
compiler = "cgo(unknown_compiler)"
case C.COMP_GCC:
compiler = "cgo(gcc)"
case C.COMP_CLANG:
compiler = "cgo(clang)"
}

var s string
for i := range C.ggml_backend_reg_count() {
reg := C.ggml_backend_reg_get(i)
fName := C.CString("ggml_backend_get_features")
defer C.free(unsafe.Pointer(fName))
get_features_fn := C.ggml_backend_reg_get_proc_address(reg, fName)
if get_features_fn != nil {
s += C.GoString(C.ggml_backend_reg_name(reg))
s += " : "
for features := C.getBackendFeatures(get_features_fn, reg); features.name != nil; features = C.getNextBackendFeatures(features) {
s += C.GoString(features.name)
s += " = "
s += C.GoString(features.value)
s += " | "
}
kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0)
C.ggml_flash_attn_ext_set_prec(kqv, C.GGML_PREC_F32)
return &Tensor{b: t.b, t: kqv}
} else {
kq := key.MulmatFullPrec(ctx, query)
kq = &Tensor{
b: t.b,
t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
}

kqv := value.Mulmat(ctx, kq)
return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
}
return s + compiler
}
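Two API changes in this file ripple through the callers: Forward is now variadic and returns the Context so graph building chains into the compute call, and every derived Tensor carries its Backend pointer so ops like ScaledDotProductAttention can consult backend flags such as flashAttention. A small sketch of the chaining shape, using stand-in types with no ggml dependency:

```go
package main

import "fmt"

type Tensor struct{ name string }

type Context struct{ graph []*Tensor }

// Forward is variadic and returns the receiver, so callers can chain
// graph building into the compute call: ctx.Forward(t).Compute(t).
func (c *Context) Forward(tensors ...*Tensor) *Context {
	c.graph = append(c.graph, tensors...)
	return c
}

// Compute evaluates the accumulated graph; this sketch just walks it.
func (c *Context) Compute(tensors ...*Tensor) {
	for _, t := range c.graph {
		fmt.Println("computing", t.name)
	}
}

func main() {
	ctx := &Context{}
	t := &Tensor{name: "logits"}
	ctx.Forward(t).Compute(t)
}
```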
74
ml/backend/ggml/ggml/src/ggml-backend-reg.cpp
vendored
@@ -484,33 +484,29 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
}
fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
for (const auto & entry : dir_it) {
try {
if (entry.is_regular_file()) {
std::string filename = entry.path().filename().string();
std::string ext = entry.path().extension().string();
if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
dl_handle_ptr handle { dl_load_library(entry.path()) };
if (!handle) {
GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(entry.path()).c_str());
continue;
}
if (entry.is_regular_file()) {
std::string filename = entry.path().filename().string();
std::string ext = entry.path().extension().string();
if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) {
dl_handle_ptr handle { dl_load_library(entry.path()) };
if (!handle) {
GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_to_string(entry.path()).c_str());
continue;
}

auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
if (!score_fn) {
GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, path_to_string(entry.path()).c_str());
continue;
}
auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score");
if (!score_fn) {
GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, path_to_string(entry.path()).c_str());
continue;
}

int s = score_fn();
GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_to_string(entry.path()).c_str(), s);
if (s > best_score) {
best_score = s;
best_path = entry.path();
}
int s = score_fn();
GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, path_to_string(entry.path()).c_str(), s);
if (s > best_score) {
best_score = s;
best_path = entry.path();
}
}
} catch (const std::exception & e) {
GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, path_to_string(entry.path()).c_str(), e.what());
}
}
}
@@ -533,6 +529,14 @@ void ggml_backend_load_all() {
ggml_backend_load_all_from_path(nullptr);
}

static void ggml_backend_try_load_best(const char * name, bool silent, const char * user_search_path) {
try {
ggml_backend_load_best(name, silent, user_search_path);
} catch (const std::exception & e) {
GGML_LOG_DEBUG("%s: failed to load %s: %s\n", __func__, name, e.what());
}
}

void ggml_backend_load_all_from_path(const char * dir_path) {
#ifdef NDEBUG
bool silent = true;
@@ -540,18 +544,18 @@ void ggml_backend_load_all_from_path(const char * dir_path) {
bool silent = false;
#endif

ggml_backend_load_best("blas", silent, dir_path);
ggml_backend_load_best("cann", silent, dir_path);
ggml_backend_load_best("cuda", silent, dir_path);
ggml_backend_load_best("hip", silent, dir_path);
ggml_backend_load_best("kompute", silent, dir_path);
ggml_backend_load_best("metal", silent, dir_path);
ggml_backend_load_best("rpc", silent, dir_path);
ggml_backend_load_best("sycl", silent, dir_path);
ggml_backend_load_best("vulkan", silent, dir_path);
ggml_backend_load_best("opencl", silent, dir_path);
ggml_backend_load_best("musa", silent, dir_path);
ggml_backend_load_best("cpu", silent, dir_path);
ggml_backend_try_load_best("blas", silent, dir_path);
ggml_backend_try_load_best("cann", silent, dir_path);
ggml_backend_try_load_best("cuda", silent, dir_path);
ggml_backend_try_load_best("hip", silent, dir_path);
ggml_backend_try_load_best("kompute", silent, dir_path);
ggml_backend_try_load_best("metal", silent, dir_path);
ggml_backend_try_load_best("rpc", silent, dir_path);
ggml_backend_try_load_best("sycl", silent, dir_path);
ggml_backend_try_load_best("vulkan", silent, dir_path);
ggml_backend_try_load_best("opencl", silent, dir_path);
ggml_backend_try_load_best("musa", silent, dir_path);
ggml_backend_try_load_best("cpu", silent, dir_path);
// check the environment variable GGML_BACKEND_PATH to load an out-of-tree backend
const char * backend_path = std::getenv("GGML_BACKEND_PATH");
if (backend_path) {
@@ -7,13 +7,30 @@ package ggml
// #include <stdlib.h>
// #include "ggml-backend.h"
// extern void sink(int level, char *text, void *user_data);
// static struct ggml_backend_feature * first_feature(ggml_backend_get_features_t fp, ggml_backend_reg_t reg) { return fp(reg); }
// static struct ggml_backend_feature * next_feature(struct ggml_backend_feature * feature) { return &feature[1]; }
/*
typedef enum { COMPILER_CLANG, COMPILER_GNUC, COMPILER_UNKNOWN } COMPILER;
static COMPILER compiler_name(void) {
#if defined(__clang__)
return COMPILER_CLANG;
#elif defined(__GNUC__)
return COMPILER_GNUC;
#else
return COMPILER_UNKNOWN;
#endif
}
*/
import "C"

import (
"context"
"fmt"
"log/slog"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"sync"
"unsafe"
@@ -22,21 +39,14 @@ import (
)

func init() {
C.ggml_log_set((C.ggml_log_callback)(C.sink), nil)
C.ggml_log_set(C.ggml_log_callback(C.sink), nil)
}

//export sink
func sink(level C.int, text *C.char, _ unsafe.Pointer) {
msg := strings.TrimSpace(C.GoString(text))
switch level {
case C.GGML_LOG_LEVEL_DEBUG:
slog.Debug(msg)
case C.GGML_LOG_LEVEL_INFO:
slog.Info(msg)
case C.GGML_LOG_LEVEL_WARN:
slog.Warn(msg)
case C.GGML_LOG_LEVEL_ERROR:
slog.Error(msg)
// slog levels are zero at INFO and spaced in multiples of 4
if slog.Default().Enabled(context.TODO(), slog.Level(int(level-C.GGML_LOG_LEVEL_INFO)*4)) {
fmt.Fprint(os.Stderr, C.GoString(text))
}
}

@@ -95,4 +105,43 @@ var OnceLoad = sync.OnceFunc(func() {
visited[abspath] = struct{}{}
}
}

slog.Info("system", "", system{})
})

type system struct{}

func (system) LogValue() slog.Value {
var attrs []slog.Attr
names := make(map[string]int)
for i := range C.ggml_backend_dev_count() {
r := C.ggml_backend_dev_backend_reg(C.ggml_backend_dev_get(i))

func() {
fName := C.CString("ggml_backend_get_features")
defer C.free(unsafe.Pointer(fName))

if fn := C.ggml_backend_reg_get_proc_address(r, fName); fn != nil {
var features []any
for f := C.first_feature(C.ggml_backend_get_features_t(fn), r); f.name != nil; f = C.next_feature(f) {
features = append(features, C.GoString(f.name), C.GoString(f.value))
}

name := C.GoString(C.ggml_backend_reg_name(r))
attrs = append(attrs, slog.Group(name+"."+strconv.Itoa(names[name]), features...))
names[name] += 1
}
}()
}

switch C.compiler_name() {
case C.COMPILER_CLANG:
attrs = append(attrs, slog.String("compiler", "cgo(clang)"))
case C.COMPILER_GNUC:
attrs = append(attrs, slog.String("compiler", "cgo(gcc)"))
default:
attrs = append(attrs, slog.String("compiler", "cgo(unknown)"))
}

return slog.GroupValue(attrs...)
}
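The rewritten sink collapses the old level switch into one arithmetic mapping. A standalone sketch of that mapping (the concrete ggml level values below are an assumption; the trick only needs DEBUG through ERROR to be consecutive with INFO as the anchor):

    package main

    import (
        "fmt"
        "log/slog"
    )

    func main() {
        // assumed ggml ordering: DEBUG=1, INFO=2, WARN=3, ERROR=4
        const info = 2
        for _, l := range []int{1, 2, 3, 4} {
            // slog levels are zero at INFO and spaced 4 apart
            // (DEBUG=-4, INFO=0, WARN=4, ERROR=8), so a linear map suffices
            fmt.Println(slog.Level((l - info) * 4))
        }
        // prints DEBUG, INFO, WARN, ERROR
    }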
@@ -3,6 +3,7 @@ package nn

import (
"fmt"

"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
)

@@ -11,40 +12,50 @@ import (
//
// Parameters:
// - ctx: Context for tensor operations
// - query: Query tensor (Q) with shape [d_k, seq_len_q, heads]
// - key: Key tensor (K) with shape [d_k, seq_len_k, kv_heads]
// - value: Value tensor (V) with shape [seq_len_k, d_v, kv_heads]
// - mask: Optional attention mask that is added to the attention score. If
// provided, should broadcast to [seq_len_k, seq_len_q, heads]
// - query: Query tensor (Q) with shape [d_k, heads, seq_len_q]
// - key: Key tensor (K) with shape [d_k, kv_heads, seq_len_k], can be nil to read from cache only
// - value: Value tensor (V) with shape [d_v, kv_heads, seq_len_k], can be nil to read from cache only
// - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension
// - cache: KV cache to store key/value and get past history, can be nil to only use provided key/value
//
// Returns:
//
// Attention output with shape [d_v, heads, seq_len_q]
func Attention(ctx ml.Context, query, key, value, mask ml.Tensor, scale float64) ml.Tensor {
if query.Dim(0) != key.Dim(0) {
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
if key != nil && value != nil {
if query.Dim(0) != key.Dim(0) {
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
}

if key.Dim(1) != value.Dim(1) {
panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1)))
}

if key.Dim(2) != value.Dim(2) {
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
}

if cache != nil {
cache.Put(ctx, key, value)
}
} else if cache == nil {
panic("key & value tensors must be provided if cache is nil")
}

if mask != nil && query.Dim(1) != mask.Dim(1) {
panic(fmt.Errorf("seq_len_q in attention operation does not match between query(%v) and mask(%v)", query.Dim(1), mask.Dim(1)))
var mask ml.Tensor
if cache != nil {
key, value, mask = cache.Get(ctx)
}

if key.Dim(1) != value.Dim(0) {
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(0)))
}

if mask != nil && key.Dim(1) != mask.Dim(0) {
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and mask(%v)", key.Dim(1), mask.Dim(0)))
}

if key.Dim(2) != value.Dim(2) {
panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
}

if sdpa, ok := query.(ml.ScaledDotProductAttention); ok {
// Only use the fast SDPA implementation if we have a cache, since that's what
// will do any expected backend-specific transformations for us
if sdpa, ok := query.(ml.ScaledDotProductAttention); ok && cache != nil {
return sdpa.ScaledDotProductAttention(ctx, key, value, mask, scale)
} else {
query = query.Permute(ctx, 0, 2, 1, 3)
key = key.Permute(ctx, 0, 2, 1, 3)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)

kq := key.MulmatFullPrec(ctx, query)

kq = kq.Scale(ctx, scale)
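With cache handling folded into nn.Attention, a model layer now hands its fresh projections plus the cache to a single call. A minimal sketch of a caller under the new signature (repo types assumed; this mirrors the llama and mllama call sites later in this diff):

    import (
        "math"

        "github.com/ollama/ollama/kvcache"
        "github.com/ollama/ollama/ml"
        "github.com/ollama/ollama/ml/nn"
    )

    // q is [headDim, numHeads, batch]; k and v are [headDim, numKVHeads, batch].
    func selfAttention(ctx ml.Context, q, k, v ml.Tensor, headDim int, cache kvcache.Cache) ml.Tensor {
        scale := 1.0 / math.Sqrt(float64(headDim))
        // Attention stores k/v in the cache, pulls history and the mask back
        // out, and uses the backend's fused SDPA path when one is available.
        return nn.Attention(ctx, q, k, v, scale, cache)
    }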
@@ -16,6 +16,7 @@ import (
_ "golang.org/x/image/tiff"
_ "golang.org/x/image/webp"

fs "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
_ "github.com/ollama/ollama/ml/backend"
@@ -100,6 +101,36 @@ func New(modelPath string, params ml.BackendParams) (Model, error) {
return m, nil
}

func NewTextProcessor(s string) (TextProcessor, error) {
r, err := os.Open(s)
if err != nil {
return nil, err
}
defer r.Close()
meta, _, err := fs.Decode(r, -1)
if err != nil {
return nil, err
}
return getTextProcessor(meta.KV())
}

func getTextProcessor(kv fs.KV) (TextProcessor, error) {
arch := kv.Architecture()
f, ok := models[arch]
if !ok {
return nil, fmt.Errorf("unsupported model architecture %q", arch)
}
m, err := f(kv)
if err != nil {
return nil, err
}
tp, ok := m.(TextProcessor)
if !ok {
return nil, fmt.Errorf("%v is not a TextProcessor", m)
}
return tp, nil
}

func populateFields(base Base, v reflect.Value, tags ...Tag) reflect.Value {
t := v.Type()

@@ -248,8 +279,7 @@ func Forward(ctx ml.Context, m Model, opts Options) (ml.Tensor, error) {
return nil, err
}

ctx.Forward(t)
ctx.Compute(t)
ctx.Forward(t).Compute(t)

return t, nil
}
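NewTextProcessor makes it possible to tokenize against a GGUF file without standing up a full backend. A hedged usage sketch (assuming, as with the BPE tokenizer elsewhere in this diff, that TextProcessor exposes Encode returning token ids):

    tp, err := model.NewTextProcessor("/path/to/model.gguf")
    if err != nil {
        log.Fatal(err)
    }
    ids, err := tp.Encode("hello world")
    if err != nil {
        log.Fatal(err)
    }
    fmt.Println(ids)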
@@ -3,9 +3,11 @@ package model

import (
"reflect"
"slices"
"strings"
"testing"

"github.com/google/go-cmp/cmp"
fs "github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/backend/ggml"
"github.com/ollama/ollama/ml/nn"
@@ -134,3 +136,40 @@ func TestPopulateFieldsAlternateName(t *testing.T) {
t.Errorf("populateFields() set incorrect values (-want +got):\n%s", diff)
}
}

func TestGetTextProcessor(t *testing.T) {
tp, err := getTextProcessor(fs.KV{})
if err == nil {
t.Error("expected error")
} else if !strings.Contains(err.Error(), "unsupported model architecture") {
t.Errorf("unexpected error: %v", err)
} else if tp != nil {
t.Error("expected nil tp")
}

models["dummy"] = func(ml.Config) (Model, error) {
return notTextProcessorModel{}, nil
}
tp, err = getTextProcessor(fs.KV{"general.architecture": "dummy"})
if err == nil {
t.Error("expected error")
} else if !strings.Contains(err.Error(), "not a TextProcessor") {
t.Errorf("unexpected error: %v", err)
} else if tp != nil {
t.Error("expected nil tp")
}
}

type notTextProcessorModel struct{}

func (notTextProcessorModel) Forward(ml.Context, Options) (ml.Tensor, error) {
panic("unimplemented")
}

func (notTextProcessorModel) Backend() ml.Backend {
panic("unimplemented")
}

func (notTextProcessorModel) Config() config {
panic("unimplemented")
}
@@ -1,7 +1,9 @@
package llama

import (
"fmt"
"math"
"strings"

"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
@@ -29,6 +31,10 @@ type Model struct {
}

func New(c ml.Config) (model.Model, error) {
if !strings.EqualFold(c.String("tokenizer.ggml.model"), "gpt2") {
return nil, fmt.Errorf("tokenizer %s not yet supported", c.String("tokenizer.ggml.model"))
}

m := Model{
BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
@@ -37,7 +43,9 @@ func New(c ml.Config) (model.Model, error) {
Types: c.Uints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
},
),
Layers: make([]Layer, c.Uint("block_count")),
@@ -79,15 +87,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)

cache.Put(ctx, k, v)
k, v, mask := cache.Get(ctx)

q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)

scaleFactor := 1.0 / math.Sqrt(float64(headDim))
kqv := nn.Attention(ctx, q, k, v, mask, scaleFactor)
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)

return sa.Output.Forward(ctx, kqv)
@@ -1,6 +1,8 @@
package mllama

import (
"fmt"

"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
@@ -25,6 +27,10 @@ const (
)

func New(c ml.Config) (model.Model, error) {
// Verify unified config
if c.Uint("vision.block_count") == 0 {
return nil, fmt.Errorf("non-unified vision model not supported")
}
m := Model{
BytePairEncoding: model.NewBytePairEncoding(
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
@@ -33,7 +39,9 @@ func New(c ml.Config) (model.Model, error) {
Types: c.Uints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
},
),
ImageProcessor: newImageProcessor(c),
@@ -41,7 +49,9 @@ func New(c ml.Config) (model.Model, error) {
TextModel: newTextModel(c),
}

m.Cache = kvcache.NewWrapperCache(kvcache.NewEncoderCache(), kvcache.NewCausalCache(m.TextModel.Shift))
encoderCache := kvcache.NewEncoderCache()
encoderCache.SetConfig(ml.CacheConfig{})
m.Cache = kvcache.NewWrapperCache(encoderCache, kvcache.NewCausalCache(m.TextModel.Shift))

return &m, nil
}

@@ -31,22 +31,15 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)

cache.Put(ctx, key, value)
key, value, mask := cache.Get(ctx)

query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)

scaleFactor := 1.0 / math.Sqrt(float64(headDim))
attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)
attention := nn.Attention(ctx, query, key, value, scaleFactor, cache)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)

return sa.Output.Forward(ctx, attention)
}

func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
// This will only get called for layers in the cache, which are just the self attention layers
// This will only get called for layers in the causal cache, which are just the self attention layers
return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
}

@@ -107,7 +100,7 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
query = ca.QueryNorm.Forward(ctx, query, opts.eps)

var key, value, mask ml.Tensor
var key, value ml.Tensor
if crossAttentionStates != nil {
numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)

@@ -119,16 +112,23 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)

cache.Put(ctx, key, value)
} else {
key, value, mask = cache.Get(ctx)
}

query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
key, value, _ = cache.Get(ctx)

scaleFactor := 1.0 / math.Sqrt(float64(headDim))
attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)

query = query.Permute(ctx, 0, 2, 1, 3)
key = key.Permute(ctx, 0, 2, 1, 3)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)

kq := key.MulmatFullPrec(ctx, query)

kq = kq.Scale(ctx, scaleFactor)
kq = kq.Softmax(ctx)

kqv := value.Mulmat(ctx, kq)
attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)

return ca.Output.Forward(ctx, attention)
@@ -30,7 +30,8 @@ type Vocabulary struct {
Scores []uint32
Merges []string

BOS, EOS int32
BOS, EOS int32
AddBOS, AddEOS bool

specialOnce sync.Once
special []string
@@ -281,6 +282,26 @@ func (bpe BytePairEncoding) Encode(s string) ([]int32, error) {
}
}

if len(ids) > 0 {
if bpe.vocab.AddBOS {
if ids[0] == bpe.vocab.BOS {
slog.Warn("adding bos token to prompt which already has it", "id", bpe.vocab.BOS)
}

slog.Debug("adding bos token to prompt", "id", bpe.vocab.BOS)
ids = append([]int32{bpe.vocab.BOS}, ids...)
}

if bpe.vocab.AddEOS {
if ids[len(ids)-1] == bpe.vocab.EOS {
slog.Warn("adding eos token to prompt which already has it", "id", bpe.vocab.EOS)
}

slog.Debug("adding eos token to prompt", "id", bpe.vocab.EOS)
ids = append(ids, bpe.vocab.EOS)
}
}

return ids, nil
}
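The new AddBOS/AddEOS handling is simple prepend/append bookkeeping. A self-contained sketch of the same logic on a plain token slice:

    package main

    import "fmt"

    func addSpecials(ids []int32, bos, eos int32, addBOS, addEOS bool) []int32 {
        if len(ids) == 0 {
            return ids
        }
        if addBOS {
            // the real code warns, but still prepends, when BOS is already present
            ids = append([]int32{bos}, ids...)
        }
        if addEOS {
            ids = append(ids, eos)
        }
        return ids
    }

    func main() {
        fmt.Println(addSpecials([]int32{5, 6, 7}, 1, 2, true, false)) // [1 5 6 7]
        fmt.Println(addSpecials([]int32{5, 6, 7}, 1, 2, true, true))  // [1 5 6 7 2]
    }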
@@ -915,7 +915,6 @@ func Execute(args []string) error {
level := slog.LevelInfo
if *verbose {
level = slog.LevelDebug
llama.EnableDebug()
}
handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
Level: level,
@@ -932,7 +931,6 @@ func Execute(args []string) error {
slog.Info("starting go runner")

llama.BackendInit()
slog.Info("system", "info", llama.PrintSystemInfo(), "threads", *threads)

server := &Server{
batchSize: *batchSize,
@@ -944,12 +942,11 @@ func Execute(args []string) error {

var tensorSplitFloats []float32
if *tensorSplit != "" {
stringFloats := regexp.MustCompile(",").Split(*tensorSplit, -1)

tensorSplitFloats = make([]float32, 0, len(stringFloats))
for _, s := range stringFloats {
splits := strings.Split(*tensorSplit, ",")
tensorSplitFloats = make([]float32, len(splits))
for i, s := range splits {
f, _ := strconv.ParseFloat(s, 32)
tensorSplitFloats = append(tensorSplitFloats, float32(f))
tensorSplitFloats[i] = float32(f)
}
}

@@ -970,13 +967,14 @@ func Execute(args []string) error {
server.cond = sync.NewCond(&server.mu)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

go server.run(ctx)

addr := "127.0.0.1:" + strconv.Itoa(*port)
listener, err := net.Listen("tcp", addr)
if err != nil {
fmt.Println("Listen error:", err)
cancel()
return err
}
defer listener.Close()
@@ -996,6 +994,5 @@ func Execute(args []string) error {
return err
}

cancel()
return nil
}
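The tensor-split rewrite swaps regexp.MustCompile(",").Split for strings.Split and indexed writes into a pre-sized slice. Both versions are correct; the value of the simpler form is that it avoids the classic pitfall of mixing the two styles, shown in this standalone sketch:

    package main

    import (
        "fmt"
        "strconv"
        "strings"
    )

    func main() {
        splits := strings.Split("0.5,0.25,0.25", ",")

        out := make([]float32, len(splits)) // length 3, zero-filled
        for i, s := range splits {
            f, _ := strconv.ParseFloat(s, 32)
            out[i] = float32(f)
        }
        fmt.Println(out) // [0.5 0.25 0.25]

        // pitfall: appending to a slice made with non-zero length grows it
        // past the intended size, leaving zeros at the front
        bad := make([]float32, len(splits))
        bad = append(bad, 1)
        fmt.Println(len(bad)) // 4, not 1
    }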
@@ -428,7 +428,8 @@ func (s *Server) processBatch() error {

// sample a token
vocabSize := len(logits) / len(options.Outputs)

// TODO: need access to vocab to apply grammar
// token = sampler.Grammar.Apply(logits)
token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize])
if err != nil {
return fmt.Errorf("failed to sample token: %w", err)
@@ -575,23 +576,18 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
}

sampler, err := sample.NewSampler(
req.Temperature,
req.TopK,
req.TopP,
req.MinP,
req.Seed,
)
if err != nil {
http.Error(w, fmt.Sprintf("Failed to create sampler: %v", err), http.StatusInternalServerError)
return
}
// TODO: if grammar is provided, load it
// if req.Grammar != "" {
// grammar := llama.NewGrammarWithTokens(req.Grammar, "root", s.model.Vocabulary)
// }
// defer grammar.Close()
// sampler := sample.WithGrammar(sample.Greedy(), grammar)

seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
numPredict: req.NumPredict,
stop: req.Stop,
numKeep: int32(req.NumKeep),
sampler: sampler,
sampler: sample.Greedy(), // TODO: add support for different samplers when performance is optimized
embedding: false,
})
if err != nil {
@@ -798,8 +794,6 @@ func (s *Server) loadModel(
panic(err)
}

slog.Info("system", "info", s.model.Backend().SystemInfo(), "threads", params.NumThreads)

// TODO(jessegross): LoRA loading
if lpath.String() != "" {
panic("loras are not yet implemented")
@@ -830,7 +824,7 @@ func Execute(args []string) error {
batchSize := fs.Int("batch-size", 512, "Batch size")
numGPULayers := fs.Int("n-gpu-layers", 0, "Number of layers to offload to GPU")
mainGPU := fs.Int("main-gpu", 0, "Main GPU")
_ = fs.Bool("flash-attn", false, "Enable flash attention")
flashAttention := fs.Bool("flash-attn", false, "Enable flash attention")
kvSize := fs.Int("ctx-size", 2048, "Context (or KV cache) size")
kvCacheType := fs.String("kv-cache-type", "", "quantization type for KV cache (default: f16)")
port := fs.Int("port", 8080, "Port to expose the server on")
@@ -875,26 +869,25 @@ func Execute(args []string) error {
}

// TODO(jessegross): Parameters that need to be implemented:
// flash-attn
// no-mmap
// mlock

var tensorSplitFloats []float32
if *tensorSplit != "" {
stringFloats := regexp.MustCompile(",").Split(*tensorSplit, -1)

tensorSplitFloats = make([]float32, 0, len(stringFloats))
for _, s := range stringFloats {
splits := strings.Split(*tensorSplit, ",")
tensorSplitFloats = make([]float32, len(splits))
for i, s := range splits {
f, _ := strconv.ParseFloat(s, 32)
tensorSplitFloats = append(tensorSplitFloats, float32(f))
tensorSplitFloats[i] = float32(f)
}
}

params := ml.BackendParams{
NumThreads: *threads,
NumGPULayers: *numGPULayers,
MainGPU: *mainGPU,
TensorSplit: tensorSplitFloats,
NumThreads: *threads,
NumGPULayers: *numGPULayers,
MainGPU: *mainGPU,
TensorSplit: tensorSplitFloats,
FlashAttention: *flashAttention,
}

server.ready.Add(1)
@@ -903,13 +896,14 @@ func Execute(args []string) error {
server.cond = sync.NewCond(&server.mu)

ctx, cancel := context.WithCancel(context.Background())
defer cancel()

go server.run(ctx)

addr := "127.0.0.1:" + strconv.Itoa(*port)
listener, err := net.Listen("tcp", addr)
if err != nil {
fmt.Println("Listen error:", err)
cancel()
return err
}
defer listener.Close()
@@ -929,6 +923,5 @@ func Execute(args []string) error {
return err
}

cancel()
return nil
}
@@ -4,6 +4,7 @@ import (
"errors"
"math"

"github.com/ollama/ollama/llama"
"golang.org/x/exp/rand"
"gonum.org/v1/gonum/stat/sampleuv"
)
@@ -54,53 +55,54 @@ func (s weighted) Sample(logits []float32) (int32, error) {
if idx, ok := w.Take(); ok {
return int32(indices[idx]), nil
}
return -1, errors.New("weighed sampler failed, no valid token found")
return -1, errors.New("weighted sampler failed, no valid token found")
}

type greedy struct {
transforms []Transform
grammar llama.Grammar
}

func Greedy(transforms ...Transform) Sampler {
return greedy{transforms: transforms}
func Greedy() Sampler {
return greedy{}
}

func WithGrammar(s Sampler, grammar llama.Grammar) Sampler {
switch t := s.(type) {
case greedy:
t.grammar = grammar
return t
default:
return s
}
}

// Sample returns the index of the maximum value in logits.
func (s greedy) Sample(logits []float32) (int32, error) {
logits64 := make([]float64, len(logits))
for i, v := range logits {
logits64[i] = float64(v)
if len(logits) == 0 {
return -1, errors.New("no logits provided for greedy sampling")
}

for _, t := range s.transforms {
logits64 = t.Apply(logits64)
}

var maxIdx int
var maxLogit float64
for i, logit := range logits64 {
if logit > maxLogit {
maxLogit = logit
maxIdx := 0
for i := range logits {
if logits[i] > logits[maxIdx] {
maxIdx = i
}
}

if maxLogit == math.Inf(-1) {
return -1, errors.New("no valid logits found for greedy sampling")
}

return int32(maxIdx), nil
}

// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int) (Sampler, error) {
transforms := []Transform{}
if temperature == 0 {
return Greedy(), nil
}

if temperature < 0 || temperature > 2 {
return nil, errors.New("temperature must be between 0 and 2")
}

if temperature != 0 {
transforms = append(transforms, Temperature(temperature))
}
transforms := []Transform{Temperature(temperature)}

if topK != 0 {
if topK <= 0 {
@@ -123,15 +125,7 @@ func NewSampler(temperature float32, topK int, topP float32, minP float32, seed
transforms = append(transforms, MinP(minP))
}

if len(transforms) == 0 {
return nil, errors.New("at least one transform is required")
}

if temperature == 0 {
return Greedy(transforms...), nil
}

if seed != 0 {
if seed >= 0 {
seed64 := uint64(seed)
return Weighted(&seed64, transforms...), nil
}
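Sampler construction changes shape here: temperature 0 now short-circuits to greedy inside NewSampler, and grammar support arrives as a decorator. A hedged sketch of the resulting call pattern (the grammar value itself is hypothetical):

    // pickSampler builds a sampler from a temperature and optionally
    // constrains the greedy sampler with a grammar.
    func pickSampler(temp float32, grammar llama.Grammar) (sample.Sampler, error) {
        s, err := sample.NewSampler(temp, 0, 0, 0, 0) // temp == 0 yields Greedy()
        if err != nil {
            return nil, err
        }
        // WithGrammar only knows how to wrap the greedy sampler; any other
        // sampler passes through unchanged, per the switch above.
        return sample.WithGrammar(s, grammar), nil
    }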
@@ -66,32 +66,15 @@ func TestSample(t *testing.T) {
callOrder: &callOrder,
}

got, err := Greedy(mock1, mock2, mock3).Sample(input)
_, err := Weighted(nil, mock1, mock2, mock3).Sample(input)
if err != nil {
t.Error(err)
return
}

want := int32(3) // Greedy sampler should pick highest logit
if want != got {
t.Errorf("index mismatch: want %d, got %d", want, got)
}
wantOrder := []int{1, 2, 3}
if diff := cmp.Diff(wantOrder, callOrder); diff != "" {
t.Errorf("call order mismatch (-want +got):\n%s", diff)
}

callOrder = nil

_, err = Weighted(nil, mock1, mock2, mock3).Sample(input)
if err != nil {
t.Error(err)
return
}
wantOrder = []int{1, 2, 3}
if diff := cmp.Diff(wantOrder, callOrder); diff != "" {
t.Errorf("call order mismatch (-want +got):\n%s", diff)
}
}

func TestNewSampler(t *testing.T) {
@@ -105,8 +88,9 @@ func TestNewSampler(t *testing.T) {
wantErr bool
}{
{
name: "no transforms",
wantErr: true,
name: "no transforms",
// temperature is 0, so greedy should be used
wantErr: false,
},
{
name: "temperature",
@@ -124,49 +108,52 @@ func TestNewSampler(t *testing.T) {
wantErr: true,
},
{
name: "top k",
topK: 10,
wantErr: false,
name: "top k",
topK: 10,
temperature: 0.8,
wantErr: false,
},
{
name: "invalid top k negative",
topK: -1,
wantErr: true,
name: "invalid top k negative",
topK: -1,
temperature: 0.8,
wantErr: true,
},
{
name: "top p",
topP: 0.9,
wantErr: false,
name: "top p",
topP: 0.9,
temperature: 0.8,
wantErr: false,
},
{
name: "invalid top p negative",
topP: -0.1,
wantErr: true,
name: "invalid top p negative",
topP: -0.1,
temperature: 0.8,
wantErr: true,
},
{
name: "invalid top p one",
topP: 1.0,
wantErr: true,
name: "invalid top p one",
topP: 1.0,
temperature: 0.8,
wantErr: true,
},
{
name: "min p",
minP: 0.2,
wantErr: false,
name: "min p",
minP: 0.2,
temperature: 0.8,
wantErr: false,
},
{
name: "invalid min p negative",
minP: -0.1,
wantErr: true,
name: "invalid min p negative",
minP: -0.1,
temperature: 0.8,
wantErr: true,
},
{
name: "invalid min p one",
minP: 1.0,
wantErr: true,
},
{
name: "seed",
seed: 42,
wantErr: true, // seed alone is not valid without other transforms
name: "invalid min p one",
minP: 1.0,
temperature: 0.8,
wantErr: true,
},
{
name: "default values",
@@ -184,7 +171,7 @@ func TestNewSampler(t *testing.T) {
topP: 0.0,
minP: 0.0,
seed: 0,
wantErr: true, // all zeroes means no transforms
wantErr: false, // all zeroes means no transforms
},
{
name: "all transforms",
@@ -216,7 +203,7 @@ func BenchmarkSample(b *testing.B) {
}

samplers := map[string]Sampler{
"Greedy": Greedy(transforms...),
"Greedy": Greedy(),
"Weighted": Weighted(nil, transforms...),
}
@@ -77,11 +77,12 @@ if [ -d "$OLLAMA_INSTALL_DIR/lib/ollama" ] ; then
fi
status "Installing ollama to $OLLAMA_INSTALL_DIR"
$SUDO install -o0 -g0 -m755 -d $BINDIR
$SUDO install -o0 -g0 -m755 -d "$OLLAMA_INSTALL_DIR"
$SUDO install -o0 -g0 -m755 -d "$OLLAMA_INSTALL_DIR/lib/ollama"
status "Downloading Linux ${ARCH} bundle"
curl --fail --show-error --location --progress-bar \
"https://ollama.com/download/ollama-linux-${ARCH}.tgz${VER_PARAM}" | \
$SUDO tar -xzf - -C "$OLLAMA_INSTALL_DIR"

if [ "$OLLAMA_INSTALL_DIR/bin/ollama" != "$BINDIR/ollama" ] ; then
status "Making ollama accessible in the PATH in $BINDIR"
$SUDO ln -sf "$OLLAMA_INSTALL_DIR/ollama" "$BINDIR/ollama"
@@ -8,6 +8,7 @@ import (
"errors"
"fmt"
"io"
"io/fs"
"log/slog"
"net/http"
"os"
@@ -34,6 +35,7 @@ var (
errOnlyGGUFSupported = errors.New("supplied file was not in GGUF format")
errUnknownType = errors.New("unknown type")
errNeitherFromOrFiles = errors.New("neither 'from' or 'files' was specified")
errFilePath = errors.New("file path must be relative")
)

func (s *Server) CreateHandler(c *gin.Context) {
@@ -46,6 +48,13 @@ func (s *Server) CreateHandler(c *gin.Context) {
return
}

for v := range r.Files {
if !fs.ValidPath(v) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errFilePath.Error()})
return
}
}

name := model.ParseName(cmp.Or(r.Model, r.Name))
if !name.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
@@ -104,7 +113,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
if r.Adapters != nil {
adapterLayers, err = convertModelFromFiles(r.Adapters, baseLayers, true, fn)
if err != nil {
for _, badReq := range []error{errNoFilesProvided, errOnlyOneAdapterSupported, errOnlyGGUFSupported, errUnknownType} {
for _, badReq := range []error{errNoFilesProvided, errOnlyOneAdapterSupported, errOnlyGGUFSupported, errUnknownType, errFilePath} {
if errors.Is(err, badReq) {
ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest}
return
@@ -221,8 +230,22 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
return nil, err
}
defer os.RemoveAll(tmpDir)
// Set up a root to validate paths
root, err := os.OpenRoot(tmpDir)
if err != nil {
return nil, err
}
defer root.Close()

for fp, digest := range files {
if !fs.ValidPath(fp) {
return nil, fmt.Errorf("%w: %s", errFilePath, fp)
}
if _, err := root.Stat(fp); err != nil && !errors.Is(err, fs.ErrNotExist) {
// Path is likely outside the root
return nil, fmt.Errorf("%w: %s: %s", errFilePath, err, fp)
}

blobPath, err := GetBlobsPath(digest)
if err != nil {
return nil, err
@@ -270,6 +293,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
if err != nil {
return nil, err
}
defer bin.Close()

f, _, err := ggml.Decode(bin, 0)
if err != nil {
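The path-validation change combines fs.ValidPath with what appears to be Go 1.24's os.OpenRoot, so that lookups cannot escape the staging directory even via symlinks. A standalone sketch of the technique (under that Go version assumption):

    package main

    import (
        "errors"
        "fmt"
        "io/fs"
        "os"
    )

    func main() {
        dir, _ := os.MkdirTemp("", "root-demo")
        defer os.RemoveAll(dir)

        root, err := os.OpenRoot(dir)
        if err != nil {
            panic(err)
        }
        defer root.Close()

        // fs.ValidPath rejects "..", absolute paths, and empty names up front
        fmt.Println(fs.ValidPath("model.safetensors")) // true
        fmt.Println(fs.ValidPath("../escape.txt"))     // false

        // root.Stat refuses to traverse outside dir; a missing-but-inside
        // path reports fs.ErrNotExist, which the server treats as acceptable
        _, err = root.Stat("model.safetensors")
        fmt.Println(errors.Is(err, fs.ErrNotExist)) // true
    }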
106 server/create_test.go Normal file
@@ -0,0 +1,106 @@
package server

import (
"bytes"
"encoding/binary"
"errors"
"os"
"path/filepath"
"strings"
"testing"

"github.com/ollama/ollama/api"
)

func TestConvertFromSafetensors(t *testing.T) {
t.Setenv("OLLAMA_MODELS", t.TempDir())

// Helper function to create a new layer and return its digest
makeTemp := func(content string) string {
l, err := NewLayer(strings.NewReader(content), "application/octet-stream")
if err != nil {
t.Fatalf("Failed to create layer: %v", err)
}
return l.Digest
}

// Create a safetensors compatible file with empty JSON content
var buf bytes.Buffer
headerSize := int64(len("{}"))
binary.Write(&buf, binary.LittleEndian, headerSize)
buf.WriteString("{}")

model := makeTemp(buf.String())
config := makeTemp(`{
"architectures": ["LlamaForCausalLM"],
"vocab_size": 32000
}`)
tokenizer := makeTemp(`{
"version": "1.0",
"truncation": null,
"padding": null,
"added_tokens": [
{
"id": 0,
"content": "<|endoftext|>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
}
]
}`)

tests := []struct {
name string
filePath string
wantErr error
}{
// Invalid
{
name: "InvalidRelativePathShallow",
filePath: filepath.Join("..", "file.safetensors"),
wantErr: errFilePath,
},
{
name: "InvalidRelativePathDeep",
filePath: filepath.Join("..", "..", "..", "..", "..", "..", "data", "file.txt"),
wantErr: errFilePath,
},
{
name: "InvalidNestedPath",
filePath: filepath.Join("dir", "..", "..", "..", "..", "..", "other.safetensors"),
wantErr: errFilePath,
},
{
name: "AbsolutePathOutsideRoot",
filePath: filepath.Join(os.TempDir(), "model.safetensors"),
wantErr: errFilePath, // Should fail since it's outside tmpDir
},
{
name: "ValidRelativePath",
filePath: "model.safetensors",
wantErr: nil,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create the minimum required file map for convertFromSafetensors
files := map[string]string{
tt.filePath: model,
"config.json": config,
"tokenizer.json": tokenizer,
}

_, err := convertFromSafetensors(files, nil, false, func(resp api.ProgressResponse) {})

if (tt.wantErr == nil && err != nil) ||
(tt.wantErr != nil && err == nil) ||
(tt.wantErr != nil && !errors.Is(err, tt.wantErr)) {
t.Errorf("convertFromSafetensors() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
28 server/internal/cache/blob/cache.go vendored
@@ -279,6 +279,18 @@ func (c *DiskCache) Get(d Digest) (Entry, error) {
// It returns an error if either the name or digest is invalid, or if link
// creation encounters any issues.
func (c *DiskCache) Link(name string, d Digest) error {
// TODO(bmizerany): Move link handling from cache to registry.
//
// We originally placed links in the cache due to its storage
// knowledge. However, the registry likely offers better context for
// naming concerns, and our API design shouldn't be tightly coupled to
// our on-disk format.
//
// Links work effectively when independent from physical location -
// they can reference content with matching SHA regardless of storage
// location. In an upcoming change, we plan to shift this
// responsibility to the registry where it better aligns with the
// system's conceptual model.
manifest, err := c.manifestPath(name)
if err != nil {
return err
@@ -304,21 +316,19 @@ func (c *DiskCache) Link(name string, d Digest) error {
return c.copyNamedFile(manifest, f, d, info.Size())
}

// Unlink removes the any link for name. If the link does not exist, nothing
// happens, and no error is returned.
//
// It returns an error if the name is invalid or if the link removal encounters
// any issues.
func (c *DiskCache) Unlink(name string) error {
// Unlink unlinks the manifest by name from the cache. If a manifest is
// removed, ok will be true; otherwise false (for example, when the name
// is not found). If an error occurs, it returns ok false, and the error.
func (c *DiskCache) Unlink(name string) (ok bool, _ error) {
manifest, err := c.manifestPath(name)
if err != nil {
return err
return false, err
}
err = os.Remove(manifest)
if errors.Is(err, fs.ErrNotExist) {
return nil
return false, nil
}
return err
return true, err
}

// GetFile returns the absolute path to the file, in the cache, for the given
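The new signature ripples to callers, which can now tell "removed" apart from "was never linked" without overloading the error. A hedged usage sketch (helper name is hypothetical):

    func removeModel(c *blob.DiskCache, name string) error {
        ok, err := c.Unlink(name)
        if err != nil {
            return err
        }
        if !ok {
            // valid name, but nothing was linked under it
            slog.Info("no manifest to remove", "name", name)
        }
        return nil
    }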
7 server/internal/cache/blob/cache_test.go vendored
@@ -13,7 +13,7 @@ import (
"testing"
"time"

"github.com/ollama/ollama/server/internal/internal/testutil"
"github.com/ollama/ollama/server/internal/testutil"
)

func init() {
@@ -479,8 +479,11 @@ func testManifestNameReuse(t *testing.T) {
}

// relink with different case
err = c.Unlink("h/n/m:t")
unlinked, err := c.Unlink("h/n/m:t")
check(err)
if !unlinked {
t.Fatal("expected unlinked")
}
err = c.Link("h/n/m:T", d1)
check(err)
2 server/internal/cache/blob/casecheck_test.go vendored
@@ -86,7 +86,7 @@ func useCaseInsensitiveTempDir(t *testing.T) bool {
// link to docs on that topic.
lines := strings.Split(volumeHint, "\n")
for _, line := range lines {
t.Log(line)
t.Skip(line)
}
}
return false
@@ -19,12 +19,15 @@ import (
"fmt"
"io"
"io/fs"
"log/slog"
"net/http"
"os"
"path/filepath"
"runtime"
"slices"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"

@@ -52,7 +55,7 @@ var (

// ErrMissingModel is returned when the model part of a name is missing
// or invalid.
ErrNameInvalid = errors.New("invalid name; must be in the form {scheme://}{host/}{namespace/}[model]{:tag}{@digest}")
ErrNameInvalid = errors.New("invalid or missing name")

// ErrCached is passed to [Trace.PushUpdate] when a layer already
// exists. It is a non-fatal error and is never returned by [Registry.Push].
@@ -71,24 +74,41 @@ const (
DefaultMaxChunkSize = 8 << 20
)

// DefaultCache returns a new disk cache for storing models. If the
// OLLAMA_MODELS environment variable is set, it uses that directory;
// otherwise, it uses $HOME/.ollama/models.
func DefaultCache() (*blob.DiskCache, error) {
var defaultCache = sync.OnceValues(func() (*blob.DiskCache, error) {
dir := os.Getenv("OLLAMA_MODELS")
if dir == "" {
home, err := os.UserHomeDir()
if err != nil {
return nil, err
}
home, _ := os.UserHomeDir()
home = cmp.Or(home, ".")
dir = filepath.Join(home, ".ollama", "models")
}
return blob.Open(dir)
})

// DefaultCache returns the default cache used by the registry. It is
// configured from the OLLAMA_MODELS environment variable, or defaults to
// $HOME/.ollama/models, or, if an error occurs obtaining the home directory,
// it uses the current working directory.
func DefaultCache() (*blob.DiskCache, error) {
return defaultCache()
}
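DefaultCache is now memoized with sync.OnceValues, so every caller shares a single DiskCache instance (and, notably, a single error). A standalone sketch of the primitive:

    package main

    import (
        "fmt"
        "sync"
    )

    var answer = sync.OnceValues(func() (int, error) {
        fmt.Println("computed once")
        return 42, nil
    })

    func main() {
        v1, _ := answer() // prints "computed once" and computes the value
        v2, _ := answer() // returns the cached pair; the func never reruns
        fmt.Println(v1, v2) // 42 42
    }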
// Error is the standard error returned by Ollama APIs.
// Error is the standard error returned by Ollama APIs. It can represent a
// single or multiple error response.
//
// Single error responses have the following format:
//
// {"code": "optional_code","error":"error message"}
//
// Multiple error responses have the following format:
//
// {"errors": [{"code": "optional_code","message":"error message"}]}
//
// Note that the error field is used in single error responses, while the
// message field is used in multiple error responses.
//
// In both cases, the code field is optional and may be empty.
type Error struct {
Status int `json:"-"`
Status int `json:"-"` // TODO(bmizerany): remove this
Code string `json:"code"`
Message string `json:"message"`
}
@@ -97,13 +117,34 @@ func (e *Error) Error() string {
return fmt.Sprintf("registry responded with status %d: %s %s", e.Status, e.Code, e.Message)
}

func (e *Error) LogValue() slog.Value {
return slog.GroupValue(
slog.Int("status", e.Status),
slog.String("code", e.Code),
slog.String("message", e.Message),
)
}

// UnmarshalJSON implements json.Unmarshaler.
func (e *Error) UnmarshalJSON(b []byte) error {
type E Error
var v struct{ Errors []E }
var v struct {
// Single error
Code string
Error string

// Multiple errors
Errors []E
}
if err := json.Unmarshal(b, &v); err != nil {
return err
}
if v.Error != "" {
// Single error case
e.Code = v.Code
e.Message = v.Error
return nil
}
if len(v.Errors) == 0 {
return fmt.Errorf("no messages in error response: %s", string(b))
}
@@ -111,18 +152,30 @@ func (e *Error) UnmarshalJSON(b []byte) error {
return nil
}
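A self-contained sketch of the dual-format decoding shown above. The hunk is truncated before it shows how the multi-error slice is folded into the receiver; copying the first entry below is an assumption made for completeness, not the repo's confirmed behavior:

    package main

    import (
        "encoding/json"
        "fmt"
    )

    // mirrors the registry Error shape above, minus Status
    type Error struct {
        Code    string `json:"code"`
        Message string `json:"message"`
    }

    func (e *Error) UnmarshalJSON(b []byte) error {
        type E Error // strips methods to avoid recursive UnmarshalJSON calls
        var v struct {
            Code   string
            Error  string // single-error form
            Errors []E    // multi-error form
        }
        if err := json.Unmarshal(b, &v); err != nil {
            return err
        }
        if v.Error != "" {
            e.Code, e.Message = v.Code, v.Error
            return nil
        }
        if len(v.Errors) == 0 {
            return fmt.Errorf("no messages in error response: %s", b)
        }
        *e = Error(v.Errors[0]) // assumed: take the first error
        return nil
    }

    func main() {
        var a, b Error
        json.Unmarshal([]byte(`{"code":"c1","error":"single form"}`), &a)
        json.Unmarshal([]byte(`{"errors":[{"code":"c2","message":"multi form"}]}`), &b)
        fmt.Println(a, b) // {c1 single form} {c2 multi form}
    }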
// TODO(bmizerany): make configurable on [Registry]
var defaultName = func() names.Name {
n := names.Parse("ollama.com/library/_:latest")
const DefaultMask = "registry.ollama.ai/library/_:latest"

var defaultMask = func() names.Name {
n := names.Parse(DefaultMask)
if !n.IsFullyQualified() {
panic("default name is not fully qualified")
panic("default mask is not fully qualified")
}
return n
}()

// CompleteName returns a fully qualified name by merging the given name with
// the default mask. If the name is already fully qualified, it is returned
// unchanged.
func CompleteName(name string) string {
return names.Merge(names.Parse(name), defaultMask).String()
}
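A hedged illustration of mask merging; the outputs shown are what the merge semantics suggest, not verified against the repo's names package, and the import alias is assumed:

    // assuming this client package is imported as "ollama"
    fmt.Println(ollama.CompleteName("llama3"))
    // -> registry.ollama.ai/library/llama3:latest
    //    (host, namespace, and tag filled in from DefaultMask)

    fmt.Println(ollama.CompleteName("example.com/me/m:v1"))
    // -> example.com/me/m:v1 (already fully qualified, unchanged)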
// Registry is a client for performing push and pull operations against an
// Ollama registry.
type Registry struct {
// Cache is the cache used to store models. If nil, [DefaultCache] is
// used.
Cache *blob.DiskCache

// UserAgent is the User-Agent header to send with requests to the
// registry. If empty, the User-Agent is determined by HTTPClient.
UserAgent string
@@ -160,21 +213,44 @@ type Registry struct {
//
// It is only used when a layer is larger than [MaxChunkingThreshold].
MaxChunkSize int64

// Mask, if set, is the name used to convert non-fully qualified names
// to fully qualified names. If empty, [DefaultMask] is used.
Mask string
}

// RegistryFromEnv returns a new Registry configured from the environment. The
func (r *Registry) cache() (*blob.DiskCache, error) {
if r.Cache != nil {
return r.Cache, nil
}
return defaultCache()
}

func (r *Registry) parseName(name string) (names.Name, error) {
mask := defaultMask
if r.Mask != "" {
mask = names.Parse(r.Mask)
}
n := names.Merge(names.Parse(name), mask)
if !n.IsFullyQualified() {
return names.Name{}, fmt.Errorf("%w: %q", ErrNameInvalid, name)
}
return n, nil
}

// DefaultRegistry returns a new Registry configured from the environment. The
// key is read from $HOME/.ollama/id_ed25519, MaxStreams is set to the
// value of OLLAMA_REGISTRY_MAXSTREAMS, and ChunkingDirectory is set to the
// system's temporary directory.
//
// It returns an error if any configuration in the environment is invalid.
func RegistryFromEnv() (*Registry, error) {
func DefaultRegistry() (*Registry, error) {
home, err := os.UserHomeDir()
if err != nil {
return nil, err
}
keyPEM, err := os.ReadFile(filepath.Join(home, ".ollama/id_ed25519"))
if err != nil {
if err != nil && errors.Is(err, fs.ErrNotExist) {
return nil, err
}
@@ -194,42 +270,6 @@ func RegistryFromEnv() (*Registry, error) {
return &rc, nil
}

type PushParams struct {
// From is an optional destination name for the model. If empty, the
// destination name is the same as the source name.
From string
}

// parseName parses name using [names.ParseExtended] and then merges the name with the
// default name, and checks that the name is fully qualified. If a digest is
// present, it parse and returns it with the other fields as their zero values.
//
// It returns an error if the name is not fully qualified, or if the digest, if
// any, is invalid.
//
// The scheme is returned as provided by [names.ParseExtended].
func parseName(s string) (scheme string, n names.Name, d blob.Digest, err error) {
scheme, n, ds := names.ParseExtended(s)
n = names.Merge(n, defaultName)
if ds != "" {
// Digest is present. Validate it.
d, err = blob.ParseDigest(ds)
if err != nil {
return "", names.Name{}, blob.Digest{}, err
}
}

// The name check is deferred until after the digest check because we
// say that digests take precedence over names, and so should there
// errors when being parsed.
if !n.IsFullyQualified() {
return "", names.Name{}, blob.Digest{}, ErrNameInvalid
}

scheme = cmp.Or(scheme, "https")
return scheme, n, d, nil
}

func (r *Registry) maxStreams() int {
n := cmp.Or(r.MaxStreams, runtime.GOMAXPROCS(0))

@@ -249,13 +289,24 @@ func (r *Registry) maxChunkSize() int64 {
return cmp.Or(r.MaxChunkSize, DefaultMaxChunkSize)
}

type PushParams struct {
// From is an optional destination name for the model. If empty, the
// destination name is the same as the source name.
From string
}

// Push pushes the model with the name in the cache to the remote registry.
func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *PushParams) error {
func (r *Registry) Push(ctx context.Context, name string, p *PushParams) error {
if p == nil {
p = &PushParams{}
}

m, err := ResolveLocal(c, cmp.Or(p.From, name))
c, err := r.cache()
if err != nil {
return err
}

m, err := r.ResolveLocal(cmp.Or(p.From, name))
if err != nil {
return err
}
@@ -278,7 +329,7 @@ func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *

t := traceFromContext(ctx)

scheme, n, _, err := parseName(name)
scheme, n, _, err := r.parseNameExtended(name)
if err != nil {
// This should never happen since ResolveLocal should have
// already validated the name.
@@ -371,8 +422,8 @@ func canRetry(err error) bool {
// chunks of the specified size, and then reassembled and verified. This is
// typically slower than splitting the model up across layers, and is mostly
// utilized for layers of type equal to "application/vnd.ollama.image".
func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) error {
scheme, n, _, err := parseName(name)
func (r *Registry) Pull(ctx context.Context, name string) error {
scheme, n, _, err := r.parseNameExtended(name)
if err != nil {
return err
}
@@ -385,6 +436,11 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err
return fmt.Errorf("%w: no layers", ErrManifestInvalid)
}

c, err := r.cache()
if err != nil {
return err
}

exists := func(l *Layer) bool {
info, err := c.Get(l.Digest)
return err == nil && info.Size == l.Size
@@ -520,6 +576,20 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err
return c.Link(m.Name, md)
}

// Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified
// before attempting to unlink the model.
func (r *Registry) Unlink(name string) (ok bool, _ error) {
n, err := r.parseName(name)
if err != nil {
return false, err
}
c, err := r.cache()
if err != nil {
return false, err
}
return c.Unlink(n.String())
}
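With the cache now owned by the Registry, callers no longer thread a *blob.DiskCache through every operation. A hedged sketch of the resulting call pattern (helper name is hypothetical):

    func refresh(ctx context.Context, name string) error {
        r, err := ollama.DefaultRegistry()
        if err != nil {
            return err
        }
        // Pull resolves the registry's cache internally via r.cache()
        if err := r.Pull(ctx, name); err != nil {
            return err
        }
        // ResolveLocal confirms the manifest landed in the local cache
        _, err = r.ResolveLocal(name)
        return err
    }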
// Manifest represents a [ollama.com/manifest].
type Manifest struct {
Name string `json:"-"` // the canonical name of the model
@@ -588,14 +658,18 @@ type Layer struct {
Size int64 `json:"size"`
}

// ResolveLocal resolves a name to a Manifest in the local cache. The name is
// parsed using [names.ParseExtended] but the scheme is ignored.
func ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) {
_, n, d, err := parseName(name)
// ResolveLocal resolves a name to a Manifest in the local cache.
func (r *Registry) ResolveLocal(name string) (*Manifest, error) {
_, n, d, err := r.parseNameExtended(name)
if err != nil {
return nil, err
}
c, err := r.cache()
if err != nil {
return nil, err
}
if !d.IsValid() {
// No digest, so resolve the manifest by name.
d, err = c.Resolve(n.String())
if err != nil {
return nil, err
@@ -617,7 +691,7 @@ func ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) {

// Resolve resolves a name to a Manifest in the remote registry.
func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error) {
scheme, n, d, err := parseName(name)
scheme, n, d, err := r.parseNameExtended(name)
if err != nil {
return nil, err
}
@@ -800,3 +874,89 @@ func maybeUnexpectedEOF(err error) error {
}
return err
}

type publicError struct {
wrapped error
message string
}

func withPublicMessagef(err error, message string, args ...any) error {
return publicError{wrapped: err, message: fmt.Sprintf(message, args...)}
}

func (e publicError) Error() string { return e.message }
func (e publicError) Unwrap() error { return e.wrapped }

var supportedSchemes = []string{
"http",
"https",
"https+insecure",
}

var supportedSchemesMessage = fmt.Sprintf("supported schemes are %v", strings.Join(supportedSchemes, ", "))

// parseNameExtended parses and validates an extended name, returning the scheme, name,
// and digest.
//
// If the scheme is empty, scheme will be "https". If an unsupported scheme is
// given, [ErrNameInvalid] wrapped with a display friendly message is returned.
//
// If the digest is invalid, [ErrNameInvalid] wrapped with a display friendly
// message is returned.
//
// If the name is not, once merged with the mask, fully qualified,
// [ErrNameInvalid] wrapped with a display friendly message is returned.
func (r *Registry) parseNameExtended(s string) (scheme string, _ names.Name, _ blob.Digest, _ error) {
scheme, name, digest := splitExtended(s)
scheme = cmp.Or(scheme, "https")
if !slices.Contains(supportedSchemes, scheme) {
err := withPublicMessagef(ErrNameInvalid, "unsupported scheme: %q: %s", scheme, supportedSchemesMessage)
return "", names.Name{}, blob.Digest{}, err
}

var d blob.Digest
if digest != "" {
var err error
d, err = blob.ParseDigest(digest)
if err != nil {
err = withPublicMessagef(ErrNameInvalid, "invalid digest: %q", digest)
return "", names.Name{}, blob.Digest{}, err
}
if name == "" {
// We can resolve a manifest from a digest only, so skip
// name validation and return the scheme and digest.
return scheme, names.Name{}, d, nil
}
}

n, err := r.parseName(name)
if err != nil {
|
||||
return "", names.Name{}, blob.Digest{}, err
|
||||
}
|
||||
return scheme, n, d, nil
|
||||
}
|
||||
|
||||
// splitExtended splits an extended name string into its scheme, name, and digest
|
||||
// parts.
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// http://ollama.com/bmizerany/smol:latest@digest
|
||||
// https://ollama.com/bmizerany/smol:latest
|
||||
// ollama.com/bmizerany/smol:latest@digest // returns "https" scheme.
|
||||
// model@digest
|
||||
// @digest
|
||||
func splitExtended(s string) (scheme, name, digest string) {
|
||||
i := strings.Index(s, "://")
|
||||
if i >= 0 {
|
||||
scheme = s[:i]
|
||||
s = s[i+3:]
|
||||
}
|
||||
i = strings.LastIndex(s, "@")
|
||||
if i >= 0 {
|
||||
digest = s[i+1:]
|
||||
s = s[:i]
|
||||
}
|
||||
return scheme, s, digest
|
||||
}
|
||||
|
||||
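To make the decomposition concrete, here is a hedged, test-style sketch of how splitExtended breaks inputs apart, assuming it runs in the same package (the inputs are illustrative):

func TestSplitExtendedSketch(t *testing.T) {
    cases := []struct{ in, scheme, name, digest string }{
        // scheme and digest both present
        {"http://ollama.com/bmizerany/smol:latest@sha256:abc", "http", "ollama.com/bmizerany/smol:latest", "sha256:abc"},
        // no scheme; parseNameExtended later defaults it to "https"
        {"ollama.com/bmizerany/smol:latest", "", "ollama.com/bmizerany/smol:latest", ""},
        // digest-only lookup; parseNameExtended skips name validation here
        {"@sha256:abc", "", "", "sha256:abc"},
    }
    for _, tt := range cases {
        scheme, name, digest := splitExtended(tt.in)
        if scheme != tt.scheme || name != tt.name || digest != tt.digest {
            t.Errorf("splitExtended(%q) = %q, %q, %q", tt.in, scheme, name, digest)
        }
    }
}
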
@@ -2,6 +2,7 @@ package ollama

import (
    "bytes"
    "cmp"
    "context"
    "encoding/json"
    "errors"
@@ -21,7 +22,7 @@ import (

    "github.com/ollama/ollama/server/internal/cache/blob"
    "github.com/ollama/ollama/server/internal/chunks"
    "github.com/ollama/ollama/server/internal/internal/testutil"
    "github.com/ollama/ollama/server/internal/testutil"
)

func TestManifestMarshalJSON(t *testing.T) {
@@ -37,20 +38,6 @@ func TestManifestMarshalJSON(t *testing.T) {
    }
}

func link(c *blob.DiskCache, name string, manifest string) {
    _, n, _, err := parseName(name)
    if err != nil {
        panic(err)
    }
    d, err := c.Import(bytes.NewReader([]byte(manifest)), int64(len(manifest)))
    if err != nil {
        panic(err)
    }
    if err := c.Link(n.String(), d); err != nil {
        panic(err)
    }
}

var errRoundTrip = errors.New("forced roundtrip error")

type recordRoundTripper http.HandlerFunc
@@ -86,6 +73,7 @@ func (rr recordRoundTripper) RoundTrip(req *http.Request) (*http.Response, error
// To simulate a network error, pass a handler that returns a 499 status code.
func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
    t.Helper()

    c, err := blob.Open(t.TempDir())
    if err != nil {
        t.Fatal(err)
@@ -98,30 +86,46 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
        }
    }

    r := &Registry{
        Cache: c,
        HTTPClient: &http.Client{
            Transport: recordRoundTripper(h),
        },
    }

    link := func(name string, manifest string) {
        n, err := r.parseName(name)
        if err != nil {
            panic(err)
        }
        d, err := c.Import(bytes.NewReader([]byte(manifest)), int64(len(manifest)))
        if err != nil {
            panic(err)
        }
        if err := c.Link(n.String(), d); err != nil {
            panic(err)
        }
    }

    commit := func(name string, layers ...*Layer) {
        t.Helper()
        data, err := json.Marshal(&Manifest{Layers: layers})
        if err != nil {
            t.Fatal(err)
        }
        link(c, name, string(data))
        link(name, string(data))
    }

    link(c, "empty", "")
    link("empty", "")
    commit("zero")
    commit("single", mklayer("exists"))
    commit("multiple", mklayer("exists"), mklayer("present"))
    commit("notfound", &Layer{Digest: blob.DigestFromBytes("notfound"), Size: int64(len("notfound"))})
    commit("null", nil)
    commit("sizemismatch", mklayer("exists"), &Layer{Digest: blob.DigestFromBytes("present"), Size: 499})
    link(c, "invalid", "!!!!!")
    link("invalid", "!!!!!")

    rc := &Registry{
        HTTPClient: &http.Client{
            Transport: recordRoundTripper(h),
        },
    }
    return rc, c
    return r, c
}

func okHandler(w http.ResponseWriter, r *http.Request) {
@@ -144,84 +148,61 @@ func importBytes(t *testing.T, c *blob.DiskCache, data string) blob.Digest {
    return d
}

func TestRegistryPushInvalidNames(t *testing.T) {
    rc, c := newClient(t, nil)

    cases := []struct {
        name string
        err  error
    }{
        {"", ErrNameInvalid},
        {"@", ErrNameInvalid},
        {"@x", blob.ErrInvalidDigest},
    }

    for _, tt := range cases {
        t.Run(tt.name, func(t *testing.T) {
            // Create a new registry and push a new image.
            err := rc.Push(t.Context(), c, tt.name, nil)
            if !errors.Is(err, tt.err) {
                t.Errorf("err = %v; want %v", err, tt.err)
            }
        })
    }
}

func withTraceUnexpected(ctx context.Context) (context.Context, *Trace) {
    t := &Trace{Update: func(*Layer, int64, error) { panic("unexpected") }}
    return WithTrace(ctx, t), t
}

func TestPushZero(t *testing.T) {
    rc, c := newClient(t, okHandler)
    err := rc.Push(t.Context(), c, "empty", nil)
    rc, _ := newClient(t, okHandler)
    err := rc.Push(t.Context(), "empty", nil)
    if !errors.Is(err, ErrManifestInvalid) {
        t.Errorf("err = %v; want %v", err, ErrManifestInvalid)
    }
}

func TestPushSingle(t *testing.T) {
    rc, c := newClient(t, okHandler)
    err := rc.Push(t.Context(), c, "single", nil)
    rc, _ := newClient(t, okHandler)
    err := rc.Push(t.Context(), "single", nil)
    testutil.Check(t, err)
}

func TestPushMultiple(t *testing.T) {
    rc, c := newClient(t, okHandler)
    err := rc.Push(t.Context(), c, "multiple", nil)
    rc, _ := newClient(t, okHandler)
    err := rc.Push(t.Context(), "multiple", nil)
    testutil.Check(t, err)
}

func TestPushNotFound(t *testing.T) {
    rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
    rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
        t.Errorf("unexpected request: %v", r)
    })
    err := rc.Push(t.Context(), c, "notfound", nil)
    err := rc.Push(t.Context(), "notfound", nil)
    if !errors.Is(err, fs.ErrNotExist) {
        t.Errorf("err = %v; want %v", err, fs.ErrNotExist)
    }
}

func TestPushNullLayer(t *testing.T) {
    rc, c := newClient(t, nil)
    err := rc.Push(t.Context(), c, "null", nil)
    rc, _ := newClient(t, nil)
    err := rc.Push(t.Context(), "null", nil)
    if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
        t.Errorf("err = %v; want invalid manifest", err)
    }
}

func TestPushSizeMismatch(t *testing.T) {
    rc, c := newClient(t, nil)
    rc, _ := newClient(t, nil)
    ctx, _ := withTraceUnexpected(t.Context())
    got := rc.Push(ctx, c, "sizemismatch", nil)
    got := rc.Push(ctx, "sizemismatch", nil)
    if got == nil || !strings.Contains(got.Error(), "size mismatch") {
        t.Errorf("err = %v; want size mismatch", got)
    }
}

func TestPushInvalid(t *testing.T) {
    rc, c := newClient(t, nil)
    err := rc.Push(t.Context(), c, "invalid", nil)
    rc, _ := newClient(t, nil)
    err := rc.Push(t.Context(), "invalid", nil)
    if err == nil || !strings.Contains(err.Error(), "invalid manifest") {
        t.Errorf("err = %v; want invalid manifest", err)
    }
@@ -229,7 +210,7 @@ func TestPushInvalid(t *testing.T) {

func TestPushExistsAtRemote(t *testing.T) {
    var pushed bool
    rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
    rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
        if strings.Contains(r.URL.Path, "/uploads/") {
            if !pushed {
                // First push. Return an uploadURL.
@@ -257,35 +238,35 @@ func TestPushExistsAtRemote(t *testing.T) {

    check := testutil.Checker(t)

    err := rc.Push(ctx, c, "single", nil)
    err := rc.Push(ctx, "single", nil)
    check(err)

    if !errors.Is(errors.Join(errs...), nil) {
        t.Errorf("errs = %v; want %v", errs, []error{ErrCached})
    }

    err = rc.Push(ctx, c, "single", nil)
    err = rc.Push(ctx, "single", nil)
    check(err)
}

func TestPushRemoteError(t *testing.T) {
    rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
    rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
        if strings.Contains(r.URL.Path, "/blobs/") {
            w.WriteHeader(500)
            io.WriteString(w, `{"errors":[{"code":"blob_error"}]}`)
            return
        }
    })
    got := rc.Push(t.Context(), c, "single", nil)
    got := rc.Push(t.Context(), "single", nil)
    checkErrCode(t, got, 500, "blob_error")
}

func TestPushLocationError(t *testing.T) {
    rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
    rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
        w.Header().Set("Location", ":///x")
        w.WriteHeader(http.StatusAccepted)
    })
    got := rc.Push(t.Context(), c, "single", nil)
    got := rc.Push(t.Context(), "single", nil)
    wantContains := "invalid upload URL"
    if got == nil || !strings.Contains(got.Error(), wantContains) {
        t.Errorf("err = %v; want to contain %v", got, wantContains)
@@ -293,14 +274,14 @@ func TestPushLocationError(t *testing.T) {
}

func TestPushUploadRoundtripError(t *testing.T) {
    rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
    rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
        if r.Host == "blob.store" {
            w.WriteHeader(499) // force RoundTrip error on upload
            return
        }
        w.Header().Set("Location", "http://blob.store/blobs/123")
    })
    got := rc.Push(t.Context(), c, "single", nil)
    got := rc.Push(t.Context(), "single", nil)
    if !errors.Is(got, errRoundTrip) {
        t.Errorf("got = %v; want %v", got, errRoundTrip)
    }
@@ -316,20 +297,20 @@ func TestPushUploadFileOpenError(t *testing.T) {
            os.Remove(c.GetFile(l.Digest))
        },
    })
    got := rc.Push(ctx, c, "single", nil)
    got := rc.Push(ctx, "single", nil)
    if !errors.Is(got, fs.ErrNotExist) {
        t.Errorf("got = %v; want fs.ErrNotExist", got)
    }
}

func TestPushCommitRoundtripError(t *testing.T) {
    rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
    rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
        if strings.Contains(r.URL.Path, "/blobs/") {
            panic("unexpected")
        }
        w.WriteHeader(499) // force RoundTrip error
    })
    err := rc.Push(t.Context(), c, "zero", nil)
    err := rc.Push(t.Context(), "zero", nil)
    if !errors.Is(err, errRoundTrip) {
        t.Errorf("err = %v; want %v", err, errRoundTrip)
    }
@@ -343,8 +324,8 @@ func checkNotExist(t *testing.T, err error) {
}

func TestRegistryPullInvalidName(t *testing.T) {
    rc, c := newClient(t, nil)
    err := rc.Pull(t.Context(), c, "://")
    rc, _ := newClient(t, nil)
    err := rc.Pull(t.Context(), "://")
    if !errors.Is(err, ErrNameInvalid) {
        t.Errorf("err = %v; want %v", err, ErrNameInvalid)
    }
@@ -359,10 +340,10 @@ func TestRegistryPullInvalidManifest(t *testing.T) {
    }

    for _, resp := range cases {
        rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
        rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
            io.WriteString(w, resp)
        })
        err := rc.Pull(t.Context(), c, "x")
        err := rc.Pull(t.Context(), "x")
        if !errors.Is(err, ErrManifestInvalid) {
            t.Errorf("err = %v; want invalid manifest", err)
        }
@@ -385,18 +366,18 @@ func TestRegistryPullNotCached(t *testing.T) {
    })

    // Confirm that the layer does not exist locally
    _, err := ResolveLocal(c, "model")
    _, err := rc.ResolveLocal("model")
    checkNotExist(t, err)

    _, err = c.Get(d)
    checkNotExist(t, err)

    err = rc.Pull(t.Context(), c, "model")
    err = rc.Pull(t.Context(), "model")
    check(err)

    mw, err := rc.Resolve(t.Context(), "model")
    check(err)
    mg, err := ResolveLocal(c, "model")
    mg, err := rc.ResolveLocal("model")
    check(err)
    if !reflect.DeepEqual(mw, mg) {
        t.Errorf("mw = %v; mg = %v", mw, mg)
@@ -421,7 +402,7 @@ func TestRegistryPullNotCached(t *testing.T) {

func TestRegistryPullCached(t *testing.T) {
    cached := blob.DigestFromBytes("exists")
    rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
    rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
        if strings.Contains(r.URL.Path, "/blobs/") {
            w.WriteHeader(499) // should not be called
            return
@@ -444,7 +425,7 @@ func TestRegistryPullCached(t *testing.T) {
    ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
    defer cancel()

    err := rc.Pull(ctx, c, "single")
    err := rc.Pull(ctx, "single")
    testutil.Check(t, err)

    want := []int64{6}
@@ -457,30 +438,30 @@ func TestRegistryPullCached(t *testing.T) {
}

func TestRegistryPullManifestNotFound(t *testing.T) {
    rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
    rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
        w.WriteHeader(http.StatusNotFound)
    })
    err := rc.Pull(t.Context(), c, "notfound")
    err := rc.Pull(t.Context(), "notfound")
    checkErrCode(t, err, 404, "")
}

func TestRegistryPullResolveRemoteError(t *testing.T) {
    rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
    rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
        w.WriteHeader(http.StatusInternalServerError)
        io.WriteString(w, `{"errors":[{"code":"an_error"}]}`)
    })
    err := rc.Pull(t.Context(), c, "single")
    err := rc.Pull(t.Context(), "single")
    checkErrCode(t, err, 500, "an_error")
}

func TestRegistryPullResolveRoundtripError(t *testing.T) {
    rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
    rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
        if strings.Contains(r.URL.Path, "/manifests/") {
            w.WriteHeader(499) // force RoundTrip error
            return
        }
    })
    err := rc.Pull(t.Context(), c, "single")
    err := rc.Pull(t.Context(), "single")
    if !errors.Is(err, errRoundTrip) {
        t.Errorf("err = %v; want %v", err, errRoundTrip)
    }
@@ -533,7 +514,7 @@ func TestRegistryPullMixedCachedNotCached(t *testing.T) {

    // Check that we pull all layers that we can.

    err := rc.Pull(ctx, c, "mixed")
    err := rc.Pull(ctx, "mixed")
    if err != nil {
        t.Fatal(err)
    }
@@ -551,7 +532,7 @@ func TestRegistryPullMixedCachedNotCached(t *testing.T) {
}

func TestRegistryPullChunking(t *testing.T) {
    rc, c := newClient(t, func(w http.ResponseWriter, r *http.Request) {
    rc, _ := newClient(t, func(w http.ResponseWriter, r *http.Request) {
        t.Log("request:", r.URL.Host, r.Method, r.URL.Path, r.Header.Get("Range"))
        if r.URL.Host != "blob.store" {
            // The production registry redirects to the blob store.
@@ -589,7 +570,7 @@ func TestRegistryPullChunking(t *testing.T) {
        },
    })

    err := rc.Pull(ctx, c, "remote")
    err := rc.Pull(ctx, "remote")
    testutil.Check(t, err)

    want := []int64{0, 3, 6}
@@ -621,7 +602,7 @@ func TestInsecureSkipVerify(t *testing.T) {
    }))
    defer s.Close()

    const name = "ollama.com/library/insecure"
    const name = "library/insecure"

    var rc Registry
    url := fmt.Sprintf("https://%s/%s", s.Listener.Addr(), name)
@@ -654,3 +635,184 @@ func TestCanRetry(t *testing.T) {
        }
    }
}

func TestErrorUnmarshal(t *testing.T) {
    cases := []struct {
        name    string
        data    string
        want    *Error
        wantErr bool
    }{
        {
            name:    "errors empty",
            data:    `{"errors":[]}`,
            wantErr: true,
        },
        {
            name:    "errors empty",
            data:    `{"errors":[]}`,
            wantErr: true,
        },
        {
            name: "errors single",
            data: `{"errors":[{"code":"blob_unknown"}]}`,
            want: &Error{Code: "blob_unknown", Message: ""},
        },
        {
            name: "errors multiple",
            data: `{"errors":[{"code":"blob_unknown"},{"code":"blob_error"}]}`,
            want: &Error{Code: "blob_unknown", Message: ""},
        },
        {
            name:    "error empty",
            data:    `{"error":""}`,
            wantErr: true,
        },
        {
            name:    "error very empty",
            data:    `{}`,
            wantErr: true,
        },
        {
            name: "error message",
            data: `{"error":"message", "code":"code"}`,
            want: &Error{Code: "code", Message: "message"},
        },
        {
            name:    "invalid value",
            data:    `{"error": 1}`,
            wantErr: true,
        },
    }
    for _, tt := range cases {
        t.Run(tt.name, func(t *testing.T) {
            var got Error
            err := json.Unmarshal([]byte(tt.data), &got)
            if err != nil {
                if tt.wantErr {
                    return
                }
                t.Errorf("Unmarshal() error = %v", err)
                // fallthrough and check got
            }
            if tt.want == nil {
                tt.want = &Error{}
            }
            if !reflect.DeepEqual(got, *tt.want) {
                t.Errorf("got = %v; want %v", got, *tt.want)
            }
        })
    }
}

// TestParseNameExtendedErrors tests that parseNameExtended returns error
// messages with enough detail for users to debug naming issues they may
// encounter. Prior to this test, the error messages were not very helpful and
// each problem was reported as the same message.
//
// It is only for testing error messages, not that all invalids and valids are
// covered. Those are in other tests for names.Name and blob.Digest.
func TestParseNameExtendedErrors(t *testing.T) {
    cases := []struct {
        name string
        err  error
        want string
    }{}

    var r Registry
    for _, tt := range cases {
        _, _, _, err := r.parseNameExtended(tt.name)
        if !errors.Is(err, tt.err) {
            t.Errorf("[%s]: err = %v; want %v", tt.name, err, tt.err)
        }
        if err != nil && !strings.Contains(err.Error(), tt.want) {
            t.Errorf("[%s]: err =\n\t%v\nwant\n\t%v", tt.name, err, tt.want)
        }
    }
}

func TestParseNameExtended(t *testing.T) {
    cases := []struct {
        in     string
        scheme string
        name   string
        digest string
        err    string
    }{
        {in: "http://m", scheme: "http", name: "m"},
        {in: "https+insecure://m", scheme: "https+insecure", name: "m"},
        {in: "http+insecure://m", err: "unsupported scheme"},

        {in: "http://m@sha256:1111111111111111111111111111111111111111111111111111111111111111", scheme: "http", name: "m", digest: "sha256:1111111111111111111111111111111111111111111111111111111111111111"},

        {in: "", err: "invalid or missing name"},
        {in: "m", scheme: "https", name: "m"},
        {in: "://", err: "invalid or missing name"},
        {in: "@sha256:deadbeef", err: "invalid digest"},
        {in: "@sha256:deadbeef@sha256:deadbeef", err: "invalid digest"},
    }
    for _, tt := range cases {
        t.Run(tt.in, func(t *testing.T) {
            var r Registry
            scheme, n, digest, err := r.parseNameExtended(tt.in)
            if err != nil {
                if tt.err == "" {
                    t.Errorf("err = %v; want nil", err)
                } else if !strings.Contains(err.Error(), tt.err) {
                    t.Errorf("err = %v; want %q", err, tt.err)
                }
            } else if tt.err != "" {
                t.Errorf("err = nil; want %q", tt.err)
            }
            if err == nil && !n.IsFullyQualified() {
                t.Errorf("name = %q; want fully qualified", n)
            }

            if scheme != tt.scheme {
                t.Errorf("scheme = %q; want %q", scheme, tt.scheme)
            }

            // smoke-test name is superset of tt.name
            if !strings.Contains(n.String(), tt.name) {
                t.Errorf("name = %q; want %q", n, tt.name)
            }

            tt.digest = cmp.Or(tt.digest, (&blob.Digest{}).String())
            if digest.String() != tt.digest {
                t.Errorf("digest = %q; want %q", digest, tt.digest)
            }
        })
    }
}

func TestUnlink(t *testing.T) {
    t.Run("found by name", func(t *testing.T) {
        rc, _ := newClient(t, nil)

        // confirm linked
        _, err := rc.ResolveLocal("single")
        if err != nil {
            t.Errorf("unexpected error: %v", err)
        }

        // unlink
        _, err = rc.Unlink("single")
        testutil.Check(t, err)

        // confirm unlinked
        _, err = rc.ResolveLocal("single")
        if !errors.Is(err, fs.ErrNotExist) {
            t.Errorf("err = %v; want fs.ErrNotExist", err)
        }
    })
    t.Run("not found by name", func(t *testing.T) {
        rc, _ := newClient(t, nil)
        ok, err := rc.Unlink("manifestNotFound")
        if err != nil {
            t.Fatal(err)
        }
        if ok {
            t.Error("expected not found")
        }
    })
}

@@ -6,6 +6,9 @@ import (

// Trace is a set of functions that are called to report progress during blob
// downloads and uploads.
//
// Use [WithTrace] to attach a Trace to a context for use with [Registry.Push]
// and [Registry.Pull].
type Trace struct {
    // Update is called during [Registry.Push] and [Registry.Pull] to
    // report the progress of blob uploads and downloads.

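For context, a minimal sketch of attaching a Trace before a pull, assuming a configured *ollama.Registry named rc (the logging is illustrative):

ctx := ollama.WithTrace(context.Background(), &ollama.Trace{
    Update: func(l *ollama.Layer, n int64, err error) {
        if err != nil {
            log.Printf("%v: %v", l.Digest, err)
            return
        }
        log.Printf("%v: %d/%d bytes", l.Digest, n, l.Size)
    },
})
if err := rc.Pull(ctx, "library/smol:latest"); err != nil {
    log.Fatal(err)
}
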
@@ -86,6 +86,8 @@ func (m *Model) readTensors(fname string) ([]*Tensor, error) {
        return nil, err
    }

    endOfHeader := 8 + headerSize // 8 bytes for header size plus the header itself

    // TODO(bmizerany): do something with metadata? This could be another
    // header read if needed. We also need to figure out if the metadata is
    // present in only one .safetensors file or if each file may have their
@@ -95,7 +97,8 @@ func (m *Model) readTensors(fname string) ([]*Tensor, error) {

    tt := make([]*Tensor, 0, len(raws))
    for name, raw := range raws {
        if !strings.HasPrefix(name, "model.layer") {
        if name == "__metadata__" {
            // TODO(bmizerany): do something with metadata?
            continue
        }
        var v struct {
@@ -112,7 +115,8 @@ func (m *Model) readTensors(fname string) ([]*Tensor, error) {

        // TODO(bmizerany): after collecting, validate all offsets make
        // tensors contiguous?
        begin, end := v.Offsets[0], v.Offsets[1]
        begin := endOfHeader + v.Offsets[0]
        end := endOfHeader + v.Offsets[1]
        if err := checkBeginEnd(finfo.Size(), begin, end); err != nil {
            return nil, err
        }

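The hunk above fixes an off-by-header bug: safetensors data offsets are relative to the first byte after the JSON header, not to the start of the file. A hedged sketch of the layout arithmetic (the file path is hypothetical):

package main

import (
    "encoding/binary"
    "fmt"
    "io"
    "os"
)

func main() {
    f, err := os.Open("model.safetensors") // hypothetical path
    if err != nil {
        panic(err)
    }
    defer f.Close()

    // A .safetensors file begins with an 8-byte little-endian header size,
    // followed by the JSON header itself.
    var prefix [8]byte
    if _, err := io.ReadFull(f, prefix[:]); err != nil {
        panic(err)
    }
    headerSize := int64(binary.LittleEndian.Uint64(prefix[:]))
    endOfHeader := 8 + headerSize // as computed in readTensors above

    // A tensor whose header entry says data_offsets [b, e] occupies file
    // bytes [endOfHeader+b, endOfHeader+e), which is what the begin/end
    // change above computes.
    fmt.Println("tensor data begins at byte", endOfHeader)
}
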
@@ -63,25 +63,28 @@ func main() {
    }
    flag.Parse()

    c, err := ollama.DefaultCache()
    if err != nil {
        log.Fatal(err)
    }

    rc, err := ollama.RegistryFromEnv()
    if err != nil {
        log.Fatal(err)
    }

    ctx := context.Background()

    err = func() error {
    err := func() error {
        switch cmd := flag.Arg(0); cmd {
        case "pull":
            return cmdPull(ctx, rc, c)
            rc, err := ollama.DefaultRegistry()
            if err != nil {
                log.Fatal(err)
            }

            return cmdPull(ctx, rc)
        case "push":
            return cmdPush(ctx, rc, c)
            rc, err := ollama.DefaultRegistry()
            if err != nil {
                log.Fatal(err)
            }
            return cmdPush(ctx, rc)
        case "import":
            c, err := ollama.DefaultCache()
            if err != nil {
                log.Fatal(err)
            }
            return cmdImport(ctx, c)
        default:
            if cmd == "" {
@@ -99,7 +102,7 @@ func main() {
    }
}

func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error {
func cmdPull(ctx context.Context, rc *ollama.Registry) error {
    model := flag.Arg(1)
    if model == "" {
        flag.Usage()
@@ -145,7 +148,7 @@ func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error

    errc := make(chan error)
    go func() {
        errc <- rc.Pull(ctx, c, model)
        errc <- rc.Pull(ctx, model)
    }()

    t := time.NewTicker(time.Second)
@@ -161,7 +164,7 @@ func cmdPull(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error
    }
}

func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error {
func cmdPush(ctx context.Context, rc *ollama.Registry) error {
    args := flag.Args()[1:]
    flag := flag.NewFlagSet("push", flag.ExitOnError)
    flagFrom := flag.String("from", "", "Use the manifest from a model by another name.")
@@ -177,7 +180,7 @@ func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error
    }

    from := cmp.Or(*flagFrom, model)
    m, err := ollama.ResolveLocal(c, from)
    m, err := rc.ResolveLocal(from)
    if err != nil {
        return err
    }
@@ -203,7 +206,7 @@ func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error
        },
    })

    return rc.Push(ctx, c, model, &ollama.PushParams{
    return rc.Push(ctx, model, &ollama.PushParams{
        From: from,
    })
}
@@ -228,6 +231,10 @@ func cmdImport(ctx context.Context, c *blob.DiskCache) error {
        flag.PrintDefaults()
    }
    flag.Parse(args)
    if *flagAs == "" {
        return fmt.Errorf("missing -as flag")
    }
    as := ollama.CompleteName(*flagAs)

    dir := cmp.Or(flag.Arg(0), ".")
    fmt.Fprintf(os.Stderr, "Reading %s\n", dir)
@@ -311,7 +318,7 @@ func cmdImport(ctx context.Context, c *blob.DiskCache) error {
            if err != nil {
                return err
            }
            return c.Link(*flagAs, d)
            return c.Link(as, d)
        }()
    }()

@@ -340,6 +347,8 @@ func cmdImport(ctx context.Context, c *blob.DiskCache) error {
            writeProgress()
        case err := <-done:
            writeProgress()
            fmt.Println()
            fmt.Println("Successfully imported", as)
            return err
        }
    }

@@ -1,3 +1,5 @@
//go:build goexperiment.synctest

package backoff

import (

@@ -8,7 +8,7 @@ import (
    "github.com/ollama/ollama/server/internal/internal/stringsx"
)

const MaxNameLength = 50 + 1 + 50 + 1 + 50 // <namespace>/<model>:<tag>
const MaxNameLength = 350 + 1 + 80 + 1 + 80 + 1 + 80 // <host>/<namespace>/<model>:<tag>

type Name struct {
    // Make incomparable to enforce use of Compare / Equal for
@@ -25,19 +25,12 @@ type Name struct {
// format of a valid name string is:
//
//    s:
//        { host } "/" { namespace } "/" { model } ":" { tag } "@" { digest }
//        { host } "/" { namespace } "/" { model } ":" { tag }
//        { host } "/" { namespace } "/" { model } "@" { digest }
//        { host } "/" { namespace } "/" { model }
//        { namespace } "/" { model } ":" { tag } "@" { digest }
//        { namespace } "/" { model } ":" { tag }
//        { namespace } "/" { model } "@" { digest }
//        { namespace } "/" { model }
//        { model } ":" { tag } "@" { digest }
//        { model } ":" { tag }
//        { model } "@" { digest }
//        { model }
//        "@" { digest }
//    host:
//        pattern: { alphanum | "_" } { alphanum | "_" | "-" | "." | ":" }*
//        length:  [1, 350]
@@ -50,9 +43,6 @@ type Name struct {
//    tag:
//        pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }*
//        length:  [1, 80]
//    digest:
//        pattern: { alphanum | "_" } { alphanum | "-" | ":" }*
//        length:  [1, 80]
//
// The name returned is not guaranteed to be valid. If it is not valid, the
// field values are left in an undefined state. Use [Name.IsValid] to check
@@ -82,23 +72,17 @@ func Parse(s string) Name {
    }
}

// ParseExtended parses and returns any scheme, Name, and digest from s in
// the form [scheme://][name][@digest]. All parts are optional.
//
// If the scheme is present, it must be followed by "://". The digest is
// prefixed by "@" and comes after the name. The name is parsed using [Parse].
//
// The scheme and digest are stripped before the name is parsed by [Parse].
//
// For convenience, the scheme is never empty. If the scheme is not present, the
// returned scheme is "https".
// Split splits an extended name string into its scheme, name, and digest
// parts.
//
// Examples:
//
//    http://ollama.com/bmizerany/smol:latest@digest
//    https://ollama.com/bmizerany/smol:latest
//    ollama.com/bmizerany/smol:latest@digest // returns "https" scheme.
func ParseExtended(s string) (scheme string, _ Name, digest string) {
//    model@digest
//    @digest
func Split(s string) (scheme, name, digest string) {
    i := strings.Index(s, "://")
    if i >= 0 {
        scheme = s[:i]
@@ -109,21 +93,7 @@ func ParseExtended(s string) (scheme string, _ Name, digest string) {
        digest = s[i+1:]
        s = s[:i]
    }
    return scheme, Parse(s), digest
}

func FormatExtended(scheme string, n Name, digest string) string {
    var b strings.Builder
    if scheme != "" {
        b.WriteString(scheme)
        b.WriteString("://")
    }
    b.WriteString(n.String())
    if digest != "" {
        b.WriteByte('@')
        b.WriteString(digest)
    }
    return b.String()
    return scheme, s, digest
}

// Merge merges two names into a single name. Non-empty host, namespace, and
@@ -141,39 +111,68 @@ func Merge(a, b Name) Name {

// IsValid returns true if the name is valid.
func (n Name) IsValid() bool {
    if n.h != "" && !isValidHost(n.h) {
    if n.h != "" && !isValidPart(partHost, n.h) {
        return false
    }
    if n.n != "" && !isValidNamespace(n.n) {
    if n.n != "" && !isValidPart(partNamespace, n.n) {
        return false
    }
    if n.m != "" && !isValidModel(n.m) {
    if n.t != "" && !isValidPart(partTag, n.t) {
        return false
    }
    if n.t != "" && !isValidTag(n.t) {
        return false
    }
    return true

    // at bare minimum, model must be present and valid
    return n.m != "" && isValidPart(partModel, n.m)
}

func (n Name) IsFullyQualified() bool {
    return n.IsValid() && n.h != "" && n.n != "" && n.m != "" && n.t != ""
}

func isValidHost(_ string) bool {
    return true // TODO: implement
const (
    partHost = iota
    partNamespace
    partModel
    partTag
)

func isValidPart(kind int, s string) bool {
    maxlen := 80
    if kind == partHost {
        maxlen = 350
    }
    if len(s) > maxlen {
        return false
    }

    for i := range s {
        if i == 0 {
            if !isAlphanumericOrUnderscore(s[i]) {
                return false
            }
            continue
        }
        switch s[i] {
        case '_', '-':
        case '.':
            if kind == partNamespace {
                return false
            }
        case ':':
            if kind != partHost {
                return false
            }
        default:
            if !isAlphanumericOrUnderscore(s[i]) {
                return false
            }
        }
    }
    return true
}

func isValidNamespace(_ string) bool {
    return true // TODO: implement
}

func isValidModel(_ string) bool {
    return true // TODO: implement
}

func isValidTag(_ string) bool {
    return true // TODO: implement
func isAlphanumericOrUnderscore(c byte) bool {
    return c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z' || c >= '0' && c <= '9' || c == '_'
}

func (n Name) Host() string { return n.h }

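A hedged sketch of how the reshaped API composes, with Split feeding Parse (the expected values are shown in comments):

scheme, name, digest := names.Split("https://ollama.com/library/smol:latest@sha256:abc")
n := names.Parse(name)
fmt.Println(scheme)               // "https"
fmt.Println(digest)               // "sha256:abc"
fmt.Println(n.IsFullyQualified()) // true: host, namespace, model, and tag all present
fmt.Println(names.Parse("smol").IsFullyQualified()) // false: only the model part is set
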
@@ -81,15 +81,11 @@ func TestParseExtended(t *testing.T) {
    }
    for _, tt := range cases {
        t.Run(tt.in, func(t *testing.T) {
            scheme, name, digest := ParseExtended(tt.in)
            if scheme != tt.wantScheme || name.Compare(tt.wantName) != 0 || digest != tt.wantDigest {
            scheme, name, digest := Split(tt.in)
            n := Parse(name)
            if scheme != tt.wantScheme || n.Compare(tt.wantName) != 0 || digest != tt.wantDigest {
                t.Errorf("ParseExtended(%q) = %q, %#v, %q, want %q, %#v, %q", tt.in, scheme, name, digest, tt.wantScheme, tt.wantName, tt.wantDigest)
            }

            // Round trip
            if got := FormatExtended(scheme, name, digest); got != tt.in {
                t.Errorf("FormatExtended(%q, %q, %q) = %q", scheme, name, digest, got)
            }
        })
    }
}
@@ -150,3 +146,75 @@ func BenchmarkParseName(b *testing.B) {
        junkName = Parse("h/n/m:t")
    }
}

const (
    part80  = "88888888888888888888888888888888888888888888888888888888888888888888888888888888"
    part350 = "33333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333"
)

var testCases = map[string]bool{ // name -> valid
    "": false,

    "_why/_the/_lucky:_stiff": true,

    // minimal
    "h/n/m:t": true,

    "host/namespace/model:tag": true,
    "host/namespace/model":     true,
    "namespace/model":          true,
    "model":                    true,

    // long (but valid)
    part80 + "/" + part80 + "/" + part80 + ":" + part80:  true,
    part350 + "/" + part80 + "/" + part80 + ":" + part80: true,

    // too long
    part80 + "/" + part80 + "/" + part80 + ":" + part350:       false,
    "x" + part350 + "/" + part80 + "/" + part80 + ":" + part80: false,

    "h/nn/mm:t": true, // bare minimum part sizes

    // unqualified
    "m":     true,
    "n/m:":  true,
    "h/n/m": true,
    "@t":    false,
    "m@d":   false,

    // invalids
    "^":      false,
    "mm:":    true,
    "/nn/mm": true,
    "//":     false, // empty model
    "//mm":   true,
    "hh//":   false, // empty model
    "//mm:@": false,
    "00@":    false,
    "@":      false,

    // not starting with alphanum
    "-hh/nn/mm:tt": false,
    "hh/-nn/mm:tt": false,
    "hh/nn/-mm:tt": false,
    "hh/nn/mm:-tt": false,

    // smells like a flag
    "-h": false,

    // hosts
    "host:https/namespace/model:tag": true,

    // colon in non-host part before tag
    "host/name:space/model:tag": false,
}

func TestParseNameValidation(t *testing.T) {
    for s, valid := range testCases {
        got := Parse(s)
        if got.IsValid() != valid {
            t.Logf("got: %v", got)
            t.Errorf("Parse(%q).IsValid() = %v; want !%[2]v", s, got.IsValid())
        }
    }
}

@@ -1,3 +1,5 @@
//go:build goexperiment.synctest

package syncs

import (

server/internal/registry/server.go (new file, 240 lines)
@@ -0,0 +1,240 @@
// Package registry provides an http.Handler for handling local Ollama API
// requests for performing tasks related to the ollama.com model registry and
// the local disk cache.
package registry

import (
    "cmp"
    "encoding/json"
    "errors"
    "io"
    "log/slog"
    "net/http"

    "github.com/ollama/ollama/server/internal/client/ollama"
)

// Local is an http.Handler for handling local Ollama API requests for
// performing tasks related to the ollama.com model registry combined with the
// local disk cache.
//
// It is not a concern of Local, or this package, to handle model creation,
// which precedes any registry operations for models it produces.
//
// NOTE: The package built for dealing with model creation should use
// [DefaultCache] to access the blob store and not attempt to read or write
// directly to the blob disk cache.
type Local struct {
    Client *ollama.Registry // required
    Logger *slog.Logger     // required

    // Fallback, if set, is used to handle requests that are not handled by
    // this handler.
    Fallback http.Handler

    // Prune, if set, is called to prune the local disk cache after a model
    // is deleted.
    Prune func() error // optional
}

// serverError is like ollama.Error, but with a Status field for the HTTP
// response code. We want to avoid adding that field to ollama.Error because it
// would always be 0 to clients (we don't want to leak the status code in
// errors), and so it would be confusing to have a field that is always 0.
type serverError struct {
    Status int `json:"-"`

    // TODO(bmizerany): Decide if we want to keep this and maybe
    // bring back later.
    Code string `json:"code"`

    Message string `json:"error"`
}

func (e serverError) Error() string {
    return e.Message
}

// Common API errors
var (
    errMethodNotAllowed = &serverError{405, "method_not_allowed", "method not allowed"}
    errNotFound         = &serverError{404, "not_found", "not found"}
    errInternalError    = &serverError{500, "internal_error", "internal server error"}
)

type statusCodeRecorder struct {
    _status int // use status() to get the status code
    http.ResponseWriter
}

func (r *statusCodeRecorder) WriteHeader(status int) {
    if r._status == 0 {
        r._status = status
    }
    r.ResponseWriter.WriteHeader(status)
}

var (
    _ http.ResponseWriter = (*statusCodeRecorder)(nil)
    _ http.CloseNotifier  = (*statusCodeRecorder)(nil)
    _ http.Flusher        = (*statusCodeRecorder)(nil)
)

// CloseNotify implements the http.CloseNotifier interface, for Gin. Remove with Gin.
//
// It panics if the underlying ResponseWriter is not a CloseNotifier.
func (r *statusCodeRecorder) CloseNotify() <-chan bool {
    return r.ResponseWriter.(http.CloseNotifier).CloseNotify()
}

// Flush implements the http.Flusher interface, for Gin. Remove with Gin.
//
// It panics if the underlying ResponseWriter is not a Flusher.
func (r *statusCodeRecorder) Flush() {
    r.ResponseWriter.(http.Flusher).Flush()
}

func (r *statusCodeRecorder) status() int {
    return cmp.Or(r._status, 200)
}

func (s *Local) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    rec := &statusCodeRecorder{ResponseWriter: w}
    s.serveHTTP(rec, r)
}

func (s *Local) serveHTTP(rec *statusCodeRecorder, r *http.Request) {
    var errattr slog.Attr
    proxied, err := func() (bool, error) {
        switch r.URL.Path {
        case "/api/delete":
            return false, s.handleDelete(rec, r)
        default:
            if s.Fallback != nil {
                s.Fallback.ServeHTTP(rec, r)
                return true, nil
            }
            return false, errNotFound
        }
    }()
    if err != nil {
        // We always log the error, so fill in the error log attribute
        errattr = slog.String("error", err.Error())

        var e *serverError
        switch {
        case errors.As(err, &e):
        case errors.Is(err, ollama.ErrNameInvalid):
            e = &serverError{400, "bad_request", err.Error()}
        default:
            e = errInternalError
        }

        data, err := json.Marshal(e)
        if err != nil {
            // unreachable
            panic(err)
        }
        rec.Header().Set("Content-Type", "application/json")
        rec.WriteHeader(e.Status)
        rec.Write(data)

        // fallthrough to log
    }

    if !proxied {
        // we're only responsible for logging if we handled the request
        var level slog.Level
        if rec.status() >= 500 {
            level = slog.LevelError
        } else if rec.status() >= 400 {
            level = slog.LevelWarn
        }

        s.Logger.LogAttrs(r.Context(), level, "http",
            errattr, // report first in line to make it easy to find

            // TODO(bmizerany): Write a test to ensure that we are logging
            // all of this correctly. That also goes for the level+error
            // logic above.
            slog.Int("status", rec.status()),
            slog.String("method", r.Method),
            slog.String("path", r.URL.Path),
            slog.Int64("content-length", r.ContentLength),
            slog.String("remote", r.RemoteAddr),
            slog.String("proto", r.Proto),
            slog.String("query", r.URL.RawQuery),
        )
    }
}

type params struct {
    DeprecatedName string `json:"name"`  // Use [params.model]
    Model          string `json:"model"` // Use [params.model]

    // AllowNonTLS is a flag that indicates a client using HTTP
    // is doing so, deliberately.
    //
    // Deprecated: This field is ignored and only present for this
    // deprecation message. It should be removed in a future release.
    //
    // Users can just use http or https+insecure to show intent to
    // communicate they want to do insecure things, without awkward and
    // confusing flags such as this.
    AllowNonTLS bool `json:"insecure"`

    // ProgressStream is a flag that indicates the client is expecting a stream of
    // progress updates.
    ProgressStream bool `json:"stream"`
}

// model returns the model name for both old and new API requests.
func (p params) model() string {
    return cmp.Or(p.Model, p.DeprecatedName)
}

func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
    if r.Method != "DELETE" {
        return errMethodNotAllowed
    }
    p, err := decodeUserJSON[*params](r.Body)
    if err != nil {
        return err
    }
    ok, err := s.Client.Unlink(p.model())
    if err != nil {
        return err
    }
    if !ok {
        return &serverError{404, "not_found", "model not found"}
    }
    if s.Prune == nil {
        return nil
    }
    return s.Prune()
}

func decodeUserJSON[T any](r io.Reader) (T, error) {
    var v T
    err := json.NewDecoder(r).Decode(&v)
    if err == nil {
        return v, nil
    }
    var zero T

    // Not sure why, but I can't seem to be able to use:
    //
    //    errors.As(err, &json.UnmarshalTypeError{})
    //
    // This is working fine in stdlib, so I'm not sure what rules changed
    // and why this no longer works here. So, we do it the verbose way.
    var a *json.UnmarshalTypeError
    var b *json.SyntaxError
    if errors.As(err, &a) || errors.As(err, &b) {
        err = &serverError{Status: 400, Message: err.Error(), Code: "bad_request"}
    }
    if errors.Is(err, io.EOF) {
        err = &serverError{Status: 400, Message: "empty request body", Code: "bad_request"}
    }
    return zero, err
}

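To show how the pieces fit together, here is a hedged sketch of mounting Local in front of an existing handler; rc stands in for a configured *ollama.Registry, and the fallback route is illustrative:

fallback := http.NewServeMux()
fallback.HandleFunc("/api/version", func(w http.ResponseWriter, r *http.Request) {
    io.WriteString(w, `{"version":"0.0.0"}`)
})

h := &registry.Local{
    Client:   rc,
    Logger:   slog.Default(),
    Fallback: fallback, // everything except /api/delete is proxied here
}
log.Fatal(http.ListenAndServe("127.0.0.1:11434", h))
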
server/internal/registry/server_test.go (new file, 165 lines)
@@ -0,0 +1,165 @@
package registry

import (
    "encoding/json"
    "net/http"
    "net/http/httptest"
    "os"
    "regexp"
    "strings"
    "testing"

    "github.com/ollama/ollama/server/internal/cache/blob"
    "github.com/ollama/ollama/server/internal/client/ollama"
    "github.com/ollama/ollama/server/internal/testutil"
)

type panicTransport struct{}

func (t *panicTransport) RoundTrip(r *http.Request) (*http.Response, error) {
    panic("unexpected RoundTrip call")
}

var panicOnRoundTrip = &http.Client{Transport: &panicTransport{}}

// bytesResetter is an interface for types that can be reset and return a byte
// slice, only. This is to prevent inadvertent use of bytes.Buffer.Read/Write
// etc for the purpose of checking logs.
type bytesResetter interface {
    Bytes() []byte
    Reset()
}

func newTestServer(t *testing.T) *Local {
    t.Helper()
    dir := t.TempDir()
    err := os.CopyFS(dir, os.DirFS("testdata/models"))
    if err != nil {
        t.Fatal(err)
    }
    c, err := blob.Open(dir)
    if err != nil {
        t.Fatal(err)
    }
    rc := &ollama.Registry{
        Cache:      c,
        HTTPClient: panicOnRoundTrip,
    }
    l := &Local{
        Client: rc,
        Logger: testutil.Slogger(t),
    }
    return l
}

func (s *Local) send(t *testing.T, method, path, body string) *httptest.ResponseRecorder {
    t.Helper()
    req := httptest.NewRequestWithContext(t.Context(), method, path, strings.NewReader(body))
    return s.sendRequest(t, req)
}

func (s *Local) sendRequest(t *testing.T, req *http.Request) *httptest.ResponseRecorder {
    t.Helper()
    w := httptest.NewRecorder()
    s.ServeHTTP(w, req)
    return w
}

type invalidReader struct{}

func (r *invalidReader) Read(p []byte) (int, error) {
    return 0, os.ErrInvalid
}

// captureLogs is a helper to capture logs from the server. It returns a
// shallow copy of the server with a new logger and a bytesResetter for the
// logs.
func captureLogs(t *testing.T, s *Local) (*Local, bytesResetter) {
    t.Helper()
    log, logs := testutil.SlogBuffer()
    l := *s // shallow copy
    l.Logger = log
    return &l, logs
}

func TestServerDelete(t *testing.T) {
    check := testutil.Checker(t)

    s := newTestServer(t)

    _, err := s.Client.ResolveLocal("smol")
    check(err)

    got := s.send(t, "DELETE", "/api/delete", `{"model": "smol"}`)
    if got.Code != 200 {
        t.Fatalf("Code = %d; want 200", got.Code)
    }

    _, err = s.Client.ResolveLocal("smol")
    if err == nil {
        t.Fatal("expected smol to have been deleted")
    }

    got = s.send(t, "DELETE", "/api/delete", `!`)
    checkErrorResponse(t, got, 400, "bad_request", "invalid character '!' looking for beginning of value")

    got = s.send(t, "GET", "/api/delete", `{"model": "smol"}`)
    checkErrorResponse(t, got, 405, "method_not_allowed", "method not allowed")

    got = s.send(t, "DELETE", "/api/delete", ``)
    checkErrorResponse(t, got, 400, "bad_request", "empty request body")

    got = s.send(t, "DELETE", "/api/delete", `{"model": "://"}`)
    checkErrorResponse(t, got, 400, "bad_request", "invalid or missing name")

    got = s.send(t, "DELETE", "/unknown_path", `{}`) // valid body
    checkErrorResponse(t, got, 404, "not_found", "not found")

    s, logs := captureLogs(t, s)
    req := httptest.NewRequestWithContext(t.Context(), "DELETE", "/api/delete", &invalidReader{})
    got = s.sendRequest(t, req)
    checkErrorResponse(t, got, 500, "internal_error", "internal server error")
    ok, err := regexp.Match(`ERROR.*error="invalid argument"`, logs.Bytes())
    check(err)
    if !ok {
        t.Logf("logs:\n%s", logs)
        t.Fatalf("expected log to contain ERROR with invalid argument")
    }
}

func TestServerUnknownPath(t *testing.T) {
    s := newTestServer(t)
    got := s.send(t, "DELETE", "/api/unknown", `{}`)
    checkErrorResponse(t, got, 404, "not_found", "not found")
}

func checkErrorResponse(t *testing.T, got *httptest.ResponseRecorder, status int, code, msg string) {
    t.Helper()

    var printedBody bool
    errorf := func(format string, args ...any) {
        t.Helper()
        if !printedBody {
            t.Logf("BODY:\n%s", got.Body.String())
            printedBody = true
        }
        t.Errorf(format, args...)
    }

    if got.Code != status {
        errorf("Code = %d; want %d", got.Code, status)
    }

    // unmarshal the error as *ollama.Error (proving *serverError is an *ollama.Error)
    var e *ollama.Error
    if err := json.Unmarshal(got.Body.Bytes(), &e); err != nil {
        errorf("unmarshal error: %v", err)
        t.FailNow()
    }
    if e.Code != code {
        errorf("Code = %q; want %q", e.Code, code)
    }
    if !strings.Contains(e.Message, msg) {
        errorf("Message = %q; want to contain %q", e.Message, msg)
    }
}

Binary file not shown.

server/internal/registry/testdata/models/manifests/registry.ollama.ai/library/smol/latest (new file, vendored)
@@ -0,0 +1 @@
{"schemaVersion":2,"mediaType":"application/vnd.docker.distribution.manifest.v2+json","config":{"mediaType":"application/vnd.docker.container.image.v1+json","digest":"sha256:ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116","size":267},"layers":[{"mediaType":"application/vnd.ollama.image.model","digest":"sha256:a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99","size":24}]}

@@ -1,12 +1,40 @@
package testutil

import (
    "bytes"
    "io"
    "log/slog"
    "os"
    "path/filepath"
    "testing"
    "time"
)

// LogWriter returns an [io.Writer] that logs each Write using t.Log.
func LogWriter(t *testing.T) io.Writer {
    return testWriter{t}
}

type testWriter struct{ t *testing.T }

func (w testWriter) Write(b []byte) (int, error) {
    w.t.Logf("%s", b)
    return len(b), nil
}

// Slogger returns a [*slog.Logger] that writes each message
// using t.Log.
func Slogger(t *testing.T) *slog.Logger {
    return slog.New(slog.NewTextHandler(LogWriter(t), nil))
}

// SlogBuffer returns a [*slog.Logger] that writes each message to out.
func SlogBuffer() (lg *slog.Logger, out *bytes.Buffer) {
    var buf bytes.Buffer
    lg = slog.New(slog.NewTextHandler(&buf, nil))
    return lg, &buf
}

// Check calls t.Fatal(err) if err is not nil.
func Check(t *testing.T, err error) {
    if err != nil {

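The helpers above are easiest to read by example; a hedged sketch of a test using both (the log lines are illustrative):

func TestLogCapture(t *testing.T) {
    // Slogger routes log output through t.Log, so it shows with -v or on failure.
    testutil.Slogger(t).Info("starting")

    // SlogBuffer captures output so tests can assert on its contents.
    lg, out := testutil.SlogBuffer()
    lg.Warn("pruning", "model", "smol")
    if !bytes.Contains(out.Bytes(), []byte("pruning")) {
        t.Fatalf("expected a pruning log, got:\n%s", out.Bytes())
    }
}
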
@@ -10,7 +10,6 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/envconfig"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/model/models/mllama"
|
||||
"github.com/ollama/ollama/template"
|
||||
@@ -93,7 +92,7 @@ func chatPrompt(ctx context.Context, m *Model, tokenize tokenizeFunc, opts *api.
|
||||
var imgData llm.ImageData
|
||||
|
||||
if isMllama {
|
||||
if envconfig.NewEngine() {
|
||||
if len(m.ProjectorPaths) == 0 {
|
||||
imgData = llm.ImageData{
|
||||
ID: len(images),
|
||||
Data: i,
|
||||
|
||||
@@ -34,6 +34,8 @@ import (
	"github.com/ollama/ollama/llm"
	"github.com/ollama/ollama/model/models/mllama"
	"github.com/ollama/ollama/openai"
	"github.com/ollama/ollama/server/internal/client/ollama"
	"github.com/ollama/ollama/server/internal/registry"
	"github.com/ollama/ollama/template"
	"github.com/ollama/ollama/types/errtypes"
	"github.com/ollama/ollama/types/model"
@@ -203,7 +205,7 @@ func (s *Server) GenerateHandler(c *gin.Context) {

	images := make([]llm.ImageData, len(req.Images))
	for i := range req.Images {
		if isMllama && !envconfig.NewEngine() {
		if isMllama && len(model.ProjectorPaths) > 0 {
			data, opts, err := mllama.Preprocess(bytes.NewReader(req.Images[i]))
			if err != nil {
				c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "error processing image"})
@@ -1126,7 +1128,7 @@ func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
		}
	}
}

func (s *Server) GenerateRoutes() http.Handler {
func (s *Server) GenerateRoutes(rc *ollama.Registry) (http.Handler, error) {
	corsConfig := cors.DefaultConfig()
	corsConfig.AllowWildcard = true
	corsConfig.AllowBrowserExtensions = true
@@ -1165,10 +1167,9 @@ func (s *Server) GenerateRoutes() http.Handler {
	r.HEAD("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
	r.GET("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })

	// Local model cache management
	// Local model cache management (new implementation is at end of function)
	r.POST("/api/pull", s.PullHandler)
	r.POST("/api/push", s.PushHandler)
	r.DELETE("/api/delete", s.DeleteHandler)
	r.HEAD("/api/tags", s.ListHandler)
	r.GET("/api/tags", s.ListHandler)
	r.POST("/api/show", s.ShowHandler)
@@ -1193,7 +1194,16 @@ func (s *Server) GenerateRoutes() http.Handler {
	r.GET("/v1/models", openai.ListMiddleware(), s.ListHandler)
	r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowHandler)

	return r
	// wrap old with new
	rs := &registry.Local{
		Client:   rc,
		Logger:   slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default()
		Fallback: r,

		Prune: PruneLayers,
	}

	return rs, nil
}

func Serve(ln net.Listener) error {
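`registry.Local` wraps the existing gin router: requests it recognizes are served by the new registry code, and everything else falls through to `Fallback`. A generic sketch of that pattern; the field names above come from the diff, but the `fallbackHandler` below is illustrative only, not the actual `registry.Local` internals:

```go
package sketch

import "net/http"

// fallbackHandler serves routes the new handler owns and delegates
// everything else to the old one, mirroring how a Fallback field
// like the one above can be used.
type fallbackHandler struct {
	owns     func(*http.Request) bool // does the new code own this route?
	new, old http.Handler
}

func (h fallbackHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if h.owns(r) {
		h.new.ServeHTTP(w, r)
		return
	}
	h.old.ServeHTTP(w, r)
}
```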
@@ -1246,12 +1256,23 @@ func Serve(ln net.Listener) error {
		}
	}

	s := &Server{addr: ln.Addr()}

	rc, err := ollama.DefaultRegistry()
	if err != nil {
		return err
	}

	h, err := s.GenerateRoutes(rc)
	if err != nil {
		return err
	}
	http.Handle("/", h)

	ctx, done := context.WithCancel(context.Background())
	schedCtx, schedDone := context.WithCancel(ctx)
	sched := InitScheduler(schedCtx)
	s := &Server{addr: ln.Addr(), sched: sched}

	http.Handle("/", s.GenerateRoutes())
	s.sched = sched

	slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version))
	srvr := &http.Server{
@@ -23,6 +23,7 @@ import (
	"github.com/ollama/ollama/api"
	"github.com/ollama/ollama/fs/ggml"
	"github.com/ollama/ollama/openai"
	"github.com/ollama/ollama/server/internal/client/ollama"
	"github.com/ollama/ollama/types/model"
	"github.com/ollama/ollama/version"
)
@@ -91,7 +92,15 @@ func equalStringSlices(a, b []string) bool {
	return true
}

func Test_Routes(t *testing.T) {
type panicTransport struct{}

func (t *panicTransport) RoundTrip(r *http.Request) (*http.Response, error) {
	panic("unexpected RoundTrip call")
}

var panicOnRoundTrip = &http.Client{Transport: &panicTransport{}}

func TestRoutes(t *testing.T) {
	type testCase struct {
		Name   string
		Method string
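`panicOnRoundTrip` is a tripwire client: if any handler under test accidentally performs a real registry request, the test panics immediately instead of silently reaching ollama.com. The mechanism in isolation, as a standalone sketch:

```go
package main

import (
	"fmt"
	"net/http"
)

type panicTransport struct{}

// RoundTrip fires before any network I/O happens, so no request
// ever leaves the process.
func (panicTransport) RoundTrip(*http.Request) (*http.Response, error) {
	panic("unexpected RoundTrip call")
}

func main() {
	defer func() { fmt.Println("recovered:", recover()) }()

	c := &http.Client{Transport: panicTransport{}}
	if _, err := c.Get("http://example.com"); err != nil {
		fmt.Println(err)
	}
}
```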
@@ -241,10 +250,10 @@ func Test_Routes(t *testing.T) {
		Method: http.MethodDelete,
		Path:   "/api/delete",
		Setup: func(t *testing.T, req *http.Request) {
			createTestModel(t, "model-to-delete")
			createTestModel(t, "model_to_delete")

			deleteReq := api.DeleteRequest{
				Name: "model-to-delete",
				Name: "model_to_delete",
			}
			jsonData, err := json.Marshal(deleteReq)
			if err != nil {
@@ -271,7 +280,7 @@ func Test_Routes(t *testing.T) {
		Path: "/api/delete",
		Setup: func(t *testing.T, req *http.Request) {
			deleteReq := api.DeleteRequest{
				Name: "non-existent-model",
				Name: "non_existent_model",
			}
			jsonData, err := json.Marshal(deleteReq)
			if err != nil {
@@ -477,10 +486,29 @@ func Test_Routes(t *testing.T) {
		},
	}

	t.Setenv("OLLAMA_MODELS", t.TempDir())
	modelsDir := t.TempDir()
	t.Setenv("OLLAMA_MODELS", modelsDir)

	rc := &ollama.Registry{
		// This is a temporary measure to allow us to move forward,
		// surfacing any code contacting ollama.com that we do not
		// intend to.
		//
		// Currently, this only handles DELETE /api/delete, which
		// should not make any contact with the ollama.com registry,
		// so be clear about that.
		//
		// Tests that do need to contact the registry here will be
		// consumed into our new server/api code packages and removed
		// from here.
		HTTPClient: panicOnRoundTrip,
	}

	s := &Server{}
	router := s.GenerateRoutes()
	router, err := s.GenerateRoutes(rc)
	if err != nil {
		t.Fatalf("failed to generate routes: %v", err)
	}

	httpSrv := httptest.NewServer(router)
	t.Cleanup(httpSrv.Close)
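End to end, each table entry now runs against a real HTTP server whose routes are wrapped by the registry-backed handler. A condensed, generic sketch of that flow; `runCase` is a hypothetical helper, and in the real test the handler is the router returned by `s.GenerateRoutes(rc)`:

```go
package sketch

import (
	"bytes"
	"net/http"
	"net/http/httptest"
	"testing"
)

// runCase exercises one route over real HTTP the way a TestRoutes
// table entry does: start a server, send the request, hand the
// response back for assertions.
func runCase(t *testing.T, h http.Handler, method, path string, body []byte) *http.Response {
	t.Helper()
	srv := httptest.NewServer(h)
	t.Cleanup(srv.Close)

	req, err := http.NewRequest(method, srv.URL+path, bytes.NewReader(body))
	if err != nil {
		t.Fatal(err)
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	return resp
}
```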