Compare commits
31 Commits
v0.5.13-rc
...
v0.5.13-rc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e41c4cbea7 | ||
|
|
ee048b76d4 | ||
|
|
af68d60a58 | ||
|
|
21aa666a1e | ||
|
|
ee141cc821 | ||
|
|
55e5776c44 | ||
|
|
854a9195f3 | ||
|
|
96a97adf9b | ||
|
|
e75c6126e9 | ||
|
|
cda6f5c66c | ||
|
|
bebb6823c0 | ||
|
|
31e472baa4 | ||
|
|
657685e85d | ||
|
|
a14912858e | ||
|
|
eed11ded30 | ||
|
|
b42aba40ed | ||
|
|
25885e5335 | ||
|
|
98d44fa39d | ||
|
|
2099e2d267 | ||
|
|
0c1041ad85 | ||
|
|
c245b0406f | ||
|
|
8b194b7520 | ||
|
|
3e8b8a1933 | ||
|
|
41dc280491 | ||
|
|
53d2990d9b | ||
|
|
e185c08ad9 | ||
|
|
2412adf42b | ||
|
|
be2ac1ed93 | ||
|
|
dc13813a03 | ||
|
|
d6af13efed | ||
|
|
a59f665235 |
@@ -23,6 +23,7 @@ set(GGML_SCHED_MAX_COPIES 4)
|
||||
set(GGML_LLAMAFILE ON)
|
||||
set(GGML_CUDA_PEER_MAX_BATCH_SIZE 128)
|
||||
set(GGML_CUDA_GRAPHS ON)
|
||||
set(GGML_CUDA_FA ON)
|
||||
|
||||
if((CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
|
||||
OR (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_SYSTEM_PROCESSOR MATCHES "arm|aarch64|ARM64|ARMv[0-9]+"))
|
||||
@@ -105,9 +106,11 @@ if(CMAKE_HIP_COMPILER)
|
||||
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip)
|
||||
|
||||
if (WIN32)
|
||||
target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY=1)
|
||||
target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY)
|
||||
endif()
|
||||
|
||||
target_compile_definitions(ggml-hip PRIVATE GGML_HIP_NO_VMM)
|
||||
|
||||
set(OLLAMA_HIP_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/rocm)
|
||||
install(TARGETS ggml-hip
|
||||
RUNTIME_DEPENDENCIES
|
||||
|
||||
@@ -28,7 +28,7 @@
|
||||
"name": "CUDA 12",
|
||||
"inherits": [ "CUDA" ],
|
||||
"cacheVariables": {
|
||||
"CMAKE_CUDA_ARCHITECTURES": "50;60;61;70;75;80;86;87;89;90;90a;100"
|
||||
"CMAKE_CUDA_ARCHITECTURES": "50;60;61;70;75;80;86;87;89;90;90a;120"
|
||||
}
|
||||
},
|
||||
{
|
||||
|
||||
@@ -6,8 +6,6 @@ Thank you for your interest in contributing to Ollama! Here are a few guidelines
|
||||
|
||||
See the [development documentation](./docs/development.md) for instructions on how to build and run Ollama locally.
|
||||
|
||||
## Pull requests
|
||||
|
||||
### Ideal issues
|
||||
|
||||
* [Bugs](https://github.com/ollama/ollama/issues?q=is%3Aissue+is%3Aopen+label%3Abug): issues where Ollama stops working or where it results in an unexpected error.
|
||||
@@ -26,11 +24,64 @@ See the [development documentation](./docs/development.md) for instructions on h
|
||||
* Changes that add significant friction to the user experience
|
||||
* Changes that create a large future maintenance burden for maintainers and contributors
|
||||
|
||||
### Best practices
|
||||
## Proposing a (non-trivial) change
|
||||
|
||||
* Commit messages: please leave both a title and a description in your commit messages. The title should be a short summary of the changes, with a leading word that explains the section of the code being changed (e.g. `api: fix parsing of prompt field`) . In the description, leave a short 2-3 sentences that explain more about the change and its impact.
|
||||
* Tests: please add test coverage to changes where possible.
|
||||
* Minimize dependencies: avoid adding new dependencies unless absolutely necessary.
|
||||
> By "non-trivial", we mean a change that is not a bug fix or small
|
||||
> documentation update. If you are unsure, please ask us on our [Discord
|
||||
> server](https://discord.gg/ollama).
|
||||
|
||||
Before opening a non-trivial Pull Request, please open an issue to discuss the change and
|
||||
get feedback from the maintainers. This helps us understand the context of the
|
||||
change and how it fits into Ollama's roadmap and prevents us from duplicating
|
||||
work or you from spending time on a change that we may not be able to accept.
|
||||
|
||||
Tips for proposals:
|
||||
|
||||
* Explain the problem you are trying to solve, not what you are trying to do.
|
||||
* Explain why the change is important.
|
||||
* Explain how the change will be used.
|
||||
* Explain how the change will be tested.
|
||||
|
||||
Additionally, for bonus points: Provide draft documentation you would expect to
|
||||
see if the change were accepted.
|
||||
|
||||
## Pull requests
|
||||
|
||||
**Commit messages**
|
||||
|
||||
The title should look like:
|
||||
|
||||
<package>: <short description>
|
||||
|
||||
The package is the most affected Go package. If the change does not affect Go
|
||||
code, then use the directory name instead. Changes to a single well-known
|
||||
file in the root directory may use the file name.
|
||||
|
||||
The short description should start with a lowercase letter and be a
|
||||
continuation of the sentence:
|
||||
|
||||
"This changes Ollama to..."
|
||||
|
||||
Examples:
|
||||
|
||||
llm/backend/mlx: support the llama architecture
|
||||
CONTRIBUTING: provide clairity on good commit messages, and bad
|
||||
|
||||
Bad Examples:
|
||||
|
||||
feat: add more emoji
|
||||
fix: was not using famous web framework
|
||||
chore: generify code
|
||||
|
||||
**Tests**
|
||||
|
||||
Please include tests. Strive to test behavior, not implementation.
|
||||
|
||||
**New dependencies**
|
||||
|
||||
Dependencies should be added sparingly. If you are adding a new dependency,
|
||||
please explain why it is necessary and what other ways you attempted that
|
||||
did not work without it.
|
||||
|
||||
## Need help?
|
||||
|
||||
|
||||
@@ -12,8 +12,9 @@ FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base
|
||||
RUN yum install -y yum-utils \
|
||||
&& yum-config-manager --add-repo https://dl.rockylinux.org/vault/rocky/8.5/AppStream/\$basearch/os/ \
|
||||
&& rpm --import https://dl.rockylinux.org/pub/rocky/RPM-GPG-KEY-Rocky-8 \
|
||||
&& dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 \
|
||||
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
|
||||
&& dnf install -y yum-utils gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 \
|
||||
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo \
|
||||
&& curl -s -L https://github.com/ccache/ccache/releases/download/v4.10.2/ccache-4.10.2-linux-x86_64.tar.xz | tar -Jx -C /usr/local/bin --strip-components 1
|
||||
ENV PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH
|
||||
|
||||
FROM --platform=linux/arm64 almalinux:8 AS base-arm64
|
||||
|
||||
@@ -386,6 +386,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [MaxKB](https://github.com/1Panel-dev/MaxKB/) (Ready-to-use & flexible RAG Chatbot)
|
||||
- [yla](https://github.com/danielekp/yla) (Web interface to freely interact with your customized models)
|
||||
- [LangBot](https://github.com/RockChinQ/LangBot) (LLM-based instant messaging bots platform, with Agents, RAG features, supports multiple platforms)
|
||||
- [1Panel](https://github.com/1Panel-dev/1Panel/) (Web-based Linux Server Management Tool)
|
||||
- [AstrBot](https://github.com/Soulter/AstrBot/) (User-friendly LLM-based multi-platform chatbot with a WebUI, supporting RAG, LLM agents, and plugins integration)
|
||||
|
||||
### Cloud
|
||||
|
||||
|
||||
@@ -10,7 +10,7 @@
|
||||
// repository].
|
||||
//
|
||||
// [the API documentation]: https://github.com/ollama/ollama/blob/main/docs/api.md
|
||||
// [in the GitHub repository]: https://github.com/ollama/ollama/tree/main/examples
|
||||
// [in the GitHub repository]: https://github.com/ollama/ollama/tree/main/api/examples
|
||||
package api
|
||||
|
||||
import (
|
||||
|
||||
@@ -73,6 +73,7 @@ func AllowedOrigins() (origins []string) {
|
||||
"file://*",
|
||||
"tauri://*",
|
||||
"vscode-webview://*",
|
||||
"vscode-file://*",
|
||||
)
|
||||
|
||||
return origins
|
||||
|
||||
@@ -69,6 +69,7 @@ func TestOrigins(t *testing.T) {
|
||||
"file://*",
|
||||
"tauri://*",
|
||||
"vscode-webview://*",
|
||||
"vscode-file://*",
|
||||
}},
|
||||
{"http://10.0.0.1", []string{
|
||||
"http://10.0.0.1",
|
||||
@@ -88,6 +89,7 @@ func TestOrigins(t *testing.T) {
|
||||
"file://*",
|
||||
"tauri://*",
|
||||
"vscode-webview://*",
|
||||
"vscode-file://*",
|
||||
}},
|
||||
{"http://172.16.0.1,https://192.168.0.1", []string{
|
||||
"http://172.16.0.1",
|
||||
@@ -108,6 +110,7 @@ func TestOrigins(t *testing.T) {
|
||||
"file://*",
|
||||
"tauri://*",
|
||||
"vscode-webview://*",
|
||||
"vscode-file://*",
|
||||
}},
|
||||
{"http://totally.safe,http://definitely.legit", []string{
|
||||
"http://totally.safe",
|
||||
@@ -128,6 +131,7 @@ func TestOrigins(t *testing.T) {
|
||||
"file://*",
|
||||
"tauri://*",
|
||||
"vscode-webview://*",
|
||||
"vscode-file://*",
|
||||
}},
|
||||
}
|
||||
for _, tt := range cases {
|
||||
|
||||
@@ -100,6 +100,10 @@ func (kv KV) Float(key string, defaultValue ...float32) float32 {
|
||||
return keyValue(kv, key, append(defaultValue, 0)...)
|
||||
}
|
||||
|
||||
func (kv KV) Bool(key string, defaultValue ...bool) bool {
|
||||
return keyValue(kv, key, append(defaultValue, false)...)
|
||||
}
|
||||
|
||||
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
|
||||
r := keyValue(kv, key, &array{})
|
||||
s := make([]string, r.size)
|
||||
@@ -120,7 +124,7 @@ func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
|
||||
return s
|
||||
}
|
||||
|
||||
func keyValue[T string | uint32 | uint64 | float32 | *array](kv KV, key string, defaultValue ...T) T {
|
||||
func keyValue[T string | uint32 | uint64 | float32 | *array | bool](kv KV, key string, defaultValue ...T) T {
|
||||
if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
|
||||
key = kv.Architecture() + "." + key
|
||||
}
|
||||
|
||||
16
go.mod
16
go.mod
@@ -1,6 +1,6 @@
|
||||
module github.com/ollama/ollama
|
||||
|
||||
go 1.24
|
||||
go 1.24.0
|
||||
|
||||
require (
|
||||
github.com/containerd/console v1.0.3
|
||||
@@ -11,7 +11,7 @@ require (
|
||||
github.com/spf13/cobra v1.7.0
|
||||
github.com/stretchr/testify v1.9.0
|
||||
github.com/x448/float16 v0.8.4
|
||||
golang.org/x/sync v0.10.0
|
||||
golang.org/x/sync v0.11.0
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -69,12 +69,12 @@ require (
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/crypto v0.31.0
|
||||
golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa
|
||||
golang.org/x/net v0.25.0 // indirect
|
||||
golang.org/x/sys v0.28.0
|
||||
golang.org/x/term v0.27.0
|
||||
golang.org/x/text v0.21.0
|
||||
golang.org/x/crypto v0.33.0
|
||||
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa
|
||||
golang.org/x/net v0.35.0 // indirect
|
||||
golang.org/x/sys v0.30.0
|
||||
golang.org/x/term v0.29.0
|
||||
golang.org/x/text v0.22.0
|
||||
google.golang.org/protobuf v1.34.1
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
28
go.sum
28
go.sum
@@ -214,16 +214,16 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
|
||||
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
|
||||
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
||||
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
|
||||
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
|
||||
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3/go.mod h1:NOZ3BPKG0ec/BKJQgnvsSFpcKLM5xXVWnvZS97DWHgE=
|
||||
golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa h1:FRnLl4eNAQl8hwxVVC17teOw8kdjVDVAiFMtgUdTSRQ=
|
||||
golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE=
|
||||
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa h1:t2QcU6V556bFjYgu4L6C+6VrCPyJZ+eyRsABUPs1mz4=
|
||||
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa/go.mod h1:BHOTPb3L19zxehTsLoJXVaTktb06DFgmdW6Wb9s8jqk=
|
||||
golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs=
|
||||
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
|
||||
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
|
||||
@@ -257,8 +257,8 @@ golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81R
|
||||
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
|
||||
golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
|
||||
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
|
||||
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
|
||||
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@@ -268,8 +268,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
|
||||
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
|
||||
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
|
||||
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
|
||||
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
@@ -285,17 +285,17 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
|
||||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
|
||||
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.27.0 h1:WP60Sv1nlK1T6SupCHbXzSaN0b9wUmsPoRS9b61A23Q=
|
||||
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
|
||||
golang.org/x/term v0.29.0 h1:L6pJp37ocefwRRtYPKSWOWzOtWSxVajvz2ldH/xi3iU=
|
||||
golang.org/x/term v0.29.0/go.mod h1:6bl4lRlvVuDgSf3179VpIxBF0o10JUpXWOnI7nErv7s=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
|
||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||
golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
|
||||
golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
|
||||
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
|
||||
@@ -29,6 +29,17 @@ type Cache interface {
|
||||
// cache implementation used.
|
||||
Put(ctx ml.Context, key, value ml.Tensor)
|
||||
|
||||
// SetConfig controls optimizations (mostly backend-specific) that may transform
|
||||
// the output of the cache to work better with specific kernels. If not called,
|
||||
// the backend settings will be used. This works well when calling Attention.
|
||||
//
|
||||
// The config can be overridden by models, especially if they require vanilla
|
||||
// output when implementing their own version of attention. To do this, pass
|
||||
// an empty ml.CacheConfig.
|
||||
//
|
||||
// Most models will not need to use this.
|
||||
SetConfig(ml.CacheConfig)
|
||||
|
||||
// ** cache management **
|
||||
|
||||
// Init sets up runtime parameters
|
||||
|
||||
@@ -22,6 +22,9 @@ type Causal struct {
|
||||
Capacity int32
|
||||
windowSize int32
|
||||
|
||||
// config controls mostly backend-specific optimizations
|
||||
config *ml.CacheConfig
|
||||
|
||||
// ** current forward pass **
|
||||
|
||||
// the active layer for Get and Put
|
||||
@@ -75,14 +78,42 @@ func NewSWACache(windowSize int32, shift shiftFn) *Causal {
|
||||
}
|
||||
|
||||
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
|
||||
if c.config == nil {
|
||||
var config ml.CacheConfig
|
||||
if cc, ok := backend.(ml.BackendCacheConfig); ok {
|
||||
config = cc.CacheConfig()
|
||||
}
|
||||
c.config = &config
|
||||
}
|
||||
|
||||
if c.config.CachePadding == 0 {
|
||||
c.config.CachePadding = 1
|
||||
}
|
||||
|
||||
if c.config.MaskBatchPadding == 0 {
|
||||
c.config.MaskBatchPadding = 1
|
||||
}
|
||||
|
||||
if c.config.MaskDType == ml.DTypeOther {
|
||||
c.config.MaskDType = ml.DTypeF32
|
||||
}
|
||||
|
||||
c.DType = dtype
|
||||
c.Capacity = capacity
|
||||
c.cells = make([]cacheCell, capacity)
|
||||
c.Capacity = int32(roundUp(int(capacity), c.config.CachePadding))
|
||||
c.cells = make([]cacheCell, c.Capacity)
|
||||
c.cellRanges = make(map[int]cellRange)
|
||||
c.backend = backend
|
||||
c.cacheCtx = backend.NewContext()
|
||||
}
|
||||
|
||||
func (c *Causal) SetConfig(config ml.CacheConfig) {
|
||||
if c.config != nil {
|
||||
panic("config cannot be changed after being previously set, either by the model or backend")
|
||||
}
|
||||
|
||||
c.config = &config
|
||||
}
|
||||
|
||||
func (c *Causal) Close() {
|
||||
c.cacheCtx.Close()
|
||||
}
|
||||
@@ -157,36 +188,91 @@ func (c *Causal) findStartLoc() (int, error) {
|
||||
return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity)
|
||||
}
|
||||
|
||||
func roundDown(length, pad int) int {
|
||||
return (length / pad) * pad
|
||||
}
|
||||
|
||||
func roundUp(length, pad int) int {
|
||||
return ((length + pad - 1) / pad) * pad
|
||||
}
|
||||
|
||||
// Builds a mask of history x batch indicating whether for each token in the batch the
|
||||
// token in the history should apply. This is based on both the sequence and causality (the
|
||||
// position of the history is not ahead of the token in the batch).
|
||||
func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Tensor, error) {
|
||||
// TODO(jessegross): This does not do padding, which is required for flash attention
|
||||
len := c.curCellRange.max - c.curCellRange.min + 1
|
||||
mask := make([]float32, c.curBatchSize*len)
|
||||
// Align and pad the two dimensions as required by the backend
|
||||
batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
|
||||
|
||||
c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
|
||||
c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
|
||||
|
||||
length := c.curCellRange.max - c.curCellRange.min + 1
|
||||
mask := make([]float32, batchSize*length)
|
||||
|
||||
for i := range c.curBatchSize {
|
||||
for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
|
||||
if !slices.Contains(c.cells[j].sequences, seqs[i]) || c.cells[j].pos > positions[i] ||
|
||||
c.cells[j].pos < positions[i]-c.windowSize {
|
||||
mask[i*len+(j-c.curCellRange.min)] = float32(math.Inf(-1))
|
||||
mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return ctx.FromFloatSlice(mask, len, c.curBatchSize)
|
||||
// Mask out any padding tokens we added. For padding that we added to the cache history, this
|
||||
// has already been masked out because the sequence doesn't match.
|
||||
for i := c.curBatchSize * length; i < len(mask); i++ {
|
||||
mask[i] = float32(math.Inf(-1))
|
||||
}
|
||||
|
||||
maskTensor, err := ctx.FromFloatSlice(mask, length, batchSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if c.config.MaskDType != ml.DTypeF32 {
|
||||
out := ctx.Empty(c.config.MaskDType, maskTensor.Shape()...)
|
||||
ctx.Forward(maskTensor.Copy(ctx, out))
|
||||
maskTensor = out
|
||||
}
|
||||
|
||||
return maskTensor, nil
|
||||
}
|
||||
|
||||
func moveCell(ctx ml.Context, objs []ml.Tensor, src, dst, len int) {
|
||||
for _, obj := range objs {
|
||||
if obj == nil {
|
||||
func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
|
||||
for i := range c.keys {
|
||||
if c.keys[i] == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
srcView := obj.View(ctx, obj.Stride(2)*src, obj.Dim(0)*obj.Dim(1)*len)
|
||||
dstView := obj.View(ctx, obj.Stride(2)*dst, obj.Dim(0)*obj.Dim(1)*len)
|
||||
key := c.keys[i]
|
||||
|
||||
ctx.Forward(srcView.Copy(ctx, dstView))
|
||||
kHeadDim := key.Dim(0)
|
||||
numKVHeads := key.Dim(1)
|
||||
rowSize := key.Stride(2)
|
||||
|
||||
kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*len)
|
||||
kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*len)
|
||||
|
||||
value := c.values[i]
|
||||
var vSrcView, vDstView ml.Tensor
|
||||
if c.config.PermutedV {
|
||||
vHeadDim := value.Dim(1)
|
||||
elemSize := value.Stride(0)
|
||||
|
||||
vSrcView = value.View(ctx, elemSize*src, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
|
||||
vDstView = value.View(ctx, elemSize*dst, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
|
||||
} else {
|
||||
vHeadDim := value.Dim(0)
|
||||
rowSize := value.Stride(2)
|
||||
|
||||
vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*len)
|
||||
vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*len)
|
||||
}
|
||||
|
||||
ctx.Forward(
|
||||
kSrcView.Copy(ctx, kDstView),
|
||||
vSrcView.Copy(ctx, vDstView),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -238,8 +324,7 @@ func (c *Causal) defrag() {
|
||||
pendingLen++
|
||||
break
|
||||
} else {
|
||||
moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
|
||||
moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
|
||||
c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
|
||||
moves++
|
||||
}
|
||||
}
|
||||
@@ -263,8 +348,7 @@ func (c *Causal) defrag() {
|
||||
}
|
||||
|
||||
if pendingLen > 0 {
|
||||
moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
|
||||
moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
|
||||
c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
|
||||
moves++
|
||||
}
|
||||
|
||||
@@ -305,33 +389,73 @@ func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
|
||||
key := c.keys[c.curLayer]
|
||||
value := c.values[c.curLayer]
|
||||
|
||||
key = key.View(ctx, key.Stride(2)*c.curCellRange.min,
|
||||
key.Dim(0), key.Stride(1),
|
||||
key.Dim(1), key.Stride(2),
|
||||
c.curMask.Dim(0),
|
||||
kHeadDim := key.Dim(0)
|
||||
numKVHeads := key.Dim(1)
|
||||
rowSize := key.Stride(2)
|
||||
cachedSize := c.curMask.Dim(0)
|
||||
|
||||
key = key.View(ctx, rowSize*c.curCellRange.min,
|
||||
kHeadDim, key.Stride(1),
|
||||
numKVHeads, key.Stride(2),
|
||||
cachedSize,
|
||||
)
|
||||
|
||||
value = value.View(ctx, key.Stride(2)*c.curCellRange.min,
|
||||
value.Dim(0), value.Stride(1),
|
||||
value.Dim(1), value.Stride(2),
|
||||
c.curMask.Dim(0),
|
||||
)
|
||||
if c.config.PermutedV {
|
||||
vHeadDim := value.Dim(1)
|
||||
elemSize := value.Stride(0)
|
||||
|
||||
value = value.View(ctx, elemSize*c.curCellRange.min,
|
||||
cachedSize, value.Stride(1),
|
||||
vHeadDim, value.Stride(2),
|
||||
numKVHeads,
|
||||
)
|
||||
} else {
|
||||
vHeadDim := value.Dim(0)
|
||||
rowSize := value.Stride(2)
|
||||
|
||||
value = value.View(ctx, rowSize*c.curCellRange.min,
|
||||
vHeadDim, value.Stride(1),
|
||||
numKVHeads, value.Stride(2),
|
||||
cachedSize,
|
||||
)
|
||||
}
|
||||
|
||||
return key, value, c.curMask
|
||||
}
|
||||
|
||||
func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
if c.curBatchSize != key.Dim(2) {
|
||||
panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, key.Dim(2)))
|
||||
kHeadDim := key.Dim(0)
|
||||
vHeadDim := value.Dim(0)
|
||||
numKVHeads := key.Dim(1)
|
||||
batchSize := key.Dim(2)
|
||||
|
||||
if c.curBatchSize != batchSize {
|
||||
panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize))
|
||||
}
|
||||
|
||||
if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
|
||||
c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, key.Dim(0), key.Dim(1), int(c.Capacity))
|
||||
c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, value.Dim(0), value.Dim(1), int(c.Capacity))
|
||||
c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity))
|
||||
|
||||
if c.config.PermutedV {
|
||||
c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads)
|
||||
} else {
|
||||
c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity))
|
||||
}
|
||||
}
|
||||
|
||||
ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, c.keys[c.curLayer].Stride(2)*c.curLoc, key.Dim(0)*key.Dim(1)*key.Dim(2))))
|
||||
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, c.values[c.curLayer].Stride(2)*c.curLoc, value.Dim(0)*value.Dim(1)*value.Dim(2))))
|
||||
rowSize := c.keys[c.curLayer].Stride(2)
|
||||
ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, kHeadDim*numKVHeads*batchSize)))
|
||||
|
||||
if c.config.PermutedV {
|
||||
elemSize := c.values[c.curLayer].Stride(0)
|
||||
|
||||
value = value.Permute(ctx, 1, 2, 0, 3)
|
||||
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)))
|
||||
} else {
|
||||
rowSize := c.values[c.curLayer].Stride(2)
|
||||
|
||||
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, vHeadDim*numKVHeads*batchSize)))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
||||
@@ -387,9 +511,13 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
|
||||
continue
|
||||
}
|
||||
|
||||
key = key.View(ctx, key.Stride(2)*seqRange.min,
|
||||
key.Dim(0), key.Stride(1),
|
||||
key.Dim(1), key.Stride(2),
|
||||
kHeadDim := key.Dim(0)
|
||||
numKVHeads := key.Dim(1)
|
||||
rowSize := key.Stride(2)
|
||||
|
||||
key = key.View(ctx, rowSize*seqRange.min,
|
||||
kHeadDim, key.Stride(1),
|
||||
numKVHeads, key.Stride(2),
|
||||
size,
|
||||
)
|
||||
|
||||
|
||||
@@ -280,9 +280,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
|
||||
|
||||
out, _, mask := cache.Get(context)
|
||||
|
||||
context.Forward(out)
|
||||
context.Forward(mask)
|
||||
context.Compute(out, mask)
|
||||
context.Forward(out, mask).Compute(out, mask)
|
||||
|
||||
if !slices.Equal(out.Floats(), test.expected) || !slices.Equal(out.Shape(), test.expectedShape) || !slices.Equal(mask.Floats(), test.expectedMask) {
|
||||
t.Errorf("TestCache: have %v (shape %v); want %v (shape %v); mask: have %v (shape %v) want %v", out.Floats(), out.Shape(), test.expected, test.expectedShape, mask.Floats(), mask.Shape(), test.expectedMask)
|
||||
@@ -311,7 +309,7 @@ func (b *testBackend) SystemInfo() string {
|
||||
|
||||
type testContext struct{}
|
||||
|
||||
func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
|
||||
func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor {
|
||||
total := 0
|
||||
|
||||
if len(shape) > 0 {
|
||||
@@ -324,8 +322,12 @@ func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
|
||||
return &testTensor{dtype: dtype, elementSize: 4, data: make([]float32, total), shape: shape}
|
||||
}
|
||||
|
||||
func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
|
||||
return c.Empty(dtype, shape...)
|
||||
}
|
||||
|
||||
func (c *testContext) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
|
||||
t := c.Zeros(ml.DTypeF32, shape...).(*testTensor)
|
||||
t := c.Empty(ml.DTypeF32, shape...).(*testTensor)
|
||||
|
||||
copy(t.data, s)
|
||||
|
||||
@@ -344,7 +346,7 @@ func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (c *testContext) Forward(ml.Tensor) {}
|
||||
func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
|
||||
|
||||
func (c *testContext) Compute(...ml.Tensor) {}
|
||||
|
||||
@@ -393,7 +395,7 @@ func (t *testTensor) Floats() []float32 {
|
||||
}
|
||||
|
||||
func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
out := ctx.Zeros(t.DType(), t.Shape()...).(*testTensor)
|
||||
out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
|
||||
|
||||
for i := range out.data {
|
||||
out.data[i] = t.data[i] + t2.(*testTensor).data[i]
|
||||
@@ -470,7 +472,7 @@ func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
|
||||
|
||||
context := &testContext{}
|
||||
|
||||
view := context.Zeros(t.dtype, s...).(*testTensor)
|
||||
view := context.Empty(t.dtype, s...).(*testTensor)
|
||||
view.data = t.data[offset : offset+len(view.data)]
|
||||
|
||||
return view
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package kvcache
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/ollama/ollama/ml"
|
||||
)
|
||||
|
||||
@@ -11,6 +13,9 @@ import (
|
||||
//
|
||||
// Not currently safe for multiple sequences
|
||||
type EncoderCache struct {
|
||||
// config controls mostly backend-specific optimizations
|
||||
config *ml.CacheConfig
|
||||
|
||||
// ** current forward pass **
|
||||
|
||||
// the active layer for Get and Put
|
||||
@@ -40,9 +45,29 @@ func NewEncoderCache() *EncoderCache {
|
||||
}
|
||||
|
||||
func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
|
||||
if c.config == nil {
|
||||
var config ml.CacheConfig
|
||||
if cc, ok := backend.(ml.BackendCacheConfig); ok {
|
||||
config = cc.CacheConfig()
|
||||
}
|
||||
c.config = &config
|
||||
}
|
||||
|
||||
if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
|
||||
panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
|
||||
}
|
||||
|
||||
c.cacheCtx = backend.NewContext()
|
||||
}
|
||||
|
||||
func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
|
||||
if c.config != nil {
|
||||
panic("config cannot be changed after being previously set, either by the model or backend")
|
||||
}
|
||||
|
||||
c.config = &config
|
||||
}
|
||||
|
||||
func (c *EncoderCache) Close() {
|
||||
c.cacheCtx.Close()
|
||||
}
|
||||
@@ -75,13 +100,19 @@ func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
|
||||
c.encoderPos = c.curPos
|
||||
c.encoderCached = true
|
||||
|
||||
if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
|
||||
c.keys[c.curLayer] = c.cacheCtx.Zeros(key.DType(), key.Shape()...)
|
||||
c.values[c.curLayer] = c.cacheCtx.Zeros(value.DType(), value.Shape()...)
|
||||
if c.config.PermutedV {
|
||||
value = value.Permute(ctx, 1, 2, 0, 3)
|
||||
}
|
||||
|
||||
ctx.Forward(key.Copy(ctx, c.keys[c.curLayer]))
|
||||
ctx.Forward(value.Copy(ctx, c.values[c.curLayer]))
|
||||
if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
|
||||
c.keys[c.curLayer] = c.cacheCtx.Empty(key.DType(), key.Shape()...)
|
||||
c.values[c.curLayer] = c.cacheCtx.Empty(value.DType(), value.Shape()...)
|
||||
}
|
||||
|
||||
ctx.Forward(
|
||||
key.Copy(ctx, c.keys[c.curLayer]),
|
||||
value.Copy(ctx, c.values[c.curLayer]),
|
||||
)
|
||||
}
|
||||
|
||||
func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) {
|
||||
|
||||
@@ -28,6 +28,12 @@ func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, capacity int32)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WrapperCache) SetConfig(config ml.CacheConfig) {
|
||||
for _, cache := range c.caches {
|
||||
cache.SetConfig(config)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WrapperCache) Close() {
|
||||
for _, cache := range c.caches {
|
||||
cache.Close()
|
||||
|
||||
1
llama/llama.cpp/include/llama.h
vendored
1
llama/llama.cpp/include/llama.h
vendored
@@ -105,6 +105,7 @@ extern "C" {
|
||||
LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26,
|
||||
LLAMA_VOCAB_PRE_TYPE_MINERVA = 27,
|
||||
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28,
|
||||
LLAMA_VOCAB_PRE_TYPE_GPT4O = 29,
|
||||
};
|
||||
|
||||
enum llama_rope_type {
|
||||
|
||||
10
llama/llama.cpp/src/llama-model.cpp
vendored
10
llama/llama.cpp/src/llama-model.cpp
vendored
@@ -2283,7 +2283,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
|
||||
// output
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
|
||||
// if output is NULL, init from the input tok embed
|
||||
if (output == NULL) {
|
||||
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
@@ -2298,8 +2302,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * n_ff }, 0);
|
||||
|
||||
layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
|
||||
layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
|
||||
layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
|
||||
layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_PHIMOE:
|
||||
|
||||
11
llama/llama.cpp/src/llama-vocab.cpp
vendored
11
llama/llama.cpp/src/llama-vocab.cpp
vendored
@@ -392,6 +392,13 @@ struct llm_tokenizer_bpe : llm_tokenizer {
|
||||
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
|
||||
};
|
||||
break;
|
||||
case LLAMA_VOCAB_PRE_TYPE_GPT4O:
|
||||
// original regex from tokenizer.json
|
||||
// [^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+
|
||||
regex_exprs = {
|
||||
"[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
|
||||
};
|
||||
break;
|
||||
default:
|
||||
// default regex for BPE tokenization pre-processing
|
||||
regex_exprs = {
|
||||
@@ -1583,6 +1590,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
} else if (
|
||||
tokenizer_pre == "megrez") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2;
|
||||
} else if (
|
||||
tokenizer_pre == "gpt-4o") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_GPT4O;
|
||||
clean_spaces = false;
|
||||
} else {
|
||||
LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__);
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
|
||||
|
||||
@@ -37,23 +37,36 @@ COMPILER inline get_compiler() {
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"context"
|
||||
_ "embed"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"runtime"
|
||||
"runtime/cgo"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"unsafe"
|
||||
|
||||
_ "github.com/ollama/ollama/llama/llama.cpp/common"
|
||||
_ "github.com/ollama/ollama/llama/llama.cpp/examples/llava"
|
||||
_ "github.com/ollama/ollama/llama/llama.cpp/src"
|
||||
"github.com/ollama/ollama/ml/backend/ggml/ggml/src"
|
||||
ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
|
||||
)
|
||||
|
||||
func init() {
|
||||
C.llama_log_set(C.ggml_log_callback(C.llamaLog), nil)
|
||||
}
|
||||
|
||||
//export llamaLog
|
||||
func llamaLog(level C.int, text *C.char, _ unsafe.Pointer) {
|
||||
// slog levels zeros INFO and are multiples of 4
|
||||
if slog.Default().Enabled(context.TODO(), slog.Level(int(level-C.GGML_LOG_LEVEL_INFO)*4)) {
|
||||
fmt.Fprint(os.Stderr, C.GoString(text))
|
||||
}
|
||||
}
|
||||
|
||||
func BackendInit() {
|
||||
ggml.OnceLoad()
|
||||
C.llama_backend_init()
|
||||
@@ -72,26 +85,6 @@ func PrintSystemInfo() string {
|
||||
return C.GoString(C.llama_print_system_info()) + compiler
|
||||
}
|
||||
|
||||
var logLevel atomic.Int32
|
||||
|
||||
func init() {
|
||||
logLevel.Store(int32(C.GGML_LOG_LEVEL_INFO))
|
||||
C.llama_log_set((C.ggml_log_callback)(C.llamaLog), nil)
|
||||
}
|
||||
|
||||
func EnableDebug() {
|
||||
logLevel.Store(int32(C.GGML_LOG_LEVEL_DEBUG))
|
||||
}
|
||||
|
||||
//export llamaLog
|
||||
func llamaLog(level int32, text *C.char, _ unsafe.Pointer) {
|
||||
if level < logLevel.Load() {
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Fprint(os.Stderr, C.GoString(text))
|
||||
}
|
||||
|
||||
func GetModelArch(modelPath string) (string, error) {
|
||||
mp := C.CString(modelPath)
|
||||
defer C.free(unsafe.Pointer(mp))
|
||||
@@ -269,7 +262,7 @@ func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
|
||||
cparams.progress_callback_user_data = unsafe.Pointer(&handle)
|
||||
}
|
||||
|
||||
m := Model{c: C.llama_load_model_from_file(C.CString(modelPath), cparams)}
|
||||
m := Model{c: C.llama_model_load_from_file(C.CString(modelPath), cparams)}
|
||||
if m.c == nil {
|
||||
return nil, fmt.Errorf("unable to load model: %s", modelPath)
|
||||
}
|
||||
@@ -278,12 +271,12 @@ func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
|
||||
}
|
||||
|
||||
func FreeModel(model *Model) {
|
||||
C.llama_free_model(model.c)
|
||||
C.llama_model_free(model.c)
|
||||
}
|
||||
|
||||
func NewContextWithModel(model *Model, params ContextParams) (*Context, error) {
|
||||
c := Context{
|
||||
c: C.llama_new_context_with_model(model.c, params.c),
|
||||
c: C.llama_init_from_model(model.c, params.c),
|
||||
numThreads: int(params.c.n_threads),
|
||||
}
|
||||
if c.c == nil {
|
||||
@@ -294,15 +287,15 @@ func NewContextWithModel(model *Model, params ContextParams) (*Context, error) {
|
||||
}
|
||||
|
||||
func (m *Model) NumVocab() int {
|
||||
return int(C.llama_n_vocab(m.Vocab()))
|
||||
return int(C.llama_vocab_n_tokens(m.Vocab()))
|
||||
}
|
||||
|
||||
func (m *Model) TokenIsEog(token int) bool {
|
||||
return bool(C.llama_token_is_eog(m.Vocab(), C.llama_token(token)))
|
||||
return bool(C.llama_vocab_is_eog(m.Vocab(), C.llama_token(token)))
|
||||
}
|
||||
|
||||
func (m *Model) AddBOSToken() bool {
|
||||
return bool(C.llama_add_bos_token(m.Vocab()))
|
||||
return bool(C.llama_vocab_get_add_bos(m.Vocab()))
|
||||
}
|
||||
|
||||
func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float32, threads int) error {
|
||||
@@ -485,7 +478,7 @@ func (m *Model) Tokenize(text string, addSpecial bool, parseSpecial bool) ([]int
|
||||
}
|
||||
|
||||
func (m *Model) NEmbd() int {
|
||||
return int(C.llama_n_embd(m.c))
|
||||
return int(C.llama_model_n_embd(m.c))
|
||||
}
|
||||
|
||||
func Quantize(infile, outfile string, ftype uint32) error {
|
||||
|
||||
80
llama/patches/0019-add-phi4-support.patch
Normal file
80
llama/patches/0019-add-phi4-support.patch
Normal file
@@ -0,0 +1,80 @@
|
||||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: jmorganca <jmorganca@gmail.com>
|
||||
Date: Thu, 27 Feb 2025 15:12:26 -0800
|
||||
Subject: [PATCH] add phi4 support
|
||||
|
||||
---
|
||||
include/llama.h | 1 +
|
||||
src/llama-model.cpp | 10 +++++++---
|
||||
src/llama-vocab.cpp | 11 +++++++++++
|
||||
3 files changed, 19 insertions(+), 3 deletions(-)
|
||||
|
||||
diff --git a/include/llama.h b/include/llama.h
|
||||
index cc948005..16774711 100644
|
||||
--- a/include/llama.h
|
||||
+++ b/include/llama.h
|
||||
@@ -105,6 +105,7 @@ extern "C" {
|
||||
LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26,
|
||||
LLAMA_VOCAB_PRE_TYPE_MINERVA = 27,
|
||||
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28,
|
||||
+ LLAMA_VOCAB_PRE_TYPE_GPT4O = 29,
|
||||
};
|
||||
|
||||
enum llama_rope_type {
|
||||
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
|
||||
index 21819080..ab1a07d1 100644
|
||||
--- a/src/llama-model.cpp
|
||||
+++ b/src/llama-model.cpp
|
||||
@@ -2283,7 +2283,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
|
||||
// output
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
|
||||
- output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0);
|
||||
+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
|
||||
+ // if output is NULL, init from the input tok embed
|
||||
+ if (output == NULL) {
|
||||
+ output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
|
||||
+ }
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
@@ -2298,8 +2302,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * n_ff }, 0);
|
||||
|
||||
- layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
|
||||
- layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
|
||||
+ layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
|
||||
+ layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_PHIMOE:
|
||||
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
|
||||
index 1ca827eb..c7ff28be 100644
|
||||
--- a/src/llama-vocab.cpp
|
||||
+++ b/src/llama-vocab.cpp
|
||||
@@ -392,6 +392,13 @@ struct llm_tokenizer_bpe : llm_tokenizer {
|
||||
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
|
||||
};
|
||||
break;
|
||||
+ case LLAMA_VOCAB_PRE_TYPE_GPT4O:
|
||||
+ // original regex from tokenizer.json
|
||||
+ // [^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+
|
||||
+ regex_exprs = {
|
||||
+ "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
|
||||
+ };
|
||||
+ break;
|
||||
default:
|
||||
// default regex for BPE tokenization pre-processing
|
||||
regex_exprs = {
|
||||
@@ -1583,6 +1590,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
|
||||
} else if (
|
||||
tokenizer_pre == "megrez") {
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2;
|
||||
+ } else if (
|
||||
+ tokenizer_pre == "gpt-4o") {
|
||||
+ pre_type = LLAMA_VOCAB_PRE_TYPE_GPT4O;
|
||||
+ clean_spaces = false;
|
||||
} else {
|
||||
LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__);
|
||||
pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;
|
||||
@@ -14,6 +14,7 @@ type Config interface {
|
||||
String(string, ...string) string
|
||||
Uint(string, ...uint32) uint32
|
||||
Float(string, ...float32) float32
|
||||
Bool(string, ...bool) bool
|
||||
|
||||
Strings(string, ...[]string) []string
|
||||
Uints(string, ...[]uint32) []uint32
|
||||
@@ -26,6 +27,35 @@ type Backend interface {
|
||||
SystemInfo() string
|
||||
}
|
||||
|
||||
// BackendCacheConfig should be implemented by backends that need special output
|
||||
// from the cache to meet specific requirements. It is frequently implemented in
|
||||
// conjunction with ScaledDotProductAttention.
|
||||
type BackendCacheConfig interface {
|
||||
CacheConfig() CacheConfig
|
||||
}
|
||||
|
||||
// CacheConfig controls optimizations (mostly backend-specific) that may transform
|
||||
// the output the cache to work better with specific kernels.
|
||||
type CacheConfig struct {
|
||||
// CachePadding specifies the multiple for the number of tokens of cache history
|
||||
// that will be returned from cache Get for k, v and mask. The capacity of the
|
||||
// cache itself will also be increased to a multiple of this size if needed.
|
||||
CachePadding int
|
||||
|
||||
// PermutedV performs Permute(ctx, 1, 2, 0, 3) on v tensors stored via Put
|
||||
// and return the permuted version via Get. This uses the cache copy operation
|
||||
// to avoid a Contiguous call on the permuted tensor.
|
||||
PermutedV bool
|
||||
|
||||
// MaskDType specifies the data type for generating the mask. If unset it will
|
||||
// default to DTypeF32.
|
||||
MaskDType DType
|
||||
|
||||
// MaskBatchPadding specifies the multiple for the batch size dimension in the mask.
|
||||
// Any position that does not correspond to an actual token will be filled with -Inf.
|
||||
MaskBatchPadding int
|
||||
}
|
||||
|
||||
// BackendParams controls how the backend loads and executes models
|
||||
type BackendParams struct {
|
||||
// NumThreads sets the number of threads to use if running on the CPU
|
||||
@@ -39,6 +69,9 @@ type BackendParams struct {
|
||||
|
||||
// TensorSplit is the fraction of the model to offload to each GPU
|
||||
TensorSplit []float32
|
||||
|
||||
// FlashAttention indicates that we should use a fused flash attention kernel
|
||||
FlashAttention bool
|
||||
}
|
||||
|
||||
var backends = make(map[string]func(*os.File, BackendParams) (Backend, error))
|
||||
@@ -60,11 +93,12 @@ func NewBackend(f *os.File, params BackendParams) (Backend, error) {
|
||||
}
|
||||
|
||||
type Context interface {
|
||||
Empty(dtype DType, shape ...int) Tensor
|
||||
Zeros(dtype DType, shape ...int) Tensor
|
||||
FromFloatSlice(s []float32, shape ...int) (Tensor, error)
|
||||
FromIntSlice(s []int32, shape ...int) (Tensor, error)
|
||||
|
||||
Forward(Tensor)
|
||||
Forward(...Tensor) Context
|
||||
Compute(...Tensor)
|
||||
MaxTensors() int
|
||||
Close()
|
||||
@@ -115,6 +149,10 @@ type Tensor interface {
|
||||
// operation equivalent to following code on a tensor named
|
||||
// query:
|
||||
//
|
||||
// query = query.Permute(ctx, 0, 2, 1, 3)
|
||||
// key = key.Permute(ctx, 0, 2, 1, 3)
|
||||
// value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
//
|
||||
// kq := key.MulmatFullPrec(ctx, query)
|
||||
//
|
||||
// kq = kq.Scale(ctx, scale)
|
||||
@@ -169,7 +207,7 @@ func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {
|
||||
return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
|
||||
})
|
||||
case DTypeF16:
|
||||
f32 := ctx.Zeros(DTypeF32, t.Shape()...)
|
||||
f32 := ctx.Empty(DTypeF32, t.Shape()...)
|
||||
f32 = t.Copy(ctx, f32)
|
||||
return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string {
|
||||
return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
|
||||
@@ -185,8 +223,7 @@ func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {
|
||||
|
||||
func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) string {
|
||||
if t.Bytes() == nil {
|
||||
ctx.Forward(t)
|
||||
ctx.Compute(t)
|
||||
ctx.Forward(t).Compute(t)
|
||||
}
|
||||
|
||||
s := make(S, mul(t.Shape()...))
|
||||
|
||||
@@ -79,6 +79,8 @@ var devices = sync.OnceValue(func() []device {
|
||||
})
|
||||
|
||||
type Backend struct {
|
||||
flashAttention bool
|
||||
|
||||
meta *fs.GGML
|
||||
cpus, gpus []Context
|
||||
tensors map[string]*Context
|
||||
@@ -192,9 +194,10 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
|
||||
}
|
||||
|
||||
return &Backend{
|
||||
meta: meta,
|
||||
cpus: cpus,
|
||||
gpus: gpus,
|
||||
flashAttention: params.FlashAttention,
|
||||
meta: meta,
|
||||
cpus: cpus,
|
||||
gpus: gpus,
|
||||
sched: C.ggml_backend_sched_new(
|
||||
(*C.ggml_backend_t)(unsafe.Pointer(&backends[0])),
|
||||
(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&bufts[0])),
|
||||
@@ -219,7 +222,7 @@ func (b *Backend) Get(name string) ml.Tensor {
|
||||
|
||||
for _, c := range append(b.gpus, b.cpus...) {
|
||||
if t := C.ggml_get_tensor(c.ctx, cname); t != nil {
|
||||
return &Tensor{t: t}
|
||||
return &Tensor{b: b, t: t}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -247,6 +250,14 @@ func (b *Backend) NewContext() ml.Context {
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Backend) CacheConfig() ml.CacheConfig {
|
||||
if b.flashAttention {
|
||||
return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16, MaskBatchPadding: C.GGML_KQ_MASK_PAD}
|
||||
} else {
|
||||
return ml.CacheConfig{CachePadding: 32, PermutedV: true}
|
||||
}
|
||||
}
|
||||
|
||||
type Context struct {
|
||||
b *Backend
|
||||
ctx *C.struct_ggml_context
|
||||
@@ -256,12 +267,16 @@ type Context struct {
|
||||
nodes int
|
||||
}
|
||||
|
||||
func (c *Context) Forward(t ml.Tensor) {
|
||||
func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
|
||||
if c.graph == nil {
|
||||
c.graph = C.ggml_new_graph_custom(c.ctx, C.size_t(c.nodes), false)
|
||||
}
|
||||
|
||||
C.ggml_build_forward_expand(c.graph, t.(*Tensor).t)
|
||||
for _, tensor := range tensors {
|
||||
C.ggml_build_forward_expand(c.graph, tensor.(*Tensor).t)
|
||||
}
|
||||
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Context) Compute(tensors ...ml.Tensor) {
|
||||
@@ -296,7 +311,7 @@ func shapeToGGML(shape []int) *C.int64_t {
|
||||
return &sh[0]
|
||||
}
|
||||
|
||||
func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
|
||||
func newTensor(ctx Context, dtype ml.DType, zero bool, shape []int) ml.Tensor {
|
||||
if len(shape) < 1 || len(shape) > 4 {
|
||||
panic("unsupported number of dimensions")
|
||||
}
|
||||
@@ -310,19 +325,29 @@ func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
|
||||
var t *C.struct_ggml_tensor
|
||||
switch dtype {
|
||||
case ml.DTypeF32:
|
||||
t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F32, C.int(len(shape)), shapeToGGML(shape))
|
||||
t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_F32, C.int(len(shape)), shapeToGGML(shape))
|
||||
case ml.DTypeF16:
|
||||
t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F16, C.int(len(shape)), shapeToGGML(shape))
|
||||
t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_F16, C.int(len(shape)), shapeToGGML(shape))
|
||||
case ml.DTypeI32:
|
||||
t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_I32, C.int(len(shape)), shapeToGGML(shape))
|
||||
t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_I32, C.int(len(shape)), shapeToGGML(shape))
|
||||
default:
|
||||
panic("unsupported dtype")
|
||||
}
|
||||
|
||||
b := C.ggml_backend_alloc_buffer(c.backend, C.ggml_nbytes(t))
|
||||
b := C.ggml_backend_alloc_buffer(ctx.backend, C.ggml_nbytes(t))
|
||||
C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
|
||||
C.ggml_set_zero(t)
|
||||
return &Tensor{t: t}
|
||||
if zero {
|
||||
C.ggml_set_zero(t)
|
||||
}
|
||||
return &Tensor{b: ctx.b, t: t}
|
||||
}
|
||||
|
||||
func (c Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
|
||||
return newTensor(c, dtype, false, shape)
|
||||
}
|
||||
|
||||
func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
|
||||
return newTensor(c, dtype, true, shape)
|
||||
}
|
||||
|
||||
func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype uint32) (ml.Tensor, error) {
|
||||
@@ -331,7 +356,7 @@ func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype u
|
||||
if n == 0 {
|
||||
var shape C.int64_t = 0
|
||||
t := C.ggml_new_tensor(ctx.ctx, dtype, 1, &shape)
|
||||
return &Tensor{t: t}, nil
|
||||
return &Tensor{b: ctx.b, t: t}, nil
|
||||
}
|
||||
|
||||
for _, v := range shape {
|
||||
@@ -346,7 +371,7 @@ func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype u
|
||||
b := C.ggml_backend_alloc_buffer(ctx.backend, C.ggml_nbytes(t))
|
||||
C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
|
||||
C.ggml_backend_tensor_set(t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t))
|
||||
return &Tensor{t: t}, nil
|
||||
return &Tensor{b: ctx.b, t: t}, nil
|
||||
}
|
||||
|
||||
func (c Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
|
||||
@@ -364,6 +389,7 @@ func (c *Context) Close() {
|
||||
}
|
||||
|
||||
type Tensor struct {
|
||||
b *Backend
|
||||
t *C.struct_ggml_tensor
|
||||
sync func()
|
||||
}
|
||||
@@ -430,6 +456,7 @@ func (t *Tensor) DType() ml.DType {
|
||||
|
||||
func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_add(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
|
||||
}
|
||||
}
|
||||
@@ -444,24 +471,28 @@ func (t *Tensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
|
||||
|
||||
func (t *Tensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_concat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(dim)),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Contiguous(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_cont(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_mul(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
|
||||
}
|
||||
}
|
||||
@@ -471,12 +502,13 @@ func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
C.ggml_mul_mat_set_prec(mul, C.GGML_PREC_F32)
|
||||
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: mul,
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
|
||||
tt := (&Tensor{t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
|
||||
tt := (&Tensor{b: t.b, t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
|
||||
if b != nil {
|
||||
tt = tt.Add(ctx, b)
|
||||
}
|
||||
@@ -485,7 +517,7 @@ func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tenso
|
||||
}
|
||||
|
||||
func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
|
||||
return (&Tensor{t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
|
||||
return (&Tensor{b: t.b, t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
|
||||
}
|
||||
|
||||
func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
@@ -494,6 +526,7 @@ func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
}
|
||||
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_pad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
|
||||
}
|
||||
}
|
||||
@@ -504,18 +537,21 @@ func (t *Tensor) Permute(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
}
|
||||
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_permute(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_get_rows(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_cpy(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
|
||||
}
|
||||
}
|
||||
@@ -524,18 +560,22 @@ func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
switch len(shape) {
|
||||
case 1:
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_reshape_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0])),
|
||||
}
|
||||
case 2:
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_reshape_2d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1])),
|
||||
}
|
||||
case 3:
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_reshape_3d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2])),
|
||||
}
|
||||
case 4:
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_reshape_4d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2]), C.int64_t(shape[3])),
|
||||
}
|
||||
default:
|
||||
@@ -545,18 +585,21 @@ func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
|
||||
func (t *Tensor) Scale(ctx ml.Context, s float64) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_scale(ctx.(*Context).ctx, t.t, (C.float)(s)),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Softmax(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_soft_max(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Tanh(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_tanh_inplace(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
@@ -567,6 +610,7 @@ func (t *Tensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
|
||||
}
|
||||
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_unpad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
|
||||
}
|
||||
}
|
||||
@@ -575,10 +619,12 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
|
||||
switch len(shape) {
|
||||
case 1:
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_view_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.size_t(offset)),
|
||||
}
|
||||
case 3:
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_view_2d(ctx.(*Context).ctx, t.t,
|
||||
C.int64_t(shape[0]), C.int64_t(shape[2]),
|
||||
C.size_t(shape[1]),
|
||||
@@ -586,6 +632,7 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
|
||||
}
|
||||
case 5:
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_view_3d(ctx.(*Context).ctx, t.t,
|
||||
C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]),
|
||||
C.size_t(shape[1]), C.size_t(shape[3]),
|
||||
@@ -593,6 +640,7 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
|
||||
}
|
||||
case 7:
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_view_4d(ctx.(*Context).ctx, t.t,
|
||||
C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]), C.int64_t(shape[6]),
|
||||
C.size_t(shape[1]), C.size_t(shape[3]), C.size_t(shape[5]),
|
||||
@@ -609,7 +657,7 @@ const (
|
||||
|
||||
func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, ropeBase, ropeScale float32) ml.Tensor {
|
||||
if ropeFactors == nil {
|
||||
ropeFactors = &Tensor{}
|
||||
ropeFactors = &Tensor{b: t.b}
|
||||
}
|
||||
|
||||
dequant := t.t
|
||||
@@ -618,6 +666,7 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
|
||||
}
|
||||
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_rope_ext(
|
||||
ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
|
||||
C.int(ropeDim),
|
||||
@@ -635,18 +684,21 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
|
||||
|
||||
func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) SILU(ctx ml.Context) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t),
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
|
||||
return &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_conv_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1)),
|
||||
}
|
||||
}
|
||||
@@ -657,13 +709,25 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.T
|
||||
kqMask = mask.(*Tensor).t
|
||||
}
|
||||
|
||||
kq := key.MulmatFullPrec(ctx, t)
|
||||
kq = &Tensor{
|
||||
t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
|
||||
}
|
||||
query := t.Permute(ctx, 0, 2, 1, 3)
|
||||
key = key.Permute(ctx, 0, 2, 1, 3)
|
||||
|
||||
kqv := value.Mulmat(ctx, kq)
|
||||
return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
if t.b.flashAttention {
|
||||
value = value.Permute(ctx, 0, 2, 1, 3)
|
||||
|
||||
kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0)
|
||||
C.ggml_flash_attn_ext_set_prec(kqv, C.GGML_PREC_F32)
|
||||
return &Tensor{b: t.b, t: kqv}
|
||||
} else {
|
||||
kq := key.MulmatFullPrec(ctx, query)
|
||||
kq = &Tensor{
|
||||
b: t.b,
|
||||
t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
|
||||
}
|
||||
|
||||
kqv := value.Mulmat(ctx, kq)
|
||||
return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Backend) SystemInfo() string {
|
||||
|
||||
@@ -10,6 +10,8 @@ package ggml
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
@@ -22,21 +24,14 @@ import (
|
||||
)
|
||||
|
||||
func init() {
|
||||
C.ggml_log_set((C.ggml_log_callback)(C.sink), nil)
|
||||
C.ggml_log_set(C.ggml_log_callback(C.sink), nil)
|
||||
}
|
||||
|
||||
//export sink
|
||||
func sink(level C.int, text *C.char, _ unsafe.Pointer) {
|
||||
msg := strings.TrimSpace(C.GoString(text))
|
||||
switch level {
|
||||
case C.GGML_LOG_LEVEL_DEBUG:
|
||||
slog.Debug(msg)
|
||||
case C.GGML_LOG_LEVEL_INFO:
|
||||
slog.Info(msg)
|
||||
case C.GGML_LOG_LEVEL_WARN:
|
||||
slog.Warn(msg)
|
||||
case C.GGML_LOG_LEVEL_ERROR:
|
||||
slog.Error(msg)
|
||||
// slog levels zeros INFO and are multiples of 4
|
||||
if slog.Default().Enabled(context.TODO(), slog.Level(int(level-C.GGML_LOG_LEVEL_INFO)*4)) {
|
||||
fmt.Fprint(os.Stderr, C.GoString(text))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package nn
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
"github.com/ollama/ollama/ml"
|
||||
)
|
||||
|
||||
@@ -11,40 +12,50 @@ import (
|
||||
//
|
||||
// Parameters:
|
||||
// - ctx: Context for tensor operations
|
||||
// - query: Query tensor (Q) with shape [d_k, seq_len_q, heads]
|
||||
// - key: Key tensor (K) with shape [d_k, seq_len_k, kv_heads]
|
||||
// - value: Value tensor (V) with shape [seq_len_k, d_v, kv_heads]
|
||||
// - mask: Optional attention mask that is added to the attention score. If
|
||||
// provided, should broadcast to [seq_len_k, seq_len_q, heads]
|
||||
// - query: Query tensor (Q) with shape [d_k, heads, seq_len_q]
|
||||
// - key: Key tensor (K) with shape [d_k, kv_heads, seq_len_k], can be nil to read from cache only
|
||||
// - value: Value tensor (V) with shape [d_v, kv_heads, seq_len_k], can be nil to read from cache only
|
||||
// - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension
|
||||
// - cache: KV cache to store key/value and get past history, can be nil to only use provided key/value
|
||||
//
|
||||
// Returns:
|
||||
//
|
||||
// Attention output with shape [d_v, heads, seq_len_q]
|
||||
func Attention(ctx ml.Context, query, key, value, mask ml.Tensor, scale float64) ml.Tensor {
|
||||
if query.Dim(0) != key.Dim(0) {
|
||||
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
|
||||
func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
|
||||
if key != nil && value != nil {
|
||||
if query.Dim(0) != key.Dim(0) {
|
||||
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
|
||||
}
|
||||
|
||||
if key.Dim(1) != value.Dim(1) {
|
||||
panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1)))
|
||||
}
|
||||
|
||||
if key.Dim(2) != value.Dim(2) {
|
||||
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
|
||||
}
|
||||
|
||||
if cache != nil {
|
||||
cache.Put(ctx, key, value)
|
||||
}
|
||||
} else if cache == nil {
|
||||
panic("key & value tensors must be provided if cache is nil")
|
||||
}
|
||||
|
||||
if mask != nil && query.Dim(1) != mask.Dim(1) {
|
||||
panic(fmt.Errorf("seq_len_q in attention operation does not match between query(%v) and mask(%v)", query.Dim(1), mask.Dim(1)))
|
||||
var mask ml.Tensor
|
||||
if cache != nil {
|
||||
key, value, mask = cache.Get(ctx)
|
||||
}
|
||||
|
||||
if key.Dim(1) != value.Dim(0) {
|
||||
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(0)))
|
||||
}
|
||||
|
||||
if mask != nil && key.Dim(1) != mask.Dim(0) {
|
||||
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and mask(%v)", key.Dim(1), mask.Dim(0)))
|
||||
}
|
||||
|
||||
if key.Dim(2) != value.Dim(2) {
|
||||
panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
|
||||
}
|
||||
|
||||
if sdpa, ok := query.(ml.ScaledDotProductAttention); ok {
|
||||
// Only use the fast SDPA implementation if we have a cache, since that's what
|
||||
// will do any expected backend-specific transformations for us
|
||||
if sdpa, ok := query.(ml.ScaledDotProductAttention); ok && cache != nil {
|
||||
return sdpa.ScaledDotProductAttention(ctx, key, value, mask, scale)
|
||||
} else {
|
||||
query = query.Permute(ctx, 0, 2, 1, 3)
|
||||
key = key.Permute(ctx, 0, 2, 1, 3)
|
||||
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
|
||||
kq := key.MulmatFullPrec(ctx, query)
|
||||
|
||||
kq = kq.Scale(ctx, scale)
|
||||
|
||||
@@ -248,8 +248,7 @@ func Forward(ctx ml.Context, m Model, opts Options) (ml.Tensor, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
ctx.Forward(t)
|
||||
ctx.Compute(t)
|
||||
ctx.Forward(t).Compute(t)
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
@@ -37,7 +37,9 @@ func New(c ml.Config) (model.Model, error) {
|
||||
Types: c.Uints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
},
|
||||
),
|
||||
Layers: make([]Layer, c.Uint("block_count")),
|
||||
@@ -79,15 +81,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
|
||||
v := sa.Value.Forward(ctx, hiddenState)
|
||||
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
|
||||
cache.Put(ctx, k, v)
|
||||
k, v, mask := cache.Get(ctx)
|
||||
|
||||
q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
|
||||
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
|
||||
kqv := nn.Attention(ctx, q, k, v, mask, scaleFactor)
|
||||
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
|
||||
kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
|
||||
|
||||
return sa.Output.Forward(ctx, kqv)
|
||||
|
||||
@@ -33,7 +33,9 @@ func New(c ml.Config) (model.Model, error) {
|
||||
Types: c.Uints("tokenizer.ggml.token_type"),
|
||||
Merges: c.Strings("tokenizer.ggml.merges"),
|
||||
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
|
||||
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
|
||||
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
|
||||
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
|
||||
},
|
||||
),
|
||||
ImageProcessor: newImageProcessor(c),
|
||||
@@ -41,7 +43,9 @@ func New(c ml.Config) (model.Model, error) {
|
||||
TextModel: newTextModel(c),
|
||||
}
|
||||
|
||||
m.Cache = kvcache.NewWrapperCache(kvcache.NewEncoderCache(), kvcache.NewCausalCache(m.TextModel.Shift))
|
||||
encoderCache := kvcache.NewEncoderCache()
|
||||
encoderCache.SetConfig(ml.CacheConfig{})
|
||||
m.Cache = kvcache.NewWrapperCache(encoderCache, kvcache.NewCausalCache(m.TextModel.Shift))
|
||||
|
||||
return &m, nil
|
||||
}
|
||||
|
||||
@@ -31,22 +31,15 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
|
||||
value := sa.Value.Forward(ctx, hiddenState)
|
||||
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
|
||||
cache.Put(ctx, key, value)
|
||||
key, value, mask := cache.Get(ctx)
|
||||
|
||||
query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
|
||||
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
|
||||
attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)
|
||||
attention := nn.Attention(ctx, query, key, value, scaleFactor, cache)
|
||||
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
|
||||
|
||||
return sa.Output.Forward(ctx, attention)
|
||||
}
|
||||
|
||||
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
|
||||
// This will only get called for layers in the cache, which are just the self attention layers
|
||||
// This will only get called for layers in the causal cache, which are just the self attention layers
|
||||
return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
|
||||
}
|
||||
|
||||
@@ -107,7 +100,7 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
|
||||
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
||||
query = ca.QueryNorm.Forward(ctx, query, opts.eps)
|
||||
|
||||
var key, value, mask ml.Tensor
|
||||
var key, value ml.Tensor
|
||||
if crossAttentionStates != nil {
|
||||
numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)
|
||||
|
||||
@@ -119,16 +112,23 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
|
||||
value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
|
||||
|
||||
cache.Put(ctx, key, value)
|
||||
} else {
|
||||
key, value, mask = cache.Get(ctx)
|
||||
}
|
||||
|
||||
query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
key, value, _ = cache.Get(ctx)
|
||||
|
||||
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
|
||||
attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)
|
||||
|
||||
query = query.Permute(ctx, 0, 2, 1, 3)
|
||||
key = key.Permute(ctx, 0, 2, 1, 3)
|
||||
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
|
||||
|
||||
kq := key.MulmatFullPrec(ctx, query)
|
||||
|
||||
kq = kq.Scale(ctx, scaleFactor)
|
||||
kq = kq.Softmax(ctx)
|
||||
|
||||
kqv := value.Mulmat(ctx, kq)
|
||||
attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
|
||||
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
|
||||
|
||||
return ca.Output.Forward(ctx, attention)
|
||||
|
||||
@@ -30,7 +30,8 @@ type Vocabulary struct {
|
||||
Scores []uint32
|
||||
Merges []string
|
||||
|
||||
BOS, EOS int32
|
||||
BOS, EOS int32
|
||||
AddBOS, AddEOS bool
|
||||
|
||||
specialOnce sync.Once
|
||||
special []string
|
||||
@@ -281,6 +282,26 @@ func (bpe BytePairEncoding) Encode(s string) ([]int32, error) {
|
||||
}
|
||||
}
|
||||
|
||||
if len(ids) > 0 {
|
||||
if bpe.vocab.AddBOS {
|
||||
if ids[0] == bpe.vocab.BOS {
|
||||
slog.Warn("adding bos token to prompt which already has it", "id", bpe.vocab.BOS)
|
||||
}
|
||||
|
||||
slog.Debug("adding bos token to prompt", "id", bpe.vocab.BOS)
|
||||
ids = append([]int32{bpe.vocab.BOS}, ids...)
|
||||
}
|
||||
|
||||
if bpe.vocab.AddEOS {
|
||||
if ids[len(ids)-1] == bpe.vocab.EOS {
|
||||
slog.Warn("adding eos token to prompt which already has it", "id", bpe.vocab.EOS)
|
||||
}
|
||||
|
||||
slog.Debug("adding eos token to prompt", "id", bpe.vocab.EOS)
|
||||
ids = append(ids, bpe.vocab.EOS)
|
||||
}
|
||||
}
|
||||
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -915,7 +915,6 @@ func Execute(args []string) error {
|
||||
level := slog.LevelInfo
|
||||
if *verbose {
|
||||
level = slog.LevelDebug
|
||||
llama.EnableDebug()
|
||||
}
|
||||
handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
|
||||
Level: level,
|
||||
@@ -944,12 +943,11 @@ func Execute(args []string) error {
|
||||
|
||||
var tensorSplitFloats []float32
|
||||
if *tensorSplit != "" {
|
||||
stringFloats := regexp.MustCompile(",").Split(*tensorSplit, -1)
|
||||
|
||||
tensorSplitFloats = make([]float32, 0, len(stringFloats))
|
||||
for _, s := range stringFloats {
|
||||
splits := strings.Split(*tensorSplit, ",")
|
||||
tensorSplitFloats = make([]float32, len(splits))
|
||||
for i, s := range splits {
|
||||
f, _ := strconv.ParseFloat(s, 32)
|
||||
tensorSplitFloats = append(tensorSplitFloats, float32(f))
|
||||
tensorSplitFloats[i] = float32(f)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -970,13 +968,14 @@ func Execute(args []string) error {
|
||||
server.cond = sync.NewCond(&server.mu)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go server.run(ctx)
|
||||
|
||||
addr := "127.0.0.1:" + strconv.Itoa(*port)
|
||||
listener, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
fmt.Println("Listen error:", err)
|
||||
cancel()
|
||||
return err
|
||||
}
|
||||
defer listener.Close()
|
||||
@@ -996,6 +995,5 @@ func Execute(args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
cancel()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -575,23 +575,11 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
sampler, err := sample.NewSampler(
|
||||
req.Temperature,
|
||||
req.TopK,
|
||||
req.TopP,
|
||||
req.MinP,
|
||||
req.Seed,
|
||||
)
|
||||
if err != nil {
|
||||
http.Error(w, fmt.Sprintf("Failed to create sampler: %v", err), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
|
||||
numPredict: req.NumPredict,
|
||||
stop: req.Stop,
|
||||
numKeep: int32(req.NumKeep),
|
||||
sampler: sampler,
|
||||
sampler: sample.Greedy(), // TODO: add support for different samplers when performance is optimized
|
||||
embedding: false,
|
||||
})
|
||||
if err != nil {
|
||||
@@ -830,7 +818,7 @@ func Execute(args []string) error {
|
||||
batchSize := fs.Int("batch-size", 512, "Batch size")
|
||||
numGPULayers := fs.Int("n-gpu-layers", 0, "Number of layers to offload to GPU")
|
||||
mainGPU := fs.Int("main-gpu", 0, "Main GPU")
|
||||
_ = fs.Bool("flash-attn", false, "Enable flash attention")
|
||||
flashAttention := fs.Bool("flash-attn", false, "Enable flash attention")
|
||||
kvSize := fs.Int("ctx-size", 2048, "Context (or KV cache) size")
|
||||
kvCacheType := fs.String("kv-cache-type", "", "quantization type for KV cache (default: f16)")
|
||||
port := fs.Int("port", 8080, "Port to expose the server on")
|
||||
@@ -875,26 +863,25 @@ func Execute(args []string) error {
|
||||
}
|
||||
|
||||
// TODO(jessegross): Parameters that need to be implemented:
|
||||
// flash-attn
|
||||
// no-mmap
|
||||
// mlock
|
||||
|
||||
var tensorSplitFloats []float32
|
||||
if *tensorSplit != "" {
|
||||
stringFloats := regexp.MustCompile(",").Split(*tensorSplit, -1)
|
||||
|
||||
tensorSplitFloats = make([]float32, 0, len(stringFloats))
|
||||
for _, s := range stringFloats {
|
||||
splits := strings.Split(*tensorSplit, ",")
|
||||
tensorSplitFloats = make([]float32, len(splits))
|
||||
for i, s := range splits {
|
||||
f, _ := strconv.ParseFloat(s, 32)
|
||||
tensorSplitFloats = append(tensorSplitFloats, float32(f))
|
||||
tensorSplitFloats[i] = float32(f)
|
||||
}
|
||||
}
|
||||
|
||||
params := ml.BackendParams{
|
||||
NumThreads: *threads,
|
||||
NumGPULayers: *numGPULayers,
|
||||
MainGPU: *mainGPU,
|
||||
TensorSplit: tensorSplitFloats,
|
||||
NumThreads: *threads,
|
||||
NumGPULayers: *numGPULayers,
|
||||
MainGPU: *mainGPU,
|
||||
TensorSplit: tensorSplitFloats,
|
||||
FlashAttention: *flashAttention,
|
||||
}
|
||||
|
||||
server.ready.Add(1)
|
||||
@@ -903,13 +890,14 @@ func Execute(args []string) error {
|
||||
server.cond = sync.NewCond(&server.mu)
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
go server.run(ctx)
|
||||
|
||||
addr := "127.0.0.1:" + strconv.Itoa(*port)
|
||||
listener, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
fmt.Println("Listen error:", err)
|
||||
cancel()
|
||||
return err
|
||||
}
|
||||
defer listener.Close()
|
||||
@@ -929,6 +917,5 @@ func Execute(args []string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
cancel()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -54,53 +54,42 @@ func (s weighted) Sample(logits []float32) (int32, error) {
|
||||
if idx, ok := w.Take(); ok {
|
||||
return int32(indices[idx]), nil
|
||||
}
|
||||
return -1, errors.New("weighed sampler failed, no valid token found")
|
||||
return -1, errors.New("weighted sampler failed, no valid token found")
|
||||
}
|
||||
|
||||
type greedy struct {
|
||||
transforms []Transform
|
||||
}
|
||||
|
||||
func Greedy(transforms ...Transform) Sampler {
|
||||
return greedy{transforms: transforms}
|
||||
type greedy struct{}
|
||||
|
||||
func Greedy() Sampler {
|
||||
return greedy{}
|
||||
}
|
||||
|
||||
// Sample returns the index of the maximum value in logits.
|
||||
func (s greedy) Sample(logits []float32) (int32, error) {
|
||||
logits64 := make([]float64, len(logits))
|
||||
for i, v := range logits {
|
||||
logits64[i] = float64(v)
|
||||
if len(logits) == 0 {
|
||||
return -1, errors.New("no logits provided for greedy sampling")
|
||||
}
|
||||
|
||||
for _, t := range s.transforms {
|
||||
logits64 = t.Apply(logits64)
|
||||
}
|
||||
|
||||
var maxIdx int
|
||||
var maxLogit float64
|
||||
for i, logit := range logits64 {
|
||||
if logit > maxLogit {
|
||||
maxLogit = logit
|
||||
maxIdx := 0
|
||||
for i := range logits {
|
||||
if logits[i] > logits[maxIdx] {
|
||||
maxIdx = i
|
||||
}
|
||||
}
|
||||
|
||||
if maxLogit == math.Inf(-1) {
|
||||
return -1, errors.New("no valid logits found for greedy sampling")
|
||||
}
|
||||
|
||||
return int32(maxIdx), nil
|
||||
}
|
||||
|
||||
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
|
||||
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int) (Sampler, error) {
|
||||
transforms := []Transform{}
|
||||
if temperature == 0 {
|
||||
return Greedy(), nil
|
||||
}
|
||||
|
||||
if temperature < 0 || temperature > 2 {
|
||||
return nil, errors.New("temperature must be between 0 and 2")
|
||||
}
|
||||
|
||||
if temperature != 0 {
|
||||
transforms = append(transforms, Temperature(temperature))
|
||||
}
|
||||
transforms := []Transform{Temperature(temperature)}
|
||||
|
||||
if topK != 0 {
|
||||
if topK <= 0 {
|
||||
@@ -123,15 +112,7 @@ func NewSampler(temperature float32, topK int, topP float32, minP float32, seed
|
||||
transforms = append(transforms, MinP(minP))
|
||||
}
|
||||
|
||||
if len(transforms) == 0 {
|
||||
return nil, errors.New("at least one transform is required")
|
||||
}
|
||||
|
||||
if temperature == 0 {
|
||||
return Greedy(transforms...), nil
|
||||
}
|
||||
|
||||
if seed != 0 {
|
||||
if seed >= 0 {
|
||||
seed64 := uint64(seed)
|
||||
return Weighted(&seed64, transforms...), nil
|
||||
}
|
||||
|
||||
@@ -66,32 +66,15 @@ func TestSample(t *testing.T) {
|
||||
callOrder: &callOrder,
|
||||
}
|
||||
|
||||
got, err := Greedy(mock1, mock2, mock3).Sample(input)
|
||||
_, err := Weighted(nil, mock1, mock2, mock3).Sample(input)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
want := int32(3) // Greedy sampler should pick highest logit
|
||||
if want != got {
|
||||
t.Errorf("index mismatch: want %d, got %d", want, got)
|
||||
}
|
||||
wantOrder := []int{1, 2, 3}
|
||||
if diff := cmp.Diff(wantOrder, callOrder); diff != "" {
|
||||
t.Errorf("call order mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
callOrder = nil
|
||||
|
||||
_, err = Weighted(nil, mock1, mock2, mock3).Sample(input)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
wantOrder = []int{1, 2, 3}
|
||||
if diff := cmp.Diff(wantOrder, callOrder); diff != "" {
|
||||
t.Errorf("call order mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewSampler(t *testing.T) {
|
||||
@@ -105,8 +88,9 @@ func TestNewSampler(t *testing.T) {
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "no transforms",
|
||||
wantErr: true,
|
||||
name: "no transforms",
|
||||
// temperature is 0, so greedy should be used
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "temperature",
|
||||
@@ -124,49 +108,52 @@ func TestNewSampler(t *testing.T) {
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "top k",
|
||||
topK: 10,
|
||||
wantErr: false,
|
||||
name: "top k",
|
||||
topK: 10,
|
||||
temperature: 0.8,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid top k negative",
|
||||
topK: -1,
|
||||
wantErr: true,
|
||||
name: "invalid top k negative",
|
||||
topK: -1,
|
||||
temperature: 0.8,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "top p",
|
||||
topP: 0.9,
|
||||
wantErr: false,
|
||||
name: "top p",
|
||||
topP: 0.9,
|
||||
temperature: 0.8,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid top p negative",
|
||||
topP: -0.1,
|
||||
wantErr: true,
|
||||
name: "invalid top p negative",
|
||||
topP: -0.1,
|
||||
temperature: 0.8,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid top p one",
|
||||
topP: 1.0,
|
||||
wantErr: true,
|
||||
name: "invalid top p one",
|
||||
topP: 1.0,
|
||||
temperature: 0.8,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "min p",
|
||||
minP: 0.2,
|
||||
wantErr: false,
|
||||
name: "min p",
|
||||
minP: 0.2,
|
||||
temperature: 0.8,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid min p negative",
|
||||
minP: -0.1,
|
||||
wantErr: true,
|
||||
name: "invalid min p negative",
|
||||
minP: -0.1,
|
||||
temperature: 0.8,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid min p one",
|
||||
minP: 1.0,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "seed",
|
||||
seed: 42,
|
||||
wantErr: true, // seed alone is not valid without other transforms
|
||||
name: "invalid min p one",
|
||||
minP: 1.0,
|
||||
temperature: 0.8,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "default values",
|
||||
@@ -184,7 +171,7 @@ func TestNewSampler(t *testing.T) {
|
||||
topP: 0.0,
|
||||
minP: 0.0,
|
||||
seed: 0,
|
||||
wantErr: true, // all zeroes means no transforms
|
||||
wantErr: false, // all zeroes means no transforms
|
||||
},
|
||||
{
|
||||
name: "all transforms",
|
||||
@@ -216,7 +203,7 @@ func BenchmarkSample(b *testing.B) {
|
||||
}
|
||||
|
||||
samplers := map[string]Sampler{
|
||||
"Greedy": Greedy(transforms...),
|
||||
"Greedy": Greedy(),
|
||||
"Weighted": Weighted(nil, transforms...),
|
||||
}
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -34,6 +35,7 @@ var (
|
||||
errOnlyGGUFSupported = errors.New("supplied file was not in GGUF format")
|
||||
errUnknownType = errors.New("unknown type")
|
||||
errNeitherFromOrFiles = errors.New("neither 'from' or 'files' was specified")
|
||||
errFilePath = errors.New("file path must be relative")
|
||||
)
|
||||
|
||||
func (s *Server) CreateHandler(c *gin.Context) {
|
||||
@@ -46,6 +48,13 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
for v := range r.Files {
|
||||
if !fs.ValidPath(v) {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errFilePath.Error()})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
name := model.ParseName(cmp.Or(r.Model, r.Name))
|
||||
if !name.IsValid() {
|
||||
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
|
||||
@@ -104,7 +113,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
|
||||
if r.Adapters != nil {
|
||||
adapterLayers, err = convertModelFromFiles(r.Adapters, baseLayers, true, fn)
|
||||
if err != nil {
|
||||
for _, badReq := range []error{errNoFilesProvided, errOnlyOneAdapterSupported, errOnlyGGUFSupported, errUnknownType} {
|
||||
for _, badReq := range []error{errNoFilesProvided, errOnlyOneAdapterSupported, errOnlyGGUFSupported, errUnknownType, errFilePath} {
|
||||
if errors.Is(err, badReq) {
|
||||
ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest}
|
||||
return
|
||||
@@ -221,8 +230,22 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
|
||||
return nil, err
|
||||
}
|
||||
defer os.RemoveAll(tmpDir)
|
||||
// Set up a root to validate paths
|
||||
root, err := os.OpenRoot(tmpDir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer root.Close()
|
||||
|
||||
for fp, digest := range files {
|
||||
if !fs.ValidPath(fp) {
|
||||
return nil, fmt.Errorf("%w: %s", errFilePath, fp)
|
||||
}
|
||||
if _, err := root.Stat(fp); err != nil && !errors.Is(err, fs.ErrNotExist) {
|
||||
// Path is likely outside the root
|
||||
return nil, fmt.Errorf("%w: %s: %s", errFilePath, err, fp)
|
||||
}
|
||||
|
||||
blobPath, err := GetBlobsPath(digest)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -270,6 +293,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer bin.Close()
|
||||
|
||||
f, _, err := ggml.Decode(bin, 0)
|
||||
if err != nil {
|
||||
|
||||
106
server/create_test.go
Normal file
106
server/create_test.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
func TestConvertFromSafetensors(t *testing.T) {
|
||||
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
||||
|
||||
// Helper function to create a new layer and return its digest
|
||||
makeTemp := func(content string) string {
|
||||
l, err := NewLayer(strings.NewReader(content), "application/octet-stream")
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to create layer: %v", err)
|
||||
}
|
||||
return l.Digest
|
||||
}
|
||||
|
||||
// Create a safetensors compatible file with empty JSON content
|
||||
var buf bytes.Buffer
|
||||
headerSize := int64(len("{}"))
|
||||
binary.Write(&buf, binary.LittleEndian, headerSize)
|
||||
buf.WriteString("{}")
|
||||
|
||||
model := makeTemp(buf.String())
|
||||
config := makeTemp(`{
|
||||
"architectures": ["LlamaForCausalLM"],
|
||||
"vocab_size": 32000
|
||||
}`)
|
||||
tokenizer := makeTemp(`{
|
||||
"version": "1.0",
|
||||
"truncation": null,
|
||||
"padding": null,
|
||||
"added_tokens": [
|
||||
{
|
||||
"id": 0,
|
||||
"content": "<|endoftext|>",
|
||||
"single_word": false,
|
||||
"lstrip": false,
|
||||
"rstrip": false,
|
||||
"normalized": false,
|
||||
"special": true
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
filePath string
|
||||
wantErr error
|
||||
}{
|
||||
// Invalid
|
||||
{
|
||||
name: "InvalidRelativePathShallow",
|
||||
filePath: filepath.Join("..", "file.safetensors"),
|
||||
wantErr: errFilePath,
|
||||
},
|
||||
{
|
||||
name: "InvalidRelativePathDeep",
|
||||
filePath: filepath.Join("..", "..", "..", "..", "..", "..", "data", "file.txt"),
|
||||
wantErr: errFilePath,
|
||||
},
|
||||
{
|
||||
name: "InvalidNestedPath",
|
||||
filePath: filepath.Join("dir", "..", "..", "..", "..", "..", "other.safetensors"),
|
||||
wantErr: errFilePath,
|
||||
},
|
||||
{
|
||||
name: "AbsolutePathOutsideRoot",
|
||||
filePath: filepath.Join(os.TempDir(), "model.safetensors"),
|
||||
wantErr: errFilePath, // Should fail since it's outside tmpDir
|
||||
},
|
||||
{
|
||||
name: "ValidRelativePath",
|
||||
filePath: "model.safetensors",
|
||||
wantErr: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create the minimum required file map for convertFromSafetensors
|
||||
files := map[string]string{
|
||||
tt.filePath: model,
|
||||
"config.json": config,
|
||||
"tokenizer.json": tokenizer,
|
||||
}
|
||||
|
||||
_, err := convertFromSafetensors(files, nil, false, func(resp api.ProgressResponse) {})
|
||||
|
||||
if (tt.wantErr == nil && err != nil) ||
|
||||
(tt.wantErr != nil && err == nil) ||
|
||||
(tt.wantErr != nil && !errors.Is(err, tt.wantErr)) {
|
||||
t.Errorf("convertFromSafetensors() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
28
server/internal/cache/blob/cache.go
vendored
28
server/internal/cache/blob/cache.go
vendored
@@ -279,6 +279,18 @@ func (c *DiskCache) Get(d Digest) (Entry, error) {
|
||||
// It returns an error if either the name or digest is invalid, or if link
|
||||
// creation encounters any issues.
|
||||
func (c *DiskCache) Link(name string, d Digest) error {
|
||||
// TODO(bmizerany): Move link handling from cache to registry.
|
||||
//
|
||||
// We originally placed links in the cache due to its storage
|
||||
// knowledge. However, the registry likely offers better context for
|
||||
// naming concerns, and our API design shouldn't be tightly coupled to
|
||||
// our on-disk format.
|
||||
//
|
||||
// Links work effectively when independent from physical location -
|
||||
// they can reference content with matching SHA regardless of storage
|
||||
// location. In an upcoming change, we plan to shift this
|
||||
// responsibility to the registry where it better aligns with the
|
||||
// system's conceptual model.
|
||||
manifest, err := c.manifestPath(name)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -304,21 +316,19 @@ func (c *DiskCache) Link(name string, d Digest) error {
|
||||
return c.copyNamedFile(manifest, f, d, info.Size())
|
||||
}
|
||||
|
||||
// Unlink removes the any link for name. If the link does not exist, nothing
|
||||
// happens, and no error is returned.
|
||||
//
|
||||
// It returns an error if the name is invalid or if the link removal encounters
|
||||
// any issues.
|
||||
func (c *DiskCache) Unlink(name string) error {
|
||||
// Unlink unlinks the manifest by name from the cache. If the name is not
|
||||
// found. If a manifest is removed ok will be true, otherwise false. If an
|
||||
// error occurs, it returns ok false, and the error.
|
||||
func (c *DiskCache) Unlink(name string) (ok bool, _ error) {
|
||||
manifest, err := c.manifestPath(name)
|
||||
if err != nil {
|
||||
return err
|
||||
return false, err
|
||||
}
|
||||
err = os.Remove(manifest)
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
return nil
|
||||
return false, nil
|
||||
}
|
||||
return err
|
||||
return true, err
|
||||
}
|
||||
|
||||
// GetFile returns the absolute path to the file, in the cache, for the given
|
||||
|
||||
7
server/internal/cache/blob/cache_test.go
vendored
7
server/internal/cache/blob/cache_test.go
vendored
@@ -13,7 +13,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/server/internal/internal/testutil"
|
||||
"github.com/ollama/ollama/server/internal/testutil"
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -479,8 +479,11 @@ func testManifestNameReuse(t *testing.T) {
|
||||
}
|
||||
|
||||
// relink with different case
|
||||
err = c.Unlink("h/n/m:t")
|
||||
unlinked, err := c.Unlink("h/n/m:t")
|
||||
check(err)
|
||||
if !unlinked {
|
||||
t.Fatal("expected unlinked")
|
||||
}
|
||||
err = c.Link("h/n/m:T", d1)
|
||||
check(err)
|
||||
|
||||
|
||||
2
server/internal/cache/blob/casecheck_test.go
vendored
2
server/internal/cache/blob/casecheck_test.go
vendored
@@ -86,7 +86,7 @@ func useCaseInsensitiveTempDir(t *testing.T) bool {
|
||||
// link to docs on that topic.
|
||||
lines := strings.Split(volumeHint, "\n")
|
||||
for _, line := range lines {
|
||||
t.Log(line)
|
||||
t.Skip(line)
|
||||
}
|
||||
}
|
||||
return false
|
||||
|
||||
@@ -19,10 +19,12 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
@@ -52,7 +54,7 @@ var (
|
||||
|
||||
// ErrMissingModel is returned when the model part of a name is missing
|
||||
// or invalid.
|
||||
ErrNameInvalid = errors.New("invalid name; must be in the form {scheme://}{host/}{namespace/}[model]{:tag}{@digest}")
|
||||
ErrNameInvalid = errors.New("invalid or missing name")
|
||||
|
||||
// ErrCached is passed to [Trace.PushUpdate] when a layer already
|
||||
// exists. It is a non-fatal error and is never returned by [Registry.Push].
|
||||
@@ -86,9 +88,23 @@ func DefaultCache() (*blob.DiskCache, error) {
|
||||
return blob.Open(dir)
|
||||
}
|
||||
|
||||
// Error is the standard error returned by Ollama APIs.
|
||||
// Error is the standard error returned by Ollama APIs. It can represent a
|
||||
// single or multiple error response.
|
||||
//
|
||||
// Single error responses have the following format:
|
||||
//
|
||||
// {"code": "optional_code","error":"error message"}
|
||||
//
|
||||
// Multiple error responses have the following format:
|
||||
//
|
||||
// {"errors": [{"code": "optional_code","message":"error message"}]}
|
||||
//
|
||||
// Note, that the error field is used in single error responses, while the
|
||||
// message field is used in multiple error responses.
|
||||
//
|
||||
// In both cases, the code field is optional and may be empty.
|
||||
type Error struct {
|
||||
Status int `json:"-"`
|
||||
Status int `json:"-"` // TODO(bmizerany): remove this
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
@@ -97,13 +113,34 @@ func (e *Error) Error() string {
|
||||
return fmt.Sprintf("registry responded with status %d: %s %s", e.Status, e.Code, e.Message)
|
||||
}
|
||||
|
||||
func (e *Error) LogValue() slog.Value {
|
||||
return slog.GroupValue(
|
||||
slog.Int("status", e.Status),
|
||||
slog.String("code", e.Code),
|
||||
slog.String("message", e.Message),
|
||||
)
|
||||
}
|
||||
|
||||
// UnmarshalJSON implements json.Unmarshaler.
|
||||
func (e *Error) UnmarshalJSON(b []byte) error {
|
||||
type E Error
|
||||
var v struct{ Errors []E }
|
||||
var v struct {
|
||||
// Single error
|
||||
Code string
|
||||
Error string
|
||||
|
||||
// Multiple errors
|
||||
Errors []E
|
||||
}
|
||||
if err := json.Unmarshal(b, &v); err != nil {
|
||||
return err
|
||||
}
|
||||
if v.Error != "" {
|
||||
// Single error case
|
||||
e.Code = v.Code
|
||||
e.Message = v.Error
|
||||
return nil
|
||||
}
|
||||
if len(v.Errors) == 0 {
|
||||
return fmt.Errorf("no messages in error response: %s", string(b))
|
||||
}
|
||||
@@ -111,15 +148,23 @@ func (e *Error) UnmarshalJSON(b []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO(bmizerany): make configurable on [Registry]
|
||||
var defaultName = func() names.Name {
|
||||
n := names.Parse("ollama.com/library/_:latest")
|
||||
const DefaultMask = "registry.ollama.ai/library/_:latest"
|
||||
|
||||
var defaultMask = func() names.Name {
|
||||
n := names.Parse(DefaultMask)
|
||||
if !n.IsFullyQualified() {
|
||||
panic("default name is not fully qualified")
|
||||
panic("default mask is not fully qualified")
|
||||
}
|
||||
return n
|
||||
}()
|
||||
|
||||
// CompleteName returns a fully qualified name by merging the given name with
|
||||
// the default mask. If the name is already fully qualified, it is returned
|
||||
// unchanged.
|
||||
func CompleteName(name string) string {
|
||||
return names.Merge(names.Parse(name), defaultMask).String()
|
||||
}
|
||||
|
||||
// Registry is a client for performing push and pull operations against an
|
||||
// Ollama registry.
|
||||
type Registry struct {
|
||||
@@ -160,21 +205,38 @@ type Registry struct {
|
||||
//
|
||||
// It is only used when a layer is larger than [MaxChunkingThreshold].
|
||||
MaxChunkSize int64
|
||||
|
||||
// Mask, if set, is the name used to convert non-fully qualified
|
||||
// names to fully qualified names. If empty, the default mask
|
||||
// ("registry.ollama.ai/library/_:latest") is used.
|
||||
Mask string
|
||||
}
|
||||
|
||||
// RegistryFromEnv returns a new Registry configured from the environment. The
|
||||
func (r *Registry) parseName(name string) (names.Name, error) {
|
||||
mask := defaultMask
|
||||
if r.Mask != "" {
|
||||
mask = names.Parse(r.Mask)
|
||||
}
|
||||
n := names.Merge(names.Parse(name), mask)
|
||||
if !n.IsFullyQualified() {
|
||||
return names.Name{}, fmt.Errorf("%w: %q", ErrNameInvalid, name)
|
||||
}
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// DefaultRegistry returns a new Registry configured from the environment. The
|
||||
// key is read from $HOME/.ollama/id_ed25519, MaxStreams is set to the
|
||||
// value of OLLAMA_REGISTRY_MAXSTREAMS, and ChunkingDirectory is set to the
|
||||
// system's temporary directory.
|
||||
//
|
||||
// It returns an error if any configuration in the environment is invalid.
|
||||
func RegistryFromEnv() (*Registry, error) {
|
||||
func DefaultRegistry() (*Registry, error) {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keyPEM, err := os.ReadFile(filepath.Join(home, ".ollama/id_ed25519"))
|
||||
if err != nil {
|
||||
if err != nil && errors.Is(err, fs.ErrNotExist) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -194,42 +256,6 @@ func RegistryFromEnv() (*Registry, error) {
|
||||
return &rc, nil
|
||||
}
|
||||
|
||||
type PushParams struct {
|
||||
// From is an optional destination name for the model. If empty, the
|
||||
// destination name is the same as the source name.
|
||||
From string
|
||||
}
|
||||
|
||||
// parseName parses name using [names.ParseExtended] and then merges the name with the
|
||||
// default name, and checks that the name is fully qualified. If a digest is
|
||||
// present, it parse and returns it with the other fields as their zero values.
|
||||
//
|
||||
// It returns an error if the name is not fully qualified, or if the digest, if
|
||||
// any, is invalid.
|
||||
//
|
||||
// The scheme is returned as provided by [names.ParseExtended].
|
||||
func parseName(s string) (scheme string, n names.Name, d blob.Digest, err error) {
|
||||
scheme, n, ds := names.ParseExtended(s)
|
||||
n = names.Merge(n, defaultName)
|
||||
if ds != "" {
|
||||
// Digest is present. Validate it.
|
||||
d, err = blob.ParseDigest(ds)
|
||||
if err != nil {
|
||||
return "", names.Name{}, blob.Digest{}, err
|
||||
}
|
||||
}
|
||||
|
||||
// The name check is deferred until after the digest check because we
|
||||
// say that digests take precedence over names, and so should there
|
||||
// errors when being parsed.
|
||||
if !n.IsFullyQualified() {
|
||||
return "", names.Name{}, blob.Digest{}, ErrNameInvalid
|
||||
}
|
||||
|
||||
scheme = cmp.Or(scheme, "https")
|
||||
return scheme, n, d, nil
|
||||
}
|
||||
|
||||
func (r *Registry) maxStreams() int {
|
||||
n := cmp.Or(r.MaxStreams, runtime.GOMAXPROCS(0))
|
||||
|
||||
@@ -249,13 +275,19 @@ func (r *Registry) maxChunkSize() int64 {
|
||||
return cmp.Or(r.MaxChunkSize, DefaultMaxChunkSize)
|
||||
}
|
||||
|
||||
type PushParams struct {
|
||||
// From is an optional destination name for the model. If empty, the
|
||||
// destination name is the same as the source name.
|
||||
From string
|
||||
}
|
||||
|
||||
// Push pushes the model with the name in the cache to the remote registry.
|
||||
func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *PushParams) error {
|
||||
if p == nil {
|
||||
p = &PushParams{}
|
||||
}
|
||||
|
||||
m, err := ResolveLocal(c, cmp.Or(p.From, name))
|
||||
m, err := r.ResolveLocal(c, cmp.Or(p.From, name))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -278,7 +310,7 @@ func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *
|
||||
|
||||
t := traceFromContext(ctx)
|
||||
|
||||
scheme, n, _, err := parseName(name)
|
||||
scheme, n, _, err := r.parseNameExtended(name)
|
||||
if err != nil {
|
||||
// This should never happen since ResolveLocal should have
|
||||
// already validated the name.
|
||||
@@ -372,7 +404,7 @@ func canRetry(err error) bool {
|
||||
// typically slower than splitting the model up across layers, and is mostly
|
||||
// utilized for layers of type equal to "application/vnd.ollama.image".
|
||||
func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) error {
|
||||
scheme, n, _, err := parseName(name)
|
||||
scheme, n, _, err := r.parseNameExtended(name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -520,6 +552,16 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err
|
||||
return c.Link(m.Name, md)
|
||||
}
|
||||
|
||||
// Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified
|
||||
// before attempting to unlink the model.
|
||||
func (r *Registry) Unlink(c *blob.DiskCache, name string) (ok bool, _ error) {
|
||||
n, err := r.parseName(name)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return c.Unlink(n.String())
|
||||
}
|
||||
|
||||
// Manifest represents a [ollama.com/manifest].
|
||||
type Manifest struct {
|
||||
Name string `json:"-"` // the canonical name of the model
|
||||
@@ -588,10 +630,9 @@ type Layer struct {
|
||||
Size int64 `json:"size"`
|
||||
}
|
||||
|
||||
// ResolveLocal resolves a name to a Manifest in the local cache. The name is
|
||||
// parsed using [names.ParseExtended] but the scheme is ignored.
|
||||
func ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) {
|
||||
_, n, d, err := parseName(name)
|
||||
// ResolveLocal resolves a name to a Manifest in the local cache.
|
||||
func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) {
|
||||
_, n, d, err := r.parseNameExtended(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -617,7 +658,7 @@ func ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) {
|
||||
|
||||
// Resolve resolves a name to a Manifest in the remote registry.
|
||||
func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error) {
|
||||
scheme, n, d, err := parseName(name)
|
||||
scheme, n, d, err := r.parseNameExtended(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -800,3 +841,89 @@ func maybeUnexpectedEOF(err error) error {
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
type publicError struct {
|
||||
wrapped error
|
||||
message string
|
||||
}
|
||||
|
||||
func withPublicMessagef(err error, message string, args ...any) error {
|
||||
return publicError{wrapped: err, message: fmt.Sprintf(message, args...)}
|
||||
}
|
||||
|
||||
func (e publicError) Error() string { return e.message }
|
||||
func (e publicError) Unwrap() error { return e.wrapped }
|
||||
|
||||
var supportedSchemes = []string{
|
||||
"http",
|
||||
"https",
|
||||
"https+insecure",
|
||||
}
|
||||
|
||||
var supportedSchemesMessage = fmt.Sprintf("supported schemes are %v", strings.Join(supportedSchemes, ", "))
|
||||
|
||||
// parseNameExtended parses and validates an extended name, returning the scheme, name,
|
||||
// and digest.
|
||||
//
|
||||
// If the scheme is empty, scheme will be "https". If an unsupported scheme is
|
||||
// given, [ErrNameInvalid] wrapped with a display friendly message is returned.
|
||||
//
|
||||
// If the digest is invalid, [ErrNameInvalid] wrapped with a display friendly
|
||||
// message is returned.
|
||||
//
|
||||
// If the name is not, once merged with the mask, fully qualified,
|
||||
// [ErrNameInvalid] wrapped with a display friendly message is returned.
|
||||
func (r *Registry) parseNameExtended(s string) (scheme string, _ names.Name, _ blob.Digest, _ error) {
|
||||
scheme, name, digest := splitExtended(s)
|
||||
scheme = cmp.Or(scheme, "https")
|
||||
if !slices.Contains(supportedSchemes, scheme) {
|
||||
err := withPublicMessagef(ErrNameInvalid, "unsupported scheme: %q: %s", scheme, supportedSchemesMessage)
|
||||
return "", names.Name{}, blob.Digest{}, err
|
||||
}
|
||||
|
||||
var d blob.Digest
|
||||
if digest != "" {
|
||||
var err error
|
||||
d, err = blob.ParseDigest(digest)
|
||||
if err != nil {
|
||||
err = withPublicMessagef(ErrNameInvalid, "invalid digest: %q", digest)
|
||||
return "", names.Name{}, blob.Digest{}, err
|
||||
}
|
||||
if name == "" {
|
||||
// We have can resolve a manifest from a digest only,
|
||||
// so skip name validation and return the scheme and
|
||||
// digest.
|
||||
return scheme, names.Name{}, d, nil
|
||||
}
|
||||
}
|
||||
|
||||
n, err := r.parseName(name)
|
||||
if err != nil {
|
||||
return "", names.Name{}, blob.Digest{}, err
|
||||
}
|
||||
return scheme, n, d, nil
|
||||
}
|
||||
|
||||
// splitExtended splits an extended name string into its scheme, name, and digest
|
||||
// parts.
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// http://ollama.com/bmizerany/smol:latest@digest
|
||||
// https://ollama.com/bmizerany/smol:latest
|
||||
// ollama.com/bmizerany/smol:latest@digest // returns "https" scheme.
|
||||
// model@digest
|
||||
// @digest
|
||||
func splitExtended(s string) (scheme, name, digest string) {
|
||||
i := strings.Index(s, "://")
|
||||
if i >= 0 {
|
||||
scheme = s[:i]
|
||||
s = s[i+3:]
|
||||
}
|
||||
i = strings.LastIndex(s, "@")
|
||||
if i >= 0 {
|
||||
digest = s[i+1:]
|
||||
s = s[:i]
|
||||
}
|
||||
return scheme, s, digest
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@ package ollama
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
@@ -21,7 +22,7 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/server/internal/cache/blob"
|
||||
"github.com/ollama/ollama/server/internal/chunks"
|
||||
"github.com/ollama/ollama/server/internal/internal/testutil"
|
||||
"github.com/ollama/ollama/server/internal/testutil"
|
||||
)
|
||||
|
||||
func TestManifestMarshalJSON(t *testing.T) {
|
||||
@@ -37,20 +38,6 @@ func TestManifestMarshalJSON(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func link(c *blob.DiskCache, name string, manifest string) {
|
||||
_, n, _, err := parseName(name)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
d, err := c.Import(bytes.NewReader([]byte(manifest)), int64(len(manifest)))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := c.Link(n.String(), d); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
var errRoundTrip = errors.New("forced roundtrip error")
|
||||
|
||||
type recordRoundTripper http.HandlerFunc
|
||||
@@ -98,30 +85,45 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
|
||||
}
|
||||
}
|
||||
|
||||
r := &Registry{
|
||||
HTTPClient: &http.Client{
|
||||
Transport: recordRoundTripper(h),
|
||||
},
|
||||
}
|
||||
|
||||
link := func(name string, manifest string) {
|
||||
n, err := r.parseName(name)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
d, err := c.Import(bytes.NewReader([]byte(manifest)), int64(len(manifest)))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := c.Link(n.String(), d); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
commit := func(name string, layers ...*Layer) {
|
||||
t.Helper()
|
||||
data, err := json.Marshal(&Manifest{Layers: layers})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
link(c, name, string(data))
|
||||
link(name, string(data))
|
||||
}
|
||||
|
||||
link(c, "empty", "")
|
||||
link("empty", "")
|
||||
commit("zero")
|
||||
commit("single", mklayer("exists"))
|
||||
commit("multiple", mklayer("exists"), mklayer("present"))
|
||||
commit("notfound", &Layer{Digest: blob.DigestFromBytes("notfound"), Size: int64(len("notfound"))})
|
||||
commit("null", nil)
|
||||
commit("sizemismatch", mklayer("exists"), &Layer{Digest: blob.DigestFromBytes("present"), Size: 499})
|
||||
link(c, "invalid", "!!!!!")
|
||||
link("invalid", "!!!!!")
|
||||
|
||||
rc := &Registry{
|
||||
HTTPClient: &http.Client{
|
||||
Transport: recordRoundTripper(h),
|
||||
},
|
||||
}
|
||||
return rc, c
|
||||
return r, c
|
||||
}
|
||||
|
||||
func okHandler(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -144,29 +146,6 @@ func importBytes(t *testing.T, c *blob.DiskCache, data string) blob.Digest {
|
||||
return d
|
||||
}
|
||||
|
||||
func TestRegistryPushInvalidNames(t *testing.T) {
|
||||
rc, c := newClient(t, nil)
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
err error
|
||||
}{
|
||||
{"", ErrNameInvalid},
|
||||
{"@", ErrNameInvalid},
|
||||
{"@x", blob.ErrInvalidDigest},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create a new registry and push a new image.
|
||||
err := rc.Push(t.Context(), c, tt.name, nil)
|
||||
if !errors.Is(err, tt.err) {
|
||||
t.Errorf("err = %v; want %v", err, tt.err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func withTraceUnexpected(ctx context.Context) (context.Context, *Trace) {
|
||||
t := &Trace{Update: func(*Layer, int64, error) { panic("unexpected") }}
|
||||
return WithTrace(ctx, t), t
|
||||
@@ -385,7 +364,7 @@ func TestRegistryPullNotCached(t *testing.T) {
|
||||
})
|
||||
|
||||
// Confirm that the layer does not exist locally
|
||||
_, err := ResolveLocal(c, "model")
|
||||
_, err := rc.ResolveLocal(c, "model")
|
||||
checkNotExist(t, err)
|
||||
|
||||
_, err = c.Get(d)
|
||||
@@ -396,7 +375,7 @@ func TestRegistryPullNotCached(t *testing.T) {
|
||||
|
||||
mw, err := rc.Resolve(t.Context(), "model")
|
||||
check(err)
|
||||
mg, err := ResolveLocal(c, "model")
|
||||
mg, err := rc.ResolveLocal(c, "model")
|
||||
check(err)
|
||||
if !reflect.DeepEqual(mw, mg) {
|
||||
t.Errorf("mw = %v; mg = %v", mw, mg)
|
||||
@@ -621,7 +600,7 @@ func TestInsecureSkipVerify(t *testing.T) {
|
||||
}))
|
||||
defer s.Close()
|
||||
|
||||
const name = "ollama.com/library/insecure"
|
||||
const name = "library/insecure"
|
||||
|
||||
var rc Registry
|
||||
url := fmt.Sprintf("https://%s/%s", s.Listener.Addr(), name)
|
||||
@@ -654,3 +633,184 @@ func TestCanRetry(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestErrorUnmarshal(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
data string
|
||||
want *Error
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "errors empty",
|
||||
data: `{"errors":[]}`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "errors empty",
|
||||
data: `{"errors":[]}`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "errors single",
|
||||
data: `{"errors":[{"code":"blob_unknown"}]}`,
|
||||
want: &Error{Code: "blob_unknown", Message: ""},
|
||||
},
|
||||
{
|
||||
name: "errors multiple",
|
||||
data: `{"errors":[{"code":"blob_unknown"},{"code":"blob_error"}]}`,
|
||||
want: &Error{Code: "blob_unknown", Message: ""},
|
||||
},
|
||||
{
|
||||
name: "error empty",
|
||||
data: `{"error":""}`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "error very empty",
|
||||
data: `{}`,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "error message",
|
||||
data: `{"error":"message", "code":"code"}`,
|
||||
want: &Error{Code: "code", Message: "message"},
|
||||
},
|
||||
{
|
||||
name: "invalid value",
|
||||
data: `{"error": 1}`,
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var got Error
|
||||
err := json.Unmarshal([]byte(tt.data), &got)
|
||||
if err != nil {
|
||||
if tt.wantErr {
|
||||
return
|
||||
}
|
||||
t.Errorf("Unmarshal() error = %v", err)
|
||||
// fallthrough and check got
|
||||
}
|
||||
if tt.want == nil {
|
||||
tt.want = &Error{}
|
||||
}
|
||||
if !reflect.DeepEqual(got, *tt.want) {
|
||||
t.Errorf("got = %v; want %v", got, *tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestParseNameErrors tests that parseName returns errors messages with enough
|
||||
// detail for users to debug naming issues they may encounter. Previous to this
|
||||
// test, the error messages were not very helpful and each problem was reported
|
||||
// as the same message.
|
||||
//
|
||||
// It is only for testing error messages, not that all invalids and valids are
|
||||
// covered. Those are in other tests for names.Name and blob.Digest.
|
||||
func TestParseNameExtendedErrors(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
err error
|
||||
want string
|
||||
}{}
|
||||
|
||||
var r Registry
|
||||
for _, tt := range cases {
|
||||
_, _, _, err := r.parseNameExtended(tt.name)
|
||||
if !errors.Is(err, tt.err) {
|
||||
t.Errorf("[%s]: err = %v; want %v", tt.name, err, tt.err)
|
||||
}
|
||||
if err != nil && !strings.Contains(err.Error(), tt.want) {
|
||||
t.Errorf("[%s]: err =\n\t%v\nwant\n\t%v", tt.name, err, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseNameExtended(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
scheme string
|
||||
name string
|
||||
digest string
|
||||
err string
|
||||
}{
|
||||
{in: "http://m", scheme: "http", name: "m"},
|
||||
{in: "https+insecure://m", scheme: "https+insecure", name: "m"},
|
||||
{in: "http+insecure://m", err: "unsupported scheme"},
|
||||
|
||||
{in: "http://m@sha256:1111111111111111111111111111111111111111111111111111111111111111", scheme: "http", name: "m", digest: "sha256:1111111111111111111111111111111111111111111111111111111111111111"},
|
||||
|
||||
{in: "", err: "invalid or missing name"},
|
||||
{in: "m", scheme: "https", name: "m"},
|
||||
{in: "://", err: "invalid or missing name"},
|
||||
{in: "@sha256:deadbeef", err: "invalid digest"},
|
||||
{in: "@sha256:deadbeef@sha256:deadbeef", err: "invalid digest"},
|
||||
}
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.in, func(t *testing.T) {
|
||||
var r Registry
|
||||
scheme, n, digest, err := r.parseNameExtended(tt.in)
|
||||
if err != nil {
|
||||
if tt.err == "" {
|
||||
t.Errorf("err = %v; want nil", err)
|
||||
} else if !strings.Contains(err.Error(), tt.err) {
|
||||
t.Errorf("err = %v; want %q", err, tt.err)
|
||||
}
|
||||
} else if tt.err != "" {
|
||||
t.Errorf("err = nil; want %q", tt.err)
|
||||
}
|
||||
if err == nil && !n.IsFullyQualified() {
|
||||
t.Errorf("name = %q; want fully qualified", n)
|
||||
}
|
||||
|
||||
if scheme != tt.scheme {
|
||||
t.Errorf("scheme = %q; want %q", scheme, tt.scheme)
|
||||
}
|
||||
|
||||
// smoke-test name is superset of tt.name
|
||||
if !strings.Contains(n.String(), tt.name) {
|
||||
t.Errorf("name = %q; want %q", n, tt.name)
|
||||
}
|
||||
|
||||
tt.digest = cmp.Or(tt.digest, (&blob.Digest{}).String())
|
||||
if digest.String() != tt.digest {
|
||||
t.Errorf("digest = %q; want %q", digest, tt.digest)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUnlink(t *testing.T) {
|
||||
t.Run("found by name", func(t *testing.T) {
|
||||
rc, c := newClient(t, nil)
|
||||
|
||||
// confirm linked
|
||||
_, err := rc.ResolveLocal(c, "single")
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
// unlink
|
||||
_, err = rc.Unlink(c, "single")
|
||||
testutil.Check(t, err)
|
||||
|
||||
// confirm unlinked
|
||||
_, err = rc.ResolveLocal(c, "single")
|
||||
if !errors.Is(err, fs.ErrNotExist) {
|
||||
t.Errorf("err = %v; want fs.ErrNotExist", err)
|
||||
}
|
||||
})
|
||||
t.Run("not found by name", func(t *testing.T) {
|
||||
rc, c := newClient(t, nil)
|
||||
ok, err := rc.Unlink(c, "manifestNotFound")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if ok {
|
||||
t.Error("expected not found")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
@@ -86,6 +86,8 @@ func (m *Model) readTensors(fname string) ([]*Tensor, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
endOfHeader := 8 + headerSize // 8 bytes for header size plus the header itself
|
||||
|
||||
// TODO(bmizerany): do something with metadata? This could be another
|
||||
// header read if needed. We also need to figure out if the metadata is
|
||||
// present in only one .safetensors file or if each file may have their
|
||||
@@ -95,7 +97,8 @@ func (m *Model) readTensors(fname string) ([]*Tensor, error) {
|
||||
|
||||
tt := make([]*Tensor, 0, len(raws))
|
||||
for name, raw := range raws {
|
||||
if !strings.HasPrefix(name, "model.layer") {
|
||||
if name == "__metadata__" {
|
||||
// TODO(bmizerany): do something with metadata?
|
||||
continue
|
||||
}
|
||||
var v struct {
|
||||
@@ -112,7 +115,8 @@ func (m *Model) readTensors(fname string) ([]*Tensor, error) {
|
||||
|
||||
// TODO(bmizerany): after collecting, validate all offests make
|
||||
// tensors contiguous?
|
||||
begin, end := v.Offsets[0], v.Offsets[1]
|
||||
begin := endOfHeader + v.Offsets[0]
|
||||
end := endOfHeader + v.Offsets[1]
|
||||
if err := checkBeginEnd(finfo.Size(), begin, end); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -68,7 +68,7 @@ func main() {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
rc, err := ollama.RegistryFromEnv()
|
||||
rc, err := ollama.DefaultRegistry()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
@@ -177,7 +177,7 @@ func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error
|
||||
}
|
||||
|
||||
from := cmp.Or(*flagFrom, model)
|
||||
m, err := ollama.ResolveLocal(c, from)
|
||||
m, err := rc.ResolveLocal(c, from)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -228,6 +228,10 @@ func cmdImport(ctx context.Context, c *blob.DiskCache) error {
|
||||
flag.PrintDefaults()
|
||||
}
|
||||
flag.Parse(args)
|
||||
if *flagAs == "" {
|
||||
return fmt.Errorf("missing -as flag")
|
||||
}
|
||||
as := ollama.CompleteName(*flagAs)
|
||||
|
||||
dir := cmp.Or(flag.Arg(0), ".")
|
||||
fmt.Fprintf(os.Stderr, "Reading %s\n", dir)
|
||||
@@ -311,7 +315,7 @@ func cmdImport(ctx context.Context, c *blob.DiskCache) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.Link(*flagAs, d)
|
||||
return c.Link(as, d)
|
||||
}()
|
||||
}()
|
||||
|
||||
@@ -340,6 +344,8 @@ func cmdImport(ctx context.Context, c *blob.DiskCache) error {
|
||||
writeProgress()
|
||||
case err := <-done:
|
||||
writeProgress()
|
||||
fmt.Println()
|
||||
fmt.Println("Successfully imported", as)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ import (
|
||||
"github.com/ollama/ollama/server/internal/internal/stringsx"
|
||||
)
|
||||
|
||||
const MaxNameLength = 50 + 1 + 50 + 1 + 50 // <namespace>/<model>:<tag>
|
||||
const MaxNameLength = 350 + 1 + 80 + 1 + 80 + 1 + 80 // <host>/<namespace>/<model>:<tag>
|
||||
|
||||
type Name struct {
|
||||
// Make incomparable to enfoce use of Compare / Equal for
|
||||
@@ -25,19 +25,12 @@ type Name struct {
|
||||
// format of a valid name string is:
|
||||
//
|
||||
// s:
|
||||
// { host } "/" { namespace } "/" { model } ":" { tag } "@" { digest }
|
||||
// { host } "/" { namespace } "/" { model } ":" { tag }
|
||||
// { host } "/" { namespace } "/" { model } "@" { digest }
|
||||
// { host } "/" { namespace } "/" { model }
|
||||
// { namespace } "/" { model } ":" { tag } "@" { digest }
|
||||
// { namespace } "/" { model } ":" { tag }
|
||||
// { namespace } "/" { model } "@" { digest }
|
||||
// { namespace } "/" { model }
|
||||
// { model } ":" { tag } "@" { digest }
|
||||
// { model } ":" { tag }
|
||||
// { model } "@" { digest }
|
||||
// { model }
|
||||
// "@" { digest }
|
||||
// host:
|
||||
// pattern: { alphanum | "_" } { alphanum | "_" | "-" | "." | ":" }*
|
||||
// length: [1, 350]
|
||||
@@ -50,9 +43,6 @@ type Name struct {
|
||||
// tag:
|
||||
// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }*
|
||||
// length: [1, 80]
|
||||
// digest:
|
||||
// pattern: { alphanum | "_" } { alphanum | "-" | ":" }*
|
||||
// length: [1, 80]
|
||||
//
|
||||
// The name returned is not guaranteed to be valid. If it is not valid, the
|
||||
// field values are left in an undefined state. Use [Name.IsValid] to check
|
||||
@@ -82,23 +72,17 @@ func Parse(s string) Name {
|
||||
}
|
||||
}
|
||||
|
||||
// ParseExtended parses and returns any scheme, Name, and digest from from s in
|
||||
// the the form [scheme://][name][@digest]. All parts are optional.
|
||||
//
|
||||
// If the scheme is present, it must be followed by "://". The digest is
|
||||
// prefixed by "@" and comes after the name. The name is parsed using [Parse].
|
||||
//
|
||||
// The scheme and digest are stripped before the name is parsed by [Parse].
|
||||
//
|
||||
// For convience, the scheme is never empty. If the scheme is not present, the
|
||||
// returned scheme is "https".
|
||||
// Split splits an extended name string into its scheme, name, and digest
|
||||
// parts.
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// http://ollama.com/bmizerany/smol:latest@digest
|
||||
// https://ollama.com/bmizerany/smol:latest
|
||||
// ollama.com/bmizerany/smol:latest@digest // returns "https" scheme.
|
||||
func ParseExtended(s string) (scheme string, _ Name, digest string) {
|
||||
// model@digest
|
||||
// @digest
|
||||
func Split(s string) (scheme, name, digest string) {
|
||||
i := strings.Index(s, "://")
|
||||
if i >= 0 {
|
||||
scheme = s[:i]
|
||||
@@ -109,21 +93,7 @@ func ParseExtended(s string) (scheme string, _ Name, digest string) {
|
||||
digest = s[i+1:]
|
||||
s = s[:i]
|
||||
}
|
||||
return scheme, Parse(s), digest
|
||||
}
|
||||
|
||||
func FormatExtended(scheme string, n Name, digest string) string {
|
||||
var b strings.Builder
|
||||
if scheme != "" {
|
||||
b.WriteString(scheme)
|
||||
b.WriteString("://")
|
||||
}
|
||||
b.WriteString(n.String())
|
||||
if digest != "" {
|
||||
b.WriteByte('@')
|
||||
b.WriteString(digest)
|
||||
}
|
||||
return b.String()
|
||||
return scheme, s, digest
|
||||
}
|
||||
|
||||
// Merge merges two names into a single name. Non-empty host, namespace, and
|
||||
@@ -141,39 +111,68 @@ func Merge(a, b Name) Name {
|
||||
|
||||
// IsValid returns true if the name is valid.
|
||||
func (n Name) IsValid() bool {
|
||||
if n.h != "" && !isValidHost(n.h) {
|
||||
if n.h != "" && !isValidPart(partHost, n.h) {
|
||||
return false
|
||||
}
|
||||
if n.n != "" && !isValidNamespace(n.n) {
|
||||
if n.n != "" && !isValidPart(partNamespace, n.n) {
|
||||
return false
|
||||
}
|
||||
if n.m != "" && !isValidModel(n.m) {
|
||||
if n.t != "" && !isValidPart(partTag, n.t) {
|
||||
return false
|
||||
}
|
||||
if n.t != "" && !isValidTag(n.t) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
|
||||
// at bare minimum, model must be present and valid
|
||||
return n.m != "" && isValidPart(partModel, n.m)
|
||||
}
|
||||
|
||||
func (n Name) IsFullyQualified() bool {
|
||||
return n.IsValid() && n.h != "" && n.n != "" && n.m != "" && n.t != ""
|
||||
}
|
||||
|
||||
func isValidHost(_ string) bool {
|
||||
return true // TODO: implement
|
||||
const (
|
||||
partHost = iota
|
||||
partNamespace
|
||||
partModel
|
||||
partTag
|
||||
)
|
||||
|
||||
func isValidPart(kind int, s string) bool {
|
||||
maxlen := 80
|
||||
if kind == partHost {
|
||||
maxlen = 350
|
||||
}
|
||||
if len(s) > maxlen {
|
||||
return false
|
||||
}
|
||||
|
||||
for i := range s {
|
||||
if i == 0 {
|
||||
if !isAlphanumericOrUnderscore(s[i]) {
|
||||
return false
|
||||
}
|
||||
continue
|
||||
}
|
||||
switch s[i] {
|
||||
case '_', '-':
|
||||
case '.':
|
||||
if kind == partNamespace {
|
||||
return false
|
||||
}
|
||||
case ':':
|
||||
if kind != partHost {
|
||||
return false
|
||||
}
|
||||
default:
|
||||
if !isAlphanumericOrUnderscore(s[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func isValidNamespace(_ string) bool {
|
||||
return true // TODO: implement
|
||||
}
|
||||
|
||||
func isValidModel(_ string) bool {
|
||||
return true // TODO: implement
|
||||
}
|
||||
|
||||
func isValidTag(_ string) bool {
|
||||
return true // TODO: implement
|
||||
func isAlphanumericOrUnderscore(c byte) bool {
|
||||
return c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z' || c >= '0' && c <= '9' || c == '_'
|
||||
}
|
||||
|
||||
func (n Name) Host() string { return n.h }
|
||||
|
||||
@@ -81,15 +81,11 @@ func TestParseExtended(t *testing.T) {
|
||||
}
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.in, func(t *testing.T) {
|
||||
scheme, name, digest := ParseExtended(tt.in)
|
||||
if scheme != tt.wantScheme || name.Compare(tt.wantName) != 0 || digest != tt.wantDigest {
|
||||
scheme, name, digest := Split(tt.in)
|
||||
n := Parse(name)
|
||||
if scheme != tt.wantScheme || n.Compare(tt.wantName) != 0 || digest != tt.wantDigest {
|
||||
t.Errorf("ParseExtended(%q) = %q, %#v, %q, want %q, %#v, %q", tt.in, scheme, name, digest, tt.wantScheme, tt.wantName, tt.wantDigest)
|
||||
}
|
||||
|
||||
// Round trip
|
||||
if got := FormatExtended(scheme, name, digest); got != tt.in {
|
||||
t.Errorf("FormatExtended(%q, %q, %q) = %q", scheme, name, digest, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -150,3 +146,75 @@ func BenchmarkParseName(b *testing.B) {
|
||||
junkName = Parse("h/n/m:t")
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
part80 = "88888888888888888888888888888888888888888888888888888888888888888888888888888888"
|
||||
part350 = "33333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333"
|
||||
)
|
||||
|
||||
var testCases = map[string]bool{ // name -> valid
|
||||
"": false,
|
||||
|
||||
"_why/_the/_lucky:_stiff": true,
|
||||
|
||||
// minimal
|
||||
"h/n/m:t": true,
|
||||
|
||||
"host/namespace/model:tag": true,
|
||||
"host/namespace/model": true,
|
||||
"namespace/model": true,
|
||||
"model": true,
|
||||
|
||||
// long (but valid)
|
||||
part80 + "/" + part80 + "/" + part80 + ":" + part80: true,
|
||||
part350 + "/" + part80 + "/" + part80 + ":" + part80: true,
|
||||
|
||||
// too long
|
||||
part80 + "/" + part80 + "/" + part80 + ":" + part350: false,
|
||||
"x" + part350 + "/" + part80 + "/" + part80 + ":" + part80: false,
|
||||
|
||||
"h/nn/mm:t": true, // bare minimum part sizes
|
||||
|
||||
// unqualified
|
||||
"m": true,
|
||||
"n/m:": true,
|
||||
"h/n/m": true,
|
||||
"@t": false,
|
||||
"m@d": false,
|
||||
|
||||
// invalids
|
||||
"^": false,
|
||||
"mm:": true,
|
||||
"/nn/mm": true,
|
||||
"//": false, // empty model
|
||||
"//mm": true,
|
||||
"hh//": false, // empty model
|
||||
"//mm:@": false,
|
||||
"00@": false,
|
||||
"@": false,
|
||||
|
||||
// not starting with alphanum
|
||||
"-hh/nn/mm:tt": false,
|
||||
"hh/-nn/mm:tt": false,
|
||||
"hh/nn/-mm:tt": false,
|
||||
"hh/nn/mm:-tt": false,
|
||||
|
||||
// smells like a flag
|
||||
"-h": false,
|
||||
|
||||
// hosts
|
||||
"host:https/namespace/model:tag": true,
|
||||
|
||||
// colon in non-host part before tag
|
||||
"host/name:space/model:tag": false,
|
||||
}
|
||||
|
||||
func TestParseNameValidation(t *testing.T) {
|
||||
for s, valid := range testCases {
|
||||
got := Parse(s)
|
||||
if got.IsValid() != valid {
|
||||
t.Logf("got: %v", got)
|
||||
t.Errorf("Parse(%q).IsValid() = %v; want !%[2]v", s, got.IsValid())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
235
server/internal/registry/server.go
Normal file
235
server/internal/registry/server.go
Normal file
@@ -0,0 +1,235 @@
|
||||
// Package registry provides an http.Handler for handling local Ollama API
|
||||
// requests for performing tasks related to the ollama.com model registry and
|
||||
// the local disk cache.
|
||||
package registry
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
|
||||
"github.com/ollama/ollama/server/internal/cache/blob"
|
||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||
)
|
||||
|
||||
// Local is an http.Handler for handling local Ollama API requests for
|
||||
// performing tasks related to the ollama.com model registry combined with the
|
||||
// local disk cache.
|
||||
//
|
||||
// It is not concern of Local, or this package, to handle model creation, which
|
||||
// proceeds any registry operations for models it produces.
|
||||
//
|
||||
// NOTE: The package built for dealing with model creation should use
|
||||
// [DefaultCache] to access the blob store and not attempt to read or write
|
||||
// directly to the blob disk cache.
|
||||
type Local struct {
|
||||
Client *ollama.Registry // required
|
||||
Cache *blob.DiskCache // required
|
||||
Logger *slog.Logger // required
|
||||
|
||||
// Fallback, if set, is used to handle requests that are not handled by
|
||||
// this handler.
|
||||
Fallback http.Handler
|
||||
}
|
||||
|
||||
// serverError is like ollama.Error, but with a Status field for the HTTP
|
||||
// response code. We want to avoid adding that field to ollama.Error because it
|
||||
// would always be 0 to clients (we don't want to leak the status code in
|
||||
// errors), and so it would be confusing to have a field that is always 0.
|
||||
type serverError struct {
|
||||
Status int `json:"-"`
|
||||
|
||||
// TODO(bmizerany): Decide if we want to keep this and maybe
|
||||
// bring back later.
|
||||
Code string `json:"code"`
|
||||
|
||||
Message string `json:"error"`
|
||||
}
|
||||
|
||||
func (e serverError) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
// Common API errors
|
||||
var (
|
||||
errMethodNotAllowed = &serverError{405, "method_not_allowed", "method not allowed"}
|
||||
errNotFound = &serverError{404, "not_found", "not found"}
|
||||
errInternalError = &serverError{500, "internal_error", "internal server error"}
|
||||
)
|
||||
|
||||
type statusCodeRecorder struct {
|
||||
_status int // use status() to get the status code
|
||||
http.ResponseWriter
|
||||
}
|
||||
|
||||
func (r *statusCodeRecorder) WriteHeader(status int) {
|
||||
if r._status == 0 {
|
||||
r._status = status
|
||||
}
|
||||
r.ResponseWriter.WriteHeader(status)
|
||||
}
|
||||
|
||||
var (
|
||||
_ http.ResponseWriter = (*statusCodeRecorder)(nil)
|
||||
_ http.CloseNotifier = (*statusCodeRecorder)(nil)
|
||||
_ http.Flusher = (*statusCodeRecorder)(nil)
|
||||
)
|
||||
|
||||
// CloseNotify implements the http.CloseNotifier interface, for Gin. Remove with Gin.
|
||||
//
|
||||
// It panics if the underlying ResponseWriter is not a CloseNotifier.
|
||||
func (r *statusCodeRecorder) CloseNotify() <-chan bool {
|
||||
return r.ResponseWriter.(http.CloseNotifier).CloseNotify()
|
||||
}
|
||||
|
||||
// Flush implements the http.Flusher interface, for Gin. Remove with Gin.
|
||||
//
|
||||
// It panics if the underlying ResponseWriter is not a Flusher.
|
||||
func (r *statusCodeRecorder) Flush() {
|
||||
r.ResponseWriter.(http.Flusher).Flush()
|
||||
}
|
||||
|
||||
func (r *statusCodeRecorder) status() int {
|
||||
return cmp.Or(r._status, 200)
|
||||
}
|
||||
|
||||
func (s *Local) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
rec := &statusCodeRecorder{ResponseWriter: w}
|
||||
s.serveHTTP(rec, r)
|
||||
}
|
||||
|
||||
func (s *Local) serveHTTP(rec *statusCodeRecorder, r *http.Request) {
|
||||
var errattr slog.Attr
|
||||
proxied, err := func() (bool, error) {
|
||||
switch r.URL.Path {
|
||||
case "/api/delete":
|
||||
return false, s.handleDelete(rec, r)
|
||||
default:
|
||||
if s.Fallback != nil {
|
||||
s.Fallback.ServeHTTP(rec, r)
|
||||
return true, nil
|
||||
}
|
||||
return false, errNotFound
|
||||
}
|
||||
}()
|
||||
if err != nil {
|
||||
// We always log the error, so fill in the error log attribute
|
||||
errattr = slog.String("error", err.Error())
|
||||
|
||||
var e *serverError
|
||||
switch {
|
||||
case errors.As(err, &e):
|
||||
case errors.Is(err, ollama.ErrNameInvalid):
|
||||
e = &serverError{400, "bad_request", err.Error()}
|
||||
default:
|
||||
e = errInternalError
|
||||
}
|
||||
|
||||
data, err := json.Marshal(e)
|
||||
if err != nil {
|
||||
// unreachable
|
||||
panic(err)
|
||||
}
|
||||
rec.Header().Set("Content-Type", "application/json")
|
||||
rec.WriteHeader(e.Status)
|
||||
rec.Write(data)
|
||||
|
||||
// fallthrough to log
|
||||
}
|
||||
|
||||
if !proxied {
|
||||
// we're only responsible for logging if we handled the request
|
||||
var level slog.Level
|
||||
if rec.status() >= 500 {
|
||||
level = slog.LevelError
|
||||
} else if rec.status() >= 400 {
|
||||
level = slog.LevelWarn
|
||||
}
|
||||
|
||||
s.Logger.LogAttrs(r.Context(), level, "http",
|
||||
errattr, // report first in line to make it easy to find
|
||||
|
||||
// TODO(bmizerany): Write a test to ensure that we are logging
|
||||
// all of this correctly. That also goes for the level+error
|
||||
// logic above.
|
||||
slog.Int("status", rec.status()),
|
||||
slog.String("method", r.Method),
|
||||
slog.String("path", r.URL.Path),
|
||||
slog.Int64("content-length", r.ContentLength),
|
||||
slog.String("remote", r.RemoteAddr),
|
||||
slog.String("proto", r.Proto),
|
||||
slog.String("query", r.URL.RawQuery),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
type params struct {
|
||||
DeprecatedName string `json:"name"` // Use [params.model]
|
||||
Model string `json:"model"` // Use [params.model]
|
||||
|
||||
// AllowNonTLS is a flag that indicates a client using HTTP
|
||||
// is doing so, deliberately.
|
||||
//
|
||||
// Deprecated: This field is ignored and only present for this
|
||||
// deprecation message. It should be removed in a future release.
|
||||
//
|
||||
// Users can just use http or https+insecure to show intent to
|
||||
// communicate they want to do insecure things, without awkward and
|
||||
// confusing flags such as this.
|
||||
AllowNonTLS bool `json:"insecure"`
|
||||
|
||||
// ProgressStream is a flag that indicates the client is expecting a stream of
|
||||
// progress updates.
|
||||
ProgressStream bool `json:"stream"`
|
||||
}
|
||||
|
||||
// model returns the model name for both old and new API requests.
|
||||
func (p params) model() string {
|
||||
return cmp.Or(p.Model, p.DeprecatedName)
|
||||
}
|
||||
|
||||
func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
|
||||
if r.Method != "DELETE" {
|
||||
return errMethodNotAllowed
|
||||
}
|
||||
p, err := decodeUserJSON[*params](r.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ok, err := s.Client.Unlink(s.Cache, p.model())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !ok {
|
||||
return &serverError{404, "not_found", "model not found"}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeUserJSON[T any](r io.Reader) (T, error) {
|
||||
var v T
|
||||
err := json.NewDecoder(r).Decode(&v)
|
||||
if err == nil {
|
||||
return v, nil
|
||||
}
|
||||
var zero T
|
||||
|
||||
// Not sure why, but I can't seem to be able to use:
|
||||
//
|
||||
// errors.As(err, &json.UnmarshalTypeError{})
|
||||
//
|
||||
// This is working fine in stdlib, so I'm not sure what rules changed
|
||||
// and why this no longer works here. So, we do it the verbose way.
|
||||
var a *json.UnmarshalTypeError
|
||||
var b *json.SyntaxError
|
||||
if errors.As(err, &a) || errors.As(err, &b) {
|
||||
err = &serverError{Status: 400, Message: err.Error(), Code: "bad_request"}
|
||||
}
|
||||
if errors.Is(err, io.EOF) {
|
||||
err = &serverError{Status: 400, Message: "empty request body", Code: "bad_request"}
|
||||
}
|
||||
return zero, err
|
||||
}
|
||||
165
server/internal/registry/server_test.go
Normal file
165
server/internal/registry/server_test.go
Normal file
@@ -0,0 +1,165 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/server/internal/cache/blob"
|
||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||
"github.com/ollama/ollama/server/internal/testutil"
|
||||
)
|
||||
|
||||
type panicTransport struct{}
|
||||
|
||||
func (t *panicTransport) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||
panic("unexpected RoundTrip call")
|
||||
}
|
||||
|
||||
var panicOnRoundTrip = &http.Client{Transport: &panicTransport{}}
|
||||
|
||||
// bytesResetter is an interface for types that can be reset and return a byte
|
||||
// slice, only. This is to prevent inadvertent use of bytes.Buffer.Read/Write
|
||||
// etc for the purpose of checking logs.
|
||||
type bytesResetter interface {
|
||||
Bytes() []byte
|
||||
Reset()
|
||||
}
|
||||
|
||||
func newTestServer(t *testing.T) *Local {
|
||||
t.Helper()
|
||||
dir := t.TempDir()
|
||||
err := os.CopyFS(dir, os.DirFS("testdata/models"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
c, err := blob.Open(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rc := &ollama.Registry{
|
||||
HTTPClient: panicOnRoundTrip,
|
||||
}
|
||||
l := &Local{
|
||||
Cache: c,
|
||||
Client: rc,
|
||||
Logger: testutil.Slogger(t),
|
||||
}
|
||||
return l
|
||||
}
|
||||
|
||||
func (s *Local) send(t *testing.T, method, path, body string) *httptest.ResponseRecorder {
|
||||
t.Helper()
|
||||
req := httptest.NewRequestWithContext(t.Context(), method, path, strings.NewReader(body))
|
||||
return s.sendRequest(t, req)
|
||||
}
|
||||
|
||||
func (s *Local) sendRequest(t *testing.T, req *http.Request) *httptest.ResponseRecorder {
|
||||
t.Helper()
|
||||
w := httptest.NewRecorder()
|
||||
s.ServeHTTP(w, req)
|
||||
return w
|
||||
}
|
||||
|
||||
type invalidReader struct{}
|
||||
|
||||
func (r *invalidReader) Read(p []byte) (int, error) {
|
||||
return 0, os.ErrInvalid
|
||||
}
|
||||
|
||||
// captureLogs is a helper to capture logs from the server. It returns a
|
||||
// shallow copy of the server with a new logger and a bytesResetter for the
|
||||
// logs.
|
||||
func captureLogs(t *testing.T, s *Local) (*Local, bytesResetter) {
|
||||
t.Helper()
|
||||
log, logs := testutil.SlogBuffer()
|
||||
l := *s // shallow copy
|
||||
l.Logger = log
|
||||
return &l, logs
|
||||
}
|
||||
|
||||
func TestServerDelete(t *testing.T) {
|
||||
check := testutil.Checker(t)
|
||||
|
||||
s := newTestServer(t)
|
||||
|
||||
_, err := s.Client.ResolveLocal(s.Cache, "smol")
|
||||
check(err)
|
||||
|
||||
got := s.send(t, "DELETE", "/api/delete", `{"model": "smol"}`)
|
||||
if got.Code != 200 {
|
||||
t.Fatalf("Code = %d; want 200", got.Code)
|
||||
}
|
||||
|
||||
_, err = s.Client.ResolveLocal(s.Cache, "smol")
|
||||
if err == nil {
|
||||
t.Fatal("expected smol to have been deleted")
|
||||
}
|
||||
|
||||
got = s.send(t, "DELETE", "/api/delete", `!`)
|
||||
checkErrorResponse(t, got, 400, "bad_request", "invalid character '!' looking for beginning of value")
|
||||
|
||||
got = s.send(t, "GET", "/api/delete", `{"model": "smol"}`)
|
||||
checkErrorResponse(t, got, 405, "method_not_allowed", "method not allowed")
|
||||
|
||||
got = s.send(t, "DELETE", "/api/delete", ``)
|
||||
checkErrorResponse(t, got, 400, "bad_request", "empty request body")
|
||||
|
||||
got = s.send(t, "DELETE", "/api/delete", `{"model": "://"}`)
|
||||
checkErrorResponse(t, got, 400, "bad_request", "invalid or missing name")
|
||||
|
||||
got = s.send(t, "DELETE", "/unknown_path", `{}`) // valid body
|
||||
checkErrorResponse(t, got, 404, "not_found", "not found")
|
||||
|
||||
s, logs := captureLogs(t, s)
|
||||
req := httptest.NewRequestWithContext(t.Context(), "DELETE", "/api/delete", &invalidReader{})
|
||||
got = s.sendRequest(t, req)
|
||||
checkErrorResponse(t, got, 500, "internal_error", "internal server error")
|
||||
ok, err := regexp.Match(`ERROR.*error="invalid argument"`, logs.Bytes())
|
||||
check(err)
|
||||
if !ok {
|
||||
t.Logf("logs:\n%s", logs)
|
||||
t.Fatalf("expected log to contain ERROR with invalid argument")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerUnknownPath(t *testing.T) {
|
||||
s := newTestServer(t)
|
||||
got := s.send(t, "DELETE", "/api/unknown", `{}`)
|
||||
checkErrorResponse(t, got, 404, "not_found", "not found")
|
||||
}
|
||||
|
||||
func checkErrorResponse(t *testing.T, got *httptest.ResponseRecorder, status int, code, msg string) {
|
||||
t.Helper()
|
||||
|
||||
var printedBody bool
|
||||
errorf := func(format string, args ...any) {
|
||||
t.Helper()
|
||||
if !printedBody {
|
||||
t.Logf("BODY:\n%s", got.Body.String())
|
||||
printedBody = true
|
||||
}
|
||||
t.Errorf(format, args...)
|
||||
}
|
||||
|
||||
if got.Code != status {
|
||||
errorf("Code = %d; want %d", got.Code, status)
|
||||
}
|
||||
|
||||
// unmarshal the error as *ollama.Error (proving *serverError is an *ollama.Error)
|
||||
var e *ollama.Error
|
||||
if err := json.Unmarshal(got.Body.Bytes(), &e); err != nil {
|
||||
errorf("unmarshal error: %v", err)
|
||||
t.FailNow()
|
||||
}
|
||||
if e.Code != code {
|
||||
errorf("Code = %q; want %q", e.Code, code)
|
||||
}
|
||||
if !strings.Contains(e.Message, msg) {
|
||||
errorf("Message = %q; want to contain %q", e.Message, msg)
|
||||
}
|
||||
}
|
||||
Binary file not shown.
@@ -0,0 +1 @@
|
||||
{"schemaVersion":2,"mediaType":"application/vnd.docker.distribution.manifest.v2+json","config":{"mediaType":"application/vnd.docker.container.image.v1+json","digest":"sha256:ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116","size":267},"layers":[{"mediaType":"application/vnd.ollama.image.model","digest":"sha256:a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99","size":24}]}
|
||||
1
server/internal/registry/testdata/models/manifests/registry.ollama.ai/library/smol/latest
vendored
Normal file
1
server/internal/registry/testdata/models/manifests/registry.ollama.ai/library/smol/latest
vendored
Normal file
@@ -0,0 +1 @@
|
||||
{"schemaVersion":2,"mediaType":"application/vnd.docker.distribution.manifest.v2+json","config":{"mediaType":"application/vnd.docker.container.image.v1+json","digest":"sha256:ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116","size":267},"layers":[{"mediaType":"application/vnd.ollama.image.model","digest":"sha256:a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99","size":24}]}
|
||||
@@ -1,12 +1,40 @@
|
||||
package testutil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// LogWriter returns an [io.Writer] that logs each Write using t.Log.
|
||||
func LogWriter(t *testing.T) io.Writer {
|
||||
return testWriter{t}
|
||||
}
|
||||
|
||||
type testWriter struct{ t *testing.T }
|
||||
|
||||
func (w testWriter) Write(b []byte) (int, error) {
|
||||
w.t.Logf("%s", b)
|
||||
return len(b), nil
|
||||
}
|
||||
|
||||
// Slogger returns a [*slog.Logger] that writes each message
|
||||
// using t.Log.
|
||||
func Slogger(t *testing.T) *slog.Logger {
|
||||
return slog.New(slog.NewTextHandler(LogWriter(t), nil))
|
||||
}
|
||||
|
||||
// SlogBuffer returns a [*slog.Logger] that writes each message to out.
|
||||
func SlogBuffer() (lg *slog.Logger, out *bytes.Buffer) {
|
||||
var buf bytes.Buffer
|
||||
lg = slog.New(slog.NewTextHandler(&buf, nil))
|
||||
return lg, &buf
|
||||
}
|
||||
|
||||
// Check calls t.Fatal(err) if err is not nil.
|
||||
func Check(t *testing.T, err error) {
|
||||
if err != nil {
|
||||
@@ -34,6 +34,9 @@ import (
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/model/models/mllama"
|
||||
"github.com/ollama/ollama/openai"
|
||||
"github.com/ollama/ollama/server/internal/cache/blob"
|
||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||
"github.com/ollama/ollama/server/internal/registry"
|
||||
"github.com/ollama/ollama/template"
|
||||
"github.com/ollama/ollama/types/errtypes"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
@@ -1126,7 +1129,7 @@ func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) GenerateRoutes() http.Handler {
|
||||
func (s *Server) GenerateRoutes(c *blob.DiskCache, rc *ollama.Registry) (http.Handler, error) {
|
||||
corsConfig := cors.DefaultConfig()
|
||||
corsConfig.AllowWildcard = true
|
||||
corsConfig.AllowBrowserExtensions = true
|
||||
@@ -1165,10 +1168,9 @@ func (s *Server) GenerateRoutes() http.Handler {
|
||||
r.HEAD("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
|
||||
r.GET("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
|
||||
|
||||
// Local model cache management
|
||||
// Local model cache management (new implementation is at end of function)
|
||||
r.POST("/api/pull", s.PullHandler)
|
||||
r.POST("/api/push", s.PushHandler)
|
||||
r.DELETE("/api/delete", s.DeleteHandler)
|
||||
r.HEAD("/api/tags", s.ListHandler)
|
||||
r.GET("/api/tags", s.ListHandler)
|
||||
r.POST("/api/show", s.ShowHandler)
|
||||
@@ -1193,7 +1195,15 @@ func (s *Server) GenerateRoutes() http.Handler {
|
||||
r.GET("/v1/models", openai.ListMiddleware(), s.ListHandler)
|
||||
r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowHandler)
|
||||
|
||||
return r
|
||||
// wrap old with new
|
||||
rs := ®istry.Local{
|
||||
Cache: c,
|
||||
Client: rc,
|
||||
Logger: slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default()
|
||||
Fallback: r,
|
||||
}
|
||||
|
||||
return rs, nil
|
||||
}
|
||||
|
||||
func Serve(ln net.Listener) error {
|
||||
@@ -1246,12 +1256,27 @@ func Serve(ln net.Listener) error {
|
||||
}
|
||||
}
|
||||
|
||||
s := &Server{addr: ln.Addr()}
|
||||
|
||||
c, err := ollama.DefaultCache()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rc, err := ollama.DefaultRegistry()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
h, err := s.GenerateRoutes(c, rc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
http.Handle("/", h)
|
||||
|
||||
ctx, done := context.WithCancel(context.Background())
|
||||
schedCtx, schedDone := context.WithCancel(ctx)
|
||||
sched := InitScheduler(schedCtx)
|
||||
s := &Server{addr: ln.Addr(), sched: sched}
|
||||
|
||||
http.Handle("/", s.GenerateRoutes())
|
||||
s.sched = sched
|
||||
|
||||
slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version))
|
||||
srvr := &http.Server{
|
||||
|
||||
@@ -23,6 +23,8 @@ import (
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/ollama/ollama/openai"
|
||||
"github.com/ollama/ollama/server/internal/cache/blob"
|
||||
"github.com/ollama/ollama/server/internal/client/ollama"
|
||||
"github.com/ollama/ollama/types/model"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
@@ -91,7 +93,15 @@ func equalStringSlices(a, b []string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func Test_Routes(t *testing.T) {
|
||||
type panicTransport struct{}
|
||||
|
||||
func (t *panicTransport) RoundTrip(r *http.Request) (*http.Response, error) {
|
||||
panic("unexpected RoundTrip call")
|
||||
}
|
||||
|
||||
var panicOnRoundTrip = &http.Client{Transport: &panicTransport{}}
|
||||
|
||||
func TestRoutes(t *testing.T) {
|
||||
type testCase struct {
|
||||
Name string
|
||||
Method string
|
||||
@@ -241,10 +251,10 @@ func Test_Routes(t *testing.T) {
|
||||
Method: http.MethodDelete,
|
||||
Path: "/api/delete",
|
||||
Setup: func(t *testing.T, req *http.Request) {
|
||||
createTestModel(t, "model-to-delete")
|
||||
createTestModel(t, "model_to_delete")
|
||||
|
||||
deleteReq := api.DeleteRequest{
|
||||
Name: "model-to-delete",
|
||||
Name: "model_to_delete",
|
||||
}
|
||||
jsonData, err := json.Marshal(deleteReq)
|
||||
if err != nil {
|
||||
@@ -271,7 +281,7 @@ func Test_Routes(t *testing.T) {
|
||||
Path: "/api/delete",
|
||||
Setup: func(t *testing.T, req *http.Request) {
|
||||
deleteReq := api.DeleteRequest{
|
||||
Name: "non-existent-model",
|
||||
Name: "non_existent_model",
|
||||
}
|
||||
jsonData, err := json.Marshal(deleteReq)
|
||||
if err != nil {
|
||||
@@ -477,10 +487,34 @@ func Test_Routes(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
t.Setenv("OLLAMA_MODELS", t.TempDir())
|
||||
modelsDir := t.TempDir()
|
||||
t.Setenv("OLLAMA_MODELS", modelsDir)
|
||||
|
||||
c, err := blob.Open(modelsDir)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open models dir: %v", err)
|
||||
}
|
||||
|
||||
rc := &ollama.Registry{
|
||||
// This is a temporary measure to allow us to move forward,
|
||||
// surfacing any code contacting ollama.com we do not intended
|
||||
// to.
|
||||
//
|
||||
// Currently, this only handles DELETE /api/delete, which
|
||||
// should not make any contact with the ollama.com registry, so
|
||||
// be clear about that.
|
||||
//
|
||||
// Tests that do need to contact the registry here, will be
|
||||
// consumed into our new server/api code packages and removed
|
||||
// from here.
|
||||
HTTPClient: panicOnRoundTrip,
|
||||
}
|
||||
|
||||
s := &Server{}
|
||||
router := s.GenerateRoutes()
|
||||
router, err := s.GenerateRoutes(c, rc)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate routes: %v", err)
|
||||
}
|
||||
|
||||
httpSrv := httptest.NewServer(router)
|
||||
t.Cleanup(httpSrv.Close)
|
||||
|
||||
Reference in New Issue
Block a user