Compare commits

..

31 Commits

Author SHA1 Message Date
Jeffrey Morgan
e41c4cbea7 build: install ccache manually in Dockerfile (#9464)
Reverts ccache installation to be done manually via curl instead of
using the dnf package manager as this has side effects of prepending
ccache's install directory to the front of the PATH
2025-03-02 16:48:31 -08:00
Blake Mizerany
ee048b76d4 server/internal/client/ollama: handle extended names in client/ollama (#9454)
The extended name format is a superset of the name format that only the
client needs to know about, not the server or other dependents of the
name package, so move the split logic into the client package.

Also, take advantage of knowing about the extended name format to allow
the client to use the extended name format when unlinking to verify they
are unlinking the manifest with the content they intend.
2025-03-02 13:30:41 -08:00
Soulter
af68d60a58 readme: add AstrBot to community integrations (#9442) 2025-03-01 21:58:34 -08:00
Jesse Gross
21aa666a1e ml: Enable support for flash attention
The GGML flash attention kernel has specific requirements for
padding and permutation. This adds support to the KV cache
for conforming to these requirements so that flash attention
can be enabled.

Flash attention can be used in the same situations as the llama
engine and is enabled by the user in the same way.
2025-03-01 20:53:23 -08:00
Jesse Gross
ee141cc821 ml: Empty tensor constructor for tensors
In cases where we allocate a tensor and then fully overwrite it with
copied data, it is wasteful to first zero out the memory.
2025-03-01 20:53:23 -08:00
Jesse Gross
55e5776c44 ggml-backend: Store parent backend as part of tensor
It can be important for a tensor to know what backend it came from -
for example, to know if flash attention is enabled.
2025-03-01 20:53:23 -08:00
Jesse Gross
854a9195f3 attention: Remove unnecessary contiguous operations
Prior to performing attention, we need to permute query, key
and value. Currently we call Contiguous after each of these
permutations, which is correct but expensive. Avoiding the
3 calls to Contiguous increases performance by over 20%.

The permutations of query and key do not violate the continuity
rules for mulmat and the Contiguous call can be simply removed.

Value requires a different permutation and does require Contiguous.
However, we can use the copy into the cache as a way to perform this
without further overhead.

To support this and avoid unexpected tensor shapes that are seen by
models, we need tighter integration between attention, cache
and backend. Future optimization will also likely need this structure
 - for example, flash attention has special padding requirements in
the cache and other backends may have their own needs.

This further contains the operations that go into attention so that
these and other optimizations can be handled transparently. Models
that have special requirements for attention can still implement
their own version of it.
2025-03-01 20:53:23 -08:00
Jeffrey Morgan
96a97adf9b build: use correct GGML_HIP_NO_VMM compiler definition for ggml-hip (#9451) 2025-03-01 17:00:31 -08:00
Jeffrey Morgan
e75c6126e9 build: set GGML_CUDA_NO_VMM for ggml-hip target (#9449) 2025-03-01 14:02:19 -08:00
Blake Mizerany
cda6f5c66c server/internal/internal/names: validate names (#9400)
This commit is a step towards a goal to make names less ceremonial
outside of the registry client. Clients of the registry package can
treat names as opaque strings, and the registry package will handle
parsing, validating, and normalizing names.

Ideally we end up with the names package tucked away in an internal
package for good. We'll see how things go.

Also, this package name is not permanent. This another step in the
on-going process of refactoring the server code, and at some point it
will most likely be renamed/moved.
2025-03-01 13:15:14 -08:00
Bruce MacDonald
bebb6823c0 server: validate local path on safetensor create (#9379)
More validation during the safetensor creation process.
Properly handle relative paths (like ./model.safetensors) while rejecting absolute paths
Add comprehensive test coverage for various paths
No functionality changes for valid inputs - existing workflows remain unaffected
Leverages Go 1.24's new os.Root functionality for secure containment
2025-02-28 16:10:43 -08:00
Michael Yang
31e472baa4 runner: defer context cancel
defer the cancel to guarantee it runs
2025-02-28 22:27:28 +00:00
Michael Yang
657685e85d fix: replace deprecated functions 2025-02-28 21:29:34 +00:00
Jeffrey Morgan
a14912858e build: add compute capability 12.0 to CUDA 12 preset (#9426)
Focuses initial Blackwell support on compute capability 12.0
which includes the 50x series of GeForce cards. In the future
additional compute capabilities may be added
2025-02-28 13:12:31 -08:00
Blake Mizerany
eed11ded30 server/.../safetensors: fix offsets and include all model parts (#9427)
Also, require the -as flag to be set when importing a model. This
prevents the confusing error message "invalid name".

Also, allow short names to be used when importing a model and
auto-complete the name with the default mask.
2025-02-28 13:08:10 -08:00
Michael Yang
b42aba40ed cuda: enable flash attention
ggml added an option to disable flash attention so explicitly enable it
2025-02-28 19:40:34 +00:00
王贺
25885e5335 docs: Add 1Panel to Community Integrations (#9312) 2025-02-28 09:53:03 -08:00
Jeffrey Morgan
98d44fa39d llama: add phi4 mini support (#9403) 2025-02-27 19:30:32 -08:00
Blake Mizerany
2099e2d267 CONTRIBUTING: provide clarity on good commit messages, and bad (#9405)
Also, our commit messages have been getting better, but we can do
better, and be more consistent. This adds more clarity on how to write
commit messages and provides examples of good and bad messages.

Also, our contributing guide was lacking helpful guidance on how to
start change proposals. This commit adds the start of that section.

Soon, we should add a proposal template to the issue tracker with a link
back to the proposal section, which should also be expanded upon.
2025-02-27 19:22:26 -08:00
Bruce MacDonald
0c1041ad85 runner: default to greedy sampler for performance (#9407)
As are adding support for weighted sampling we have seen some performance
regressions, bypassing the sampler logic for now and defaulting to greedy
until we can benchmark the new sampler logic.
2025-02-27 16:41:20 -08:00
Parth Sareen
c245b0406f sample: remove transforms from greedy sampling (#9377) 2025-02-27 15:44:53 -08:00
Michael Yang
8b194b7520 kvcache: update tests 2025-02-27 22:27:16 +00:00
Michael Yang
3e8b8a1933 ml: update Context.Forward interface
update Context.Forward to accept multiple tensors to match
Context.Compute signature

update Context.Forward to return Context such that it can be chained
with Context.Compute
2025-02-27 22:27:16 +00:00
Blake Mizerany
41dc280491 server/internal/registry: implement CloseNotify and Flush (for now) (#9402)
This fixes panics introduced in 2412adf42b
when Gin ungracefully assumes that the http.ResponseWriter implements
http.CloseNotifier and http.Flusher, which our new statusCodeRecorder
does not. This is a temporary fix until we can pour the rest of the Gin
out.
2025-02-27 14:00:37 -08:00
Michael Yang
53d2990d9b model: add bos token if configured 2025-02-27 21:04:59 +00:00
Jesse Gross
e185c08ad9 go.mod: Use full version for go 1.24.0
Otherwise on Linux I get:
go: download go1.24 for linux/amd64: toolchain not available
2025-02-27 13:01:32 -08:00
Blake Mizerany
2412adf42b server/internal: replace model delete API with new registry handler. (#9347)
This commit introduces a new API implementation for handling
interactions with the registry and the local model cache. The new API is
located in server/internal/registry. The package name is "registry" and
should be considered temporary; it is hidden and not bleeding outside of
the server package. As the commits roll in, we'll start consuming more
of the API and then let reverse osmosis take effect, at which point it
will surface closer to the root level packages as much as needed.
2025-02-27 12:04:53 -08:00
Steven Hartland
be2ac1ed93 docs: fix api examples link (#9360)
Fix the examples link in the go package documentation for the API.
2025-02-27 10:51:12 -08:00
Eries Trisnadi
dc13813a03 server: allow vscode-file origins (#9313) 2025-02-27 10:39:43 -08:00
Michael Yang
d6af13efed runner: simplify tensor split parsing 2025-02-27 18:36:46 +00:00
Michael Yang
a59f665235 ml/backend/ggml: fix debug logging 2025-02-27 18:30:57 +00:00
53 changed files with 1932 additions and 524 deletions

View File

@@ -23,6 +23,7 @@ set(GGML_SCHED_MAX_COPIES 4)
set(GGML_LLAMAFILE ON)
set(GGML_CUDA_PEER_MAX_BATCH_SIZE 128)
set(GGML_CUDA_GRAPHS ON)
set(GGML_CUDA_FA ON)
if((CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_OSX_ARCHITECTURES MATCHES "arm64")
OR (NOT CMAKE_OSX_ARCHITECTURES AND NOT CMAKE_SYSTEM_PROCESSOR MATCHES "arm|aarch64|ARM64|ARMv[0-9]+"))
@@ -105,9 +106,11 @@ if(CMAKE_HIP_COMPILER)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-hip)
if (WIN32)
target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY=1)
target_compile_definitions(ggml-hip PRIVATE GGML_CUDA_NO_PEER_COPY)
endif()
target_compile_definitions(ggml-hip PRIVATE GGML_HIP_NO_VMM)
set(OLLAMA_HIP_INSTALL_DIR ${OLLAMA_INSTALL_DIR}/rocm)
install(TARGETS ggml-hip
RUNTIME_DEPENDENCIES

View File

@@ -28,7 +28,7 @@
"name": "CUDA 12",
"inherits": [ "CUDA" ],
"cacheVariables": {
"CMAKE_CUDA_ARCHITECTURES": "50;60;61;70;75;80;86;87;89;90;90a;100"
"CMAKE_CUDA_ARCHITECTURES": "50;60;61;70;75;80;86;87;89;90;90a;120"
}
},
{

View File

@@ -6,8 +6,6 @@ Thank you for your interest in contributing to Ollama! Here are a few guidelines
See the [development documentation](./docs/development.md) for instructions on how to build and run Ollama locally.
## Pull requests
### Ideal issues
* [Bugs](https://github.com/ollama/ollama/issues?q=is%3Aissue+is%3Aopen+label%3Abug): issues where Ollama stops working or where it results in an unexpected error.
@@ -26,11 +24,64 @@ See the [development documentation](./docs/development.md) for instructions on h
* Changes that add significant friction to the user experience
* Changes that create a large future maintenance burden for maintainers and contributors
### Best practices
## Proposing a (non-trivial) change
* Commit messages: please leave both a title and a description in your commit messages. The title should be a short summary of the changes, with a leading word that explains the section of the code being changed (e.g. `api: fix parsing of prompt field`) . In the description, leave a short 2-3 sentences that explain more about the change and its impact.
* Tests: please add test coverage to changes where possible.
* Minimize dependencies: avoid adding new dependencies unless absolutely necessary.
> By "non-trivial", we mean a change that is not a bug fix or small
> documentation update. If you are unsure, please ask us on our [Discord
> server](https://discord.gg/ollama).
Before opening a non-trivial Pull Request, please open an issue to discuss the change and
get feedback from the maintainers. This helps us understand the context of the
change and how it fits into Ollama's roadmap and prevents us from duplicating
work or you from spending time on a change that we may not be able to accept.
Tips for proposals:
* Explain the problem you are trying to solve, not what you are trying to do.
* Explain why the change is important.
* Explain how the change will be used.
* Explain how the change will be tested.
Additionally, for bonus points: Provide draft documentation you would expect to
see if the change were accepted.
## Pull requests
**Commit messages**
The title should look like:
<package>: <short description>
The package is the most affected Go package. If the change does not affect Go
code, then use the directory name instead. Changes to a single well-known
file in the root directory may use the file name.
The short description should start with a lowercase letter and be a
continuation of the sentence:
"This changes Ollama to..."
Examples:
llm/backend/mlx: support the llama architecture
CONTRIBUTING: provide clairity on good commit messages, and bad
Bad Examples:
feat: add more emoji
fix: was not using famous web framework
chore: generify code
**Tests**
Please include tests. Strive to test behavior, not implementation.
**New dependencies**
Dependencies should be added sparingly. If you are adding a new dependency,
please explain why it is necessary and what other ways you attempted that
did not work without it.
## Need help?

View File

@@ -12,8 +12,9 @@ FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base
RUN yum install -y yum-utils \
&& yum-config-manager --add-repo https://dl.rockylinux.org/vault/rocky/8.5/AppStream/\$basearch/os/ \
&& rpm --import https://dl.rockylinux.org/pub/rocky/RPM-GPG-KEY-Rocky-8 \
&& dnf install -y yum-utils ccache gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 \
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
&& dnf install -y yum-utils gcc-toolset-10-gcc-10.2.1-8.2.el8 gcc-toolset-10-gcc-c++-10.2.1-8.2.el8 \
&& yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo \
&& curl -s -L https://github.com/ccache/ccache/releases/download/v4.10.2/ccache-4.10.2-linux-x86_64.tar.xz | tar -Jx -C /usr/local/bin --strip-components 1
ENV PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH
FROM --platform=linux/arm64 almalinux:8 AS base-arm64

View File

@@ -386,6 +386,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [MaxKB](https://github.com/1Panel-dev/MaxKB/) (Ready-to-use & flexible RAG Chatbot)
- [yla](https://github.com/danielekp/yla) (Web interface to freely interact with your customized models)
- [LangBot](https://github.com/RockChinQ/LangBot) (LLM-based instant messaging bots platform, with Agents, RAG features, supports multiple platforms)
- [1Panel](https://github.com/1Panel-dev/1Panel/) (Web-based Linux Server Management Tool)
- [AstrBot](https://github.com/Soulter/AstrBot/) (User-friendly LLM-based multi-platform chatbot with a WebUI, supporting RAG, LLM agents, and plugins integration)
### Cloud

View File

@@ -10,7 +10,7 @@
// repository].
//
// [the API documentation]: https://github.com/ollama/ollama/blob/main/docs/api.md
// [in the GitHub repository]: https://github.com/ollama/ollama/tree/main/examples
// [in the GitHub repository]: https://github.com/ollama/ollama/tree/main/api/examples
package api
import (

View File

@@ -73,6 +73,7 @@ func AllowedOrigins() (origins []string) {
"file://*",
"tauri://*",
"vscode-webview://*",
"vscode-file://*",
)
return origins

View File

@@ -69,6 +69,7 @@ func TestOrigins(t *testing.T) {
"file://*",
"tauri://*",
"vscode-webview://*",
"vscode-file://*",
}},
{"http://10.0.0.1", []string{
"http://10.0.0.1",
@@ -88,6 +89,7 @@ func TestOrigins(t *testing.T) {
"file://*",
"tauri://*",
"vscode-webview://*",
"vscode-file://*",
}},
{"http://172.16.0.1,https://192.168.0.1", []string{
"http://172.16.0.1",
@@ -108,6 +110,7 @@ func TestOrigins(t *testing.T) {
"file://*",
"tauri://*",
"vscode-webview://*",
"vscode-file://*",
}},
{"http://totally.safe,http://definitely.legit", []string{
"http://totally.safe",
@@ -128,6 +131,7 @@ func TestOrigins(t *testing.T) {
"file://*",
"tauri://*",
"vscode-webview://*",
"vscode-file://*",
}},
}
for _, tt := range cases {

View File

@@ -100,6 +100,10 @@ func (kv KV) Float(key string, defaultValue ...float32) float32 {
return keyValue(kv, key, append(defaultValue, 0)...)
}
func (kv KV) Bool(key string, defaultValue ...bool) bool {
return keyValue(kv, key, append(defaultValue, false)...)
}
func (kv KV) Strings(key string, defaultValue ...[]string) []string {
r := keyValue(kv, key, &array{})
s := make([]string, r.size)
@@ -120,7 +124,7 @@ func (kv KV) Uints(key string, defaultValue ...[]uint32) []uint32 {
return s
}
func keyValue[T string | uint32 | uint64 | float32 | *array](kv KV, key string, defaultValue ...T) T {
func keyValue[T string | uint32 | uint64 | float32 | *array | bool](kv KV, key string, defaultValue ...T) T {
if !strings.HasPrefix(key, "tokenizer.") && !strings.HasPrefix(key, "general.") {
key = kv.Architecture() + "." + key
}

16
go.mod
View File

@@ -1,6 +1,6 @@
module github.com/ollama/ollama
go 1.24
go 1.24.0
require (
github.com/containerd/console v1.0.3
@@ -11,7 +11,7 @@ require (
github.com/spf13/cobra v1.7.0
github.com/stretchr/testify v1.9.0
github.com/x448/float16 v0.8.4
golang.org/x/sync v0.10.0
golang.org/x/sync v0.11.0
)
require (
@@ -69,12 +69,12 @@ require (
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.31.0
golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa
golang.org/x/net v0.25.0 // indirect
golang.org/x/sys v0.28.0
golang.org/x/term v0.27.0
golang.org/x/text v0.21.0
golang.org/x/crypto v0.33.0
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa
golang.org/x/net v0.35.0 // indirect
golang.org/x/sys v0.30.0
golang.org/x/term v0.29.0
golang.org/x/text v0.22.0
google.golang.org/protobuf v1.34.1
gopkg.in/yaml.v3 v3.0.1 // indirect
)

28
go.sum
View File

@@ -214,16 +214,16 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.31.0 h1:ihbySMvVjLAeSH1IbfcRTkD/iNscyz8rGzjF/E5hV6U=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/crypto v0.33.0 h1:IOBPskki6Lysi0lo9qQvbxiQ+FvsCC/YWOecCHAixus=
golang.org/x/crypto v0.33.0/go.mod h1:bVdXmD7IV/4GdElGPozy6U7lWdRXA4qyRVGJV57uQ5M=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3/go.mod h1:NOZ3BPKG0ec/BKJQgnvsSFpcKLM5xXVWnvZS97DWHgE=
golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa h1:FRnLl4eNAQl8hwxVVC17teOw8kdjVDVAiFMtgUdTSRQ=
golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE=
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa h1:t2QcU6V556bFjYgu4L6C+6VrCPyJZ+eyRsABUPs1mz4=
golang.org/x/exp v0.0.0-20250218142911-aa4b98e5adaa/go.mod h1:BHOTPb3L19zxehTsLoJXVaTktb06DFgmdW6Wb9s8jqk=
golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs=
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0=
@@ -257,8 +257,8 @@ golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81R
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM=
golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.25.0 h1:d/OCCoBEUq33pjydKrGQhw7IlUPI2Oylr+8qLx49kac=
golang.org/x/net v0.25.0/go.mod h1:JkAGAh7GEvH74S6FOH42FLoXpXbE/aqXSrIQjXgsiwM=
golang.org/x/net v0.35.0 h1:T5GQRQb2y08kTAByq9L4/bz8cipCdA8FbRTXewonqY8=
golang.org/x/net v0.35.0/go.mod h1:EglIi67kWsHKlRzzVMUD93VMSWGFOMSZgxFjparz1Qk=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -268,8 +268,8 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sync v0.11.0 h1:GGz8+XQP4FvTTrjZPzNKTMFtSXH80RAzG+5ghFPgK9w=
golang.org/x/sync v0.11.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
@@ -285,17 +285,17 @@ golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.28.0 h1:Fksou7UEQUWlKvIdsqzJmUmCX3cZuD2+P3XyyzwMhlA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc=
golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.27.0 h1:WP60Sv1nlK1T6SupCHbXzSaN0b9wUmsPoRS9b61A23Q=
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
golang.org/x/term v0.29.0 h1:L6pJp37ocefwRRtYPKSWOWzOtWSxVajvz2ldH/xi3iU=
golang.org/x/term v0.29.0/go.mod h1:6bl4lRlvVuDgSf3179VpIxBF0o10JUpXWOnI7nErv7s=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/text v0.22.0 h1:bofq7m3/HAFvbF51jz3Q9wLg3jkvSPuiZu/pD1XwgtM=
golang.org/x/text v0.22.0/go.mod h1:YRoo4H8PVmsu+E3Ou7cqLVH8oXWIHVoX0jqUWALQhfY=
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

View File

@@ -29,6 +29,17 @@ type Cache interface {
// cache implementation used.
Put(ctx ml.Context, key, value ml.Tensor)
// SetConfig controls optimizations (mostly backend-specific) that may transform
// the output of the cache to work better with specific kernels. If not called,
// the backend settings will be used. This works well when calling Attention.
//
// The config can be overridden by models, especially if they require vanilla
// output when implementing their own version of attention. To do this, pass
// an empty ml.CacheConfig.
//
// Most models will not need to use this.
SetConfig(ml.CacheConfig)
// ** cache management **
// Init sets up runtime parameters

View File

@@ -22,6 +22,9 @@ type Causal struct {
Capacity int32
windowSize int32
// config controls mostly backend-specific optimizations
config *ml.CacheConfig
// ** current forward pass **
// the active layer for Get and Put
@@ -75,14 +78,42 @@ func NewSWACache(windowSize int32, shift shiftFn) *Causal {
}
func (c *Causal) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
if c.config == nil {
var config ml.CacheConfig
if cc, ok := backend.(ml.BackendCacheConfig); ok {
config = cc.CacheConfig()
}
c.config = &config
}
if c.config.CachePadding == 0 {
c.config.CachePadding = 1
}
if c.config.MaskBatchPadding == 0 {
c.config.MaskBatchPadding = 1
}
if c.config.MaskDType == ml.DTypeOther {
c.config.MaskDType = ml.DTypeF32
}
c.DType = dtype
c.Capacity = capacity
c.cells = make([]cacheCell, capacity)
c.Capacity = int32(roundUp(int(capacity), c.config.CachePadding))
c.cells = make([]cacheCell, c.Capacity)
c.cellRanges = make(map[int]cellRange)
c.backend = backend
c.cacheCtx = backend.NewContext()
}
func (c *Causal) SetConfig(config ml.CacheConfig) {
if c.config != nil {
panic("config cannot be changed after being previously set, either by the model or backend")
}
c.config = &config
}
func (c *Causal) Close() {
c.cacheCtx.Close()
}
@@ -157,36 +188,91 @@ func (c *Causal) findStartLoc() (int, error) {
return 0, fmt.Errorf("%w (length: %v)", ErrKvCacheFull, c.Capacity)
}
func roundDown(length, pad int) int {
return (length / pad) * pad
}
func roundUp(length, pad int) int {
return ((length + pad - 1) / pad) * pad
}
// Builds a mask of history x batch indicating whether for each token in the batch the
// token in the history should apply. This is based on both the sequence and causality (the
// position of the history is not ahead of the token in the batch).
func (c *Causal) buildMask(ctx ml.Context, positions []int32, seqs []int) (ml.Tensor, error) {
// TODO(jessegross): This does not do padding, which is required for flash attention
len := c.curCellRange.max - c.curCellRange.min + 1
mask := make([]float32, c.curBatchSize*len)
// Align and pad the two dimensions as required by the backend
batchSize := roundUp(c.curBatchSize, c.config.MaskBatchPadding)
c.curCellRange.min = roundDown(c.curCellRange.min, c.config.CachePadding)
c.curCellRange.max = roundUp(c.curCellRange.max+1, c.config.CachePadding) - 1
length := c.curCellRange.max - c.curCellRange.min + 1
mask := make([]float32, batchSize*length)
for i := range c.curBatchSize {
for j := c.curCellRange.min; j <= c.curCellRange.max; j++ {
if !slices.Contains(c.cells[j].sequences, seqs[i]) || c.cells[j].pos > positions[i] ||
c.cells[j].pos < positions[i]-c.windowSize {
mask[i*len+(j-c.curCellRange.min)] = float32(math.Inf(-1))
mask[i*length+(j-c.curCellRange.min)] = float32(math.Inf(-1))
}
}
}
return ctx.FromFloatSlice(mask, len, c.curBatchSize)
// Mask out any padding tokens we added. For padding that we added to the cache history, this
// has already been masked out because the sequence doesn't match.
for i := c.curBatchSize * length; i < len(mask); i++ {
mask[i] = float32(math.Inf(-1))
}
maskTensor, err := ctx.FromFloatSlice(mask, length, batchSize)
if err != nil {
return nil, err
}
if c.config.MaskDType != ml.DTypeF32 {
out := ctx.Empty(c.config.MaskDType, maskTensor.Shape()...)
ctx.Forward(maskTensor.Copy(ctx, out))
maskTensor = out
}
return maskTensor, nil
}
func moveCell(ctx ml.Context, objs []ml.Tensor, src, dst, len int) {
for _, obj := range objs {
if obj == nil {
func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
for i := range c.keys {
if c.keys[i] == nil {
continue
}
srcView := obj.View(ctx, obj.Stride(2)*src, obj.Dim(0)*obj.Dim(1)*len)
dstView := obj.View(ctx, obj.Stride(2)*dst, obj.Dim(0)*obj.Dim(1)*len)
key := c.keys[i]
ctx.Forward(srcView.Copy(ctx, dstView))
kHeadDim := key.Dim(0)
numKVHeads := key.Dim(1)
rowSize := key.Stride(2)
kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*len)
kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*len)
value := c.values[i]
var vSrcView, vDstView ml.Tensor
if c.config.PermutedV {
vHeadDim := value.Dim(1)
elemSize := value.Stride(0)
vSrcView = value.View(ctx, elemSize*src, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
vDstView = value.View(ctx, elemSize*dst, len, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)
} else {
vHeadDim := value.Dim(0)
rowSize := value.Stride(2)
vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*len)
vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*len)
}
ctx.Forward(
kSrcView.Copy(ctx, kDstView),
vSrcView.Copy(ctx, vDstView),
)
}
}
@@ -238,8 +324,7 @@ func (c *Causal) defrag() {
pendingLen++
break
} else {
moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
moves++
}
}
@@ -263,8 +348,7 @@ func (c *Causal) defrag() {
}
if pendingLen > 0 {
moveCell(ctx, c.keys, pendingSrc, pendingDst, pendingLen)
moveCell(ctx, c.values, pendingSrc, pendingDst, pendingLen)
c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
moves++
}
@@ -305,33 +389,73 @@ func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
key := c.keys[c.curLayer]
value := c.values[c.curLayer]
key = key.View(ctx, key.Stride(2)*c.curCellRange.min,
key.Dim(0), key.Stride(1),
key.Dim(1), key.Stride(2),
c.curMask.Dim(0),
kHeadDim := key.Dim(0)
numKVHeads := key.Dim(1)
rowSize := key.Stride(2)
cachedSize := c.curMask.Dim(0)
key = key.View(ctx, rowSize*c.curCellRange.min,
kHeadDim, key.Stride(1),
numKVHeads, key.Stride(2),
cachedSize,
)
value = value.View(ctx, key.Stride(2)*c.curCellRange.min,
value.Dim(0), value.Stride(1),
value.Dim(1), value.Stride(2),
c.curMask.Dim(0),
)
if c.config.PermutedV {
vHeadDim := value.Dim(1)
elemSize := value.Stride(0)
value = value.View(ctx, elemSize*c.curCellRange.min,
cachedSize, value.Stride(1),
vHeadDim, value.Stride(2),
numKVHeads,
)
} else {
vHeadDim := value.Dim(0)
rowSize := value.Stride(2)
value = value.View(ctx, rowSize*c.curCellRange.min,
vHeadDim, value.Stride(1),
numKVHeads, value.Stride(2),
cachedSize,
)
}
return key, value, c.curMask
}
func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
if c.curBatchSize != key.Dim(2) {
panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, key.Dim(2)))
kHeadDim := key.Dim(0)
vHeadDim := value.Dim(0)
numKVHeads := key.Dim(1)
batchSize := key.Dim(2)
if c.curBatchSize != batchSize {
panic(fmt.Errorf("inconsistent batch sizes (layer: %v, batch size: %v layer batch size: %v)", c.curLayer, c.curBatchSize, batchSize))
}
if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, key.Dim(0), key.Dim(1), int(c.Capacity))
c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, value.Dim(0), value.Dim(1), int(c.Capacity))
c.keys[c.curLayer] = c.cacheCtx.Zeros(c.DType, kHeadDim, numKVHeads, int(c.Capacity))
if c.config.PermutedV {
c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, int(c.Capacity), vHeadDim, numKVHeads)
} else {
c.values[c.curLayer] = c.cacheCtx.Zeros(c.DType, vHeadDim, numKVHeads, int(c.Capacity))
}
}
ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, c.keys[c.curLayer].Stride(2)*c.curLoc, key.Dim(0)*key.Dim(1)*key.Dim(2))))
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, c.values[c.curLayer].Stride(2)*c.curLoc, value.Dim(0)*value.Dim(1)*value.Dim(2))))
rowSize := c.keys[c.curLayer].Stride(2)
ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, kHeadDim*numKVHeads*batchSize)))
if c.config.PermutedV {
elemSize := c.values[c.curLayer].Stride(0)
value = value.Permute(ctx, 1, 2, 0, 3)
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, int(c.Capacity)*elemSize, vHeadDim*numKVHeads)))
} else {
rowSize := c.values[c.curLayer].Stride(2)
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, vHeadDim*numKVHeads*batchSize)))
}
}
func (c *Causal) CopyPrefix(srcSeq, dstSeq int, len int32) {
@@ -387,9 +511,13 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
continue
}
key = key.View(ctx, key.Stride(2)*seqRange.min,
key.Dim(0), key.Stride(1),
key.Dim(1), key.Stride(2),
kHeadDim := key.Dim(0)
numKVHeads := key.Dim(1)
rowSize := key.Stride(2)
key = key.View(ctx, rowSize*seqRange.min,
kHeadDim, key.Stride(1),
numKVHeads, key.Stride(2),
size,
)

View File

@@ -280,9 +280,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
out, _, mask := cache.Get(context)
context.Forward(out)
context.Forward(mask)
context.Compute(out, mask)
context.Forward(out, mask).Compute(out, mask)
if !slices.Equal(out.Floats(), test.expected) || !slices.Equal(out.Shape(), test.expectedShape) || !slices.Equal(mask.Floats(), test.expectedMask) {
t.Errorf("TestCache: have %v (shape %v); want %v (shape %v); mask: have %v (shape %v) want %v", out.Floats(), out.Shape(), test.expected, test.expectedShape, mask.Floats(), mask.Shape(), test.expectedMask)
@@ -311,7 +309,7 @@ func (b *testBackend) SystemInfo() string {
type testContext struct{}
func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
func (c *testContext) Empty(dtype ml.DType, shape ...int) ml.Tensor {
total := 0
if len(shape) > 0 {
@@ -324,8 +322,12 @@ func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
return &testTensor{dtype: dtype, elementSize: 4, data: make([]float32, total), shape: shape}
}
func (c *testContext) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
return c.Empty(dtype, shape...)
}
func (c *testContext) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
t := c.Zeros(ml.DTypeF32, shape...).(*testTensor)
t := c.Empty(ml.DTypeF32, shape...).(*testTensor)
copy(t.data, s)
@@ -344,7 +346,7 @@ func (c *testContext) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) {
return out, nil
}
func (c *testContext) Forward(ml.Tensor) {}
func (c *testContext) Forward(...ml.Tensor) ml.Context { return c }
func (c *testContext) Compute(...ml.Tensor) {}
@@ -393,7 +395,7 @@ func (t *testTensor) Floats() []float32 {
}
func (t *testTensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
out := ctx.Zeros(t.DType(), t.Shape()...).(*testTensor)
out := ctx.Empty(t.DType(), t.Shape()...).(*testTensor)
for i := range out.data {
out.data[i] = t.data[i] + t2.(*testTensor).data[i]
@@ -470,7 +472,7 @@ func (t *testTensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
context := &testContext{}
view := context.Zeros(t.dtype, s...).(*testTensor)
view := context.Empty(t.dtype, s...).(*testTensor)
view.data = t.data[offset : offset+len(view.data)]
return view

View File

@@ -1,6 +1,8 @@
package kvcache
import (
"fmt"
"github.com/ollama/ollama/ml"
)
@@ -11,6 +13,9 @@ import (
//
// Not currently safe for multiple sequences
type EncoderCache struct {
// config controls mostly backend-specific optimizations
config *ml.CacheConfig
// ** current forward pass **
// the active layer for Get and Put
@@ -40,9 +45,29 @@ func NewEncoderCache() *EncoderCache {
}
func (c *EncoderCache) Init(backend ml.Backend, dtype ml.DType, capacity int32) {
if c.config == nil {
var config ml.CacheConfig
if cc, ok := backend.(ml.BackendCacheConfig); ok {
config = cc.CacheConfig()
}
c.config = &config
}
if c.config.CachePadding != 0 && c.config.CachePadding != 1 {
panic(fmt.Errorf("encoder cache is unable to enforce requested CachePadding (%v)", c.config.CachePadding))
}
c.cacheCtx = backend.NewContext()
}
func (c *EncoderCache) SetConfig(config ml.CacheConfig) {
if c.config != nil {
panic("config cannot be changed after being previously set, either by the model or backend")
}
c.config = &config
}
func (c *EncoderCache) Close() {
c.cacheCtx.Close()
}
@@ -75,13 +100,19 @@ func (c *EncoderCache) Put(ctx ml.Context, key, value ml.Tensor) {
c.encoderPos = c.curPos
c.encoderCached = true
if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
c.keys[c.curLayer] = c.cacheCtx.Zeros(key.DType(), key.Shape()...)
c.values[c.curLayer] = c.cacheCtx.Zeros(value.DType(), value.Shape()...)
if c.config.PermutedV {
value = value.Permute(ctx, 1, 2, 0, 3)
}
ctx.Forward(key.Copy(ctx, c.keys[c.curLayer]))
ctx.Forward(value.Copy(ctx, c.values[c.curLayer]))
if c.keys[c.curLayer] == nil || c.values[c.curLayer] == nil {
c.keys[c.curLayer] = c.cacheCtx.Empty(key.DType(), key.Shape()...)
c.values[c.curLayer] = c.cacheCtx.Empty(value.DType(), value.Shape()...)
}
ctx.Forward(
key.Copy(ctx, c.keys[c.curLayer]),
value.Copy(ctx, c.values[c.curLayer]),
)
}
func (c *EncoderCache) CopyPrefix(srcSeq, dstSeq int, len int32) {

View File

@@ -28,6 +28,12 @@ func (c *WrapperCache) Init(backend ml.Backend, dtype ml.DType, capacity int32)
}
}
func (c *WrapperCache) SetConfig(config ml.CacheConfig) {
for _, cache := range c.caches {
cache.SetConfig(config)
}
}
func (c *WrapperCache) Close() {
for _, cache := range c.caches {
cache.Close()

View File

@@ -105,6 +105,7 @@ extern "C" {
LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26,
LLAMA_VOCAB_PRE_TYPE_MINERVA = 27,
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28,
LLAMA_VOCAB_PRE_TYPE_GPT4O = 29,
};
enum llama_rope_type {

View File

@@ -2283,7 +2283,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
// output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
// if output is NULL, init from the input tok embed
if (output == NULL) {
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
}
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
@@ -2298,8 +2302,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * n_ff }, 0);
layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
}
} break;
case LLM_ARCH_PHIMOE:

View File

@@ -392,6 +392,13 @@ struct llm_tokenizer_bpe : llm_tokenizer {
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
};
break;
case LLAMA_VOCAB_PRE_TYPE_GPT4O:
// original regex from tokenizer.json
// [^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+
regex_exprs = {
"[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
};
break;
default:
// default regex for BPE tokenization pre-processing
regex_exprs = {
@@ -1583,6 +1590,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
} else if (
tokenizer_pre == "megrez") {
pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2;
} else if (
tokenizer_pre == "gpt-4o") {
pre_type = LLAMA_VOCAB_PRE_TYPE_GPT4O;
clean_spaces = false;
} else {
LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__);
pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;

View File

@@ -37,23 +37,36 @@ COMPILER inline get_compiler() {
import "C"
import (
"context"
_ "embed"
"errors"
"fmt"
"log/slog"
"os"
"runtime"
"runtime/cgo"
"slices"
"strings"
"sync/atomic"
"unsafe"
_ "github.com/ollama/ollama/llama/llama.cpp/common"
_ "github.com/ollama/ollama/llama/llama.cpp/examples/llava"
_ "github.com/ollama/ollama/llama/llama.cpp/src"
"github.com/ollama/ollama/ml/backend/ggml/ggml/src"
ggml "github.com/ollama/ollama/ml/backend/ggml/ggml/src"
)
func init() {
C.llama_log_set(C.ggml_log_callback(C.llamaLog), nil)
}
//export llamaLog
func llamaLog(level C.int, text *C.char, _ unsafe.Pointer) {
// slog levels zeros INFO and are multiples of 4
if slog.Default().Enabled(context.TODO(), slog.Level(int(level-C.GGML_LOG_LEVEL_INFO)*4)) {
fmt.Fprint(os.Stderr, C.GoString(text))
}
}
func BackendInit() {
ggml.OnceLoad()
C.llama_backend_init()
@@ -72,26 +85,6 @@ func PrintSystemInfo() string {
return C.GoString(C.llama_print_system_info()) + compiler
}
var logLevel atomic.Int32
func init() {
logLevel.Store(int32(C.GGML_LOG_LEVEL_INFO))
C.llama_log_set((C.ggml_log_callback)(C.llamaLog), nil)
}
func EnableDebug() {
logLevel.Store(int32(C.GGML_LOG_LEVEL_DEBUG))
}
//export llamaLog
func llamaLog(level int32, text *C.char, _ unsafe.Pointer) {
if level < logLevel.Load() {
return
}
fmt.Fprint(os.Stderr, C.GoString(text))
}
func GetModelArch(modelPath string) (string, error) {
mp := C.CString(modelPath)
defer C.free(unsafe.Pointer(mp))
@@ -269,7 +262,7 @@ func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
cparams.progress_callback_user_data = unsafe.Pointer(&handle)
}
m := Model{c: C.llama_load_model_from_file(C.CString(modelPath), cparams)}
m := Model{c: C.llama_model_load_from_file(C.CString(modelPath), cparams)}
if m.c == nil {
return nil, fmt.Errorf("unable to load model: %s", modelPath)
}
@@ -278,12 +271,12 @@ func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
}
func FreeModel(model *Model) {
C.llama_free_model(model.c)
C.llama_model_free(model.c)
}
func NewContextWithModel(model *Model, params ContextParams) (*Context, error) {
c := Context{
c: C.llama_new_context_with_model(model.c, params.c),
c: C.llama_init_from_model(model.c, params.c),
numThreads: int(params.c.n_threads),
}
if c.c == nil {
@@ -294,15 +287,15 @@ func NewContextWithModel(model *Model, params ContextParams) (*Context, error) {
}
func (m *Model) NumVocab() int {
return int(C.llama_n_vocab(m.Vocab()))
return int(C.llama_vocab_n_tokens(m.Vocab()))
}
func (m *Model) TokenIsEog(token int) bool {
return bool(C.llama_token_is_eog(m.Vocab(), C.llama_token(token)))
return bool(C.llama_vocab_is_eog(m.Vocab(), C.llama_token(token)))
}
func (m *Model) AddBOSToken() bool {
return bool(C.llama_add_bos_token(m.Vocab()))
return bool(C.llama_vocab_get_add_bos(m.Vocab()))
}
func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float32, threads int) error {
@@ -485,7 +478,7 @@ func (m *Model) Tokenize(text string, addSpecial bool, parseSpecial bool) ([]int
}
func (m *Model) NEmbd() int {
return int(C.llama_n_embd(m.c))
return int(C.llama_model_n_embd(m.c))
}
func Quantize(infile, outfile string, ftype uint32) error {

View File

@@ -0,0 +1,80 @@
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: jmorganca <jmorganca@gmail.com>
Date: Thu, 27 Feb 2025 15:12:26 -0800
Subject: [PATCH] add phi4 support
---
include/llama.h | 1 +
src/llama-model.cpp | 10 +++++++---
src/llama-vocab.cpp | 11 +++++++++++
3 files changed, 19 insertions(+), 3 deletions(-)
diff --git a/include/llama.h b/include/llama.h
index cc948005..16774711 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -105,6 +105,7 @@ extern "C" {
LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26,
LLAMA_VOCAB_PRE_TYPE_MINERVA = 27,
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28,
+ LLAMA_VOCAB_PRE_TYPE_GPT4O = 29,
};
enum llama_rope_type {
diff --git a/src/llama-model.cpp b/src/llama-model.cpp
index 21819080..ab1a07d1 100644
--- a/src/llama-model.cpp
+++ b/src/llama-model.cpp
@@ -2283,7 +2283,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
// output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
- output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0);
+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+ // if output is NULL, init from the input tok embed
+ if (output == NULL) {
+ output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+ }
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
@@ -2298,8 +2302,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, 2 * n_ff }, 0);
- layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
- layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_embd_head/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+ layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+ layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
}
} break;
case LLM_ARCH_PHIMOE:
diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp
index 1ca827eb..c7ff28be 100644
--- a/src/llama-vocab.cpp
+++ b/src/llama-vocab.cpp
@@ -392,6 +392,13 @@ struct llm_tokenizer_bpe : llm_tokenizer {
"'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)",
};
break;
+ case LLAMA_VOCAB_PRE_TYPE_GPT4O:
+ // original regex from tokenizer.json
+ // [^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+
+ regex_exprs = {
+ "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
+ };
+ break;
default:
// default regex for BPE tokenization pre-processing
regex_exprs = {
@@ -1583,6 +1590,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
} else if (
tokenizer_pre == "megrez") {
pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2;
+ } else if (
+ tokenizer_pre == "gpt-4o") {
+ pre_type = LLAMA_VOCAB_PRE_TYPE_GPT4O;
+ clean_spaces = false;
} else {
LLAMA_LOG_WARN("%s: missing or unrecognized pre-tokenizer type, using: 'default'\n", __func__);
pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT;

View File

@@ -14,6 +14,7 @@ type Config interface {
String(string, ...string) string
Uint(string, ...uint32) uint32
Float(string, ...float32) float32
Bool(string, ...bool) bool
Strings(string, ...[]string) []string
Uints(string, ...[]uint32) []uint32
@@ -26,6 +27,35 @@ type Backend interface {
SystemInfo() string
}
// BackendCacheConfig should be implemented by backends that need special output
// from the cache to meet specific requirements. It is frequently implemented in
// conjunction with ScaledDotProductAttention.
type BackendCacheConfig interface {
CacheConfig() CacheConfig
}
// CacheConfig controls optimizations (mostly backend-specific) that may transform
// the output the cache to work better with specific kernels.
type CacheConfig struct {
// CachePadding specifies the multiple for the number of tokens of cache history
// that will be returned from cache Get for k, v and mask. The capacity of the
// cache itself will also be increased to a multiple of this size if needed.
CachePadding int
// PermutedV performs Permute(ctx, 1, 2, 0, 3) on v tensors stored via Put
// and return the permuted version via Get. This uses the cache copy operation
// to avoid a Contiguous call on the permuted tensor.
PermutedV bool
// MaskDType specifies the data type for generating the mask. If unset it will
// default to DTypeF32.
MaskDType DType
// MaskBatchPadding specifies the multiple for the batch size dimension in the mask.
// Any position that does not correspond to an actual token will be filled with -Inf.
MaskBatchPadding int
}
// BackendParams controls how the backend loads and executes models
type BackendParams struct {
// NumThreads sets the number of threads to use if running on the CPU
@@ -39,6 +69,9 @@ type BackendParams struct {
// TensorSplit is the fraction of the model to offload to each GPU
TensorSplit []float32
// FlashAttention indicates that we should use a fused flash attention kernel
FlashAttention bool
}
var backends = make(map[string]func(*os.File, BackendParams) (Backend, error))
@@ -60,11 +93,12 @@ func NewBackend(f *os.File, params BackendParams) (Backend, error) {
}
type Context interface {
Empty(dtype DType, shape ...int) Tensor
Zeros(dtype DType, shape ...int) Tensor
FromFloatSlice(s []float32, shape ...int) (Tensor, error)
FromIntSlice(s []int32, shape ...int) (Tensor, error)
Forward(Tensor)
Forward(...Tensor) Context
Compute(...Tensor)
MaxTensors() int
Close()
@@ -115,6 +149,10 @@ type Tensor interface {
// operation equivalent to following code on a tensor named
// query:
//
// query = query.Permute(ctx, 0, 2, 1, 3)
// key = key.Permute(ctx, 0, 2, 1, 3)
// value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
//
// kq := key.MulmatFullPrec(ctx, query)
//
// kq = kq.Scale(ctx, scale)
@@ -169,7 +207,7 @@ func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {
return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
})
case DTypeF16:
f32 := ctx.Zeros(DTypeF32, t.Shape()...)
f32 := ctx.Empty(DTypeF32, t.Shape()...)
f32 = t.Copy(ctx, f32)
return dump[[]float32](ctx, f32, opts[0].Items, func(f float32) string {
return strconv.FormatFloat(float64(f), 'f', opts[0].Precision, 32)
@@ -185,8 +223,7 @@ func Dump(ctx Context, t Tensor, opts ...DumpOptions) string {
func dump[S ~[]E, E number](ctx Context, t Tensor, items int, fn func(E) string) string {
if t.Bytes() == nil {
ctx.Forward(t)
ctx.Compute(t)
ctx.Forward(t).Compute(t)
}
s := make(S, mul(t.Shape()...))

View File

@@ -79,6 +79,8 @@ var devices = sync.OnceValue(func() []device {
})
type Backend struct {
flashAttention bool
meta *fs.GGML
cpus, gpus []Context
tensors map[string]*Context
@@ -192,9 +194,10 @@ func New(r *os.File, params ml.BackendParams) (ml.Backend, error) {
}
return &Backend{
meta: meta,
cpus: cpus,
gpus: gpus,
flashAttention: params.FlashAttention,
meta: meta,
cpus: cpus,
gpus: gpus,
sched: C.ggml_backend_sched_new(
(*C.ggml_backend_t)(unsafe.Pointer(&backends[0])),
(*C.ggml_backend_buffer_type_t)(unsafe.Pointer(&bufts[0])),
@@ -219,7 +222,7 @@ func (b *Backend) Get(name string) ml.Tensor {
for _, c := range append(b.gpus, b.cpus...) {
if t := C.ggml_get_tensor(c.ctx, cname); t != nil {
return &Tensor{t: t}
return &Tensor{b: b, t: t}
}
}
@@ -247,6 +250,14 @@ func (b *Backend) NewContext() ml.Context {
}
}
func (b *Backend) CacheConfig() ml.CacheConfig {
if b.flashAttention {
return ml.CacheConfig{CachePadding: 256, MaskDType: ml.DTypeF16, MaskBatchPadding: C.GGML_KQ_MASK_PAD}
} else {
return ml.CacheConfig{CachePadding: 32, PermutedV: true}
}
}
type Context struct {
b *Backend
ctx *C.struct_ggml_context
@@ -256,12 +267,16 @@ type Context struct {
nodes int
}
func (c *Context) Forward(t ml.Tensor) {
func (c *Context) Forward(tensors ...ml.Tensor) ml.Context {
if c.graph == nil {
c.graph = C.ggml_new_graph_custom(c.ctx, C.size_t(c.nodes), false)
}
C.ggml_build_forward_expand(c.graph, t.(*Tensor).t)
for _, tensor := range tensors {
C.ggml_build_forward_expand(c.graph, tensor.(*Tensor).t)
}
return c
}
func (c *Context) Compute(tensors ...ml.Tensor) {
@@ -296,7 +311,7 @@ func shapeToGGML(shape []int) *C.int64_t {
return &sh[0]
}
func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
func newTensor(ctx Context, dtype ml.DType, zero bool, shape []int) ml.Tensor {
if len(shape) < 1 || len(shape) > 4 {
panic("unsupported number of dimensions")
}
@@ -310,19 +325,29 @@ func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
var t *C.struct_ggml_tensor
switch dtype {
case ml.DTypeF32:
t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F32, C.int(len(shape)), shapeToGGML(shape))
t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_F32, C.int(len(shape)), shapeToGGML(shape))
case ml.DTypeF16:
t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_F16, C.int(len(shape)), shapeToGGML(shape))
t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_F16, C.int(len(shape)), shapeToGGML(shape))
case ml.DTypeI32:
t = C.ggml_new_tensor(c.ctx, C.GGML_TYPE_I32, C.int(len(shape)), shapeToGGML(shape))
t = C.ggml_new_tensor(ctx.ctx, C.GGML_TYPE_I32, C.int(len(shape)), shapeToGGML(shape))
default:
panic("unsupported dtype")
}
b := C.ggml_backend_alloc_buffer(c.backend, C.ggml_nbytes(t))
b := C.ggml_backend_alloc_buffer(ctx.backend, C.ggml_nbytes(t))
C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
C.ggml_set_zero(t)
return &Tensor{t: t}
if zero {
C.ggml_set_zero(t)
}
return &Tensor{b: ctx.b, t: t}
}
func (c Context) Empty(dtype ml.DType, shape ...int) ml.Tensor {
return newTensor(c, dtype, false, shape)
}
func (c Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor {
return newTensor(c, dtype, true, shape)
}
func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype uint32) (ml.Tensor, error) {
@@ -331,7 +356,7 @@ func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype u
if n == 0 {
var shape C.int64_t = 0
t := C.ggml_new_tensor(ctx.ctx, dtype, 1, &shape)
return &Tensor{t: t}, nil
return &Tensor{b: ctx.b, t: t}, nil
}
for _, v := range shape {
@@ -346,7 +371,7 @@ func fromSlice[S ~[]E, E float32 | int32](ctx Context, s S, shape []int, dtype u
b := C.ggml_backend_alloc_buffer(ctx.backend, C.ggml_nbytes(t))
C.ggml_backend_tensor_alloc(b, t, C.ggml_backend_buffer_get_base(b))
C.ggml_backend_tensor_set(t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t))
return &Tensor{t: t}, nil
return &Tensor{b: ctx.b, t: t}, nil
}
func (c Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) {
@@ -364,6 +389,7 @@ func (c *Context) Close() {
}
type Tensor struct {
b *Backend
t *C.struct_ggml_tensor
sync func()
}
@@ -430,6 +456,7 @@ func (t *Tensor) DType() ml.DType {
func (t *Tensor) Add(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_add(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
}
}
@@ -444,24 +471,28 @@ func (t *Tensor) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor {
func (t *Tensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_concat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(dim)),
}
}
func (t *Tensor) Contiguous(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_cont(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) Mul(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_mul(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
}
}
func (t *Tensor) Mulmat(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_mul_mat(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
}
}
@@ -471,12 +502,13 @@ func (t *Tensor) MulmatFullPrec(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
C.ggml_mul_mat_set_prec(mul, C.GGML_PREC_F32)
return &Tensor{
b: t.b,
t: mul,
}
}
func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor {
tt := (&Tensor{t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
tt := (&Tensor{b: t.b, t: C.ggml_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
if b != nil {
tt = tt.Add(ctx, b)
}
@@ -485,7 +517,7 @@ func (t *Tensor) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tenso
}
func (t *Tensor) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor {
return (&Tensor{t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
return (&Tensor{b: t.b, t: C.ggml_rms_norm(ctx.(*Context).ctx, t.t, C.float(eps))}).Mul(ctx, w)
}
func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
@@ -494,6 +526,7 @@ func (t *Tensor) Pad(ctx ml.Context, shape ...int) ml.Tensor {
}
return &Tensor{
b: t.b,
t: C.ggml_pad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
}
}
@@ -504,18 +537,21 @@ func (t *Tensor) Permute(ctx ml.Context, shape ...int) ml.Tensor {
}
return &Tensor{
b: t.b,
t: C.ggml_permute(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
}
}
func (t *Tensor) Rows(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_get_rows(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
}
}
func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_cpy(ctx.(*Context).ctx, t.t, t2.(*Tensor).t),
}
}
@@ -524,18 +560,22 @@ func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
switch len(shape) {
case 1:
return &Tensor{
b: t.b,
t: C.ggml_reshape_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0])),
}
case 2:
return &Tensor{
b: t.b,
t: C.ggml_reshape_2d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1])),
}
case 3:
return &Tensor{
b: t.b,
t: C.ggml_reshape_3d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2])),
}
case 4:
return &Tensor{
b: t.b,
t: C.ggml_reshape_4d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.int64_t(shape[1]), C.int64_t(shape[2]), C.int64_t(shape[3])),
}
default:
@@ -545,18 +585,21 @@ func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
func (t *Tensor) Scale(ctx ml.Context, s float64) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_scale(ctx.(*Context).ctx, t.t, (C.float)(s)),
}
}
func (t *Tensor) Softmax(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_soft_max(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) Tanh(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_tanh_inplace(ctx.(*Context).ctx, t.t),
}
}
@@ -567,6 +610,7 @@ func (t *Tensor) Unpad(ctx ml.Context, shape ...int) ml.Tensor {
}
return &Tensor{
b: t.b,
t: C.ggml_unpad(ctx.(*Context).ctx, t.t, C.int(shape[0]), C.int(shape[1]), C.int(shape[2]), C.int(shape[3])),
}
}
@@ -575,10 +619,12 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
switch len(shape) {
case 1:
return &Tensor{
b: t.b,
t: C.ggml_view_1d(ctx.(*Context).ctx, t.t, C.int64_t(shape[0]), C.size_t(offset)),
}
case 3:
return &Tensor{
b: t.b,
t: C.ggml_view_2d(ctx.(*Context).ctx, t.t,
C.int64_t(shape[0]), C.int64_t(shape[2]),
C.size_t(shape[1]),
@@ -586,6 +632,7 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
}
case 5:
return &Tensor{
b: t.b,
t: C.ggml_view_3d(ctx.(*Context).ctx, t.t,
C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]),
C.size_t(shape[1]), C.size_t(shape[3]),
@@ -593,6 +640,7 @@ func (t *Tensor) View(ctx ml.Context, offset int, shape ...int) ml.Tensor {
}
case 7:
return &Tensor{
b: t.b,
t: C.ggml_view_4d(ctx.(*Context).ctx, t.t,
C.int64_t(shape[0]), C.int64_t(shape[2]), C.int64_t(shape[4]), C.int64_t(shape[6]),
C.size_t(shape[1]), C.size_t(shape[3]), C.size_t(shape[5]),
@@ -609,7 +657,7 @@ const (
func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim uint32, ropeBase, ropeScale float32) ml.Tensor {
if ropeFactors == nil {
ropeFactors = &Tensor{}
ropeFactors = &Tensor{b: t.b}
}
dequant := t.t
@@ -618,6 +666,7 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
}
return &Tensor{
b: t.b,
t: C.ggml_rope_ext(
ctx.(*Context).ctx, dequant, positionIDs.(*Tensor).t, ropeFactors.(*Tensor).t,
C.int(ropeDim),
@@ -635,18 +684,21 @@ func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDi
func (t *Tensor) GELU(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_gelu_inplace(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) SILU(ctx ml.Context) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_silu_inplace(ctx.(*Context).ctx, t.t),
}
}
func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor {
return &Tensor{
b: t.b,
t: C.ggml_conv_2d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int(s0), C.int(s1), C.int(p0), C.int(p1), C.int(d0), C.int(d1)),
}
}
@@ -657,13 +709,25 @@ func (t *Tensor) ScaledDotProductAttention(ctx ml.Context, key, value, mask ml.T
kqMask = mask.(*Tensor).t
}
kq := key.MulmatFullPrec(ctx, t)
kq = &Tensor{
t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
}
query := t.Permute(ctx, 0, 2, 1, 3)
key = key.Permute(ctx, 0, 2, 1, 3)
kqv := value.Mulmat(ctx, kq)
return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
if t.b.flashAttention {
value = value.Permute(ctx, 0, 2, 1, 3)
kqv := C.ggml_flash_attn_ext(ctx.(*Context).ctx, query.(*Tensor).t, key.(*Tensor).t, value.(*Tensor).t, kqMask, C.float(scale), 0, 0)
C.ggml_flash_attn_ext_set_prec(kqv, C.GGML_PREC_F32)
return &Tensor{b: t.b, t: kqv}
} else {
kq := key.MulmatFullPrec(ctx, query)
kq = &Tensor{
b: t.b,
t: C.ggml_soft_max_ext(ctx.(*Context).ctx, kq.(*Tensor).t, kqMask, C.float(scale), 0),
}
kqv := value.Mulmat(ctx, kq)
return kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
}
}
func (b *Backend) SystemInfo() string {

View File

@@ -10,6 +10,8 @@ package ggml
import "C"
import (
"context"
"fmt"
"log/slog"
"os"
"path/filepath"
@@ -22,21 +24,14 @@ import (
)
func init() {
C.ggml_log_set((C.ggml_log_callback)(C.sink), nil)
C.ggml_log_set(C.ggml_log_callback(C.sink), nil)
}
//export sink
func sink(level C.int, text *C.char, _ unsafe.Pointer) {
msg := strings.TrimSpace(C.GoString(text))
switch level {
case C.GGML_LOG_LEVEL_DEBUG:
slog.Debug(msg)
case C.GGML_LOG_LEVEL_INFO:
slog.Info(msg)
case C.GGML_LOG_LEVEL_WARN:
slog.Warn(msg)
case C.GGML_LOG_LEVEL_ERROR:
slog.Error(msg)
// slog levels zeros INFO and are multiples of 4
if slog.Default().Enabled(context.TODO(), slog.Level(int(level-C.GGML_LOG_LEVEL_INFO)*4)) {
fmt.Fprint(os.Stderr, C.GoString(text))
}
}

View File

@@ -3,6 +3,7 @@ package nn
import (
"fmt"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
)
@@ -11,40 +12,50 @@ import (
//
// Parameters:
// - ctx: Context for tensor operations
// - query: Query tensor (Q) with shape [d_k, seq_len_q, heads]
// - key: Key tensor (K) with shape [d_k, seq_len_k, kv_heads]
// - value: Value tensor (V) with shape [seq_len_k, d_v, kv_heads]
// - mask: Optional attention mask that is added to the attention score. If
// provided, should broadcast to [seq_len_k, seq_len_q, heads]
// - query: Query tensor (Q) with shape [d_k, heads, seq_len_q]
// - key: Key tensor (K) with shape [d_k, kv_heads, seq_len_k], can be nil to read from cache only
// - value: Value tensor (V) with shape [d_v, kv_heads, seq_len_k], can be nil to read from cache only
// - scale: Scaling factor, typically 1/√d_k where d_k is the key dimension
// - cache: KV cache to store key/value and get past history, can be nil to only use provided key/value
//
// Returns:
//
// Attention output with shape [d_v, heads, seq_len_q]
func Attention(ctx ml.Context, query, key, value, mask ml.Tensor, scale float64) ml.Tensor {
if query.Dim(0) != key.Dim(0) {
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
func Attention(ctx ml.Context, query, key, value ml.Tensor, scale float64, cache kvcache.Cache) ml.Tensor {
if key != nil && value != nil {
if query.Dim(0) != key.Dim(0) {
panic(fmt.Errorf("d_k in attention operation does not match between query(%v) and key(%v)", query.Dim(0), key.Dim(0)))
}
if key.Dim(1) != value.Dim(1) {
panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(1)))
}
if key.Dim(2) != value.Dim(2) {
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
}
if cache != nil {
cache.Put(ctx, key, value)
}
} else if cache == nil {
panic("key & value tensors must be provided if cache is nil")
}
if mask != nil && query.Dim(1) != mask.Dim(1) {
panic(fmt.Errorf("seq_len_q in attention operation does not match between query(%v) and mask(%v)", query.Dim(1), mask.Dim(1)))
var mask ml.Tensor
if cache != nil {
key, value, mask = cache.Get(ctx)
}
if key.Dim(1) != value.Dim(0) {
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and value(%v)", key.Dim(1), value.Dim(0)))
}
if mask != nil && key.Dim(1) != mask.Dim(0) {
panic(fmt.Errorf("seq_len_k in attention operation does not match between key(%v) and mask(%v)", key.Dim(1), mask.Dim(0)))
}
if key.Dim(2) != value.Dim(2) {
panic(fmt.Errorf("kv_heads in attention operation does not match between key(%v) and value(%v)", key.Dim(2), value.Dim(2)))
}
if sdpa, ok := query.(ml.ScaledDotProductAttention); ok {
// Only use the fast SDPA implementation if we have a cache, since that's what
// will do any expected backend-specific transformations for us
if sdpa, ok := query.(ml.ScaledDotProductAttention); ok && cache != nil {
return sdpa.ScaledDotProductAttention(ctx, key, value, mask, scale)
} else {
query = query.Permute(ctx, 0, 2, 1, 3)
key = key.Permute(ctx, 0, 2, 1, 3)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
kq := key.MulmatFullPrec(ctx, query)
kq = kq.Scale(ctx, scale)

View File

@@ -248,8 +248,7 @@ func Forward(ctx ml.Context, m Model, opts Options) (ml.Tensor, error) {
return nil, err
}
ctx.Forward(t)
ctx.Compute(t)
ctx.Forward(t).Compute(t)
return t, nil
}

View File

@@ -37,7 +37,9 @@ func New(c ml.Config) (model.Model, error) {
Types: c.Uints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
},
),
Layers: make([]Layer, c.Uint("block_count")),
@@ -79,15 +81,8 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
v := sa.Value.Forward(ctx, hiddenState)
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
cache.Put(ctx, k, v)
k, v, mask := cache.Get(ctx)
q = q.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
k = k.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
v = v.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
kqv := nn.Attention(ctx, q, k, v, mask, scaleFactor)
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
return sa.Output.Forward(ctx, kqv)

View File

@@ -33,7 +33,9 @@ func New(c ml.Config) (model.Model, error) {
Types: c.Uints("tokenizer.ggml.token_type"),
Merges: c.Strings("tokenizer.ggml.merges"),
BOS: int32(c.Uint("tokenizer.ggml.bos_token_id")),
AddBOS: c.Bool("tokenizer.ggml.add_bos_token", true),
EOS: int32(c.Uint("tokenizer.ggml.eos_token_id")),
AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false),
},
),
ImageProcessor: newImageProcessor(c),
@@ -41,7 +43,9 @@ func New(c ml.Config) (model.Model, error) {
TextModel: newTextModel(c),
}
m.Cache = kvcache.NewWrapperCache(kvcache.NewEncoderCache(), kvcache.NewCausalCache(m.TextModel.Shift))
encoderCache := kvcache.NewEncoderCache()
encoderCache.SetConfig(ml.CacheConfig{})
m.Cache = kvcache.NewWrapperCache(encoderCache, kvcache.NewCausalCache(m.TextModel.Shift))
return &m, nil
}

View File

@@ -31,22 +31,15 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
value := sa.Value.Forward(ctx, hiddenState)
value = value.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
cache.Put(ctx, key, value)
key, value, mask := cache.Get(ctx)
query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)
attention := nn.Attention(ctx, query, key, value, scaleFactor, cache)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return sa.Output.Forward(ctx, attention)
}
func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
// This will only get called for layers in the cache, which are just the self attention layers
// This will only get called for layers in the causal cache, which are just the self attention layers
return key.RoPE(ctx, shift, m.RopeFactors, m.ropeDim, m.ropeBase, m.ropeScale), nil
}
@@ -107,7 +100,7 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
query = ca.QueryNorm.Forward(ctx, query, opts.eps)
var key, value, mask ml.Tensor
var key, value ml.Tensor
if crossAttentionStates != nil {
numVisionTokens, numTiles := crossAttentionStates.Dim(1), crossAttentionStates.Dim(2)
@@ -119,16 +112,23 @@ func (ca *TextCrossAttention) Forward(ctx ml.Context, hiddenState, crossAttentio
value = value.Reshape(ctx, headDim, opts.numKVHeads, numVisionTokens*numTiles)
cache.Put(ctx, key, value)
} else {
key, value, mask = cache.Get(ctx)
}
query = query.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
key = key.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
key, value, _ = cache.Get(ctx)
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
attention := nn.Attention(ctx, query, key, value, mask, scaleFactor)
query = query.Permute(ctx, 0, 2, 1, 3)
key = key.Permute(ctx, 0, 2, 1, 3)
value = value.Permute(ctx, 1, 2, 0, 3).Contiguous(ctx)
kq := key.MulmatFullPrec(ctx, query)
kq = kq.Scale(ctx, scaleFactor)
kq = kq.Softmax(ctx)
kqv := value.Mulmat(ctx, kq)
attention := kqv.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx)
attention = attention.Reshape(ctx, opts.hiddenSize, batchSize)
return ca.Output.Forward(ctx, attention)

View File

@@ -30,7 +30,8 @@ type Vocabulary struct {
Scores []uint32
Merges []string
BOS, EOS int32
BOS, EOS int32
AddBOS, AddEOS bool
specialOnce sync.Once
special []string
@@ -281,6 +282,26 @@ func (bpe BytePairEncoding) Encode(s string) ([]int32, error) {
}
}
if len(ids) > 0 {
if bpe.vocab.AddBOS {
if ids[0] == bpe.vocab.BOS {
slog.Warn("adding bos token to prompt which already has it", "id", bpe.vocab.BOS)
}
slog.Debug("adding bos token to prompt", "id", bpe.vocab.BOS)
ids = append([]int32{bpe.vocab.BOS}, ids...)
}
if bpe.vocab.AddEOS {
if ids[len(ids)-1] == bpe.vocab.EOS {
slog.Warn("adding eos token to prompt which already has it", "id", bpe.vocab.EOS)
}
slog.Debug("adding eos token to prompt", "id", bpe.vocab.EOS)
ids = append(ids, bpe.vocab.EOS)
}
}
return ids, nil
}

View File

@@ -915,7 +915,6 @@ func Execute(args []string) error {
level := slog.LevelInfo
if *verbose {
level = slog.LevelDebug
llama.EnableDebug()
}
handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
Level: level,
@@ -944,12 +943,11 @@ func Execute(args []string) error {
var tensorSplitFloats []float32
if *tensorSplit != "" {
stringFloats := regexp.MustCompile(",").Split(*tensorSplit, -1)
tensorSplitFloats = make([]float32, 0, len(stringFloats))
for _, s := range stringFloats {
splits := strings.Split(*tensorSplit, ",")
tensorSplitFloats = make([]float32, len(splits))
for i, s := range splits {
f, _ := strconv.ParseFloat(s, 32)
tensorSplitFloats = append(tensorSplitFloats, float32(f))
tensorSplitFloats[i] = float32(f)
}
}
@@ -970,13 +968,14 @@ func Execute(args []string) error {
server.cond = sync.NewCond(&server.mu)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go server.run(ctx)
addr := "127.0.0.1:" + strconv.Itoa(*port)
listener, err := net.Listen("tcp", addr)
if err != nil {
fmt.Println("Listen error:", err)
cancel()
return err
}
defer listener.Close()
@@ -996,6 +995,5 @@ func Execute(args []string) error {
return err
}
cancel()
return nil
}

View File

@@ -575,23 +575,11 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
return
}
sampler, err := sample.NewSampler(
req.Temperature,
req.TopK,
req.TopP,
req.MinP,
req.Seed,
)
if err != nil {
http.Error(w, fmt.Sprintf("Failed to create sampler: %v", err), http.StatusInternalServerError)
return
}
seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
numPredict: req.NumPredict,
stop: req.Stop,
numKeep: int32(req.NumKeep),
sampler: sampler,
sampler: sample.Greedy(), // TODO: add support for different samplers when performance is optimized
embedding: false,
})
if err != nil {
@@ -830,7 +818,7 @@ func Execute(args []string) error {
batchSize := fs.Int("batch-size", 512, "Batch size")
numGPULayers := fs.Int("n-gpu-layers", 0, "Number of layers to offload to GPU")
mainGPU := fs.Int("main-gpu", 0, "Main GPU")
_ = fs.Bool("flash-attn", false, "Enable flash attention")
flashAttention := fs.Bool("flash-attn", false, "Enable flash attention")
kvSize := fs.Int("ctx-size", 2048, "Context (or KV cache) size")
kvCacheType := fs.String("kv-cache-type", "", "quantization type for KV cache (default: f16)")
port := fs.Int("port", 8080, "Port to expose the server on")
@@ -875,26 +863,25 @@ func Execute(args []string) error {
}
// TODO(jessegross): Parameters that need to be implemented:
// flash-attn
// no-mmap
// mlock
var tensorSplitFloats []float32
if *tensorSplit != "" {
stringFloats := regexp.MustCompile(",").Split(*tensorSplit, -1)
tensorSplitFloats = make([]float32, 0, len(stringFloats))
for _, s := range stringFloats {
splits := strings.Split(*tensorSplit, ",")
tensorSplitFloats = make([]float32, len(splits))
for i, s := range splits {
f, _ := strconv.ParseFloat(s, 32)
tensorSplitFloats = append(tensorSplitFloats, float32(f))
tensorSplitFloats[i] = float32(f)
}
}
params := ml.BackendParams{
NumThreads: *threads,
NumGPULayers: *numGPULayers,
MainGPU: *mainGPU,
TensorSplit: tensorSplitFloats,
NumThreads: *threads,
NumGPULayers: *numGPULayers,
MainGPU: *mainGPU,
TensorSplit: tensorSplitFloats,
FlashAttention: *flashAttention,
}
server.ready.Add(1)
@@ -903,13 +890,14 @@ func Execute(args []string) error {
server.cond = sync.NewCond(&server.mu)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
go server.run(ctx)
addr := "127.0.0.1:" + strconv.Itoa(*port)
listener, err := net.Listen("tcp", addr)
if err != nil {
fmt.Println("Listen error:", err)
cancel()
return err
}
defer listener.Close()
@@ -929,6 +917,5 @@ func Execute(args []string) error {
return err
}
cancel()
return nil
}

View File

@@ -54,53 +54,42 @@ func (s weighted) Sample(logits []float32) (int32, error) {
if idx, ok := w.Take(); ok {
return int32(indices[idx]), nil
}
return -1, errors.New("weighed sampler failed, no valid token found")
return -1, errors.New("weighted sampler failed, no valid token found")
}
type greedy struct {
transforms []Transform
}
func Greedy(transforms ...Transform) Sampler {
return greedy{transforms: transforms}
type greedy struct{}
func Greedy() Sampler {
return greedy{}
}
// Sample returns the index of the maximum value in logits.
func (s greedy) Sample(logits []float32) (int32, error) {
logits64 := make([]float64, len(logits))
for i, v := range logits {
logits64[i] = float64(v)
if len(logits) == 0 {
return -1, errors.New("no logits provided for greedy sampling")
}
for _, t := range s.transforms {
logits64 = t.Apply(logits64)
}
var maxIdx int
var maxLogit float64
for i, logit := range logits64 {
if logit > maxLogit {
maxLogit = logit
maxIdx := 0
for i := range logits {
if logits[i] > logits[maxIdx] {
maxIdx = i
}
}
if maxLogit == math.Inf(-1) {
return -1, errors.New("no valid logits found for greedy sampling")
}
return int32(maxIdx), nil
}
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int) (Sampler, error) {
transforms := []Transform{}
if temperature == 0 {
return Greedy(), nil
}
if temperature < 0 || temperature > 2 {
return nil, errors.New("temperature must be between 0 and 2")
}
if temperature != 0 {
transforms = append(transforms, Temperature(temperature))
}
transforms := []Transform{Temperature(temperature)}
if topK != 0 {
if topK <= 0 {
@@ -123,15 +112,7 @@ func NewSampler(temperature float32, topK int, topP float32, minP float32, seed
transforms = append(transforms, MinP(minP))
}
if len(transforms) == 0 {
return nil, errors.New("at least one transform is required")
}
if temperature == 0 {
return Greedy(transforms...), nil
}
if seed != 0 {
if seed >= 0 {
seed64 := uint64(seed)
return Weighted(&seed64, transforms...), nil
}

View File

@@ -66,32 +66,15 @@ func TestSample(t *testing.T) {
callOrder: &callOrder,
}
got, err := Greedy(mock1, mock2, mock3).Sample(input)
_, err := Weighted(nil, mock1, mock2, mock3).Sample(input)
if err != nil {
t.Error(err)
return
}
want := int32(3) // Greedy sampler should pick highest logit
if want != got {
t.Errorf("index mismatch: want %d, got %d", want, got)
}
wantOrder := []int{1, 2, 3}
if diff := cmp.Diff(wantOrder, callOrder); diff != "" {
t.Errorf("call order mismatch (-want +got):\n%s", diff)
}
callOrder = nil
_, err = Weighted(nil, mock1, mock2, mock3).Sample(input)
if err != nil {
t.Error(err)
return
}
wantOrder = []int{1, 2, 3}
if diff := cmp.Diff(wantOrder, callOrder); diff != "" {
t.Errorf("call order mismatch (-want +got):\n%s", diff)
}
}
func TestNewSampler(t *testing.T) {
@@ -105,8 +88,9 @@ func TestNewSampler(t *testing.T) {
wantErr bool
}{
{
name: "no transforms",
wantErr: true,
name: "no transforms",
// temperature is 0, so greedy should be used
wantErr: false,
},
{
name: "temperature",
@@ -124,49 +108,52 @@ func TestNewSampler(t *testing.T) {
wantErr: true,
},
{
name: "top k",
topK: 10,
wantErr: false,
name: "top k",
topK: 10,
temperature: 0.8,
wantErr: false,
},
{
name: "invalid top k negative",
topK: -1,
wantErr: true,
name: "invalid top k negative",
topK: -1,
temperature: 0.8,
wantErr: true,
},
{
name: "top p",
topP: 0.9,
wantErr: false,
name: "top p",
topP: 0.9,
temperature: 0.8,
wantErr: false,
},
{
name: "invalid top p negative",
topP: -0.1,
wantErr: true,
name: "invalid top p negative",
topP: -0.1,
temperature: 0.8,
wantErr: true,
},
{
name: "invalid top p one",
topP: 1.0,
wantErr: true,
name: "invalid top p one",
topP: 1.0,
temperature: 0.8,
wantErr: true,
},
{
name: "min p",
minP: 0.2,
wantErr: false,
name: "min p",
minP: 0.2,
temperature: 0.8,
wantErr: false,
},
{
name: "invalid min p negative",
minP: -0.1,
wantErr: true,
name: "invalid min p negative",
minP: -0.1,
temperature: 0.8,
wantErr: true,
},
{
name: "invalid min p one",
minP: 1.0,
wantErr: true,
},
{
name: "seed",
seed: 42,
wantErr: true, // seed alone is not valid without other transforms
name: "invalid min p one",
minP: 1.0,
temperature: 0.8,
wantErr: true,
},
{
name: "default values",
@@ -184,7 +171,7 @@ func TestNewSampler(t *testing.T) {
topP: 0.0,
minP: 0.0,
seed: 0,
wantErr: true, // all zeroes means no transforms
wantErr: false, // all zeroes means no transforms
},
{
name: "all transforms",
@@ -216,7 +203,7 @@ func BenchmarkSample(b *testing.B) {
}
samplers := map[string]Sampler{
"Greedy": Greedy(transforms...),
"Greedy": Greedy(),
"Weighted": Weighted(nil, transforms...),
}

View File

@@ -8,6 +8,7 @@ import (
"errors"
"fmt"
"io"
"io/fs"
"log/slog"
"net/http"
"os"
@@ -34,6 +35,7 @@ var (
errOnlyGGUFSupported = errors.New("supplied file was not in GGUF format")
errUnknownType = errors.New("unknown type")
errNeitherFromOrFiles = errors.New("neither 'from' or 'files' was specified")
errFilePath = errors.New("file path must be relative")
)
func (s *Server) CreateHandler(c *gin.Context) {
@@ -46,6 +48,13 @@ func (s *Server) CreateHandler(c *gin.Context) {
return
}
for v := range r.Files {
if !fs.ValidPath(v) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errFilePath.Error()})
return
}
}
name := model.ParseName(cmp.Or(r.Model, r.Name))
if !name.IsValid() {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": errtypes.InvalidModelNameErrMsg})
@@ -104,7 +113,7 @@ func (s *Server) CreateHandler(c *gin.Context) {
if r.Adapters != nil {
adapterLayers, err = convertModelFromFiles(r.Adapters, baseLayers, true, fn)
if err != nil {
for _, badReq := range []error{errNoFilesProvided, errOnlyOneAdapterSupported, errOnlyGGUFSupported, errUnknownType} {
for _, badReq := range []error{errNoFilesProvided, errOnlyOneAdapterSupported, errOnlyGGUFSupported, errUnknownType, errFilePath} {
if errors.Is(err, badReq) {
ch <- gin.H{"error": err.Error(), "status": http.StatusBadRequest}
return
@@ -221,8 +230,22 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
return nil, err
}
defer os.RemoveAll(tmpDir)
// Set up a root to validate paths
root, err := os.OpenRoot(tmpDir)
if err != nil {
return nil, err
}
defer root.Close()
for fp, digest := range files {
if !fs.ValidPath(fp) {
return nil, fmt.Errorf("%w: %s", errFilePath, fp)
}
if _, err := root.Stat(fp); err != nil && !errors.Is(err, fs.ErrNotExist) {
// Path is likely outside the root
return nil, fmt.Errorf("%w: %s: %s", errFilePath, err, fp)
}
blobPath, err := GetBlobsPath(digest)
if err != nil {
return nil, err
@@ -270,6 +293,7 @@ func convertFromSafetensors(files map[string]string, baseLayers []*layerGGML, is
if err != nil {
return nil, err
}
defer bin.Close()
f, _, err := ggml.Decode(bin, 0)
if err != nil {

106
server/create_test.go Normal file
View File

@@ -0,0 +1,106 @@
package server
import (
"bytes"
"encoding/binary"
"errors"
"os"
"path/filepath"
"strings"
"testing"
"github.com/ollama/ollama/api"
)
func TestConvertFromSafetensors(t *testing.T) {
t.Setenv("OLLAMA_MODELS", t.TempDir())
// Helper function to create a new layer and return its digest
makeTemp := func(content string) string {
l, err := NewLayer(strings.NewReader(content), "application/octet-stream")
if err != nil {
t.Fatalf("Failed to create layer: %v", err)
}
return l.Digest
}
// Create a safetensors compatible file with empty JSON content
var buf bytes.Buffer
headerSize := int64(len("{}"))
binary.Write(&buf, binary.LittleEndian, headerSize)
buf.WriteString("{}")
model := makeTemp(buf.String())
config := makeTemp(`{
"architectures": ["LlamaForCausalLM"],
"vocab_size": 32000
}`)
tokenizer := makeTemp(`{
"version": "1.0",
"truncation": null,
"padding": null,
"added_tokens": [
{
"id": 0,
"content": "<|endoftext|>",
"single_word": false,
"lstrip": false,
"rstrip": false,
"normalized": false,
"special": true
}
]
}`)
tests := []struct {
name string
filePath string
wantErr error
}{
// Invalid
{
name: "InvalidRelativePathShallow",
filePath: filepath.Join("..", "file.safetensors"),
wantErr: errFilePath,
},
{
name: "InvalidRelativePathDeep",
filePath: filepath.Join("..", "..", "..", "..", "..", "..", "data", "file.txt"),
wantErr: errFilePath,
},
{
name: "InvalidNestedPath",
filePath: filepath.Join("dir", "..", "..", "..", "..", "..", "other.safetensors"),
wantErr: errFilePath,
},
{
name: "AbsolutePathOutsideRoot",
filePath: filepath.Join(os.TempDir(), "model.safetensors"),
wantErr: errFilePath, // Should fail since it's outside tmpDir
},
{
name: "ValidRelativePath",
filePath: "model.safetensors",
wantErr: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create the minimum required file map for convertFromSafetensors
files := map[string]string{
tt.filePath: model,
"config.json": config,
"tokenizer.json": tokenizer,
}
_, err := convertFromSafetensors(files, nil, false, func(resp api.ProgressResponse) {})
if (tt.wantErr == nil && err != nil) ||
(tt.wantErr != nil && err == nil) ||
(tt.wantErr != nil && !errors.Is(err, tt.wantErr)) {
t.Errorf("convertFromSafetensors() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

View File

@@ -279,6 +279,18 @@ func (c *DiskCache) Get(d Digest) (Entry, error) {
// It returns an error if either the name or digest is invalid, or if link
// creation encounters any issues.
func (c *DiskCache) Link(name string, d Digest) error {
// TODO(bmizerany): Move link handling from cache to registry.
//
// We originally placed links in the cache due to its storage
// knowledge. However, the registry likely offers better context for
// naming concerns, and our API design shouldn't be tightly coupled to
// our on-disk format.
//
// Links work effectively when independent from physical location -
// they can reference content with matching SHA regardless of storage
// location. In an upcoming change, we plan to shift this
// responsibility to the registry where it better aligns with the
// system's conceptual model.
manifest, err := c.manifestPath(name)
if err != nil {
return err
@@ -304,21 +316,19 @@ func (c *DiskCache) Link(name string, d Digest) error {
return c.copyNamedFile(manifest, f, d, info.Size())
}
// Unlink removes the any link for name. If the link does not exist, nothing
// happens, and no error is returned.
//
// It returns an error if the name is invalid or if the link removal encounters
// any issues.
func (c *DiskCache) Unlink(name string) error {
// Unlink unlinks the manifest by name from the cache. If the name is not
// found. If a manifest is removed ok will be true, otherwise false. If an
// error occurs, it returns ok false, and the error.
func (c *DiskCache) Unlink(name string) (ok bool, _ error) {
manifest, err := c.manifestPath(name)
if err != nil {
return err
return false, err
}
err = os.Remove(manifest)
if errors.Is(err, fs.ErrNotExist) {
return nil
return false, nil
}
return err
return true, err
}
// GetFile returns the absolute path to the file, in the cache, for the given

View File

@@ -13,7 +13,7 @@ import (
"testing"
"time"
"github.com/ollama/ollama/server/internal/internal/testutil"
"github.com/ollama/ollama/server/internal/testutil"
)
func init() {
@@ -479,8 +479,11 @@ func testManifestNameReuse(t *testing.T) {
}
// relink with different case
err = c.Unlink("h/n/m:t")
unlinked, err := c.Unlink("h/n/m:t")
check(err)
if !unlinked {
t.Fatal("expected unlinked")
}
err = c.Link("h/n/m:T", d1)
check(err)

View File

@@ -86,7 +86,7 @@ func useCaseInsensitiveTempDir(t *testing.T) bool {
// link to docs on that topic.
lines := strings.Split(volumeHint, "\n")
for _, line := range lines {
t.Log(line)
t.Skip(line)
}
}
return false

View File

@@ -19,10 +19,12 @@ import (
"fmt"
"io"
"io/fs"
"log/slog"
"net/http"
"os"
"path/filepath"
"runtime"
"slices"
"strconv"
"strings"
"sync/atomic"
@@ -52,7 +54,7 @@ var (
// ErrMissingModel is returned when the model part of a name is missing
// or invalid.
ErrNameInvalid = errors.New("invalid name; must be in the form {scheme://}{host/}{namespace/}[model]{:tag}{@digest}")
ErrNameInvalid = errors.New("invalid or missing name")
// ErrCached is passed to [Trace.PushUpdate] when a layer already
// exists. It is a non-fatal error and is never returned by [Registry.Push].
@@ -86,9 +88,23 @@ func DefaultCache() (*blob.DiskCache, error) {
return blob.Open(dir)
}
// Error is the standard error returned by Ollama APIs.
// Error is the standard error returned by Ollama APIs. It can represent a
// single or multiple error response.
//
// Single error responses have the following format:
//
// {"code": "optional_code","error":"error message"}
//
// Multiple error responses have the following format:
//
// {"errors": [{"code": "optional_code","message":"error message"}]}
//
// Note, that the error field is used in single error responses, while the
// message field is used in multiple error responses.
//
// In both cases, the code field is optional and may be empty.
type Error struct {
Status int `json:"-"`
Status int `json:"-"` // TODO(bmizerany): remove this
Code string `json:"code"`
Message string `json:"message"`
}
@@ -97,13 +113,34 @@ func (e *Error) Error() string {
return fmt.Sprintf("registry responded with status %d: %s %s", e.Status, e.Code, e.Message)
}
func (e *Error) LogValue() slog.Value {
return slog.GroupValue(
slog.Int("status", e.Status),
slog.String("code", e.Code),
slog.String("message", e.Message),
)
}
// UnmarshalJSON implements json.Unmarshaler.
func (e *Error) UnmarshalJSON(b []byte) error {
type E Error
var v struct{ Errors []E }
var v struct {
// Single error
Code string
Error string
// Multiple errors
Errors []E
}
if err := json.Unmarshal(b, &v); err != nil {
return err
}
if v.Error != "" {
// Single error case
e.Code = v.Code
e.Message = v.Error
return nil
}
if len(v.Errors) == 0 {
return fmt.Errorf("no messages in error response: %s", string(b))
}
@@ -111,15 +148,23 @@ func (e *Error) UnmarshalJSON(b []byte) error {
return nil
}
// TODO(bmizerany): make configurable on [Registry]
var defaultName = func() names.Name {
n := names.Parse("ollama.com/library/_:latest")
const DefaultMask = "registry.ollama.ai/library/_:latest"
var defaultMask = func() names.Name {
n := names.Parse(DefaultMask)
if !n.IsFullyQualified() {
panic("default name is not fully qualified")
panic("default mask is not fully qualified")
}
return n
}()
// CompleteName returns a fully qualified name by merging the given name with
// the default mask. If the name is already fully qualified, it is returned
// unchanged.
func CompleteName(name string) string {
return names.Merge(names.Parse(name), defaultMask).String()
}
// Registry is a client for performing push and pull operations against an
// Ollama registry.
type Registry struct {
@@ -160,21 +205,38 @@ type Registry struct {
//
// It is only used when a layer is larger than [MaxChunkingThreshold].
MaxChunkSize int64
// Mask, if set, is the name used to convert non-fully qualified
// names to fully qualified names. If empty, the default mask
// ("registry.ollama.ai/library/_:latest") is used.
Mask string
}
// RegistryFromEnv returns a new Registry configured from the environment. The
func (r *Registry) parseName(name string) (names.Name, error) {
mask := defaultMask
if r.Mask != "" {
mask = names.Parse(r.Mask)
}
n := names.Merge(names.Parse(name), mask)
if !n.IsFullyQualified() {
return names.Name{}, fmt.Errorf("%w: %q", ErrNameInvalid, name)
}
return n, nil
}
// DefaultRegistry returns a new Registry configured from the environment. The
// key is read from $HOME/.ollama/id_ed25519, MaxStreams is set to the
// value of OLLAMA_REGISTRY_MAXSTREAMS, and ChunkingDirectory is set to the
// system's temporary directory.
//
// It returns an error if any configuration in the environment is invalid.
func RegistryFromEnv() (*Registry, error) {
func DefaultRegistry() (*Registry, error) {
home, err := os.UserHomeDir()
if err != nil {
return nil, err
}
keyPEM, err := os.ReadFile(filepath.Join(home, ".ollama/id_ed25519"))
if err != nil {
if err != nil && errors.Is(err, fs.ErrNotExist) {
return nil, err
}
@@ -194,42 +256,6 @@ func RegistryFromEnv() (*Registry, error) {
return &rc, nil
}
type PushParams struct {
// From is an optional destination name for the model. If empty, the
// destination name is the same as the source name.
From string
}
// parseName parses name using [names.ParseExtended] and then merges the name with the
// default name, and checks that the name is fully qualified. If a digest is
// present, it parse and returns it with the other fields as their zero values.
//
// It returns an error if the name is not fully qualified, or if the digest, if
// any, is invalid.
//
// The scheme is returned as provided by [names.ParseExtended].
func parseName(s string) (scheme string, n names.Name, d blob.Digest, err error) {
scheme, n, ds := names.ParseExtended(s)
n = names.Merge(n, defaultName)
if ds != "" {
// Digest is present. Validate it.
d, err = blob.ParseDigest(ds)
if err != nil {
return "", names.Name{}, blob.Digest{}, err
}
}
// The name check is deferred until after the digest check because we
// say that digests take precedence over names, and so should there
// errors when being parsed.
if !n.IsFullyQualified() {
return "", names.Name{}, blob.Digest{}, ErrNameInvalid
}
scheme = cmp.Or(scheme, "https")
return scheme, n, d, nil
}
func (r *Registry) maxStreams() int {
n := cmp.Or(r.MaxStreams, runtime.GOMAXPROCS(0))
@@ -249,13 +275,19 @@ func (r *Registry) maxChunkSize() int64 {
return cmp.Or(r.MaxChunkSize, DefaultMaxChunkSize)
}
type PushParams struct {
// From is an optional destination name for the model. If empty, the
// destination name is the same as the source name.
From string
}
// Push pushes the model with the name in the cache to the remote registry.
func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *PushParams) error {
if p == nil {
p = &PushParams{}
}
m, err := ResolveLocal(c, cmp.Or(p.From, name))
m, err := r.ResolveLocal(c, cmp.Or(p.From, name))
if err != nil {
return err
}
@@ -278,7 +310,7 @@ func (r *Registry) Push(ctx context.Context, c *blob.DiskCache, name string, p *
t := traceFromContext(ctx)
scheme, n, _, err := parseName(name)
scheme, n, _, err := r.parseNameExtended(name)
if err != nil {
// This should never happen since ResolveLocal should have
// already validated the name.
@@ -372,7 +404,7 @@ func canRetry(err error) bool {
// typically slower than splitting the model up across layers, and is mostly
// utilized for layers of type equal to "application/vnd.ollama.image".
func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) error {
scheme, n, _, err := parseName(name)
scheme, n, _, err := r.parseNameExtended(name)
if err != nil {
return err
}
@@ -520,6 +552,16 @@ func (r *Registry) Pull(ctx context.Context, c *blob.DiskCache, name string) err
return c.Link(m.Name, md)
}
// Unlink is like [blob.DiskCache.Unlink], but makes name fully qualified
// before attempting to unlink the model.
func (r *Registry) Unlink(c *blob.DiskCache, name string) (ok bool, _ error) {
n, err := r.parseName(name)
if err != nil {
return false, err
}
return c.Unlink(n.String())
}
// Manifest represents a [ollama.com/manifest].
type Manifest struct {
Name string `json:"-"` // the canonical name of the model
@@ -588,10 +630,9 @@ type Layer struct {
Size int64 `json:"size"`
}
// ResolveLocal resolves a name to a Manifest in the local cache. The name is
// parsed using [names.ParseExtended] but the scheme is ignored.
func ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) {
_, n, d, err := parseName(name)
// ResolveLocal resolves a name to a Manifest in the local cache.
func (r *Registry) ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) {
_, n, d, err := r.parseNameExtended(name)
if err != nil {
return nil, err
}
@@ -617,7 +658,7 @@ func ResolveLocal(c *blob.DiskCache, name string) (*Manifest, error) {
// Resolve resolves a name to a Manifest in the remote registry.
func (r *Registry) Resolve(ctx context.Context, name string) (*Manifest, error) {
scheme, n, d, err := parseName(name)
scheme, n, d, err := r.parseNameExtended(name)
if err != nil {
return nil, err
}
@@ -800,3 +841,89 @@ func maybeUnexpectedEOF(err error) error {
}
return err
}
type publicError struct {
wrapped error
message string
}
func withPublicMessagef(err error, message string, args ...any) error {
return publicError{wrapped: err, message: fmt.Sprintf(message, args...)}
}
func (e publicError) Error() string { return e.message }
func (e publicError) Unwrap() error { return e.wrapped }
var supportedSchemes = []string{
"http",
"https",
"https+insecure",
}
var supportedSchemesMessage = fmt.Sprintf("supported schemes are %v", strings.Join(supportedSchemes, ", "))
// parseNameExtended parses and validates an extended name, returning the scheme, name,
// and digest.
//
// If the scheme is empty, scheme will be "https". If an unsupported scheme is
// given, [ErrNameInvalid] wrapped with a display friendly message is returned.
//
// If the digest is invalid, [ErrNameInvalid] wrapped with a display friendly
// message is returned.
//
// If the name is not, once merged with the mask, fully qualified,
// [ErrNameInvalid] wrapped with a display friendly message is returned.
func (r *Registry) parseNameExtended(s string) (scheme string, _ names.Name, _ blob.Digest, _ error) {
scheme, name, digest := splitExtended(s)
scheme = cmp.Or(scheme, "https")
if !slices.Contains(supportedSchemes, scheme) {
err := withPublicMessagef(ErrNameInvalid, "unsupported scheme: %q: %s", scheme, supportedSchemesMessage)
return "", names.Name{}, blob.Digest{}, err
}
var d blob.Digest
if digest != "" {
var err error
d, err = blob.ParseDigest(digest)
if err != nil {
err = withPublicMessagef(ErrNameInvalid, "invalid digest: %q", digest)
return "", names.Name{}, blob.Digest{}, err
}
if name == "" {
// We have can resolve a manifest from a digest only,
// so skip name validation and return the scheme and
// digest.
return scheme, names.Name{}, d, nil
}
}
n, err := r.parseName(name)
if err != nil {
return "", names.Name{}, blob.Digest{}, err
}
return scheme, n, d, nil
}
// splitExtended splits an extended name string into its scheme, name, and digest
// parts.
//
// Examples:
//
// http://ollama.com/bmizerany/smol:latest@digest
// https://ollama.com/bmizerany/smol:latest
// ollama.com/bmizerany/smol:latest@digest // returns "https" scheme.
// model@digest
// @digest
func splitExtended(s string) (scheme, name, digest string) {
i := strings.Index(s, "://")
if i >= 0 {
scheme = s[:i]
s = s[i+3:]
}
i = strings.LastIndex(s, "@")
if i >= 0 {
digest = s[i+1:]
s = s[:i]
}
return scheme, s, digest
}

View File

@@ -2,6 +2,7 @@ package ollama
import (
"bytes"
"cmp"
"context"
"encoding/json"
"errors"
@@ -21,7 +22,7 @@ import (
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/chunks"
"github.com/ollama/ollama/server/internal/internal/testutil"
"github.com/ollama/ollama/server/internal/testutil"
)
func TestManifestMarshalJSON(t *testing.T) {
@@ -37,20 +38,6 @@ func TestManifestMarshalJSON(t *testing.T) {
}
}
func link(c *blob.DiskCache, name string, manifest string) {
_, n, _, err := parseName(name)
if err != nil {
panic(err)
}
d, err := c.Import(bytes.NewReader([]byte(manifest)), int64(len(manifest)))
if err != nil {
panic(err)
}
if err := c.Link(n.String(), d); err != nil {
panic(err)
}
}
var errRoundTrip = errors.New("forced roundtrip error")
type recordRoundTripper http.HandlerFunc
@@ -98,30 +85,45 @@ func newClient(t *testing.T, h http.HandlerFunc) (*Registry, *blob.DiskCache) {
}
}
r := &Registry{
HTTPClient: &http.Client{
Transport: recordRoundTripper(h),
},
}
link := func(name string, manifest string) {
n, err := r.parseName(name)
if err != nil {
panic(err)
}
d, err := c.Import(bytes.NewReader([]byte(manifest)), int64(len(manifest)))
if err != nil {
panic(err)
}
if err := c.Link(n.String(), d); err != nil {
panic(err)
}
}
commit := func(name string, layers ...*Layer) {
t.Helper()
data, err := json.Marshal(&Manifest{Layers: layers})
if err != nil {
t.Fatal(err)
}
link(c, name, string(data))
link(name, string(data))
}
link(c, "empty", "")
link("empty", "")
commit("zero")
commit("single", mklayer("exists"))
commit("multiple", mklayer("exists"), mklayer("present"))
commit("notfound", &Layer{Digest: blob.DigestFromBytes("notfound"), Size: int64(len("notfound"))})
commit("null", nil)
commit("sizemismatch", mklayer("exists"), &Layer{Digest: blob.DigestFromBytes("present"), Size: 499})
link(c, "invalid", "!!!!!")
link("invalid", "!!!!!")
rc := &Registry{
HTTPClient: &http.Client{
Transport: recordRoundTripper(h),
},
}
return rc, c
return r, c
}
func okHandler(w http.ResponseWriter, r *http.Request) {
@@ -144,29 +146,6 @@ func importBytes(t *testing.T, c *blob.DiskCache, data string) blob.Digest {
return d
}
func TestRegistryPushInvalidNames(t *testing.T) {
rc, c := newClient(t, nil)
cases := []struct {
name string
err error
}{
{"", ErrNameInvalid},
{"@", ErrNameInvalid},
{"@x", blob.ErrInvalidDigest},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
// Create a new registry and push a new image.
err := rc.Push(t.Context(), c, tt.name, nil)
if !errors.Is(err, tt.err) {
t.Errorf("err = %v; want %v", err, tt.err)
}
})
}
}
func withTraceUnexpected(ctx context.Context) (context.Context, *Trace) {
t := &Trace{Update: func(*Layer, int64, error) { panic("unexpected") }}
return WithTrace(ctx, t), t
@@ -385,7 +364,7 @@ func TestRegistryPullNotCached(t *testing.T) {
})
// Confirm that the layer does not exist locally
_, err := ResolveLocal(c, "model")
_, err := rc.ResolveLocal(c, "model")
checkNotExist(t, err)
_, err = c.Get(d)
@@ -396,7 +375,7 @@ func TestRegistryPullNotCached(t *testing.T) {
mw, err := rc.Resolve(t.Context(), "model")
check(err)
mg, err := ResolveLocal(c, "model")
mg, err := rc.ResolveLocal(c, "model")
check(err)
if !reflect.DeepEqual(mw, mg) {
t.Errorf("mw = %v; mg = %v", mw, mg)
@@ -621,7 +600,7 @@ func TestInsecureSkipVerify(t *testing.T) {
}))
defer s.Close()
const name = "ollama.com/library/insecure"
const name = "library/insecure"
var rc Registry
url := fmt.Sprintf("https://%s/%s", s.Listener.Addr(), name)
@@ -654,3 +633,184 @@ func TestCanRetry(t *testing.T) {
}
}
}
func TestErrorUnmarshal(t *testing.T) {
cases := []struct {
name string
data string
want *Error
wantErr bool
}{
{
name: "errors empty",
data: `{"errors":[]}`,
wantErr: true,
},
{
name: "errors empty",
data: `{"errors":[]}`,
wantErr: true,
},
{
name: "errors single",
data: `{"errors":[{"code":"blob_unknown"}]}`,
want: &Error{Code: "blob_unknown", Message: ""},
},
{
name: "errors multiple",
data: `{"errors":[{"code":"blob_unknown"},{"code":"blob_error"}]}`,
want: &Error{Code: "blob_unknown", Message: ""},
},
{
name: "error empty",
data: `{"error":""}`,
wantErr: true,
},
{
name: "error very empty",
data: `{}`,
wantErr: true,
},
{
name: "error message",
data: `{"error":"message", "code":"code"}`,
want: &Error{Code: "code", Message: "message"},
},
{
name: "invalid value",
data: `{"error": 1}`,
wantErr: true,
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
var got Error
err := json.Unmarshal([]byte(tt.data), &got)
if err != nil {
if tt.wantErr {
return
}
t.Errorf("Unmarshal() error = %v", err)
// fallthrough and check got
}
if tt.want == nil {
tt.want = &Error{}
}
if !reflect.DeepEqual(got, *tt.want) {
t.Errorf("got = %v; want %v", got, *tt.want)
}
})
}
}
// TestParseNameErrors tests that parseName returns errors messages with enough
// detail for users to debug naming issues they may encounter. Previous to this
// test, the error messages were not very helpful and each problem was reported
// as the same message.
//
// It is only for testing error messages, not that all invalids and valids are
// covered. Those are in other tests for names.Name and blob.Digest.
func TestParseNameExtendedErrors(t *testing.T) {
cases := []struct {
name string
err error
want string
}{}
var r Registry
for _, tt := range cases {
_, _, _, err := r.parseNameExtended(tt.name)
if !errors.Is(err, tt.err) {
t.Errorf("[%s]: err = %v; want %v", tt.name, err, tt.err)
}
if err != nil && !strings.Contains(err.Error(), tt.want) {
t.Errorf("[%s]: err =\n\t%v\nwant\n\t%v", tt.name, err, tt.want)
}
}
}
func TestParseNameExtended(t *testing.T) {
cases := []struct {
in string
scheme string
name string
digest string
err string
}{
{in: "http://m", scheme: "http", name: "m"},
{in: "https+insecure://m", scheme: "https+insecure", name: "m"},
{in: "http+insecure://m", err: "unsupported scheme"},
{in: "http://m@sha256:1111111111111111111111111111111111111111111111111111111111111111", scheme: "http", name: "m", digest: "sha256:1111111111111111111111111111111111111111111111111111111111111111"},
{in: "", err: "invalid or missing name"},
{in: "m", scheme: "https", name: "m"},
{in: "://", err: "invalid or missing name"},
{in: "@sha256:deadbeef", err: "invalid digest"},
{in: "@sha256:deadbeef@sha256:deadbeef", err: "invalid digest"},
}
for _, tt := range cases {
t.Run(tt.in, func(t *testing.T) {
var r Registry
scheme, n, digest, err := r.parseNameExtended(tt.in)
if err != nil {
if tt.err == "" {
t.Errorf("err = %v; want nil", err)
} else if !strings.Contains(err.Error(), tt.err) {
t.Errorf("err = %v; want %q", err, tt.err)
}
} else if tt.err != "" {
t.Errorf("err = nil; want %q", tt.err)
}
if err == nil && !n.IsFullyQualified() {
t.Errorf("name = %q; want fully qualified", n)
}
if scheme != tt.scheme {
t.Errorf("scheme = %q; want %q", scheme, tt.scheme)
}
// smoke-test name is superset of tt.name
if !strings.Contains(n.String(), tt.name) {
t.Errorf("name = %q; want %q", n, tt.name)
}
tt.digest = cmp.Or(tt.digest, (&blob.Digest{}).String())
if digest.String() != tt.digest {
t.Errorf("digest = %q; want %q", digest, tt.digest)
}
})
}
}
func TestUnlink(t *testing.T) {
t.Run("found by name", func(t *testing.T) {
rc, c := newClient(t, nil)
// confirm linked
_, err := rc.ResolveLocal(c, "single")
if err != nil {
t.Errorf("unexpected error: %v", err)
}
// unlink
_, err = rc.Unlink(c, "single")
testutil.Check(t, err)
// confirm unlinked
_, err = rc.ResolveLocal(c, "single")
if !errors.Is(err, fs.ErrNotExist) {
t.Errorf("err = %v; want fs.ErrNotExist", err)
}
})
t.Run("not found by name", func(t *testing.T) {
rc, c := newClient(t, nil)
ok, err := rc.Unlink(c, "manifestNotFound")
if err != nil {
t.Fatal(err)
}
if ok {
t.Error("expected not found")
}
})
}

View File

@@ -86,6 +86,8 @@ func (m *Model) readTensors(fname string) ([]*Tensor, error) {
return nil, err
}
endOfHeader := 8 + headerSize // 8 bytes for header size plus the header itself
// TODO(bmizerany): do something with metadata? This could be another
// header read if needed. We also need to figure out if the metadata is
// present in only one .safetensors file or if each file may have their
@@ -95,7 +97,8 @@ func (m *Model) readTensors(fname string) ([]*Tensor, error) {
tt := make([]*Tensor, 0, len(raws))
for name, raw := range raws {
if !strings.HasPrefix(name, "model.layer") {
if name == "__metadata__" {
// TODO(bmizerany): do something with metadata?
continue
}
var v struct {
@@ -112,7 +115,8 @@ func (m *Model) readTensors(fname string) ([]*Tensor, error) {
// TODO(bmizerany): after collecting, validate all offests make
// tensors contiguous?
begin, end := v.Offsets[0], v.Offsets[1]
begin := endOfHeader + v.Offsets[0]
end := endOfHeader + v.Offsets[1]
if err := checkBeginEnd(finfo.Size(), begin, end); err != nil {
return nil, err
}

View File

@@ -68,7 +68,7 @@ func main() {
log.Fatal(err)
}
rc, err := ollama.RegistryFromEnv()
rc, err := ollama.DefaultRegistry()
if err != nil {
log.Fatal(err)
}
@@ -177,7 +177,7 @@ func cmdPush(ctx context.Context, rc *ollama.Registry, c *blob.DiskCache) error
}
from := cmp.Or(*flagFrom, model)
m, err := ollama.ResolveLocal(c, from)
m, err := rc.ResolveLocal(c, from)
if err != nil {
return err
}
@@ -228,6 +228,10 @@ func cmdImport(ctx context.Context, c *blob.DiskCache) error {
flag.PrintDefaults()
}
flag.Parse(args)
if *flagAs == "" {
return fmt.Errorf("missing -as flag")
}
as := ollama.CompleteName(*flagAs)
dir := cmp.Or(flag.Arg(0), ".")
fmt.Fprintf(os.Stderr, "Reading %s\n", dir)
@@ -311,7 +315,7 @@ func cmdImport(ctx context.Context, c *blob.DiskCache) error {
if err != nil {
return err
}
return c.Link(*flagAs, d)
return c.Link(as, d)
}()
}()
@@ -340,6 +344,8 @@ func cmdImport(ctx context.Context, c *blob.DiskCache) error {
writeProgress()
case err := <-done:
writeProgress()
fmt.Println()
fmt.Println("Successfully imported", as)
return err
}
}

View File

@@ -8,7 +8,7 @@ import (
"github.com/ollama/ollama/server/internal/internal/stringsx"
)
const MaxNameLength = 50 + 1 + 50 + 1 + 50 // <namespace>/<model>:<tag>
const MaxNameLength = 350 + 1 + 80 + 1 + 80 + 1 + 80 // <host>/<namespace>/<model>:<tag>
type Name struct {
// Make incomparable to enfoce use of Compare / Equal for
@@ -25,19 +25,12 @@ type Name struct {
// format of a valid name string is:
//
// s:
// { host } "/" { namespace } "/" { model } ":" { tag } "@" { digest }
// { host } "/" { namespace } "/" { model } ":" { tag }
// { host } "/" { namespace } "/" { model } "@" { digest }
// { host } "/" { namespace } "/" { model }
// { namespace } "/" { model } ":" { tag } "@" { digest }
// { namespace } "/" { model } ":" { tag }
// { namespace } "/" { model } "@" { digest }
// { namespace } "/" { model }
// { model } ":" { tag } "@" { digest }
// { model } ":" { tag }
// { model } "@" { digest }
// { model }
// "@" { digest }
// host:
// pattern: { alphanum | "_" } { alphanum | "_" | "-" | "." | ":" }*
// length: [1, 350]
@@ -50,9 +43,6 @@ type Name struct {
// tag:
// pattern: { alphanum | "_" } { alphanum | "-" | "_" | "." }*
// length: [1, 80]
// digest:
// pattern: { alphanum | "_" } { alphanum | "-" | ":" }*
// length: [1, 80]
//
// The name returned is not guaranteed to be valid. If it is not valid, the
// field values are left in an undefined state. Use [Name.IsValid] to check
@@ -82,23 +72,17 @@ func Parse(s string) Name {
}
}
// ParseExtended parses and returns any scheme, Name, and digest from from s in
// the the form [scheme://][name][@digest]. All parts are optional.
//
// If the scheme is present, it must be followed by "://". The digest is
// prefixed by "@" and comes after the name. The name is parsed using [Parse].
//
// The scheme and digest are stripped before the name is parsed by [Parse].
//
// For convience, the scheme is never empty. If the scheme is not present, the
// returned scheme is "https".
// Split splits an extended name string into its scheme, name, and digest
// parts.
//
// Examples:
//
// http://ollama.com/bmizerany/smol:latest@digest
// https://ollama.com/bmizerany/smol:latest
// ollama.com/bmizerany/smol:latest@digest // returns "https" scheme.
func ParseExtended(s string) (scheme string, _ Name, digest string) {
// model@digest
// @digest
func Split(s string) (scheme, name, digest string) {
i := strings.Index(s, "://")
if i >= 0 {
scheme = s[:i]
@@ -109,21 +93,7 @@ func ParseExtended(s string) (scheme string, _ Name, digest string) {
digest = s[i+1:]
s = s[:i]
}
return scheme, Parse(s), digest
}
func FormatExtended(scheme string, n Name, digest string) string {
var b strings.Builder
if scheme != "" {
b.WriteString(scheme)
b.WriteString("://")
}
b.WriteString(n.String())
if digest != "" {
b.WriteByte('@')
b.WriteString(digest)
}
return b.String()
return scheme, s, digest
}
// Merge merges two names into a single name. Non-empty host, namespace, and
@@ -141,39 +111,68 @@ func Merge(a, b Name) Name {
// IsValid returns true if the name is valid.
func (n Name) IsValid() bool {
if n.h != "" && !isValidHost(n.h) {
if n.h != "" && !isValidPart(partHost, n.h) {
return false
}
if n.n != "" && !isValidNamespace(n.n) {
if n.n != "" && !isValidPart(partNamespace, n.n) {
return false
}
if n.m != "" && !isValidModel(n.m) {
if n.t != "" && !isValidPart(partTag, n.t) {
return false
}
if n.t != "" && !isValidTag(n.t) {
return false
}
return true
// at bare minimum, model must be present and valid
return n.m != "" && isValidPart(partModel, n.m)
}
func (n Name) IsFullyQualified() bool {
return n.IsValid() && n.h != "" && n.n != "" && n.m != "" && n.t != ""
}
func isValidHost(_ string) bool {
return true // TODO: implement
const (
partHost = iota
partNamespace
partModel
partTag
)
func isValidPart(kind int, s string) bool {
maxlen := 80
if kind == partHost {
maxlen = 350
}
if len(s) > maxlen {
return false
}
for i := range s {
if i == 0 {
if !isAlphanumericOrUnderscore(s[i]) {
return false
}
continue
}
switch s[i] {
case '_', '-':
case '.':
if kind == partNamespace {
return false
}
case ':':
if kind != partHost {
return false
}
default:
if !isAlphanumericOrUnderscore(s[i]) {
return false
}
}
}
return true
}
func isValidNamespace(_ string) bool {
return true // TODO: implement
}
func isValidModel(_ string) bool {
return true // TODO: implement
}
func isValidTag(_ string) bool {
return true // TODO: implement
func isAlphanumericOrUnderscore(c byte) bool {
return c >= 'A' && c <= 'Z' || c >= 'a' && c <= 'z' || c >= '0' && c <= '9' || c == '_'
}
func (n Name) Host() string { return n.h }

View File

@@ -81,15 +81,11 @@ func TestParseExtended(t *testing.T) {
}
for _, tt := range cases {
t.Run(tt.in, func(t *testing.T) {
scheme, name, digest := ParseExtended(tt.in)
if scheme != tt.wantScheme || name.Compare(tt.wantName) != 0 || digest != tt.wantDigest {
scheme, name, digest := Split(tt.in)
n := Parse(name)
if scheme != tt.wantScheme || n.Compare(tt.wantName) != 0 || digest != tt.wantDigest {
t.Errorf("ParseExtended(%q) = %q, %#v, %q, want %q, %#v, %q", tt.in, scheme, name, digest, tt.wantScheme, tt.wantName, tt.wantDigest)
}
// Round trip
if got := FormatExtended(scheme, name, digest); got != tt.in {
t.Errorf("FormatExtended(%q, %q, %q) = %q", scheme, name, digest, got)
}
})
}
}
@@ -150,3 +146,75 @@ func BenchmarkParseName(b *testing.B) {
junkName = Parse("h/n/m:t")
}
}
const (
part80 = "88888888888888888888888888888888888888888888888888888888888888888888888888888888"
part350 = "33333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333333"
)
var testCases = map[string]bool{ // name -> valid
"": false,
"_why/_the/_lucky:_stiff": true,
// minimal
"h/n/m:t": true,
"host/namespace/model:tag": true,
"host/namespace/model": true,
"namespace/model": true,
"model": true,
// long (but valid)
part80 + "/" + part80 + "/" + part80 + ":" + part80: true,
part350 + "/" + part80 + "/" + part80 + ":" + part80: true,
// too long
part80 + "/" + part80 + "/" + part80 + ":" + part350: false,
"x" + part350 + "/" + part80 + "/" + part80 + ":" + part80: false,
"h/nn/mm:t": true, // bare minimum part sizes
// unqualified
"m": true,
"n/m:": true,
"h/n/m": true,
"@t": false,
"m@d": false,
// invalids
"^": false,
"mm:": true,
"/nn/mm": true,
"//": false, // empty model
"//mm": true,
"hh//": false, // empty model
"//mm:@": false,
"00@": false,
"@": false,
// not starting with alphanum
"-hh/nn/mm:tt": false,
"hh/-nn/mm:tt": false,
"hh/nn/-mm:tt": false,
"hh/nn/mm:-tt": false,
// smells like a flag
"-h": false,
// hosts
"host:https/namespace/model:tag": true,
// colon in non-host part before tag
"host/name:space/model:tag": false,
}
func TestParseNameValidation(t *testing.T) {
for s, valid := range testCases {
got := Parse(s)
if got.IsValid() != valid {
t.Logf("got: %v", got)
t.Errorf("Parse(%q).IsValid() = %v; want !%[2]v", s, got.IsValid())
}
}
}

View File

@@ -0,0 +1,235 @@
// Package registry provides an http.Handler for handling local Ollama API
// requests for performing tasks related to the ollama.com model registry and
// the local disk cache.
package registry
import (
"cmp"
"encoding/json"
"errors"
"io"
"log/slog"
"net/http"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/client/ollama"
)
// Local is an http.Handler for handling local Ollama API requests for
// performing tasks related to the ollama.com model registry combined with the
// local disk cache.
//
// It is not concern of Local, or this package, to handle model creation, which
// proceeds any registry operations for models it produces.
//
// NOTE: The package built for dealing with model creation should use
// [DefaultCache] to access the blob store and not attempt to read or write
// directly to the blob disk cache.
type Local struct {
Client *ollama.Registry // required
Cache *blob.DiskCache // required
Logger *slog.Logger // required
// Fallback, if set, is used to handle requests that are not handled by
// this handler.
Fallback http.Handler
}
// serverError is like ollama.Error, but with a Status field for the HTTP
// response code. We want to avoid adding that field to ollama.Error because it
// would always be 0 to clients (we don't want to leak the status code in
// errors), and so it would be confusing to have a field that is always 0.
type serverError struct {
Status int `json:"-"`
// TODO(bmizerany): Decide if we want to keep this and maybe
// bring back later.
Code string `json:"code"`
Message string `json:"error"`
}
func (e serverError) Error() string {
return e.Message
}
// Common API errors
var (
errMethodNotAllowed = &serverError{405, "method_not_allowed", "method not allowed"}
errNotFound = &serverError{404, "not_found", "not found"}
errInternalError = &serverError{500, "internal_error", "internal server error"}
)
type statusCodeRecorder struct {
_status int // use status() to get the status code
http.ResponseWriter
}
func (r *statusCodeRecorder) WriteHeader(status int) {
if r._status == 0 {
r._status = status
}
r.ResponseWriter.WriteHeader(status)
}
var (
_ http.ResponseWriter = (*statusCodeRecorder)(nil)
_ http.CloseNotifier = (*statusCodeRecorder)(nil)
_ http.Flusher = (*statusCodeRecorder)(nil)
)
// CloseNotify implements the http.CloseNotifier interface, for Gin. Remove with Gin.
//
// It panics if the underlying ResponseWriter is not a CloseNotifier.
func (r *statusCodeRecorder) CloseNotify() <-chan bool {
return r.ResponseWriter.(http.CloseNotifier).CloseNotify()
}
// Flush implements the http.Flusher interface, for Gin. Remove with Gin.
//
// It panics if the underlying ResponseWriter is not a Flusher.
func (r *statusCodeRecorder) Flush() {
r.ResponseWriter.(http.Flusher).Flush()
}
func (r *statusCodeRecorder) status() int {
return cmp.Or(r._status, 200)
}
func (s *Local) ServeHTTP(w http.ResponseWriter, r *http.Request) {
rec := &statusCodeRecorder{ResponseWriter: w}
s.serveHTTP(rec, r)
}
func (s *Local) serveHTTP(rec *statusCodeRecorder, r *http.Request) {
var errattr slog.Attr
proxied, err := func() (bool, error) {
switch r.URL.Path {
case "/api/delete":
return false, s.handleDelete(rec, r)
default:
if s.Fallback != nil {
s.Fallback.ServeHTTP(rec, r)
return true, nil
}
return false, errNotFound
}
}()
if err != nil {
// We always log the error, so fill in the error log attribute
errattr = slog.String("error", err.Error())
var e *serverError
switch {
case errors.As(err, &e):
case errors.Is(err, ollama.ErrNameInvalid):
e = &serverError{400, "bad_request", err.Error()}
default:
e = errInternalError
}
data, err := json.Marshal(e)
if err != nil {
// unreachable
panic(err)
}
rec.Header().Set("Content-Type", "application/json")
rec.WriteHeader(e.Status)
rec.Write(data)
// fallthrough to log
}
if !proxied {
// we're only responsible for logging if we handled the request
var level slog.Level
if rec.status() >= 500 {
level = slog.LevelError
} else if rec.status() >= 400 {
level = slog.LevelWarn
}
s.Logger.LogAttrs(r.Context(), level, "http",
errattr, // report first in line to make it easy to find
// TODO(bmizerany): Write a test to ensure that we are logging
// all of this correctly. That also goes for the level+error
// logic above.
slog.Int("status", rec.status()),
slog.String("method", r.Method),
slog.String("path", r.URL.Path),
slog.Int64("content-length", r.ContentLength),
slog.String("remote", r.RemoteAddr),
slog.String("proto", r.Proto),
slog.String("query", r.URL.RawQuery),
)
}
}
type params struct {
DeprecatedName string `json:"name"` // Use [params.model]
Model string `json:"model"` // Use [params.model]
// AllowNonTLS is a flag that indicates a client using HTTP
// is doing so, deliberately.
//
// Deprecated: This field is ignored and only present for this
// deprecation message. It should be removed in a future release.
//
// Users can just use http or https+insecure to show intent to
// communicate they want to do insecure things, without awkward and
// confusing flags such as this.
AllowNonTLS bool `json:"insecure"`
// ProgressStream is a flag that indicates the client is expecting a stream of
// progress updates.
ProgressStream bool `json:"stream"`
}
// model returns the model name for both old and new API requests.
func (p params) model() string {
return cmp.Or(p.Model, p.DeprecatedName)
}
func (s *Local) handleDelete(_ http.ResponseWriter, r *http.Request) error {
if r.Method != "DELETE" {
return errMethodNotAllowed
}
p, err := decodeUserJSON[*params](r.Body)
if err != nil {
return err
}
ok, err := s.Client.Unlink(s.Cache, p.model())
if err != nil {
return err
}
if !ok {
return &serverError{404, "not_found", "model not found"}
}
return nil
}
func decodeUserJSON[T any](r io.Reader) (T, error) {
var v T
err := json.NewDecoder(r).Decode(&v)
if err == nil {
return v, nil
}
var zero T
// Not sure why, but I can't seem to be able to use:
//
// errors.As(err, &json.UnmarshalTypeError{})
//
// This is working fine in stdlib, so I'm not sure what rules changed
// and why this no longer works here. So, we do it the verbose way.
var a *json.UnmarshalTypeError
var b *json.SyntaxError
if errors.As(err, &a) || errors.As(err, &b) {
err = &serverError{Status: 400, Message: err.Error(), Code: "bad_request"}
}
if errors.Is(err, io.EOF) {
err = &serverError{Status: 400, Message: "empty request body", Code: "bad_request"}
}
return zero, err
}

View File

@@ -0,0 +1,165 @@
package registry
import (
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"regexp"
"strings"
"testing"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/server/internal/testutil"
)
type panicTransport struct{}
func (t *panicTransport) RoundTrip(r *http.Request) (*http.Response, error) {
panic("unexpected RoundTrip call")
}
var panicOnRoundTrip = &http.Client{Transport: &panicTransport{}}
// bytesResetter is an interface for types that can be reset and return a byte
// slice, only. This is to prevent inadvertent use of bytes.Buffer.Read/Write
// etc for the purpose of checking logs.
type bytesResetter interface {
Bytes() []byte
Reset()
}
func newTestServer(t *testing.T) *Local {
t.Helper()
dir := t.TempDir()
err := os.CopyFS(dir, os.DirFS("testdata/models"))
if err != nil {
t.Fatal(err)
}
c, err := blob.Open(dir)
if err != nil {
t.Fatal(err)
}
rc := &ollama.Registry{
HTTPClient: panicOnRoundTrip,
}
l := &Local{
Cache: c,
Client: rc,
Logger: testutil.Slogger(t),
}
return l
}
func (s *Local) send(t *testing.T, method, path, body string) *httptest.ResponseRecorder {
t.Helper()
req := httptest.NewRequestWithContext(t.Context(), method, path, strings.NewReader(body))
return s.sendRequest(t, req)
}
func (s *Local) sendRequest(t *testing.T, req *http.Request) *httptest.ResponseRecorder {
t.Helper()
w := httptest.NewRecorder()
s.ServeHTTP(w, req)
return w
}
type invalidReader struct{}
func (r *invalidReader) Read(p []byte) (int, error) {
return 0, os.ErrInvalid
}
// captureLogs is a helper to capture logs from the server. It returns a
// shallow copy of the server with a new logger and a bytesResetter for the
// logs.
func captureLogs(t *testing.T, s *Local) (*Local, bytesResetter) {
t.Helper()
log, logs := testutil.SlogBuffer()
l := *s // shallow copy
l.Logger = log
return &l, logs
}
func TestServerDelete(t *testing.T) {
check := testutil.Checker(t)
s := newTestServer(t)
_, err := s.Client.ResolveLocal(s.Cache, "smol")
check(err)
got := s.send(t, "DELETE", "/api/delete", `{"model": "smol"}`)
if got.Code != 200 {
t.Fatalf("Code = %d; want 200", got.Code)
}
_, err = s.Client.ResolveLocal(s.Cache, "smol")
if err == nil {
t.Fatal("expected smol to have been deleted")
}
got = s.send(t, "DELETE", "/api/delete", `!`)
checkErrorResponse(t, got, 400, "bad_request", "invalid character '!' looking for beginning of value")
got = s.send(t, "GET", "/api/delete", `{"model": "smol"}`)
checkErrorResponse(t, got, 405, "method_not_allowed", "method not allowed")
got = s.send(t, "DELETE", "/api/delete", ``)
checkErrorResponse(t, got, 400, "bad_request", "empty request body")
got = s.send(t, "DELETE", "/api/delete", `{"model": "://"}`)
checkErrorResponse(t, got, 400, "bad_request", "invalid or missing name")
got = s.send(t, "DELETE", "/unknown_path", `{}`) // valid body
checkErrorResponse(t, got, 404, "not_found", "not found")
s, logs := captureLogs(t, s)
req := httptest.NewRequestWithContext(t.Context(), "DELETE", "/api/delete", &invalidReader{})
got = s.sendRequest(t, req)
checkErrorResponse(t, got, 500, "internal_error", "internal server error")
ok, err := regexp.Match(`ERROR.*error="invalid argument"`, logs.Bytes())
check(err)
if !ok {
t.Logf("logs:\n%s", logs)
t.Fatalf("expected log to contain ERROR with invalid argument")
}
}
func TestServerUnknownPath(t *testing.T) {
s := newTestServer(t)
got := s.send(t, "DELETE", "/api/unknown", `{}`)
checkErrorResponse(t, got, 404, "not_found", "not found")
}
func checkErrorResponse(t *testing.T, got *httptest.ResponseRecorder, status int, code, msg string) {
t.Helper()
var printedBody bool
errorf := func(format string, args ...any) {
t.Helper()
if !printedBody {
t.Logf("BODY:\n%s", got.Body.String())
printedBody = true
}
t.Errorf(format, args...)
}
if got.Code != status {
errorf("Code = %d; want %d", got.Code, status)
}
// unmarshal the error as *ollama.Error (proving *serverError is an *ollama.Error)
var e *ollama.Error
if err := json.Unmarshal(got.Body.Bytes(), &e); err != nil {
errorf("unmarshal error: %v", err)
t.FailNow()
}
if e.Code != code {
errorf("Code = %q; want %q", e.Code, code)
}
if !strings.Contains(e.Message, msg) {
errorf("Message = %q; want to contain %q", e.Message, msg)
}
}

View File

@@ -0,0 +1 @@
{"schemaVersion":2,"mediaType":"application/vnd.docker.distribution.manifest.v2+json","config":{"mediaType":"application/vnd.docker.container.image.v1+json","digest":"sha256:ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116","size":267},"layers":[{"mediaType":"application/vnd.ollama.image.model","digest":"sha256:a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99","size":24}]}

View File

@@ -0,0 +1 @@
{"schemaVersion":2,"mediaType":"application/vnd.docker.distribution.manifest.v2+json","config":{"mediaType":"application/vnd.docker.container.image.v1+json","digest":"sha256:ca239d7bd8ea90e4a5d2e6bf88f8d74a47b14336e73eb4e18bed4dd325018116","size":267},"layers":[{"mediaType":"application/vnd.ollama.image.model","digest":"sha256:a4e5e156ddec27e286f75328784d7106b60a4eb1d246e950a001a3f944fbda99","size":24}]}

View File

@@ -1,12 +1,40 @@
package testutil
import (
"bytes"
"io"
"log/slog"
"os"
"path/filepath"
"testing"
"time"
)
// LogWriter returns an [io.Writer] that logs each Write using t.Log.
func LogWriter(t *testing.T) io.Writer {
return testWriter{t}
}
type testWriter struct{ t *testing.T }
func (w testWriter) Write(b []byte) (int, error) {
w.t.Logf("%s", b)
return len(b), nil
}
// Slogger returns a [*slog.Logger] that writes each message
// using t.Log.
func Slogger(t *testing.T) *slog.Logger {
return slog.New(slog.NewTextHandler(LogWriter(t), nil))
}
// SlogBuffer returns a [*slog.Logger] that writes each message to out.
func SlogBuffer() (lg *slog.Logger, out *bytes.Buffer) {
var buf bytes.Buffer
lg = slog.New(slog.NewTextHandler(&buf, nil))
return lg, &buf
}
// Check calls t.Fatal(err) if err is not nil.
func Check(t *testing.T, err error) {
if err != nil {

View File

@@ -34,6 +34,9 @@ import (
"github.com/ollama/ollama/llm"
"github.com/ollama/ollama/model/models/mllama"
"github.com/ollama/ollama/openai"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/server/internal/registry"
"github.com/ollama/ollama/template"
"github.com/ollama/ollama/types/errtypes"
"github.com/ollama/ollama/types/model"
@@ -1126,7 +1129,7 @@ func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
}
}
func (s *Server) GenerateRoutes() http.Handler {
func (s *Server) GenerateRoutes(c *blob.DiskCache, rc *ollama.Registry) (http.Handler, error) {
corsConfig := cors.DefaultConfig()
corsConfig.AllowWildcard = true
corsConfig.AllowBrowserExtensions = true
@@ -1165,10 +1168,9 @@ func (s *Server) GenerateRoutes() http.Handler {
r.HEAD("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
r.GET("/api/version", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"version": version.Version}) })
// Local model cache management
// Local model cache management (new implementation is at end of function)
r.POST("/api/pull", s.PullHandler)
r.POST("/api/push", s.PushHandler)
r.DELETE("/api/delete", s.DeleteHandler)
r.HEAD("/api/tags", s.ListHandler)
r.GET("/api/tags", s.ListHandler)
r.POST("/api/show", s.ShowHandler)
@@ -1193,7 +1195,15 @@ func (s *Server) GenerateRoutes() http.Handler {
r.GET("/v1/models", openai.ListMiddleware(), s.ListHandler)
r.GET("/v1/models/:model", openai.RetrieveMiddleware(), s.ShowHandler)
return r
// wrap old with new
rs := &registry.Local{
Cache: c,
Client: rc,
Logger: slog.Default(), // TODO(bmizerany): Take a logger, do not use slog.Default()
Fallback: r,
}
return rs, nil
}
func Serve(ln net.Listener) error {
@@ -1246,12 +1256,27 @@ func Serve(ln net.Listener) error {
}
}
s := &Server{addr: ln.Addr()}
c, err := ollama.DefaultCache()
if err != nil {
return err
}
rc, err := ollama.DefaultRegistry()
if err != nil {
return err
}
h, err := s.GenerateRoutes(c, rc)
if err != nil {
return err
}
http.Handle("/", h)
ctx, done := context.WithCancel(context.Background())
schedCtx, schedDone := context.WithCancel(ctx)
sched := InitScheduler(schedCtx)
s := &Server{addr: ln.Addr(), sched: sched}
http.Handle("/", s.GenerateRoutes())
s.sched = sched
slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version))
srvr := &http.Server{

View File

@@ -23,6 +23,8 @@ import (
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/fs/ggml"
"github.com/ollama/ollama/openai"
"github.com/ollama/ollama/server/internal/cache/blob"
"github.com/ollama/ollama/server/internal/client/ollama"
"github.com/ollama/ollama/types/model"
"github.com/ollama/ollama/version"
)
@@ -91,7 +93,15 @@ func equalStringSlices(a, b []string) bool {
return true
}
func Test_Routes(t *testing.T) {
type panicTransport struct{}
func (t *panicTransport) RoundTrip(r *http.Request) (*http.Response, error) {
panic("unexpected RoundTrip call")
}
var panicOnRoundTrip = &http.Client{Transport: &panicTransport{}}
func TestRoutes(t *testing.T) {
type testCase struct {
Name string
Method string
@@ -241,10 +251,10 @@ func Test_Routes(t *testing.T) {
Method: http.MethodDelete,
Path: "/api/delete",
Setup: func(t *testing.T, req *http.Request) {
createTestModel(t, "model-to-delete")
createTestModel(t, "model_to_delete")
deleteReq := api.DeleteRequest{
Name: "model-to-delete",
Name: "model_to_delete",
}
jsonData, err := json.Marshal(deleteReq)
if err != nil {
@@ -271,7 +281,7 @@ func Test_Routes(t *testing.T) {
Path: "/api/delete",
Setup: func(t *testing.T, req *http.Request) {
deleteReq := api.DeleteRequest{
Name: "non-existent-model",
Name: "non_existent_model",
}
jsonData, err := json.Marshal(deleteReq)
if err != nil {
@@ -477,10 +487,34 @@ func Test_Routes(t *testing.T) {
},
}
t.Setenv("OLLAMA_MODELS", t.TempDir())
modelsDir := t.TempDir()
t.Setenv("OLLAMA_MODELS", modelsDir)
c, err := blob.Open(modelsDir)
if err != nil {
t.Fatalf("failed to open models dir: %v", err)
}
rc := &ollama.Registry{
// This is a temporary measure to allow us to move forward,
// surfacing any code contacting ollama.com we do not intended
// to.
//
// Currently, this only handles DELETE /api/delete, which
// should not make any contact with the ollama.com registry, so
// be clear about that.
//
// Tests that do need to contact the registry here, will be
// consumed into our new server/api code packages and removed
// from here.
HTTPClient: panicOnRoundTrip,
}
s := &Server{}
router := s.GenerateRoutes()
router, err := s.GenerateRoutes(c, rc)
if err != nil {
t.Fatalf("failed to generate routes: %v", err)
}
httpSrv := httptest.NewServer(router)
t.Cleanup(httpSrv.Close)