Compare commits

..

26 Commits

Author SHA1 Message Date
Blake Mizerany
49c126fde8 build.go: introduce a friendlier way to build Ollama
This commit introduces a friendlier way to build Ollama's dependencies
and the binary without abusing `go generate`, removing the
unnecessary extra steps it brings with it.

This script also provides nicer feedback to the user about what is
happening during the build process.

At the end, it prints a helpful message to the user about what to do
next (e.g. run the new local Ollama).
2024-04-09 13:52:08 -07:00
writinwaters
1341ee1b56 Update README.md (#3539)
RAGFlow now supports integration with Ollama.
2024-04-08 10:58:14 -04:00
Jeffrey Morgan
63efa075a0 update generate scripts with new LLAMA_CUDA variable, set HIP_PLATFORM to avoid compiler errors (#3528) 2024-04-07 19:29:51 -04:00
Thomas Vitale
cb03fc9571 Docs: Remove wrong parameter for Chat Completion (#3515)
Fixes gh-3514

Signed-off-by: Thomas Vitale <ThomasVitale@users.noreply.github.com>
2024-04-06 09:08:35 -07:00
Michael Yang
a5ec9cfc0f Merge pull request #3508 from ollama/mxyng/rope 2024-04-05 18:46:06 -07:00
Michael Yang
be517e491c no rope parameters 2024-04-05 18:05:27 -07:00
Michael Yang
fc8e108642 Merge pull request #3496 from ollama/mxyng/cmd-r-graph
add command-r graph estimate
2024-04-05 12:26:21 -07:00
Daniel Hiltgen
c5d5c4a96c Merge pull request #3491 from dhiltgen/context_bust_test
Add test case for context exhaustion
2024-04-04 16:20:20 -07:00
Daniel Hiltgen
dfe330fa1c Merge pull request #3488 from mofanke/fix-windows-dll-compress
fix dll compress in windows building
2024-04-04 16:12:13 -07:00
Michael Yang
01f77ae25d add command-r graph estimate 2024-04-04 14:07:24 -07:00
Daniel Hiltgen
483b81a863 Merge pull request #3494 from dhiltgen/ci_release
Fail fast if mingw missing on windows
2024-04-04 10:15:40 -07:00
Daniel Hiltgen
36bd967722 Fail fast if mingw missing on windows 2024-04-04 09:51:26 -07:00
Jeffrey Morgan
b0e7d35db8 use an older version of the mac os sdk in release (#3484) 2024-04-04 09:48:54 -07:00
Daniel Hiltgen
aeb1fb5192 Add test case for context exhaustion
Confirmed this fails on 0.1.30 with known regression
but passes on main
2024-04-04 07:42:17 -07:00
Daniel Hiltgen
a2e60ebcaf Merge pull request #3490 from dhiltgen/ci_fixes
CI missing archive
2024-04-04 07:24:24 -07:00
Daniel Hiltgen
883ec4d1ef CI missing archive 2024-04-04 07:23:27 -07:00
mofanke
4de0126719 fix dll compress in windows building 2024-04-04 21:27:33 +08:00
Daniel Hiltgen
9768e2dc75 Merge pull request #3481 from dhiltgen/ci_fixes
CI subprocess path fix
2024-04-03 19:29:09 -07:00
Daniel Hiltgen
08600d5bec CI subprocess path fix 2024-04-03 19:12:53 -07:00
Daniel Hiltgen
a624e672d2 Merge pull request #3479 from dhiltgen/ci_fixes
Fix CI release glitches
2024-04-03 18:42:27 -07:00
Daniel Hiltgen
e4a7e5b2ca Fix CI release glitches
The subprocess change moved the build directory
arm64 builds weren't setting cross-compilation flags when building on x86
2024-04-03 16:41:40 -07:00
Michael Yang
a0a15cfd5b Merge pull request #3463 from ollama/mxyng/graph-estimate
update graph size estimate
2024-04-03 14:27:30 -07:00
Michael Yang
12e923e158 update graph size estimate 2024-04-03 13:34:12 -07:00
Jeffrey Morgan
cd135317d2 Fix macOS builds on older SDKs (#3467) 2024-04-03 10:45:54 -07:00
Michael Yang
4f895d633f Merge pull request #3466 from ollama/mxyng/head-kv
default head_kv to 1
2024-04-03 10:41:00 -07:00
Michael Yang
90f071c658 default head_kv to 1 2024-04-02 16:37:59 -07:00
69 changed files with 434 additions and 5416 deletions

View File

@@ -8,7 +8,7 @@ on:
jobs:
# Full build of the Mac assets
build-darwin:
runs-on: macos-latest
runs-on: macos-12
environment: release
steps:
- uses: actions/checkout@v4
@@ -38,9 +38,11 @@ jobs:
APPLE_PASSWORD: ${{ secrets.APPLE_PASSWORD }}
APPLE_TEAM_ID: ${{ vars.APPLE_TEAM_ID }}
APPLE_ID: ${{ vars.APPLE_ID }}
SDKROOT: /Applications/Xcode_13.4.1.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk
DEVELOPER_DIR: /Applications/Xcode_13.4.1.app/Contents/Developer
run: |
./scripts/build_darwin.sh
- uses: actions/upload-artifact@v4
with:
name: dist-darwin
@@ -48,7 +50,6 @@ jobs:
dist/*arwin*
!dist/*-cov
# Windows builds take a long time to both install the dependencies and build, so parallelize
# CPU generation step
generate-windows-cpu:
@@ -94,12 +95,15 @@ jobs:
cd $env:GITHUB_WORKSPACE
$env:CMAKE_SYSTEM_VERSION="10.0.22621.0"
$env:PATH="$gopath;$env:PATH"
go generate -x ./...
$env:GOARCH = ""; go run build.go -f -d -target=${{ matrix.arch }}
name: go generate
- uses: actions/upload-artifact@v4
with:
name: generate-windows-cpu
path: llm/llama.cpp/build/**/lib/*
path: |
llm/build/**/bin/*
llm/build/**/*.a
# ROCm generation step
generate-windows-rocm:
@@ -138,7 +142,7 @@ jobs:
with:
go-version: '1.22'
cache: true
- name: "Install ROCm"
- name: 'Install ROCm'
run: |
$ErrorActionPreference = "Stop"
write-host "downloading AMD HIP Installer"
@@ -146,7 +150,7 @@ jobs:
write-host "Installing AMD HIP"
Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait
write-host "Completed AMD HIP"
- name: "Verify ROCm"
- name: 'Verify ROCm'
run: |
& 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' --version
- run: go get ./...
@@ -160,7 +164,7 @@ jobs:
$env:HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path)
go generate -x ./...
name: go generate
- name: "gather rocm dependencies"
- name: 'gather rocm dependencies'
run: |
$HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path)
md "dist\deps\bin\rocblas\library"
@@ -170,7 +174,7 @@ jobs:
- uses: actions/upload-artifact@v4
with:
name: generate-windows-rocm
path: llm/llama.cpp/build/**/lib/*
path: llm/build/**/bin/*
- uses: actions/upload-artifact@v4
with:
name: windows-rocm-deps
@@ -213,7 +217,7 @@ jobs:
with:
go-version: '1.22'
cache: true
- name: "Install CUDA"
- name: 'Install CUDA'
run: |
$ErrorActionPreference = "Stop"
write-host "downloading CUDA Installer"
@@ -227,7 +231,7 @@ jobs:
echo "CUDA_PATH=$cudaPath" >> $env:GITHUB_ENV
echo "CUDA_PATH_V${cudaVer}=$cudaPath" >> $env:GITHUB_ENV
echo "CUDA_PATH_VX_Y=CUDA_PATH_V${cudaVer}" >> $env:GITHUB_ENV
- name: "Verify CUDA"
- name: 'Verify CUDA'
run: nvcc -V
- run: go get ./...
- name: go generate
@@ -240,7 +244,7 @@ jobs:
$env:PATH="$gopath;$cudabin;$env:PATH"
$env:OLLAMA_SKIP_CPU_GENERATE="1"
go generate -x ./...
- name: "gather cuda dependencies"
- name: 'gather cuda dependencies'
run: |
$NVIDIA_DIR=(resolve-path 'C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\*\bin\')[0]
md "dist\deps"
@@ -250,7 +254,7 @@ jobs:
- uses: actions/upload-artifact@v4
with:
name: generate-windows-cuda
path: llm/llama.cpp/build/**/lib/*
path: llm/build/**/bin/*
- uses: actions/upload-artifact@v4
with:
name: windows-cuda-deps
@@ -303,11 +307,11 @@ jobs:
- uses: actions/download-artifact@v4
with:
name: generate-windows-cpu
path: llm/llama.cpp/build
path: llm/build
- uses: actions/download-artifact@v4
with:
name: generate-windows-cuda
path: llm/llama.cpp/build
path: llm/build
- uses: actions/download-artifact@v4
with:
name: windows-cuda-deps
@@ -319,8 +323,8 @@ jobs:
- uses: actions/download-artifact@v4
with:
name: generate-windows-rocm
path: llm/llama.cpp/build
- run: dir llm/llama.cpp/build
path: llm/build
- run: dir llm/build
- run: |
$gopath=(get-command go).source | split-path -parent
& "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Launch-VsDevShell.ps1"
@@ -336,14 +340,14 @@ jobs:
name: dist-windows
path: dist/*.exe
# Linux x86 assets built using the container based build
# Linux x86 assets built using the container based build
build-linux-amd64:
environment: release
runs-on: linux
env:
OLLAMA_SKIP_MANIFEST_CREATE: "1"
OLLAMA_SKIP_MANIFEST_CREATE: '1'
BUILD_ARCH: amd64
PUSH: "1"
PUSH: '1'
steps:
- uses: actions/checkout@v4
with:
@@ -373,9 +377,9 @@ jobs:
environment: release
runs-on: linux-arm64
env:
OLLAMA_SKIP_MANIFEST_CREATE: "1"
OLLAMA_SKIP_MANIFEST_CREATE: '1'
BUILD_ARCH: arm64
PUSH: "1"
PUSH: '1'
steps:
- uses: actions/checkout@v4
with:
@@ -383,7 +387,7 @@ jobs:
- name: Set Version
shell: bash
run: echo "VERSION=${GITHUB_REF_NAME#v}" >> $GITHUB_ENV
- name: "Install Docker"
- name: 'Install Docker'
run: |
# Add Docker's official GPG key:
env
@@ -420,7 +424,7 @@ jobs:
!dist/*-cov
# Aggregate all the assets and ship a release
release:
release:
needs:
- build-darwin
- build-windows
@@ -431,8 +435,8 @@ jobs:
permissions:
contents: write
env:
OLLAMA_SKIP_IMAGE_BUILD: "1"
PUSH: "1"
OLLAMA_SKIP_IMAGE_BUILD: '1'
PUSH: '1'
steps:
- uses: actions/checkout@v4
- name: Set Version
@@ -460,11 +464,11 @@ jobs:
with:
name: ${{ env.RELEASE_VERSION }}
allowUpdates: true
artifacts: "dist/*"
artifacts: 'dist/*'
draft: true
prerelease: true
omitBodyDuringUpdate: true
generateReleaseNotes: true
omitDraftDuringUpdate: true
omitPrereleaseDuringUpdate: true
replacesArtifacts: true
replacesArtifacts: true

View File

@@ -1,5 +1,16 @@
name: test
concurrency:
# For PRs, later CI runs preempt previous ones. e.g. a force push on a PR
# cancels running CI jobs and starts all new ones.
#
# For non-PR pushes, concurrency.group needs to be unique for every distinct
# CI run we want to have happen. Use run_id, which in practice means all
# non-PR CI runs will be allowed to run without preempting each other.
group: ${{ github.workflow }}-$${{ github.pull_request.number || github.run_id }}
cancel-in-progress: true
on:
pull_request:
paths:
@@ -62,12 +73,14 @@ jobs:
$env:CMAKE_SYSTEM_VERSION="10.0.22621.0"
$env:PATH="$gopath;$gccpath;$env:PATH"
echo $env:PATH
go generate -x ./...
$env:GOARCH = ""; go run build.go -f -d -target=${{ matrix.arch }}
if: ${{ startsWith(matrix.os, 'windows-') }}
name: "Windows Go Generate"
- run: go generate -x ./...
name: 'Windows Go Generate'
- run: |
GOARCH= go run build.go -f -d -target=${{ matrix.arch }}
if: ${{ ! startsWith(matrix.os, 'windows-') }}
name: "Unix Go Generate"
name: 'Unix Go Generate'
- uses: actions/upload-artifact@v4
with:
name: ${{ matrix.os }}-${{ matrix.arch }}-libraries
@@ -98,7 +111,7 @@ jobs:
- run: go get ./...
- run: |
git config --global --add safe.directory /__w/ollama/ollama
go generate -x ./...
GOARCH= go run build.go -f -d -target=${{ matrix.arch }}
env:
OLLAMA_SKIP_CPU_GENERATE: '1'
- uses: actions/upload-artifact@v4
@@ -129,13 +142,13 @@ jobs:
- run: go get ./...
- run: |
git config --global --add safe.directory /__w/ollama/ollama
go generate -x ./...
GOARCH= go run build.go -f -d -target=${{ matrix.arch }}
env:
OLLAMA_SKIP_CPU_GENERATE: '1'
- uses: actions/upload-artifact@v4
with:
name: rocm-${{ matrix.rocm-version }}-libraries
path: llm/build/**/lib/*
path: llm/build/**/bin/*
# ROCm generation step
generate-windows-rocm:
@@ -148,7 +161,7 @@ jobs:
with:
go-version: '1.22'
cache: true
- name: "Install ROCm"
- name: 'Install ROCm'
run: |
$ErrorActionPreference = "Stop"
write-host "downloading AMD HIP Installer"
@@ -156,7 +169,7 @@ jobs:
write-host "Installing AMD HIP"
Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait
write-host "Completed AMD HIP"
- name: "Verify ROCm"
- name: 'Verify ROCm'
run: |
& 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' --version
- run: go get ./...
@@ -168,8 +181,9 @@ jobs:
$env:PATH="$gopath;$env:PATH"
$env:OLLAMA_SKIP_CPU_GENERATE="1"
$env:HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path)
go generate -x ./...
name: go generate
$env:GOARCH = ""; go run build.go -f -d -target=${{ matrix.arch }}
name: go run build.go
env:
OLLAMA_SKIP_CPU_GENERATE: '1'
# TODO - do we need any artifacts?
@@ -185,7 +199,7 @@ jobs:
with:
go-version: '1.22'
cache: true
- name: "Install CUDA"
- name: 'Install CUDA'
run: |
$ErrorActionPreference = "Stop"
write-host "downloading CUDA Installer"
@@ -199,10 +213,10 @@ jobs:
echo "CUDA_PATH=$cudaPath" >> $env:GITHUB_ENV
echo "CUDA_PATH_V${cudaVer}=$cudaPath" >> $env:GITHUB_ENV
echo "CUDA_PATH_VX_Y=CUDA_PATH_V${cudaVer}" >> $env:GITHUB_ENV
- name: "Verify CUDA"
- name: 'Verify CUDA'
run: nvcc -V
- run: go get ./...
- name: go generate
- name: go run build.go
run: |
$gopath=(get-command go).source | split-path -parent
$cudabin=(get-command nvcc).source | split-path
@@ -211,12 +225,12 @@ jobs:
$env:CMAKE_SYSTEM_VERSION="10.0.22621.0"
$env:PATH="$gopath;$cudabin;$env:PATH"
$env:OLLAMA_SKIP_CPU_GENERATE="1"
go generate -x ./...
$env:GOARCH = ""; go run build.go -f -d -target=${{ matrix.arch }}
env:
OLLAMA_SKIP_CPU_GENERATE: '1'
# TODO - do we need any artifacts?
lint:
strategy:
matrix:
@@ -248,18 +262,18 @@ jobs:
esac >>$GITHUB_ENV
shell: bash
- run: |
mkdir -p llm/build/linux/$ARCH/stub/bin/
touch llm/build/linux/$ARCH/stub/bin/stub.so
mkdir -p llm/build/linux/$ARCH/stub/bin
touch llm/build/linux/$ARCH/stub/bin/ollama_llama_server
if: ${{ startsWith(matrix.os, 'ubuntu-') }}
- run: |
mkdir -p llm/build/darwin/$ARCH/stub/bin/
touch llm/build/darwin/$ARCH/stub/bin/stub.dylib
touch llm/ggml-metal.metal
mkdir -p llm/build/darwin/$ARCH/stub/bin
touch llm/build/darwin/$ARCH/stub/bin/ollama_llama_server
if: ${{ startsWith(matrix.os, 'macos-') }}
- run: |
mkdir -p llm/build/windows/$ARCH/stub/stub/bin/
touch llm/build/windows/$ARCH/stub/stub/bin/stub.dll
mkdir -p llm/build/windows/$ARCH/stub/bin
touch llm/build/windows/$ARCH/stub/bin/ollama_llama_server
if: ${{ startsWith(matrix.os, 'windows-') }}
shell: bash
- uses: golangci/golangci-lint-action@v4
with:
args: --timeout 8m0s
@@ -277,7 +291,7 @@ jobs:
env:
GOARCH: ${{ matrix.arch }}
CGO_ENABLED: '1'
OLLAMA_CPU_TARGET: "static"
OLLAMA_CPU_TARGET: 'static'
steps:
- uses: actions/checkout@v4
with:
@@ -286,6 +300,12 @@ jobs:
with:
go-version: '1.22'
cache: true
- run: |
GOARCH= go run build.go -f -d -target=${{ matrix.arch }}
if: ${{ ! startsWith(matrix.os, 'windows-') }}
- run: |
$env:GOARCH = ""; go run build.go -f -d -target=${{ matrix.arch }}
if: ${{ startsWith(matrix.os, 'windows-') }}
- run: go get
- run: |
case ${{ matrix.arch }} in
@@ -294,21 +314,20 @@ jobs:
esac >>$GITHUB_ENV
shell: bash
- run: |
mkdir -p llm/build/linux/$ARCH/stub/bin/
touch llm//build/linux/$ARCH/stub/bin/stub.so
mkdir -p llm/build/linux/$ARCH/stub/bin
touch llm/build/linux/$ARCH/stub/bin/ollama_llama_server
if: ${{ startsWith(matrix.os, 'ubuntu-') }}
- run: |
mkdir -p llm/build/darwin/$ARCH/stub/bin/
touch llm/build/darwin/$ARCH/stub/bin/stub.dylib
touch llm/ggml-metal.metal
mkdir -p llm/build/darwin/$ARCH/stub/bin
touch llm/build/darwin/$ARCH/stub/bin/ollama_llama_server
if: ${{ startsWith(matrix.os, 'macos-') }}
- run: |
mkdir -p llm/build/windows/$ARCH/stub/stub/bin/
touch llm/build/windows/$ARCH/stub/stub/bin/stub.dll
mkdir -p llm/build/windows/$ARCH/stub/bin
touch llm/build/windows/$ARCH/stub/bin/ollama_llama_server
if: ${{ startsWith(matrix.os, 'windows-') }}
- run: go generate ./...
- run: go build
- run: go test -v ./...
shell: bash
- run: |
go test -v ./...
- uses: actions/upload-artifact@v4
with:
name: ${{ matrix.os }}-binaries

View File

@@ -201,16 +201,10 @@ Install `cmake` and `go`:
brew install cmake go
```
Then generate dependencies:
```
go generate ./...
```
Then build the binary:
```
go build .
go run build.go
```
More detailed instructions can be found in the [developer guide](https://github.com/ollama/ollama/blob/main/docs/development.md)
@@ -292,6 +286,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Ollama-chats RPG](https://github.com/drazdra/ollama-chats)
- [ChatOllama: Open Source Chatbot based on Ollama with Knowledge Bases](https://github.com/sugarforever/chat-ollama)
- [CRAG Ollama Chat: Simple Web Search with Corrective RAG](https://github.com/Nagi-ovo/CRAG-Ollama-Chat)
- [RAGFlow: Open-source Retrieval-Augmented Generation engine based on deep document understanding](https://github.com/infiniflow/ragflow)
### Terminal

View File

@@ -121,8 +121,6 @@ type Runner struct {
VocabOnly bool `json:"vocab_only,omitempty"`
UseMMap bool `json:"use_mmap,omitempty"`
UseMLock bool `json:"use_mlock,omitempty"`
RopeFrequencyBase float32 `json:"rope_frequency_base,omitempty"`
RopeFrequencyScale float32 `json:"rope_frequency_scale,omitempty"`
NumThread int `json:"num_thread,omitempty"`
}
@@ -383,8 +381,6 @@ func DefaultOptions() Options {
Runner: Runner{
// options set when the model is loaded
NumCtx: 2048,
RopeFrequencyBase: 10000.0,
RopeFrequencyScale: 1.0,
NumBatch: 512,
NumGPU: -1, // -1 here indicates that NumGPU should be set dynamically
NumGQA: 1,

192
build.go Normal file
View File

@@ -0,0 +1,192 @@
//go:build ignore
package main
import (
"cmp"
"errors"
"flag"
"log"
"os"
"os/exec"
"path/filepath"
"runtime"
)
// Flags controlling the build script's behavior.
var (
	flagForce     = flag.Bool("f", false, "force re-generation of dependencies")
	flagSkipBuild = flag.Bool("d", false, "generate dependencies only (e.g. skip 'go build .')")

	// Flag to set GOARCH explicitly for cross-platform builds, e.g., in
	// CI to target a different platform than the build matrix default.
	// This allows us to run generate without a separate build step for
	// building the script binary for the host ARCH and then running the
	// generate script for the target ARCH. Instead, we can just run
	// `go run build.go -target=$GOARCH` to generate the deps.
	// NOTE(review): only GOARCH is configurable here; GOOS follows the
	// host OS (see buildLlammaCPP's switch on runtime.GOOS).
	flagGOARCH = flag.String("target", "", "sets GOARCH to use when generating dependencies and building")
)
func buildEnv() []string {
return append(os.Environ(),
"GOARCH="+cmp.Or(*flagGOARCH, runtime.GOARCH),
)
}
// main drives the build: it validates the working directory, checks the
// required toolchain (cmake, gcc), generates the llama.cpp dependencies,
// and finally builds the Ollama server binary, logging progress throughout.
func main() {
	log.SetFlags(0) // plain output: suppress timestamps on log lines
	flag.Usage = func() {
		log.Printf("Usage: go run build.go [flags]")
		log.Println()
		log.Println("Flags:")
		flag.PrintDefaults()
		log.Println()
		log.Println("This script builds the Ollama server binary and generates the llama.cpp")
		log.Println("bindings for the current platform. It assumes that the current working")
		log.Println("directory is the root directory of the Ollama project.")
		log.Println()
		log.Println("If the -d flag is provided, the script will only generate the dependencies")
		log.Println("and skip building the Ollama server binary.")
		log.Println()
		log.Println("If the -f flag is provided, the script will force re-generation of the")
		log.Println("dependencies.")
		log.Println()
		log.Println("If the -target flag is provided, the script will set GOARCH to the value")
		log.Println("of the flag. This is useful for cross-platform builds.")
		log.Println()
		log.Println("The script will check for the required dependencies (cmake, gcc) and")
		log.Println("print their version.")
		log.Println()
		log.Println("The script will also check if it is being run from the root directory of")
		log.Println("the Ollama project.")
		log.Println()
		os.Exit(1) // usage is only printed on misuse; exit non-zero
	}
	flag.Parse()
	log.Printf("=== Building Ollama ===")
	// Deferred success banner. log.Fatalf below calls os.Exit, which skips
	// deferred functions, so this next-steps message only prints on a
	// fully successful run.
	defer func() {
		log.Printf("=== Done building Ollama ===")
		log.Println()
		log.Println("To run the Ollama server, use:")
		log.Println()
		log.Println(" ./ollama serve")
		log.Println()
	}()
	// The script takes no positional arguments; anything extra is misuse.
	if flag.NArg() > 0 {
		flag.Usage()
	}
	if !inRootDir() {
		log.Fatalf("Please run this script from the root directory of the Ollama project.")
	}
	if err := checkDependencies(); err != nil {
		log.Fatalf("Failed dependency check: %v", err)
	}
	// Generate the llama.cpp bindings first; the Go binary build depends
	// on the artifacts produced under llm/build.
	if err := buildLlammaCPP(); err != nil {
		log.Fatalf("Failed to build llama.cpp: %v", err)
	}
	if err := goBuildOllama(); err != nil {
		log.Fatalf("Failed to build ollama Go binary: %v", err)
	}
}
// checkDependencies does a quick check to see if the required dependencies
// (cmake, gcc) are installed on the system and functioning enough to print
// their version. Tool output is forwarded to this process's stdout/stderr.
//
// TODO(bmizerany): Check the actual version of the dependencies? Seems a
// little daunting given diff versions might print diff things. This should
// be good enough for now.
func checkDependencies() error {
	tools := [][]string{
		{"cmake", "--version"},
		{"gcc", "--version"},
	}
	var errs []error
	for _, tool := range tools {
		name := tool[0]
		log.Printf("=== Checking for %s ===", name)
		cmd := exec.Command(name, tool[1:]...)
		cmd.Stdout = os.Stdout
		cmd.Stderr = os.Stderr
		errs = append(errs, cmd.Run())
		log.Printf("=== Done checking for %s ===\n\n", name)
	}
	return errors.Join(errs...)
}
// goBuildOllama compiles the Ollama server binary into ./ollama with
// `go build`, inheriting this process's stdout/stderr and the buildEnv
// environment (GOARCH override). When -d is set, the build is skipped
// and only a notice is logged.
func goBuildOllama() error {
	log.Println("=== Building Ollama binary ===")
	defer log.Printf("=== Done building Ollama binary ===\n\n")

	if *flagSkipBuild {
		log.Println("Skipping 'go build -o ollama .'")
		return nil
	}

	build := exec.Command("go", "build", "-o", "ollama", ".")
	build.Env = buildEnv()
	build.Stdout = os.Stdout
	build.Stderr = os.Stderr
	return build.Run()
}
// buildLlammaCPP generates the llama.cpp bindings for the current platform
// by running the platform-specific generate script under llm/generate.
//
// It assumes that the current working directory is the root directory of the
// Ollama project. If llm/build already exists the step is skipped unless the
// -f flag was given, which removes the directory first.
func buildLlammaCPP() error {
	log.Println("=== Generating dependencies ===")
	defer log.Printf("=== Done generating dependencies ===\n\n")
	if *flagForce {
		// -f: discard previously generated artifacts so the scripts
		// rebuild from scratch.
		if err := os.RemoveAll(filepath.Join("llm", "build")); err != nil {
			return err
		}
	}
	if isDirectory(filepath.Join("llm", "build")) {
		log.Println("llm/build already exists; skipping. Use -f to force re-generate.")
		return nil
	}

	scriptDir, err := filepath.Abs(filepath.Join("llm", "generate"))
	if err != nil {
		return err
	}

	// Pick the generate script for the host OS. Return an error for an
	// unsupported OS rather than log.Fatalf: this function already
	// reports failures through its error result, and Fatalf would
	// os.Exit, bypassing main's uniform error handling.
	var cmd *exec.Cmd
	switch runtime.GOOS {
	case "windows":
		script := filepath.Join(scriptDir, "gen_windows.ps1")
		cmd = exec.Command("powershell", "-ExecutionPolicy", "Bypass", "-File", script)
	case "linux":
		script := filepath.Join(scriptDir, "gen_linux.sh")
		cmd = exec.Command("bash", script)
	case "darwin":
		script := filepath.Join(scriptDir, "gen_darwin.sh")
		cmd = exec.Command("bash", script)
	default:
		return errors.New("unsupported OS: " + runtime.GOOS)
	}
	cmd.Dir = filepath.Join("llm", "generate")
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	cmd.Env = buildEnv()

	// Log the effective GOARCH (the -target override when set, matching
	// what buildEnv passes to the scripts), not just the host arch.
	log.Printf("Running GOOS=%s GOARCH=%s %s", runtime.GOOS, cmp.Or(*flagGOARCH, runtime.GOARCH), cmd.Args)
	return cmd.Run()
}
// isDirectory reports whether path exists and is a directory. Any stat
// error (including "does not exist") is treated as "not a directory".
func isDirectory(path string) bool {
	if info, err := os.Stat(path); err == nil {
		return info.IsDir()
	}
	return false
}
// inRootDir returns true if the current working directory is the root
// directory of the Ollama project. It looks for a file named "go.mod"
// as the marker.
func inRootDir() bool {
	if _, err := os.Stat("go.mod"); err != nil {
		return false
	}
	return true
}

View File

@@ -32,7 +32,6 @@ type Params struct {
AttentionHeads int `json:"num_attention_heads"` // n_head
KeyValHeads int `json:"num_key_value_heads"`
NormEPS float64 `json:"rms_norm_eps"`
RopeFreqBase float64 `json:"rope_theta"`
BoSTokenID int `json:"bos_token_id"`
EoSTokenID int `json:"eos_token_id"`
HeadDimension int `json:"head_dim"`

View File

@@ -144,7 +144,6 @@ func (m *MistralModel) WriteGGUF() (string, error) {
"llama.attention.head_count": uint32(m.Params.AttentionHeads),
"llama.attention.head_count_kv": uint32(m.Params.KeyValHeads),
"llama.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS),
"llama.rope.freq_base": float32(m.Params.RopeFreqBase),
"general.file_type": uint32(1),
"tokenizer.ggml.model": "llama",

View File

@@ -394,7 +394,6 @@ Advanced parameters (optional):
- `format`: the format to return a response in. Currently the only accepted value is `json`
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `template`: the prompt template to use (overrides what is defined in the `Modelfile`)
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)

View File

@@ -23,13 +23,7 @@ export OLLAMA_DEBUG=1
Get the required libraries and build the native LLM code:
```bash
go generate ./...
```
Then build ollama:
```bash
go build .
go run build.go
```
Now you can run `ollama`:
@@ -38,6 +32,16 @@ Now you can run `ollama`:
./ollama
```
### Rebuilding the native code
If at any point you need to rebuild the native code, you can run the
build.go script again using the `-f` flag to force a rebuild, and,
optionally, the `-d` flag to skip building the Go binary:
```bash
go run build.go -f -d
```
### Linux
#### Linux CUDA (NVIDIA)
@@ -53,16 +57,10 @@ specifying an environment variable `CUDA_LIB_DIR` to the location of the shared
libraries, and `CUDACXX` to the location of the nvcc compiler. You can customize
the set of target CUDA architectures by setting CMAKE_CUDA_ARCHITECTURES (e.g. "50;60;70")
Then generate dependencies:
```
go generate ./...
```
Then build the binary:
```
go build .
go run build.go
```
#### Linux ROCm (AMD)
@@ -78,21 +76,17 @@ install (typically `/opt/rocm`), and `CLBlast_DIR` to the location of the
CLBlast install (typically `/usr/lib/cmake/CLBlast`). You can also customize
the AMD GPU targets by setting AMDGPU_TARGETS (e.g. `AMDGPU_TARGETS="gfx1101;gfx1102"`)
```
go generate ./...
```
Then build the binary:
```
go build .
go run build.go
```
ROCm requires elevated privileges to access the GPU at runtime. On most distros you can add your user account to the `render` group, or run as root.
#### Advanced CPU Settings
By default, running `go generate ./...` will compile a few different variations
By default, running `go run build.go` will compile a few different variations
of the LLM library based on common CPU families and vector math capabilities,
including a lowest-common-denominator which should run on almost any 64 bit CPU
somewhat slowly. At runtime, Ollama will auto-detect the optimal variation to
@@ -102,8 +96,7 @@ like to use. For example, to compile an optimized binary for an Intel i9-9880H,
you might use:
```
OLLAMA_CUSTOM_CPU_DEFS="-DLLAMA_AVX=on -DLLAMA_AVX2=on -DLLAMA_F16C=on -DLLAMA_FMA=on" go generate ./...
go build .
OLLAMA_CUSTOM_CPU_DEFS="-DLLAMA_AVX=on -DLLAMA_AVX2=on -DLLAMA_F16C=on -DLLAMA_FMA=on" go run build.go
```
#### Containerized Linux Build
@@ -124,8 +117,7 @@ Install required tools:
```powershell
$env:CGO_ENABLED="1"
go generate ./...
go build .
go run build.go
```
#### Windows CUDA (NVIDIA)
@@ -142,4 +134,4 @@ In addition to the common Windows development tools described above, install AMD
- [AMD HIP](https://www.amd.com/en/developer/resources/rocm-hub/hip-sdk.html)
- [Strawberry Perl](https://strawberryperl.com/)
Lastly, add `ninja.exe` included with MSVC to the system path (e.g. `C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\Common7\IDE\CommonExtensions\Microsoft\CMake\Ninja`).
Lastly, add `ninja.exe` included with MSVC to the system path (e.g. `C:\Program Files (x86)\Microsoft Visual Studio\2019\Community\Common7\IDE\CommonExtensions\Microsoft\CMake\Ninja`).

28
go.mod
View File

@@ -10,7 +10,7 @@ require (
github.com/emirpasic/gods v1.18.1
github.com/gin-gonic/gin v1.9.1
github.com/golang/protobuf v1.5.0 // indirect
github.com/google/uuid v1.6.0
github.com/google/uuid v1.0.0
github.com/mitchellh/mapstructure v1.5.0
github.com/olekukonko/tablewriter v0.0.5
github.com/spf13/cobra v1.7.0
@@ -19,35 +19,23 @@ require (
golang.org/x/sync v0.3.0
)
require (
github.com/minio/minio-go/v7 v7.0.69
github.com/pdevine/tensor v0.0.0-20240228013915-64ccaa8d9ca9
kr.dev/diff v0.3.0
)
require github.com/pdevine/tensor v0.0.0-20240228013915-64ccaa8d9ca9
require (
github.com/apache/arrow/go/arrow v0.0.0-20201229220542-30ce2eb5d4dc // indirect
github.com/chewxy/hm v1.0.0 // indirect
github.com/chewxy/math32 v1.0.8 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/gogo/protobuf v1.3.2 // indirect
github.com/google/flatbuffers v1.12.0 // indirect
github.com/klauspost/compress v1.17.6 // indirect
github.com/mattn/go-runewidth v0.0.14 // indirect
github.com/minio/md5-simd v1.1.2 // indirect
github.com/minio/sha256-simd v1.0.1 // indirect
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/rivo/uniseg v0.2.0 // indirect
github.com/rogpeppe/go-internal v1.8.1 // indirect
github.com/rs/xid v1.5.0 // indirect
github.com/xtgo/set v1.0.0 // indirect
go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
gonum.org/v1/gonum v0.8.2 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
gorgonia.org/vecf32 v0.9.0 // indirect
gorgonia.org/vecf64 v0.9.0 // indirect
)
@@ -65,7 +53,7 @@ require (
github.com/google/go-cmp v0.5.9 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/klauspost/cpuid/v2 v2.2.6 // indirect
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
github.com/leodido/go-urn v1.2.4 // indirect
github.com/mattn/go-isatty v0.0.19 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
@@ -75,12 +63,12 @@ require (
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.11 // indirect
golang.org/x/arch v0.3.0 // indirect
golang.org/x/crypto v0.19.0
golang.org/x/crypto v0.14.0
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63
golang.org/x/net v0.21.0 // indirect
golang.org/x/sys v0.17.0
golang.org/x/term v0.17.0
golang.org/x/text v0.14.0 // indirect
golang.org/x/net v0.17.0 // indirect
golang.org/x/sys v0.13.0
golang.org/x/term v0.13.0
golang.org/x/text v0.13.0 // indirect
google.golang.org/protobuf v1.30.0
gopkg.in/yaml.v3 v3.0.1 // indirect
)

51
go.sum
View File

@@ -26,8 +26,6 @@ github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1/go.mod h1:uw2gLc
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
@@ -88,8 +86,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.0.0 h1:b4Gk+7WdP/d3HZH8EJsZpvV7EtDOgaZLtnaNGIu1adA=
github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
@@ -97,12 +95,9 @@ github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHm
github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.17.6 h1:60eq2E/jlfwQXtvZEeBUYADs+BwKBWURIY+Gj2eRGjI=
github.com/klauspost/compress v1.17.6/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.6 h1:ndNyv040zDGIDh8thGkXYjnFtiN02M1PVVF+JE/48xc=
github.com/klauspost/cpuid/v2 v2.2.6/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk=
github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
@@ -120,12 +115,6 @@ github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU=
github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34=
github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM=
github.com/minio/minio-go/v7 v7.0.69 h1:l8AnsQFyY1xiwa/DaQskY4NXSLA2yrGsW5iD9nRPVS0=
github.com/minio/minio-go/v7 v7.0.69/go.mod h1:XAvOPJQ5Xlzk5o3o/ArO2NMbhSGkimC+bpW/ngRKDmQ=
github.com/minio/sha256-simd v1.0.1 h1:6kaan5IFmwTNynnKKpDHe6FWHohJOHhCPchzK49dzMM=
github.com/minio/sha256-simd v1.0.1/go.mod h1:Pz6AKMiUdngCLpeTL/RJY1M9rUuPMYujV5xJjtbRSN8=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -140,7 +129,6 @@ github.com/pdevine/tensor v0.0.0-20240228013915-64ccaa8d9ca9/go.mod h1:nR7l3gM6u
github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo=
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e h1:aoZm08cpOy4WuID//EZDgcC4zIxODThtZNPirFr42+A=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
@@ -150,11 +138,8 @@ github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
github.com/rogpeppe/go-internal v1.8.1 h1:geMPLpDpQOgVyCg5z5GoRwLHepNdb71NXb67XFkP+Eg=
github.com/rogpeppe/go-internal v1.8.1/go.mod h1:JeRgkft04UBgHMgCIwADu4Pn6Mtm5d4nPKWu0nJ5d+o=
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
@@ -196,8 +181,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo=
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@@ -220,8 +205,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20200904194848-62affa334b73/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -241,18 +226,18 @@ golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.17.0 h1:mkTF7LCd6WGJNL3K1Ad7kwxNfYAW6a8a8QqtMblp/4U=
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
golang.org/x/term v0.13.0 h1:bb+I9cTfFazGW51MZqBVmZy7+JEJMouUHTUSKVQLBek=
golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
@@ -307,8 +292,6 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
@@ -320,6 +303,4 @@ gorgonia.org/vecf64 v0.9.0 h1:bgZDP5x0OzBF64PjMGC3EvTdOoMEcmfAh1VCUnZFm1A=
gorgonia.org/vecf64 v0.9.0/go.mod h1:hp7IOWCnRiVQKON73kkC/AUMtEXyf9kGlVrtPQ9ccVA=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
kr.dev/diff v0.3.0 h1:o/T8/tkAq9IuRIuFqCupyKPC5iSY3WXpVZ2p6ZK3Emw=
kr.dev/diff v0.3.0/go.mod h1:XiTaLOg2/PD0cmXY7WQXUR8RAF3RwWpqIQEj910J2NY=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=

View File

@@ -0,0 +1,29 @@
//go:build integration
package integration
import (
"context"
"net/http"
"testing"
"time"
"github.com/ollama/ollama/api"
)
func TestContextExhaustion(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) // TODO maybe shorter?
defer cancel()
// Set up the test data
req := api.GenerateRequest{
Model: "llama2",
Prompt: "Write me a story with a ton of emojis?",
Stream: &stream,
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
"num_ctx": 128,
},
}
GenerateTestHelper(ctx, t, &http.Client{}, req, []string{"once", "upon", "lived"})
}

View File

@@ -15,10 +15,6 @@ import (
// TODO - this would ideally be in the llm package, but that would require some refactoring of interfaces in the server
// package to avoid circular dependencies
// WARNING - these tests will fail on mac if you don't manually copy ggml-metal.metal to this dir (./server)
//
// TODO - Fix this ^^
var (
stream = false
req = [2]api.GenerateRequest{

View File

@@ -1,6 +1,6 @@
#!/bin/bash
# This script is intended to run inside the go generate
# working directory must be ./llm/generate/
# This script is intended to run inside the `go run build.go` script, which
# sets the working directory to the correct location: ./llm/generate/.
# TODO - add hardening to detect missing tools (cmake, etc.)
@@ -18,7 +18,7 @@ sign() {
fi
}
COMMON_DARWIN_DEFS="-DCMAKE_OSX_DEPLOYMENT_TARGET=11.0 -DCMAKE_SYSTEM_NAME=Darwin -DLLAMA_METAL_EMBED_LIBRARY=on"
COMMON_DARWIN_DEFS="-DCMAKE_OSX_DEPLOYMENT_TARGET=11.3 -DLLAMA_METAL_MACOSX_VERSION_MIN=11.3 -DCMAKE_SYSTEM_NAME=Darwin -DLLAMA_METAL_EMBED_LIBRARY=on"
case "${GOARCH}" in
"amd64")
@@ -41,7 +41,7 @@ case "${GOARCH}" in
BUILD_DIR="../build/darwin/${ARCH}/cpu"
echo "Building LCD CPU"
build
sign ${BUILD_DIR}/lib/libext_server.dylib
sign ${BUILD_DIR}/bin/ollama_llama_server
compress
#
@@ -53,7 +53,7 @@ case "${GOARCH}" in
BUILD_DIR="../build/darwin/${ARCH}/cpu_avx"
echo "Building AVX CPU"
build
sign ${BUILD_DIR}/lib/libext_server.dylib
sign ${BUILD_DIR}/bin/ollama_llama_server
compress
#
@@ -66,7 +66,7 @@ case "${GOARCH}" in
echo "Building AVX2 CPU"
EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation"
build
sign ${BUILD_DIR}/lib/libext_server.dylib
sign ${BUILD_DIR}/bin/ollama_llama_server
compress
;;
"arm64")
@@ -74,25 +74,25 @@ case "${GOARCH}" in
# Static build for linking into the Go binary
init_vars
CMAKE_TARGETS="--target llama --target ggml"
CMAKE_DEFS="${COMMON_CPU_DEFS} -DBUILD_SHARED_LIBS=off -DLLAMA_ACCELERATE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
CMAKE_DEFS="-DCMAKE_OSX_DEPLOYMENT_TARGET=11.3 -DCMAKE_SYSTEM_NAME=Darwin -DBUILD_SHARED_LIBS=off -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=off -DLLAMA_ACCELERATE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
BUILD_DIR="../build/darwin/${ARCH}_static"
echo "Building static library"
build
init_vars
CMAKE_DEFS="${COMMON_DARWIN_DEFS} -DLLAMA_METAL_EMBED_LIBRARY=on -DLLAMA_ACCELERATE=on -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=on ${CMAKE_DEFS}"
CMAKE_DEFS="${COMMON_DARWIN_DEFS} -DLLAMA_ACCELERATE=on -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=on ${CMAKE_DEFS}"
BUILD_DIR="../build/darwin/${ARCH}/metal"
EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders"
build
sign ${BUILD_DIR}/lib/libext_server.dylib
sign ${BUILD_DIR}/bin/ollama_llama_server
compress
;;
*)
echo "GOARCH must be set"
echo "this script is meant to be run from within go generate"
echo "this script is meant to be run from within 'go run build.go'"
exit 1
;;
esac
cleanup
echo "go generate completed. LLM runners: $(cd ${BUILD_DIR}/..; echo *)"
echo "code generation completed. LLM runners: $(cd ${BUILD_DIR}/..; echo *)"

View File

@@ -1,6 +1,6 @@
#!/bin/bash
# This script is intended to run inside the go generate
# working directory must be llm/generate/
# This script is intended to run with the `go run build.go` script, which
# sets the working directory to the correct location: ./llm/generate/.
# First we build one or more CPU based LLM libraries
#
@@ -172,7 +172,7 @@ if [ -d "${CUDA_LIB_DIR}" ]; then
# Disabling has minimal performance effect while maintaining compatibility.
ARM64_DEFS="-DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_CUDA_F16=off"
fi
CMAKE_DEFS="-DLLAMA_CUBLAS=on -DLLAMA_CUDA_FORCE_MMQ=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} ${ARM64_DEFS}"
CMAKE_DEFS="-DLLAMA_CUDA=on -DLLAMA_CUDA_FORCE_MMQ=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} ${ARM64_DEFS}"
BUILD_DIR="../build/linux/${ARCH}/cuda${CUDA_VARIANT}"
EXTRA_LIBS="-L${CUDA_LIB_DIR} -lcudart -lcublas -lcublasLt -lcuda"
build
@@ -237,4 +237,4 @@ if [ -d "${ROCM_PATH}" ]; then
fi
cleanup
echo "go generate completed. LLM runners: $(cd ${BUILD_DIR}/..; echo *)"
echo "code generation completed. LLM runners: $(cd ${BUILD_DIR}/..; echo *)"

View File

@@ -146,7 +146,7 @@ function compress {
}
write-host "Compressing dlls..."
$binaries = dir "${script:buildDir}/bin/*.dll"
$dlls = dir "${script:buildDir}/bin/*.dll"
foreach ($file in $dlls) {
& "$script:GZIP" --best -f $file
}
@@ -183,9 +183,17 @@ if ($null -eq ${env:OLLAMA_SKIP_CPU_GENERATE}) {
# GCC build for direct linking into the Go binary
init_vars
# cmake will silently fallback to msvc compilers if mingw isn't in the path, so detect and fail fast
# as we need this to be compiled by gcc for golang to be able to link with itx
write-host "Checking for MinGW..."
# error action ensures we exit on failure
get-command gcc
get-command mingw32-make
$script:cmakeTargets = @("llama", "ggml")
$script:cmakeDefs = @(
"-G", "MinGW Makefiles"
"-DCMAKE_C_COMPILER=gcc.exe",
"-DCMAKE_CXX_COMPILER=g++.exe",
"-DBUILD_SHARED_LIBS=off",
"-DLLAMA_NATIVE=off",
"-DLLAMA_AVX=off",
@@ -234,7 +242,7 @@ if ($null -ne $script:CUDA_LIB_DIR) {
}
init_vars
$script:buildDir="../build/windows/${script:ARCH}/cuda$script:CUDA_VARIANT"
$script:cmakeDefs += @("-A", "x64", "-DLLAMA_CUBLAS=ON", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DCUDAToolkit_INCLUDE_DIR=$script:CUDA_INCLUDE_DIR", "-DCMAKE_CUDA_ARCHITECTURES=${script:CMAKE_CUDA_ARCHITECTURES}")
$script:cmakeDefs += @("-A", "x64", "-DLLAMA_CUDA=ON", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DCUDAToolkit_INCLUDE_DIR=$script:CUDA_INCLUDE_DIR", "-DCMAKE_CUDA_ARCHITECTURES=${script:CMAKE_CUDA_ARCHITECTURES}")
build
sign
compress
@@ -253,6 +261,7 @@ if ($null -ne $env:HIP_PATH) {
"-DCMAKE_C_COMPILER=clang.exe",
"-DCMAKE_CXX_COMPILER=clang++.exe",
"-DLLAMA_HIPBLAS=on",
"-DHIP_PLATFORM=amd",
"-DLLAMA_AVX=on",
"-DLLAMA_AVX2=off",
"-DCMAKE_POSITION_INDEPENDENT_CODE=on",
@@ -279,4 +288,4 @@ if ($null -ne $env:HIP_PATH) {
cleanup
write-host "`ngo generate completed. LLM runners: $(get-childitem -path ${script:SRC_DIR}\llm\build\windows\${script:ARCH})"
write-host "`ncode generation completed. LLM runners: $(get-childitem -path ${script:SRC_DIR}\llm\build\windows\${script:ARCH})"

View File

@@ -1,3 +0,0 @@
package generate
//go:generate bash ./gen_darwin.sh

View File

@@ -1,3 +0,0 @@
package generate
//go:generate bash ./gen_linux.sh

View File

@@ -1,3 +0,0 @@
package generate
//go:generate powershell -ExecutionPolicy Bypass -File ./gen_windows.ps1

View File

@@ -148,15 +148,15 @@ func (kv KV) HeadCount() uint64 {
}
func (kv KV) HeadCountKV() uint64 {
return kv.u64(fmt.Sprintf("%s.attention.head_count_kv", kv.Architecture()))
if headCountKV := kv.u64(fmt.Sprintf("%s.attention.head_count_kv", kv.Architecture())); headCountKV > 0 {
return headCountKV
}
return 1
}
func (kv KV) GQA() uint64 {
if headCountKV := kv.HeadCountKV(); headCountKV > 0 {
return kv.HeadCount() / headCountKV
}
return 0
return kv.HeadCount() / kv.HeadCountKV()
}
func (kv KV) EmbeddingLength() uint64 {
@@ -303,3 +303,50 @@ func DecodeGGML(rs io.ReadSeeker) (*GGML, int64, error) {
model: model,
}, offset, nil
}
func (llm GGML) GraphSize(context, batch int) (int64, bool) {
embeddingLength := llm.KV().EmbeddingLength()
headCount := llm.KV().HeadCount()
headCountKV := llm.KV().HeadCountKV()
vocabLength := len(llm.KV()["tokenizer.ggml.tokens"].([]any))
var attnQKVWeight1 uint64 = 0
for _, t := range llm.Tensors() {
if strings.HasSuffix(t.Name, ".attn_qkv.weight") && len(t.Shape) >= 2 {
attnQKVWeight1 = t.Shape[1]
break
}
}
var ffnGate1 uint64 = 0
for _, t := range llm.Tensors() {
if strings.Index(t.Name, ".ffn_gate") > 0 && len(t.Shape) >= 2 {
ffnGate1 = t.Shape[1]
break
}
}
switch llm.KV().Architecture() {
case "gemma", "command-r":
return 4 * int64(batch) * int64(embeddingLength+uint64(vocabLength)), true
case "phi2":
return max(
4*int64(batch)*int64(embeddingLength+uint64(vocabLength)),
4*int64(batch)*int64(1+4*embeddingLength+uint64(context)+attnQKVWeight1+uint64(context)*headCount),
), true
case "qwen2":
return max(
4*int64(batch)*int64(embeddingLength+uint64(vocabLength)),
4*int64(batch)*int64(1+2*embeddingLength+uint64(context)+uint64(context)*headCount),
), true
case "llama":
if ffnGate1 > 0 {
// moe
return 4 * int64(batch) * int64(2+3*embeddingLength+uint64(context)+uint64(context)*headCount+2*headCountKV+ffnGate1), true
}
return 4 * int64(batch) * int64(1+4*embeddingLength+uint64(context)+uint64(context)*headCount), true
}
return 0, false
}

View File

@@ -79,10 +79,11 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
// fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv
kv := 2 * 2 * int64(opts.NumCtx) * int64(ggml.KV().BlockCount()) * int64(ggml.KV().EmbeddingLength()) / int64(ggml.KV().HeadCount()) * int64(ggml.KV().HeadCountKV())
// this amount is the overhead + tensors in memory
// TODO: get this from the llama.cpp's graph calculations instead of
// estimating it's 1/6 * kv_cache_size * num_gqa
graph := int64(ggml.KV().GQA()) * kv / 6
graph, ok := ggml.GraphSize(opts.NumCtx, min(opts.NumCtx, opts.NumBatch))
if !ok {
graph = int64(ggml.KV().GQA()) * kv / 6
}
usedMemory += graph
if (usedMemory > availableMemory || slices.Contains(cpuOnlyFamilies, ggml.KV().Architecture())) && info.Library != "metal" {
@@ -171,14 +172,6 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
params = append(params, "--main-gpu", fmt.Sprintf("%d", opts.MainGPU))
}
if opts.RopeFrequencyBase > 0 {
params = append(params, "--rope-freq-base", fmt.Sprintf("%f", opts.RopeFrequencyBase))
}
if opts.RopeFrequencyScale > 0 {
params = append(params, "--rope-freq-scale", fmt.Sprintf("%f", opts.RopeFrequencyScale))
}
if len(adapters) > 0 {
// TODO: applying multiple adapters is not supported by the llama.cpp server yet
params = append(params, "--lora", adapters[0])

View File

@@ -1,113 +0,0 @@
package api
import (
"errors"
"fmt"
"net/http"
"os"
"github.com/ollama/ollama/x/build"
"github.com/ollama/ollama/x/client/ollama/apitype"
"github.com/ollama/ollama/x/oweb"
"github.com/ollama/ollama/x/registry"
regtype "github.com/ollama/ollama/x/registry/apitype"
)
// Common API Errors
var (
errUnqualifiedRef = oweb.Invalid("invalid", "name", "must be fully qualified")
errRefNotFound = oweb.Invalid("not_found", "name", "no such model")
)
type Server struct {
Build *build.Server
}
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
oweb.Serve(s.serveHTTP, w, r)
}
func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) error {
switch r.URL.Path {
case "/v1/push":
return s.handlePush(w, r)
default:
return oweb.ErrNotFound
}
}
func want(r *http.Request, method, path string) bool {
return r.Method == method && r.URL.Path == path
}
func (s *Server) handlePush(_ http.ResponseWriter, r *http.Request) error {
if r.Method != "POST" {
return oweb.ErrMethodNotAllowed
}
params, err := oweb.DecodeJSON[apitype.PushRequest](r.Body)
if err != nil {
return err
}
if params.Name == "" {
return oweb.Missing("name")
}
const registryURLTODO = "http://localhost:8888"
man, err := s.Build.ManifestData(params.Name)
if err != nil {
if errors.Is(err, build.ErrNotFound) {
return errRefNotFound
}
return err
}
c := registry.Client{BaseURL: registryURLTODO}
requirements, err := c.Push(r.Context(), params.Name, man, nil)
if err != nil {
return err
}
var uploads []regtype.CompletePart
for _, rq := range requirements {
l, err := s.Build.LayerFile(rq.Digest)
if err != nil {
return err
}
err = func() error {
f, err := os.Open(l)
if err != nil {
return err
}
defer f.Close()
cp, err := registry.PushLayer(r.Context(), f, rq.URL, rq.Offset, rq.Size)
if err != nil {
return err
}
uploads = append(uploads, cp)
return nil
}()
if err != nil {
return err
}
}
// commit the manifest to the registry
requirements, err = c.Push(r.Context(), params.Name, man, &registry.PushParams{
CompleteParts: uploads,
})
if err != nil {
return err
}
for _, r := range requirements {
err = errors.Join(err, fmt.Errorf("push failed for %q", r.Digest))
}
return err
}
func (s *Server) handlePull(w http.ResponseWriter, r *http.Request) error {
return oweb.ErrNotFound
}

View File

@@ -1,209 +0,0 @@
package build
import (
"encoding/json"
"errors"
"fmt"
"io/fs"
"os"
"path/filepath"
"github.com/ollama/ollama/x/build/internal/blobstore"
"github.com/ollama/ollama/x/model"
)
// Errors
var (
ErrIncompleteRef = errors.New("unqualified ref")
ErrBuildPresentInRef = errors.New("build present in ref")
ErrUnsupportedModelFormat = errors.New("unsupported model format")
ErrMissingFileType = errors.New("missing 'general.file_type' key")
ErrNotFound = errors.New("not found")
)
type mediaType string
// Known media types
const (
mediaTypeModel mediaType = "application/vnd.ollama.image.model"
)
type Server struct {
st *blobstore.Store
}
// Open starts a new build server that uses dir as the base directory for all
// build artifacts. If dir is empty, DefaultDir is used.
//
// It returns an error if the provided or default dir cannot be initialized.
func Open(dir string) (*Server, error) {
if dir == "" {
var err error
dir, err = DefaultDir()
if err != nil {
return nil, err
}
}
st, err := blobstore.Open(dir)
if err != nil {
return nil, err
}
return &Server{st: st}, nil
}
func (s *Server) Build(ref string, f model.File) error {
mp := model.ParseName(ref)
if !mp.IsCompleteNoBuild() {
return fmt.Errorf("%w: %q", ErrIncompleteRef, ref)
}
// 1. Resolve FROM
// a. If it's a local file (gguf), hash it and add it to the store.
// c. If it's a remote file (http), refuse.
// 2. Turn other pragmas into layers, and add them to the store.
// 3. Create a manifest from the layers.
// 4. Store the manifest in the manifest cache
// 5. Done.
if f.From == "" {
return &model.FileError{Pragma: "FROM", Message: "missing"}
}
var layers []layerJSON
id, info, size, err := s.importModel(f.From)
if err != nil {
return err
}
layers = append(layers, layerJSON{
ID: id,
MediaType: mediaTypeModel,
Size: size,
})
id, size, err = blobstore.PutString(s.st, f.License)
if err != nil {
return err
}
layers = append(layers, layerJSON{
ID: id,
MediaType: "text/plain",
Size: size,
})
data, err := json.Marshal(manifestJSON{Layers: layers})
if err != nil {
return err
}
return s.setManifestData(
mp.WithBuild(info.FileType.String()),
data,
)
}
func (s *Server) LayerFile(digest string) (string, error) {
fileName := s.st.OutputFilename(blobstore.ParseID(digest))
_, err := os.Stat(fileName)
if errors.Is(err, fs.ErrNotExist) {
return "", fmt.Errorf("%w: %q", ErrNotFound, digest)
}
return fileName, nil
}
func (s *Server) ManifestData(ref string) ([]byte, error) {
data, _, err := s.resolve(model.ParseName(ref))
return data, err
}
// WeightFile returns the absolute path to the weights file for the given model ref.
func (s *Server) WeightsFile(ref string) (string, error) {
m, err := s.getManifest(model.ParseName(ref))
if err != nil {
return "", err
}
for _, l := range m.Layers {
if l.MediaType == mediaTypeModel {
return s.st.OutputFilename(l.ID), nil
}
}
return "", fmt.Errorf("missing weights layer for %q", ref)
}
// resolve returns the data for the given ref, if any.
//
// TODO: This should ideally return an ID, but the current on
// disk layout is that the actual manifest is stored in the "ref" instead of
// a pointer to a content-addressed blob. I (bmizerany) think we should
// change the on-disk layout to store the manifest in a content-addressed
// blob, and then have the ref point to that blob. This would simplify the
// code, allow us to have integrity checks on the manifest, and clean up
// this interface.
func (s *Server) resolve(ref model.Name) (data []byte, fileName string, err error) {
fileName, err = s.refFileName(ref)
if err != nil {
return nil, "", err
}
data, err = os.ReadFile(fileName)
if errors.Is(err, fs.ErrNotExist) {
return nil, "", fmt.Errorf("%w: %q", ErrNotFound, ref)
}
if err != nil {
// do not wrap the error here, as it is likely an I/O error
// and we want to preserve the absraction since we may not
// be on disk later.
return nil, "", fmt.Errorf("manifest read error: %v", err)
}
return data, fileName, nil
}
func (s *Server) SetManifestData(ref string, data []byte) error {
return s.setManifestData(model.ParseName(ref), data)
}
// Set sets the data for the given ref.
func (s *Server) setManifestData(mp model.Name, data []byte) error {
path, err := s.refFileName(mp)
if err != nil {
return err
}
if err := os.MkdirAll(filepath.Dir(path), 0777); err != nil {
return err
}
if err := os.WriteFile(path, data, 0666); err != nil {
return err
}
return nil
}
func (s *Server) refFileName(mp model.Name) (string, error) {
if !mp.IsComplete() {
return "", fmt.Errorf("ref not fully qualified: %q", mp)
}
return filepath.Join(s.st.Dir(), "manifests", filepath.Join(mp.Parts()...)), nil
}
type manifestJSON struct {
// Layers is the list of layers in the manifest.
Layers []layerJSON `json:"layers"`
}
// Layer is a layer in a model manifest.
type layerJSON struct {
// ID is the ID of the layer.
ID blobstore.ID `json:"digest"`
MediaType mediaType `json:"mediaType"`
Size int64 `json:"size"`
}
func (s *Server) getManifest(ref model.Name) (manifestJSON, error) {
data, path, err := s.resolve(ref)
if err != nil {
return manifestJSON{}, err
}
var m manifestJSON
if err := json.Unmarshal(data, &m); err != nil {
return manifestJSON{}, &fs.PathError{Op: "unmarshal", Path: path, Err: err}
}
return m, nil
}

View File

@@ -1,163 +0,0 @@
package build
import (
"errors"
"os"
"path/filepath"
"testing"
"github.com/ollama/ollama/x/encoding/gguf"
"github.com/ollama/ollama/x/model"
)
const qualifiedRef = "x/y/z:latest+Q4_0"
func TestServerBuildErrors(t *testing.T) {
dir := t.TempDir()
s, err := Open(dir)
if err != nil {
t.Fatal(err)
}
t.Run("unqualified ref", func(t *testing.T) {
err := s.Build("x", model.File{})
if !errors.Is(err, ErrIncompleteRef) {
t.Fatalf("Build() err = %v; want unqualified ref", err)
}
})
t.Run("FROM pragma missing", func(t *testing.T) {
err := s.Build(qualifiedRef, model.File{})
var e *model.FileError
if !errors.As(err, &e) {
t.Fatalf("unexpected error: %v", err)
}
if e.Pragma != "FROM" {
t.Errorf("e.Pragma = %s; want FROM", e.Pragma)
}
if e.Message != "missing" {
t.Errorf("e.Message = %s; want missing", e.Message)
}
})
t.Run("FROM file not found", func(t *testing.T) {
err := s.Build(qualifiedRef, model.File{From: "bar"})
if !errors.Is(err, os.ErrNotExist) {
t.Fatalf("Build() err = %v; want file not found", err)
}
})
t.Run("FROM gguf", func(t *testing.T) {
w := newWorkDir(t)
// Write a gguf file without general.file_type metadata.
w.write("gguf", ""+
"GGUF"+ // magic
"\x03\x00\x00\x00"+ // version
"\x00\x00\x00\x00\x00\x00\x00\x00"+ // numMetaValues
"\x00\x00\x00\x00\x00\x00\x00\x00"+ // numTensors
"",
)
err := s.Build(qualifiedRef, model.File{From: w.fileName("gguf")})
if !errors.Is(err, ErrMissingFileType) {
t.Fatalf("Build() err = %#v; want missing file type", err)
}
})
t.Run("FROM obscure dir", func(t *testing.T) {
w := newWorkDir(t)
w.mkdirAll("unknown")
if err := s.Build(qualifiedRef, model.File{From: w.fileName("unknown")}); err != ErrUnsupportedModelFormat {
t.Fatalf("Build() err = %#v; want unsupported model type", err)
}
})
t.Run("FROM unsupported model type", func(t *testing.T) {
w := newWorkDir(t)
from := w.write("unknown", "unknown content")
err := s.Build(qualifiedRef, model.File{From: from})
if !errors.Is(err, ErrUnsupportedModelFormat) {
t.Fatalf("Build() err = %#v; want unsupported model type", err)
}
})
}
func TestBuildBasicGGUF(t *testing.T) {
w := newWorkDir(t)
w.write("gguf", ""+
"GGUF"+ // magic
"\x03\x00\x00\x00"+ // version
"\x00\x00\x00\x00\x00\x00\x00\x00"+ // numTensors
"\x01\x00\x00\x00\x00\x00\x00\x00"+ // numMetaValues
// general.file_type key
"\x11\x00\x00\x00\x00\x00\x00\x00"+ // key length
"general.file_type"+ // key
"\x04\x00\x00\x00"+ // type (uint32)
"\x02\x00\x00\x00\x00\x00\x00\x00"+ // uint32 value
"",
)
dir := t.TempDir()
s, err := Open(dir)
if err != nil {
t.Fatal(err)
}
if err := s.Build(qualifiedRef, model.File{From: w.fileName("gguf")}); err != nil {
t.Fatal(err)
}
filepath.Walk(dir, func(p string, info os.FileInfo, err error) error {
t.Logf("file: %s", p)
return nil
})
_, err = s.WeightsFile("unknown/y/z:latest+Q4_0")
if !errors.Is(err, ErrNotFound) {
t.Fatalf("WeightsFile() err = %v; want not found", err)
}
path, err := s.WeightsFile("x/y/z:latest+Q4_0")
if err != nil {
t.Fatal(err)
}
info, err := gguf.Stat(path)
if err != nil {
t.Fatal(err)
}
if info.FileType != gguf.TypeQ4_0 {
t.Errorf("info.FileType = %d; want 1", info.FileType)
}
}
type work struct {
t testing.TB
dir string
}
func newWorkDir(t *testing.T) work {
return work{t: t, dir: t.TempDir()}
}
func (w work) write(name, content string) (path string) {
w.t.Helper()
path = w.fileName(name)
if err := os.WriteFile(path, []byte(content), 0644); err != nil {
w.t.Fatal(err)
}
return path
}
func (w work) fileName(name string) string {
w.t.Helper()
return filepath.Join(w.dir, name)
}
func (w work) mkdirAll(path string) {
w.t.Helper()
if err := os.MkdirAll(filepath.Join(w.dir, path), 0755); err != nil {
w.t.Fatal(err)
}
}

View File

@@ -1,12 +0,0 @@
package build
func convertSafeTensorToGGUF(path string) (ggufPath string, err error) {
// TODO: decine on hueristic for converting safetensor to gguf and
// the errors that can be returned. For now, we just say
// "unsupported", however it may be intended to be a valid safe
// tensor but we hit an error in the conversion.
//
// I (bmizernay) think this will naturally evolve as we implement
// the conversion.
return "", ErrUnsupportedModelFormat
}

View File

@@ -1,28 +0,0 @@
package build
import (
"os"
"path/filepath"
"sync"
)
var (
	// defaultDir memoizes the default model directory so the environment
	// variable and home-directory lookup happen at most once per process.
	defaultDir = sync.OnceValues(func() (string, error) {
		dir := os.Getenv("OLLAMA_MODELS")
		if dir == "" {
			home, err := os.UserHomeDir()
			if err != nil {
				return "", err
			}
			dir = filepath.Join(home, ".ollama", "models")
		}
		return dir, nil
	})
)

// DefaultDir returns the default directory for models. It returns the value
// of the OLLAMA_MODELS environment variable if set; otherwise it returns
// "$HOME/.ollama/models".
func DefaultDir() (string, error) {
	return defaultDir()
}

View File

@@ -1,59 +0,0 @@
package build
import (
"errors"
"fmt"
"os"
"github.com/ollama/ollama/x/build/internal/blobstore"
"github.com/ollama/ollama/x/encoding/gguf"
)
// importError adapts err to importModel's four-value return shape so call
// sites can simply `return importError(err)`.
func importError(err error) (blobstore.ID, gguf.Info, int64, error) {
	return blobstore.ID{}, gguf.Info{}, 0, err
}

// importModel imports the model at path into the server's blob store,
// returning the blob ID, parsed GGUF info, and stored size in bytes. A
// directory is treated as a safetensor model; a regular file as GGUF.
func (s *Server) importModel(path string) (_ blobstore.ID, _ gguf.Info, size int64, _ error) {
	info, err := os.Stat(path)
	if err != nil {
		return importError(err)
	}
	if info.IsDir() {
		return s.importSafeTensor(path)
	} else {
		return s.importGGUF(path)
	}
}
// importGGUF validates the GGUF file at path and stores its bytes in the
// blob store, returning the resulting blob ID, header info, and size.
// A file without the GGUF magic yields ErrUnsupportedModelFormat; a file
// whose metadata lacks "general.file_type" yields ErrMissingFileType.
func (s *Server) importGGUF(path string) (_ blobstore.ID, _ gguf.Info, size int64, _ error) {
	f, err := os.Open(path)
	if err != nil {
		return importError(err)
	}
	defer f.Close()

	// StatReader rewinds the reader itself, so reading from the start is
	// not required here.
	info, err := gguf.StatReader(f)
	if errors.Is(err, gguf.ErrBadMagic) {
		return importError(ErrUnsupportedModelFormat)
	}
	if err != nil {
		return importError(err)
	}
	// NOTE(review): gguf.TypeF32 is also 0, so a file whose file_type is
	// legitimately F32 is indistinguishable here from one missing the
	// entry — confirm this is intended.
	if info.FileType == 0 {
		return importError(fmt.Errorf("%w: %q", ErrMissingFileType, path))
	}

	// Put seeks back to the start before hashing, so the header read
	// above does not affect the stored bytes.
	id, size, err := s.st.Put(f)
	if err != nil {
		return importError(err)
	}
	return id, info, size, nil
}
// importSafeTensor converts the safetensor model directory at path to GGUF
// and then imports the converted file. Conversion is currently
// unimplemented, so this always fails with ErrUnsupportedModelFormat.
func (s *Server) importSafeTensor(path string) (_ blobstore.ID, _ gguf.Info, size int64, _ error) {
	path, err := convertSafeTensorToGGUF(path)
	if err != nil {
		return importError(err)
	}
	return s.importGGUF(path)
}

View File

@@ -1,329 +0,0 @@
// Package blobstore implements a blob store.
package blobstore
import (
	"bytes"
	"crypto/sha256"
	"encoding/hex"
	"errors"
	"fmt"
	"io"
	"io/fs"
	"os"
	"path/filepath"
	"strings"
	"time"

	"github.com/ollama/ollama/x/types/structs"
)
var (
	// ErrInvalidID is reserved for reporting malformed blob IDs.
	ErrInvalidID = errors.New("invalid ID")
)

// HashSize is the size, in bytes, of a SHA-256 digest.
const HashSize = 32

// An ID is a blob output key, the hash of an output of a computation.
type ID struct {
	a [HashSize]byte
}

// MarshalText implements encoding.TextMarshaler.
func (id ID) MarshalText() ([]byte, error) {
	return []byte(id.String()), nil
}

// UnmarshalText implements encoding.TextUnmarshaler. Invalid input yields
// the zero (invalid) ID rather than an error.
func (id *ID) UnmarshalText(text []byte) error {
	*id = ParseID(string(text))
	return nil
}

// ParseID parses a string of the form "sha256-<64 hex digits>" into an ID.
// It returns the zero ID (for which Valid reports false) when s lacks the
// prefix, has the wrong length, or is not entirely valid hex.
func ParseID(s string) ID {
	const prefix = "sha256-"
	h, ok := strings.CutPrefix(s, prefix)
	if !ok {
		return ID{}
	}
	if len(h) != HashSize*2 {
		return ID{}
	}
	// Use hex.DecodeString rather than fmt.Sscanf("%x", ...): Sscanf
	// stops at the first non-hex byte and can report success after a
	// partial parse, which would accept corrupt IDs of the right
	// length. DecodeString rejects any non-hex input outright.
	b, err := hex.DecodeString(h)
	if err != nil {
		return ID{}
	}
	var id ID
	copy(id.a[:], b)
	return id
}

// String formats the ID as "sha256-<hex>", or "" for the invalid ID.
func (id ID) String() string {
	if !id.Valid() {
		return ""
	}
	return fmt.Sprintf("sha256-%x", id.a[:])
}

// Valid reports whether id is non-zero.
func (id ID) Valid() bool {
	return id != ID{}
}

// Match reports whether id equals the digest h.
func (id ID) Match(h [HashSize]byte) bool {
	return id.a == h
}
// A Store is a blob store, backed by a file system directory tree.
type Store struct {
	dir string
	now func() time.Time
}

// Open opens and returns the store in the given directory.
//
// It is safe for multiple processes on a single machine to use the
// same store directory in a local file system simultaneously.
// They will coordinate using operating system file locks and may
// duplicate effort but will not corrupt the store.
//
// However, it is NOT safe for multiple processes on different machines
// to share a store directory (for example, if the directory were stored
// in a network file system). File locking is notoriously unreliable in
// network file systems and may not suffice to protect the store.
func Open(dir string) (*Store, error) {
	info, err := os.Stat(dir)
	if err != nil {
		return nil, err
	}
	if !info.IsDir() {
		return nil, &fs.PathError{Op: "open", Path: dir, Err: fmt.Errorf("not a directory")}
	}
	// Make sure the blobs subdirectory exists before handing out the store.
	if err := os.MkdirAll(filepath.Join(dir, "blobs"), 0777); err != nil {
		return nil, err
	}
	return &Store{dir: dir, now: time.Now}, nil
}

// Dir reports the directory the store was opened on.
func (s *Store) Dir() string {
	return s.dir
}

// fileName returns the name of the blob file corresponding to the given id.
func (s *Store) fileName(id ID) string {
	return filepath.Join(s.dir, "blobs", fmt.Sprintf("sha256-%x", id.a[:]))
}
// An entryNotFoundError indicates that a store entry was not found, with an
// optional underlying reason.
type entryNotFoundError struct {
	Err error
}

// Error implements the error interface, appending the underlying reason
// when one is present.
func (e *entryNotFoundError) Error() string {
	if e.Err != nil {
		return fmt.Sprintf("store entry not found: %v", e.Err)
	}
	return "store entry not found"
}

// Unwrap exposes the underlying cause to errors.Is and errors.As.
func (e *entryNotFoundError) Unwrap() error {
	return e.Err
}
// An Entry describes a single blob in the store.
type Entry struct {
	_ structs.Incomparable // prevent == on Entry so the representation can change

	ID   ID
	Size int64
	Time time.Time // when added to store
}
// GetFile looks up the blob ID in the store and returns
// the name of the corresponding data file.
//
// A missing or size-mismatched file on disk is reported as a
// not-found error rather than returned.
func GetFile(s *Store, id ID) (file string, entry Entry, err error) {
	entry, err = s.Get(id)
	if err != nil {
		return "", Entry{}, err
	}
	file = s.OutputFilename(entry.ID)
	info, err := os.Stat(file)
	if err != nil {
		return "", Entry{}, &entryNotFoundError{Err: err}
	}
	// Guard against a file that was truncated or half-written.
	if info.Size() != entry.Size {
		return "", Entry{}, &entryNotFoundError{Err: errors.New("file incomplete")}
	}
	return file, entry, nil
}
// GetBytes looks up the blob ID in the store and returns
// the corresponding output bytes.
// GetBytes should only be used for data that can be expected to fit in memory.
func GetBytes(s *Store, id ID) ([]byte, Entry, error) {
	entry, err := s.Get(id)
	if err != nil {
		return nil, entry, err
	}
	data, err := os.ReadFile(s.OutputFilename(entry.ID))
	if err != nil {
		// Treat an unreadable blob file the same as a missing entry.
		return nil, entry, &entryNotFoundError{Err: err}
	}
	// The data is only valid if its digest matches the requested ID;
	// reject it otherwise.
	if !entry.ID.Match(sha256.Sum256(data)) {
		return nil, entry, &entryNotFoundError{Err: errors.New("bad checksum")}
	}
	return data, entry, nil
}
// OutputFilename returns the name of the blob file for the given ID.
func (s *Store) OutputFilename(id ID) string {
	file := s.fileName(id)
	// TODO(bmizerany): touch as "used" for cache trimming (see
	// cache.go in cmd/go/internal/cache for the full reference
	// implementation to go off of).
	return file
}
// Get looks up the blob ID in the store,
// returning the corresponding output ID and file size, if any.
// Note that finding an output ID does not guarantee that the
// saved file for that output ID is still available.
func (s *Store) Get(id ID) (Entry, error) {
	file := s.fileName(id)
	info, err := os.Stat(file)
	if err != nil {
		return Entry{}, &entryNotFoundError{Err: err}
	}
	return Entry{
		ID:   id,
		Size: info.Size(),
		Time: info.ModTime(),
	}, nil
}
// Close closes the store. It currently has no cleanup work to do.
func (s *Store) Close() error {
	// TODO(bmizerany): return c.Trim()
	return nil
}

// Put stores the data read from the given file into the store as ID.
//
// It may read file twice. The content of file must not change between the
// two passes.
func (s *Store) Put(file io.ReadSeeker) (ID, int64, error) {
	return s.put(file)
}

// PutBytes stores data in s, returning its ID and size. See Put.
func PutBytes(s *Store, data []byte) (ID, int64, error) {
	return s.Put(bytes.NewReader(data))
}

// PutString stores data in s, returning its ID and size. See Put.
func PutString(s *Store, data string) (ID, int64, error) {
	return s.Put(strings.NewReader(data))
}
// put hashes file's entire contents (first pass) to derive its
// content-addressed ID, then copies the bytes into the blob file for that
// ID (second pass, inside copyFile) if not already present.
func (s *Store) put(file io.ReadSeeker) (ID, int64, error) {
	// Compute output ID.
	h := sha256.New()
	if _, err := file.Seek(0, 0); err != nil {
		return ID{}, 0, err
	}
	size, err := io.Copy(h, file)
	if err != nil {
		return ID{}, 0, err
	}
	var out ID
	// Sum appends into the zero-length slice backed by out.a, filling it.
	h.Sum(out.a[:0])

	// Copy to blob file (if not already present).
	if err := s.copyFile(file, out, size); err != nil {
		return out, size, err
	}

	// TODO: Add to manifest index.
	return out, size, nil
}
// copyFile copies file into the store, expecting it to have the given
// output ID and size, if that file is not present already.
//
// (A leftover debug println of the blob file name was removed; it wrote to
// stderr on every store operation.)
func (s *Store) copyFile(file io.ReadSeeker, out ID, size int64) error {
	name := s.fileName(out)
	info, err := os.Stat(name)
	if err == nil && info.Size() == size {
		// Check hash of the existing file before trusting it.
		if f, err := os.Open(name); err == nil {
			h := sha256.New()
			io.Copy(h, f) // NOTE(review): read error ignored; a short read just fails the hash check below
			f.Close()
			var out2 ID
			h.Sum(out2.a[:0])
			if out == out2 {
				// Already stored with the right contents.
				return nil
			}
		}
		// Hash did not match. Fall through and rewrite file.
	}

	// Copy file to blobs directory.
	mode := os.O_RDWR | os.O_CREATE
	if err == nil && info.Size() > size { // shouldn't happen but fix in case
		mode |= os.O_TRUNC
	}
	f, err := os.OpenFile(name, mode, 0666)
	if err != nil {
		return err
	}
	defer f.Close()
	if size == 0 {
		// File now exists with correct size.
		// Only one possible zero-length file, so contents are OK too.
		// Early return here makes sure there's a "last byte" for code below.
		return nil
	}

	// From here on, if any of the I/O writing the file fails,
	// we make a best-effort attempt to truncate the file f
	// before returning, to avoid leaving bad bytes in the file.

	// Copy file to f, but also into h to double-check hash.
	if _, err := file.Seek(0, 0); err != nil {
		f.Truncate(0)
		return err
	}
	h := sha256.New()
	w := io.MultiWriter(f, h)
	if _, err := io.CopyN(w, file, size-1); err != nil {
		f.Truncate(0)
		return err
	}

	// Check last byte before writing it; writing it will make the size match
	// what other processes expect to find and might cause them to start
	// using the file.
	buf := make([]byte, 1)
	if _, err := file.Read(buf); err != nil {
		f.Truncate(0)
		return err
	}
	h.Write(buf)
	sum := h.Sum(nil)
	if !bytes.Equal(sum, out.a[:]) {
		f.Truncate(0)
		return fmt.Errorf("file content changed underfoot")
	}

	// Commit manifest entry.
	if _, err := f.Write(buf); err != nil {
		f.Truncate(0)
		return err
	}
	if err := f.Close(); err != nil {
		// Data might not have been written,
		// but file may look like it is the right size.
		// To be extra careful, remove stored file.
		os.Remove(name)
		return err
	}
	os.Chtimes(name, s.now(), s.now()) // mainly for tests
	return nil
}

View File

@@ -1,54 +0,0 @@
package blobstore
import (
"strings"
"testing"
)
func TestParseID(t *testing.T) {
const valid = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
var invalid = strings.Repeat("\x00", HashSize*2)
cases := []struct {
in string
want string
}{
{"", invalid},
{"sha256-", invalid},
{"sha256-" + valid, valid},
{"" + valid, invalid}, // no prefix
{"sha123-" + valid, invalid}, // invalid prefix
{"sha256-" + valid[1:], invalid}, // too short
{"sha256-" + valid + "a", invalid}, // too long
{"sha256-!" + valid[1:], invalid}, // invalid hex
}
for _, tt := range cases {
t.Run("", func(t *testing.T) {
// sanity check
if len(tt.want) > HashSize*2 {
panic("invalid test")
}
got := ParseID(tt.in)
wantValid := tt.want != invalid
if wantValid {
if !got.Valid() {
t.Errorf("ParseID(%q).Valid() = false; want true", tt.in)
}
if got.String() != "sha256-"+tt.want {
t.Errorf("ParseID(%q).String() = %q; want %q", tt.in, got.String(), "sha256-"+tt.want)
}
} else {
if got.Valid() {
t.Errorf("ParseID(%q).Valid() = true; want false", tt.in)
}
if got.String() != "" {
t.Errorf("ParseID(%q).String() = %q; want %q", tt.in, got.String(), "")
}
}
})
}
}

View File

@@ -1,128 +0,0 @@
package blobstore
import (
"errors"
"iter"
"os"
"path/filepath"
"testing"
"time"
"github.com/ollama/ollama/x/model"
"kr.dev/diff"
)
const (
blobNameHello = "sha256-2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824"
)
func TestStoreBasicBlob(t *testing.T) {
dir := t.TempDir()
checkDir(t, dir, nil)
st, err := Open(dir)
if err != nil {
t.Fatal(err)
}
now := time.Now()
st.now = func() time.Time { return now }
checkDir(t, dir, []string{
"blobs/",
})
id, size, err := PutBytes(st, []byte("hello"))
if err != nil {
t.Fatal(err)
}
if id != ParseID(blobNameHello) {
t.Errorf("unexpected ID: %s", id)
}
if size != 5 {
t.Errorf("unexpected size: %d", size)
}
checkDir(t, dir, []string{
"blobs/",
"blobs/" + blobNameHello,
})
got, err := st.Get(id)
if err != nil {
t.Fatal(err)
}
diff.Test(t, t.Errorf, got, Entry{
ID: id,
Size: 5,
Time: now,
})
file := st.OutputFilename(id)
wantFile := filepath.Join(dir, "blobs", blobNameHello)
if file != wantFile {
t.Errorf("unexpected file: %s", file)
}
// Check tags
name := model.ParseName("registry.ollama.ai/library/test:latest+KQED")
t.Logf("RESOLVING: %q", name.Parts())
}
// checkDir checks that the directory at dir contains the files in want. The
// files in want must be relative to dir.
//
// directories are suffixed with a slash (e.g. "foo/" instead of "foo").
//
// want must be in lexicographic order.
func checkDir(t testing.TB, dir string, want []string) {
	t.Helper()
	var matches []string
	for path, err := range walkDir(dir) {
		t.Helper()
		if err != nil {
			t.Fatal(err)
		}
		t.Logf("found %s", path)

		// The root itself is not part of the expected listing.
		if path == "./" {
			continue
		}
		path = filepath.ToSlash(path)
		matches = append(matches, path)
	}
	diff.Test(t, t.Errorf, matches, want)
}
var errStop = errors.New("stop")
func walkDir(dir string) iter.Seq2[string, error] {
return func(yield func(string, error) bool) {
err := filepath.WalkDir(dir, func(path string, info os.DirEntry, err error) error {
if err != nil {
return err
}
path, err = filepath.Rel(dir, path)
if err != nil {
return err
}
path = filepath.ToSlash(path)
if info.IsDir() {
path += "/"
}
if !yield(path, nil) {
return errStop
}
return nil
})
if !errors.Is(err, errStop) && err != nil {
yield("", err)
}
}
}

View File

@@ -1,31 +0,0 @@
package apitype
import "time"
// A Message is a single chat message exchanged with a model.
type Message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

// A Model describes a model known to the server.
type Model struct {
	Ref        string `json:"ref"`
	Digest     string `json:"digest"`
	Size       int64  `json:"size"`
	ModifiedAt int64  `json:"modified"` // nanoseconds since the Unix epoch (see Modifed)
}

// Modifed returns ModifiedAt as a time.Time.
//
// NOTE(review): the method name is missing the second "i" of "Modified";
// renaming would break callers, so it is only flagged here.
func (m Model) Modifed() time.Time {
	return time.Unix(0, m.ModifiedAt)
}

// A PushRequest asks the server to push a model to a registry.
type PushRequest struct {
	Name     string `json:"name"` // Ref is the official term, "name" is for backward compatibility with existing clients.
	Insecure bool   `json:"insecure"`
	Stream   bool   `json:"stream"`
}

// A PushStatus is a single progress update reported during a push.
type PushStatus struct {
	Status string `json:"status"`
	Digest string `json:"digest"`
	Total  int64  `json:"total"`
}

View File

@@ -1,173 +0,0 @@
package ollama
import (
"bytes"
"cmp"
"context"
"encoding/json"
"fmt"
"io"
"io/fs"
"iter"
"net/http"
"os"
"strings"
"github.com/ollama/ollama/x/client/ollama/apitype"
"github.com/ollama/ollama/x/types/empty"
)
// TODO(bmizerany): PROGRESS INDICATORS!!!!

// DefaultBaseURL is the base URL used when OLLAMA_BASE_URL is not set.
const DefaultBaseURL = "http://localhost:11434"

// envBaseURL is resolved once at program start from OLLAMA_BASE_URL,
// falling back to DefaultBaseURL.
var envBaseURL = cmp.Or(os.Getenv("OLLAMA_BASE_URL"), DefaultBaseURL)

// Default returns a new client with the default base URL.
func Default() *Client {
	return &Client{BaseURL: envBaseURL}
}

// I_Acknowledge_This_API_Is_Under_Development is a flag that must be set to
// true for any instance of Client to work.
var I_Acknowledge_This_API_Is_Under_Development bool

// Client is a client for the Ollama API.
type Client struct {
	// BaseURL is the base URL of the Ollama API.
	BaseURL string

	HTTPClient *http.Client // The HTTP client to use. If nil, http.DefaultClient is used.
}

// Build requests the remote Ollama service to build a model. It uploads any
// source files the server needs.
//
// Not yet implemented.
func (c *Client) Build(ctx context.Context, ref string, modelfile []byte, source fs.FS) error {
	panic("not implemented")
}

// Push requests the remote Ollama service to push a model to the server.
func (c *Client) Push(ctx context.Context, ref string) error {
	_, err := Do[empty.Message](ctx, c, "POST", "/v1/push", apitype.PushRequest{Name: ref})
	return err
}

// Pull requests the remote Ollama service to pull a model. Not yet implemented.
func (c *Client) Pull(ctx context.Context, ref string) error {
	panic("not implemented")
}

// List iterates over the models known to the server. Not yet implemented.
func (c *Client) List(ctx context.Context) iter.Seq2[apitype.Model, error] {
	panic("not implemented")
}

// Show returns details about a single model. Not yet implemented.
func (c *Client) Show(ctx context.Context, ref string) (*apitype.Model, error) {
	panic("not implemented")
}

// Remove deletes a model from the server. Not yet implemented.
func (c *Client) Remove(ctx context.Context, ref string) error {
	panic("not implemented")
}

// Copy duplicates srcRef under the name dstRef. Not yet implemented.
func (c *Client) Copy(ctx context.Context, dstRef, srcRef string) error {
	panic("not implemented")
}

// Run sends messages to the model named by ref. Not yet implemented.
func (c *Client) Run(ctx context.Context, ref string, messages []apitype.Message) error {
	panic("not implemented")
}
// Error is a structured error response from the Ollama server.
type Error struct {
	// Status is the HTTP status code returned by the server.
	Status int `json:"status"`

	// Code specifies a machine readable code indicating the class of
	// error this error is. See http://docs.ollama.com/errors for a full
	// list of error codes.
	Code string `json:"code"`

	// Message is a humage readable message that describes the error. It
	// may change across versions of the API, so it should not be used for
	// programmatic decisions.
	Message string `json:"message,omitempty"`

	// Field is the field in the request that caused the error, if any.
	Field string `json:"field,omitempty"`

	// Value is the value of the field that caused the error, if any.
	Value string `json:"value,omitempty"`
}

// Error renders the error as "ollama: <code> [<field>][: <value>][: <message>]".
func (e *Error) Error() string {
	var sb strings.Builder
	sb.WriteString("ollama: ")
	sb.WriteString(e.Code)
	if e.Field != "" {
		sb.WriteString(" " + e.Field)
	}
	if e.Value != "" {
		sb.WriteString(": " + e.Value)
	}
	if e.Message != "" {
		sb.WriteString(": " + e.Message)
	}
	return sb.String()
}
// Do encodes in and sends it in a request to the Ollama server and decodes
// the response into Res, or an error response (non-2xx) into an *Error, or
// any error encountered decoding the response.
func Do[Res any](ctx context.Context, c *Client, method, path string, in any) (*Res, error) {
	var body bytes.Buffer
	// TODO(bmizerany): pool and reuse this buffer AND the encoder
	if err := encodeJSON(&body, in); err != nil {
		return nil, err
	}

	urlStr := c.BaseURL + path
	req, err := http.NewRequestWithContext(ctx, method, urlStr, &body)
	if err != nil {
		return nil, err
	}

	hc := cmp.Or(c.HTTPClient, http.DefaultClient)
	res, err := hc.Do(req)
	if err != nil {
		return nil, err
	}
	defer res.Body.Close()

	if res.StatusCode/100 != 2 {
		// Tee the body so that a malformed error payload can be
		// quoted verbatim in the returned error.
		var buf bytes.Buffer
		body := io.TeeReader(res.Body, &buf)
		e, err := decodeJSON[Error](body)
		if err != nil {
			err := fmt.Errorf("ollama: invalid error response from server (status %d): %q", res.StatusCode, buf.String())
			return nil, err
		}
		return nil, e
	}
	return decodeJSON[Res](res.Body)
}
// decodeJSON decodes JSON from r into a new value of type T.
//
// NOTE: decodeJSON (and encodeJSON) are copied and pasted from oweb.go;
// please do not try to consolidate them, so we can keep ollama/client free
// from dependencies which are moving targets and not pulling enough weight
// to justify their inclusion.
func decodeJSON[T any](r io.Reader) (*T, error) {
	var v T
	if err := json.NewDecoder(r).Decode(&v); err != nil {
		return nil, err
	}
	return &v, nil
}

// encodeJSON encodes v as JSON to w.
//
// NOTE: see the NOTE above decodeJSON.
func encodeJSON(w io.Writer, v any) error {
	// TODO(bmizerany): pool and reuse encoder
	return json.NewEncoder(w).Encode(v)
}

View File

@@ -1,100 +0,0 @@
// Bllamo is a (new) tool for managing Ollama models.
//
// Usage:
//
// bllamo <command> [arguments]
//
// The commands are:
//
// build build a model from a Modelfile
// list list all models
// push push a model from an ollama registry
// pull pull a model from an ollama registry
// delete delete a model from an ollama registry
// help display help for a command
package main
import (
"cmp"
"context"
"flag"
"fmt"
"net/http"
"os"
"github.com/ollama/ollama/x/api"
"github.com/ollama/ollama/x/build"
"github.com/ollama/ollama/x/client/ollama"
"github.com/ollama/ollama/x/registry"
)
// main dispatches the first argument as a subcommand via Main, exiting 2
// when no command is given and 1 when a command fails.
func main() {
	flag.Parse()
	args := flag.Args()
	if len(args) < 1 {
		fmt.Fprintln(os.Stderr, "bllamo: no command provided")
		os.Exit(2)
	}
	if err := Main(flag.Args()...); err != nil {
		fmt.Fprintf(os.Stderr, "%v\n", err)
		os.Exit(1)
	}
}
// TODOUsage is a placeholder error returned where real usage text has yet
// to be written.
var TODOUsage = fmt.Errorf("TODO: usage")

// commands maps a subcommand name to its implementation.
var commands = map[string]func(ctx context.Context, args ...string) error{
	"build":    cmdBuild,
	"push":     cmdPush,
	"serve":    cmdServe,
	"registry": cmdRegistry,
}

// Main is the entry point for the blammo command.
//
// NOTE(review): the messages and docs mix the spellings "blammo" and
// "bllamo"; confirm which is intended before unifying.
func Main(args ...string) error {
	cmd := args[0]
	args = args[1:]
	if f, ok := commands[cmd]; ok {
		ctx := context.TODO()
		return f(ctx, args...)
	}
	return fmt.Errorf("blammo: unknown command %q", cmd)
}

// cmdBuild implements "build": it reads the Modelfile (flag -f, defaulting
// to "Modelfile") and asks the server to build the model named by the one
// positional argument.
func cmdBuild(ctx context.Context, args ...string) error {
	var v struct {
		Modelfile string `flag:"f,the Modelfile to use"`
	}
	fs := readFlags("build", args, &v)
	if fs.NArg() != 1 {
		return TODOUsage
	}
	modelfile, err := os.ReadFile(cmp.Or(v.Modelfile, "Modelfile"))
	if err != nil {
		return err
	}
	// Use fs.Arg(0), not args[0]: when a -f flag is present, args[0] is
	// the flag itself, not the model ref (matches cmdPush).
	return ollama.Default().Build(ctx, fs.Arg(0), modelfile, os.DirFS("."))
}

// cmdRegistry implements "registry": it serves a registry server on :8888.
func cmdRegistry(_ context.Context, _ ...string) error {
	var s registry.Server
	return http.ListenAndServe(":8888", &s)
}

// cmdServe implements "serve": it serves the build API on :11434.
func cmdServe(ctx context.Context, args ...string) error {
	bs, err := build.Open("")
	if err != nil {
		return err
	}
	return http.ListenAndServe(":11434", &api.Server{Build: bs})
}

// cmdPush implements "push": it pushes the one named model.
func cmdPush(ctx context.Context, args ...string) error {
	fs := readFlags("push", args, nil)
	if fs.NArg() != 1 {
		return TODOUsage
	}
	return ollama.Default().Push(ctx, fs.Arg(0))
}

View File

@@ -1,59 +0,0 @@
package main
import (
"flag"
"fmt"
"reflect"
"strings"
)
// readFlags builds a flag.FlagSet named name using reflection over the
// struct that v points to, registers one flag per field carrying a "flag"
// struct tag (either "name" or "name,usage"), and parses args into it.
// v may be nil, in which case no flags are registered. Only string and
// bool fields are supported; any other field kind panics. Example usage:
//
//	func main() {
//		var flags struct {
//			Modelfile string `flag:"f,path to the Modelfile"`
//		}
//		fs := readFlags("build", os.Args[1:], &flags)
//		_ = fs.Args()
//	}
func readFlags(name string, args []string, v any) *flag.FlagSet {
	fs := flag.NewFlagSet(name, flag.ExitOnError)
	// Parse runs deferred, after every flag below has been registered.
	defer fs.Parse(args)
	if v == nil {
		return fs
	}
	// v must be a pointer to a struct; reflect on the pointed-to value.
	// (Reflecting on the pointer Value itself, as before, panics in
	// NumField/Field, and tags must be read from the struct type, not
	// from each field's own type.)
	rv := reflect.ValueOf(v).Elem()
	rt := rv.Type()
	for i := 0; i < rt.NumField(); i++ {
		f := rv.Field(i)
		if !f.CanSet() {
			continue
		}
		tag := rt.Field(i).Tag.Get("flag")
		if tag == "" {
			continue
		}
		// "name,usage" or just "name" (Cut leaves usage empty).
		flagName, usage, _ := strings.Cut(tag, ",")
		// TODO(bmizerany): add more types as needed
		switch f.Kind() {
		case reflect.String:
			fs.StringVar(f.Addr().Interface().(*string), flagName, "", usage)
		case reflect.Bool:
			fs.BoolVar(f.Addr().Interface().(*bool), flagName, false, usage)
		default:
			panic(fmt.Sprintf("unsupported type %v", f.Kind()))
		}
	}
	return fs
}

View File

@@ -1,97 +0,0 @@
// Gguf is a tool for learning about GGUF files.
//
// Usage:
//
// gguf [flags] <file>
package main
import (
"flag"
"fmt"
"io"
"log"
"os"
"text/tabwriter"
"github.com/ollama/ollama/x/encoding/gguf"
)
func main() {
	if err := Main(os.Stdout, os.Args[1:]...); err != nil {
		log.Fatal(err)
	}
}

// Main runs the gguf tool: it prints the version, metadata entries, and
// tensor records of the single GGUF file named in args as a table on
// stdout. The -gpu flag is a byte budget; tensors are reported onGPU until
// the running total of tensor sizes exceeds it.
func Main(stdout io.Writer, args ...string) error {
	fs := flag.NewFlagSet("gguf", flag.ExitOnError)
	flagGPU := fs.Uint64("gpu", 0, "use N bytes of GPU memory (default is 0)")
	fs.Usage = func() {
		io.WriteString(stdout, "Gguf is a tool for learning about GGUF files.\n")
		io.WriteString(stdout, "\n")
		io.WriteString(stdout, "Usage:\n")
		io.WriteString(stdout, "\n")
		io.WriteString(stdout, "\tgguf [flags] <file>\n")
		io.WriteString(stdout, "\n")
		var numFlags int
		fs.VisitAll(func(*flag.Flag) { numFlags++ })
		if numFlags > 0 {
			io.WriteString(stdout, "Flags:\n")
			fs.PrintDefaults()
		}
	}
	fs.Parse(args)
	if fs.NArg() != 1 {
		fs.Usage()
		os.Exit(2)
	}

	file := fs.Arg(0)
	f, err := os.Open(file)
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()

	g, err := gguf.ReadFile(f)
	if err != nil {
		log.Fatal(err)
	}

	tw := tabwriter.NewWriter(stdout, 0, 2, 2, ' ', 0)
	defer tw.Flush()

	fmt.Fprintf(tw, "version:\t%d\n", g.Version())

	// Metadata must be exhausted before iterating Tensors.
	for m, err := range g.Metadata {
		if err != nil {
			log.Fatal(err)
		}
		// Elide long value arrays so the table stays readable.
		if len(m.Values) > 5 {
			fmt.Fprintf(tw, "meta:\t%q: ... (%d values)\n", m.Key, len(m.Values))
		} else {
			fmt.Fprintf(tw, "meta:\t%q: %v\n", m.Key, m.Values)
		}
	}

	var i int
	var totalLayerBytes uint64
	var offGPU bool // set once cumulative tensor bytes exceed the -gpu budget
	for t, err := range g.Tensors {
		if err != nil {
			log.Fatal(err)
		}
		totalLayerBytes += t.Size
		if totalLayerBytes > *flagGPU {
			offGPU = true
		}
		const msg = "tensor (layer %000d):\t%q\t%s\tdims=%v\toffset=%d\tsize=%d\tonGPU=%v\n"
		fmt.Fprintf(tw, msg, i, t.Name, t.Type, t.Dimensions, t.Offset, t.Size, !offGPU)
		i++
	}
	return nil
}

View File

@@ -1,376 +0,0 @@
package gguf
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"os"
"strconv"
"strings"
"github.com/ollama/ollama/x/types/structs"
)
// TODO(bmizerany): determine a more reasonable value for MaxDimensions.

// MaxDimensions is the maximum number of dimensions a tensor can have.
const MaxDimensions uint32 = 1e6

// Errors
var (
	// ErrBadMagic is returned when the magic bytes at the start of the
	// file are invalid. This is useful for detecting if the file is not
	// a gguf file.
	ErrBadMagic = errors.New("gguf: bad magic")

	// ErrUnsupportedVersion is returned for any GGUF version other than 3.
	ErrUnsupportedVersion = errors.New("gguf: unsupported version")

	// ErrMangled reports structurally inconsistent data; see the reader
	// for the cases that return it.
	ErrMangled = errors.New("gguf: mangled data")
)
// Type identifies a GGUF tensor data (quantization) type.
type Type uint32

// Tensor types as defined by the GGUF spec.
const (
	TypeF32   Type = 0
	TypeF16   Type = 1
	TypeQ4_0  Type = 2
	TypeQ4_1  Type = 3
	TypeQ5_0  Type = 6
	TypeQ5_1  Type = 7
	TypeQ8_0  Type = 8
	TypeQ8_1  Type = 9
	TypeQ2_K  Type = 10
	TypeQ3_K  Type = 11
	TypeQ4_K  Type = 12
	TypeQ5_K  Type = 13
	TypeQ6_K  Type = 14
	TypeQ8_K  Type = 15
	TypeI8    Type = 16
	TypeI16   Type = 17
	TypeI32   Type = 18
	TypeCount Type = 19
)

var typeNames = map[Type]string{
	TypeF32:   "F32",
	TypeF16:   "F16",
	TypeQ4_0:  "Q4_0",
	TypeQ4_1:  "Q4_1",
	TypeQ5_0:  "Q5_0",
	TypeQ5_1:  "Q5_1",
	TypeQ8_0:  "Q8_0",
	TypeQ8_1:  "Q8_1",
	TypeQ2_K:  "Q2_K",
	TypeQ3_K:  "Q3_K",
	TypeQ4_K:  "Q4_K",
	TypeQ5_K:  "Q5_K",
	TypeQ6_K:  "Q6_K",
	TypeQ8_K:  "Q8_K",
	TypeI8:    "I8",
	TypeI16:   "I16",
	TypeI32:   "I32",
	TypeCount: "COUNT",
}

// String returns the GGUF name of t, or a placeholder for unknown types.
func (t Type) String() string {
	if name := typeNames[t]; name != "" {
		return name
	}
	return fmt.Sprintf("(!unknown_type %d!)", t)
}

// ValueType is the type of a metadata value.
type ValueType uint32

// String returns the GGUF name of t, or a placeholder for unknown types.
func (t ValueType) String() string {
	if name := metaTypeNames[t]; name != "" {
		return name
	}
	return fmt.Sprintf("(!unknown_value_type %d!)", t)
}

// Metadata value types as defined by the GGUF spec.
const (
	ValueTypeUint8   ValueType = 0
	ValueTypeInt8    ValueType = 1
	ValueTypeUint16  ValueType = 2
	ValueTypeInt16   ValueType = 3
	ValueTypeUint32  ValueType = 4
	ValueTypeInt32   ValueType = 5
	ValueTypeFloat32 ValueType = 6
	ValueTypeBool    ValueType = 7
	ValueTypeString  ValueType = 8
	ValueTypeArray   ValueType = 9
	ValueTypeUint64  ValueType = 10
	ValueTypeInt64   ValueType = 11
	ValueTypeFloat64 ValueType = 12
)

var metaTypeNames = map[ValueType]string{
	ValueTypeUint8:   "uint8",
	ValueTypeInt8:    "int8",
	ValueTypeUint16:  "uint16",
	ValueTypeInt16:   "int16",
	ValueTypeUint32:  "uint32",
	ValueTypeInt32:   "int32",
	ValueTypeFloat32: "float32",
	ValueTypeBool:    "bool",
	ValueTypeString:  "string",
	ValueTypeArray:   "array",
	ValueTypeUint64:  "uint64",
	ValueTypeInt64:   "int64",
	ValueTypeFloat64: "float64",
}

// TensorInfo describes a single tensor record in a GGUF file. Size is not
// stored in the file; the reader derives it from consecutive offsets.
type TensorInfo struct {
	Name       string
	Dimensions []uint64
	Type       Type
	Offset     uint64
	Size       uint64
}

// A MetaValue is one raw metadata value together with its type.
type MetaValue struct {
	Type  ValueType
	Value []byte
}

// String renders the value as "<type>(<rendering>)" for display.
func (v MetaValue) String() string {
	var b strings.Builder
	b.WriteString(v.Type.String())
	b.WriteString("(")
	switch v.Type {
	case ValueTypeArray:
		b.WriteString("[...]")
	case ValueTypeString:
		b.WriteString(strconv.Quote(string(v.Value)))
	case ValueTypeBool:
		if len(v.Value) == 0 {
			// Guard: indexing v.Value[0] below would panic on an
			// empty value.
			b.WriteString("(!invalid bool)")
			break
		}
		switch v.Value[0] {
		case 0:
			b.WriteString("false")
		case 1:
			b.WriteString("true")
		default:
			b.WriteString("!invalid bool")
		}
	case ValueTypeUint8, ValueTypeInt8, ValueTypeUint16, ValueTypeInt16, ValueTypeUint32, ValueTypeInt32, ValueTypeUint64, ValueTypeInt64, ValueTypeFloat32, ValueTypeFloat64:
		// Zero-extend the little-endian payload to 8 bytes. copy
		// bounds itself, so this also handles full 8-byte values,
		// which the earlier `if len < 8` guard skipped entirely
		// (rendering every 8-byte value as 0).
		var buf [8]byte
		copy(buf[:], v.Value)
		fmt.Fprintf(&b, "%v", binary.LittleEndian.Uint64(buf[:]))
	default:
		fmt.Fprintf(&b, "%v", v.Value)
	}
	b.WriteString(")")
	return b.String()
}
// A MetaEntry is a single key/value metadata pair read from a GGUF file.
type MetaEntry struct {
	Key    string
	Type   ValueType
	Values []MetaValue
}

// String returns the first value's raw bytes as a string, or "" when there
// are no values.
func (e MetaEntry) String() string {
	if len(e.Values) == 0 {
		return ""
	}
	return string(e.Values[0].Value)
}

// Uint32 decodes the first value as a little-endian uint32. It returns 0
// when there is no value or the value holds fewer than 4 bytes.
func (e MetaEntry) Uint32() uint32 {
	if len(e.Values) == 0 {
		return 0
	}
	v := e.Values[0].Value
	if len(v) < 4 {
		// Guard: binary.LittleEndian.Uint32 panics on slices shorter
		// than 4 bytes, which a truncated or corrupt file could
		// otherwise trigger.
		return 0
	}
	return binary.LittleEndian.Uint32(v)
}

// FileType interprets the entry (e.g. "general.file_type") as a Type,
// returning the TypeCount sentinel when no values are present.
func (e MetaEntry) FileType() Type {
	if len(e.Values) == 0 {
		return TypeCount
	}
	return Type(e.Uint32())
}
// GoString renders the entry as "key: type(v0, v1, ...)" for debugging.
func (e MetaEntry) GoString() string {
	var b strings.Builder
	b.WriteString(e.Key)
	b.WriteString(": ")
	b.WriteString(e.Type.String())
	b.WriteString("(")
	for i, v := range e.Values {
		if i > 0 {
			b.WriteString(", ")
		}
		b.WriteString(v.String())
	}
	b.WriteString(")")
	return b.String()
}
// Info summarizes a GGUF file's header: its format version and file type.
type Info struct {
	_ structs.Incomparable // prevent comparison of Info values so we can change the implementation later

	Version  int
	FileType Type
}

// Stat opens the file at path and returns its GGUF header Info.
func Stat(path string) (Info, error) {
	f, err := os.Open(path)
	if err != nil {
		return Info{}, err
	}
	defer f.Close()
	return StatReader(f)
}

// StatReader reads the header information from r and returns an Info
// struct with the version and file type.
//
// It returns an error if any.
//
// As a special case, it returns ErrBadMagic if the file does not start with
// the magic bytes. This can be used to detect if the file is not a GGUF
// file.
func StatReader(r io.ReadSeeker) (Info, error) {
	// Rewind first so StatReader also works on an already-read stream.
	if _, err := r.Seek(0, 0); err != nil {
		return Info{}, err
	}
	f, err := ReadFile(r)
	if err != nil {
		return Info{}, err
	}
	info := Info{Version: f.Version()}
	// Scan the metadata for the "general.file_type" entry.
	for m, err := range f.Metadata {
		if err != nil {
			return Info{}, err
		}
		if m.Key == "general.file_type" {
			if m.Type != ValueTypeUint32 {
				return Info{}, fmt.Errorf("unexpected type for metadata key %q: %v, want %v", m.Key, m.Type, ValueTypeUint32)
			}
			info.FileType = m.FileType()
		}
	}
	return info, nil
}

// A File provides iteration over a GGUF file's metadata and tensor records.
type File struct {
	version       uint32
	numMetaValues uint64
	numTensors    uint64
	gr            *ggufReader
}

// ReadFile reads header information from r and returns a File, ready for
// iteration over Metadata and Tensors.
func ReadFile(r io.Reader) (*File, error) {
	f, err := readFile(r)
	if err != nil {
		return nil, err
	}
	return f, nil
}

// Version reports the GGUF format version read from the header.
func (f *File) Version() int {
	return int(f.version)
}
// Metadata iterates over the metadata in the file. It must be exhausted
// before calling Tensors.
//
// It is not resumable.
func (f *File) Metadata(yield func(MetaEntry, error) bool) {
	var n int // index of the entry being read, used for error context
	for range f.numMetaValues {
		meta, err := f.gr.readMetaEntry()
		if err != nil {
			err = fmt.Errorf("error reading metadata entry %d: %w", n, err)
			yield(MetaEntry{}, err)
			return
		}
		if !yield(meta, nil) {
			return
		}
		n++
	}
}

// Tensors iterates over the tensors in the file. It must only be called
// after exhausting the metadata iterator.
//
// Each tensor is yielded one step late: the file stores only offsets, so a
// record's size is computed as the difference between the next record's
// offset and its own.
//
// It is not resumable.
func (f *File) Tensors(yield func(TensorInfo, error) bool) {
	var last TensorInfo
	for range f.numTensors {
		info, err := f.gr.readTensorInfo()
		// If the last tensor had a valid offset, yield it.
		//
		// NOTE: No tensor should have an offset of 0 because the
		// offset is the start of the tensor data which is always
		// after the magic bytes, version, numMetaValues, and
		// numTensors, which MUST all be non-zero bytes as per the
		// GGUF spec.
		if last.Offset > 0 {
			if !yield(last, err) {
				return
			}
		}
		if err != nil {
			yield(TensorInfo{}, err)
			return
		}
		// Tensor data does not include size, so we need to
		// calculate it based on the offset of the previous tensor
		// offset to the current.
		offset0 := last.Offset
		last = info
		last.Size = info.Offset - offset0
	}
	// Flush the final pending tensor.
	if last.Offset > 0 {
		yield(last, nil)
	}
}
// magicBytes is the GGUF file magic: "GGUF" in ASCII.
var magicBytes = []byte{0x47, 0x47, 0x55, 0x46}

// readFile parses the fixed GGUF header from r (magic, version, tensor
// count, metadata count); metadata and tensor records are then read lazily
// through the returned File's iterators. Only version 3 is accepted.
func readFile(r io.Reader) (*File, error) {
	gr := &ggufReader{r: &reader{r: r}}
	magic, err := gr.next(4)
	if err != nil {
		// Join so callers can match both the read error and ErrBadMagic.
		return nil, errors.Join(err, ErrBadMagic)
	}
	if !bytes.Equal(magic, magicBytes) {
		return nil, ErrBadMagic
	}
	version, err := gr.readUint32()
	if err != nil {
		return nil, err
	}
	if version != 3 {
		return nil, fmt.Errorf("%w: %d", ErrUnsupportedVersion, version)
	}
	numTensors, err := gr.readUint64()
	if err != nil {
		return nil, err
	}
	numMetaValues, err := gr.readUint64()
	if err != nil {
		return nil, err
	}
	info := &File{
		version:       version,
		numMetaValues: numMetaValues,
		numTensors:    numTensors,
		gr:            gr,
	}
	return info, nil
}

View File

@@ -1,345 +0,0 @@
package gguf
import (
"errors"
"io"
"strings"
"testing"
"kr.dev/diff"
)
func TestStat(t *testing.T) {
cases := []struct {
name string
data string
wantInfo Info
wantErr error
}{
{
name: "empty",
wantErr: ErrBadMagic,
},
{
name: "bad magic",
data: "\xBB\xAA\xDD\x00",
wantErr: ErrBadMagic,
},
{
name: "bad version",
data: string(magicBytes) +
"\x02\x00\x00\x00" + // version
"",
wantErr: ErrUnsupportedVersion,
},
{
name: "valid general.file_type",
data: string(magicBytes) + // magic
"\x03\x00\x00\x00" + // version
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
"\x01\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
// general.file_type key
"\x11\x00\x00\x00\x00\x00\x00\x00" + // key length
"general.file_type" + // key
"\x04\x00\x00\x00" + // type (uint32)
"\x01\x00\x00\x00\x00\x00\x00\x00" + // uint32 value
"",
wantInfo: Info{
Version: 3,
FileType: 1,
},
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
info, err := StatReader(strings.NewReader(tt.data))
if tt.wantErr != nil {
if !errors.Is(err, tt.wantErr) {
t.Fatalf("err = %v; want %q", err, tt.wantErr)
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
diff.Test(t, t.Errorf, info, tt.wantInfo)
})
}
}
// TestReadInfo drives ReadFile over hand-built GGUF byte strings and
// checks header errors, decoded metadata entries, and decoded tensor
// records (including sizes computed from neighboring tensor offsets).
func TestReadInfo(t *testing.T) {
	cases := []struct {
		name string
		data string

		wantMeta      []MetaEntry
		wantTensor    []TensorInfo
		wantReadErr   error
		wantMetaErr   error
		wantTensorErr error

		// wantInfo is currently unused by the assertions below.
		wantInfo Info
	}{
		{
			name:        "empty",
			wantReadErr: io.ErrUnexpectedEOF,
		},
		{
			name:        "bad magic",
			data:        "\xBB\xAA\xDD\x00",
			wantReadErr: ErrBadMagic,
		},
		{
			name: "bad version",
			data: string(magicBytes) +
				"\x02\x00\x00\x00" + // version
				"",
			wantReadErr: ErrUnsupportedVersion,
		},
		{
			name: "no metadata or tensors",
			data: string(magicBytes) + // magic
				"\x03\x00\x00\x00" + // version
				"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
				"\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
				"",
			wantReadErr: nil,
		},
		{
			name: "good metadata",
			data: string(magicBytes) + // magic
				"\x03\x00\x00\x00" + // version
				"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // key length
				"K" + // key
				"\x08\x00\x00\x00" + // type (string)
				"\x02\x00\x00\x00\x00\x00\x00\x00" + // string length
				"VV" + // string value
				"",
			wantMeta: []MetaEntry{
				{Key: "K", Type: ValueTypeString, Values: []MetaValue{{Type: ValueTypeString, Value: []byte("VV")}}},
			},
		},
		{
			name: "good metadata with multiple values",
			data: string(magicBytes) + // magic
				"\x03\x00\x00\x00" + // version
				"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
				"\x02\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues

				// MetaEntry 1
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // key length
				"x" + // key
				"\x08\x00\x00\x00" + // type (string)
				"\x02\x00\x00\x00\x00\x00\x00\x00" + // string length
				"XX" + // string value

				// MetaEntry 2
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // key length
				"y" + // key
				"\x04\x00\x00\x00" + // type (uint32)
				"\x99\x88\x77\x66" + // uint32 value
				"",
			wantMeta: []MetaEntry{
				{Key: "x", Type: ValueTypeString, Values: []MetaValue{{
					Type:  ValueTypeString,
					Value: []byte("XX"),
				}}},
				{Key: "y", Type: ValueTypeUint32, Values: []MetaValue{{
					Type:  ValueTypeUint32,
					Value: []byte{0x99, 0x88, 0x77, 0x66},
				}}},
			},
		},
		{
			name: "negative string length in meta key",
			data: string(magicBytes) + // magic
				"\x03\x00\x00\x00" + // version
				"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
				"\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF" + // key length
				"K" + // key
				"\x08\x00\x00\x00" + // type (string)
				"\x02\x00\x00\x00\x00\x00\x00\x00" + // string length
				"VV" + // string value
				"",
			wantMetaErr: ErrMangled,
		},

		// Tensor tests
		{
			name: "good tensor",
			data: string(magicBytes) + // magic
				"\x03\x00\x00\x00" + // version
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // numTensors
				"\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues

				// Tensor 1
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
				"t" +

				// dimensions
				"\x01\x00\x00\x00" + // dimensions length
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // dimension[0]

				"\x03\x00\x00\x00" + // type (i8)
				"\x00\x01\x00\x00\x00\x00\x00\x00" + // offset
				"",
			wantTensor: []TensorInfo{
				{
					Name:       "t",
					Dimensions: []uint64{1},
					Type:       TypeQ4_1,
					Offset:     256,
					Size:       256,
				},
			},
		},
		{
			name: "too many dimensions",
			data: string(magicBytes) + // magic
				"\x03\x00\x00\x00" + // version
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // numTensors
				"\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues

				// Tensor 1
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
				"t" +
				"\x00\x00\x00\x01" + // dimensions length
				"",
			wantTensorErr: ErrMangled,
		},
		{
			name: "size computed",
			data: string(magicBytes) + // magic
				"\x03\x00\x00\x00" + // version
				"\x02\x00\x00\x00\x00\x00\x00\x00" + // numTensors
				"\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues

				// Tensor 1
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
				"t" +
				"\x01\x00\x00\x00" + // dimensions length
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // dimension[0]
				"\x03\x00\x00\x00" + // type (i8)
				"\x00\x01\x00\x00\x00\x00\x00\x00" + // offset

				// Tensor 2
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
				"t" +
				"\x01\x00\x00\x00" + // dimensions length
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // dimension[0]
				"\x03\x00\x00\x00" + // type (i8)
				"\x00\x03\x00\x00\x00\x00\x00\x00" + // offset
				"",
			wantTensor: []TensorInfo{
				{
					Name:       "t",
					Dimensions: []uint64{1},
					Type:       TypeQ4_1,
					Offset:     256,
					Size:       256,
				},
				{
					Name:       "t",
					Dimensions: []uint64{1},
					Type:       TypeQ4_1,
					Offset:     768,
					Size:       512,
				},
			},
		},
	}

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			f, err := ReadFile(strings.NewReader(tt.data))
			if err != nil {
				if !errors.Is(err, tt.wantReadErr) {
					t.Fatalf("unexpected ReadFile error: %v", err)
				}
				return
			}

			var got []MetaEntry
			for meta, err := range f.Metadata {
				if !errors.Is(err, tt.wantMetaErr) {
					// Report the error actually wanted by this
					// case, not (as before) always ErrMangled.
					t.Fatalf("err = %v; want %v", err, tt.wantMetaErr)
				}
				if err != nil {
					return
				}
				got = append(got, meta)
			}
			diff.Test(t, t.Errorf, got, tt.wantMeta)

			var gotT []TensorInfo
			for tinfo, err := range f.Tensors {
				if !errors.Is(err, tt.wantTensorErr) {
					t.Fatalf("err = %v; want %v", err, tt.wantTensorErr)
				}
				if err != nil {
					return
				}
				gotT = append(gotT, tinfo)
			}
			diff.Test(t, t.Errorf, gotT, tt.wantTensor)
		})
	}
}
// FuzzReadInfo fuzzes the full decode path: parse the header, then walk
// all metadata and tensor records. Inputs the decoder rejects are skipped
// rather than failed; every decoded tensor must have a positive offset
// and size.
func FuzzReadInfo(f *testing.F) {
	f.Add(string(magicBytes))
	f.Add(string(magicBytes) +
		"\x03\x00\x00\x00" + // version
		"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
		"\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
		"")
	f.Add(string(magicBytes) +
		"\x03\x00\x00\x00" + // version
		"\x01\x00\x00\x00\x00\x00\x00\x00" + // numTensors
		"\x01\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
		"\x01\x00\x00\x00\x00\x00\x00\x00" + // key length
		"K" + // key
		"\x08\x00\x00\x00" + // type (string)
		"\x02\x00\x00\x00\x00\x00\x00\x00" + // string length
		"VV" + // string value
		"\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
		"t" +
		"\x01\x00\x00\x00" + // dimensions length
		"\x01\x00\x00\x00\x00\x00\x00\x00" + // dimension[0]
		"\x03\x00\x00\x00" + // type (i8)
		"\x05\x00\x00\x00\x00\x00\x00\x00" + // offset
		"")

	f.Fuzz(func(t *testing.T, data string) {
		gf, err := ReadFile(strings.NewReader(data))
		if err != nil {
			t.Logf("ReadFile error: %v", err)
			t.Skip()
		}
		for _, err := range gf.Metadata {
			if err != nil {
				t.Logf("metadata error: %v", err)
				t.Skip()
			}
		}
		for tinfo, err := range gf.Tensors {
			if err != nil {
				t.Logf("tensor error: %v", err)
				t.Skip()
			}
			// Log the decoded tensor record (previously this
			// logged the *testing.T by mistake).
			if tinfo.Offset <= 0 {
				t.Logf("invalid tensor offset: %+v", tinfo)
				t.Skip()
			}
			if tinfo.Size <= 0 {
				t.Logf("invalid tensor size: %+v", tinfo)
				t.Skip()
			}
		}
	})
}

View File

@@ -1,195 +0,0 @@
package gguf
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"iter"
)
// ggufReader decodes GGUF primitives (ints, strings, metadata and tensor
// records) from an underlying sliding-window reader.
type ggufReader struct {
	r *reader
	n int // total bytes consumed so far
}
// readMetaEntry decodes one metadata record: a string key, a value type,
// and one or more values (several when the type is an array).
func (r *ggufReader) readMetaEntry() (MetaEntry, error) {
	key, err := r.readString()
	if err != nil {
		return MetaEntry{}, err
	}
	typ, err := r.readValueType()
	if err != nil {
		return MetaEntry{}, err
	}
	var vals []MetaValue
	for v, verr := range r.readMetaValues(typ) {
		if verr != nil {
			return MetaEntry{}, fmt.Errorf("(key=%q type=%s): %w", key, typ, verr)
		}
		vals = append(vals, v)
	}
	e := MetaEntry{
		Key:    string(key),
		Type:   typ,
		Values: vals,
	}
	return e, nil
}
// readMetaValue decodes a single scalar metadata value of type typ.
// Array values are driven by readMetaValues; nesting is rejected here.
func (r *ggufReader) readMetaValue(typ ValueType) (MetaValue, error) {
	var value []byte
	var err error
	switch typ {
	case ValueTypeUint8, ValueTypeInt8, ValueTypeBool:
		value, err = r.next(1)
	case ValueTypeUint16, ValueTypeInt16:
		value, err = r.next(2)
	case ValueTypeUint32, ValueTypeInt32, ValueTypeFloat32:
		value, err = r.next(4)
	case ValueTypeUint64, ValueTypeInt64, ValueTypeFloat64:
		value, err = r.next(8)
	case ValueTypeString:
		value, err = r.readString()
	case ValueTypeArray:
		err = fmt.Errorf("nested arrays are not supported")
	default:
		err = fmt.Errorf("unsupported metadata type: %d", typ)
	}
	if err != nil {
		return MetaValue{}, err
	}
	// Clone: the bytes returned by next/readString alias the reader's
	// window and are invalidated by subsequent reads.
	return MetaValue{Type: typ, Value: bytes.Clone(value)}, nil
}
// readMetaValues yields the value(s) of a metadata entry of type typ.
// For ValueTypeArray it reads an element type and a count and yields each
// element; otherwise it yields exactly one scalar. An error is yielded in
// the second position and always terminates the sequence.
func (r *ggufReader) readMetaValues(typ ValueType) iter.Seq2[MetaValue, error] {
	return func(yield func(MetaValue, error) bool) {
		if typ == ValueTypeArray {
			atyp, err := r.readValueType()
			if err != nil {
				err = fmt.Errorf("invalid type: %w", err)
				yield(MetaValue{}, err)
				return
			}
			n, err := r.readUint64()
			if err != nil {
				err = fmt.Errorf("invalid length: %w", err)
				yield(MetaValue{}, err)
				return
			}
			for i := range n {
				v, err := r.readMetaValue(atyp)
				if err != nil {
					err = fmt.Errorf("invalid entry (type=%s) %d: %w", atyp, i, err)
					yield(MetaValue{}, err)
					return
				}
				if !yield(v, nil) {
					return
				}
			}
		} else {
			v, err := r.readMetaValue(typ)
			if err != nil {
				err = fmt.Errorf("error reading metadata value: %w", err)
				yield(MetaValue{}, err)
				return
			}
			yield(v, nil)
		}
	}
}
// readValueType decodes a metadata value-type tag (a little-endian uint32).
func (r *ggufReader) readValueType() (ValueType, error) {
	n, err := r.readUint32()
	return ValueType(n), err
}
// readTensorInfo decodes one tensor record: name, dimension count (capped
// at MaxDimensions to bound allocation on mangled input), the dimensions,
// element type, and byte offset.
//
// NOTE(review): Size is left zero here — presumably computed by the caller
// from neighboring offsets; confirm against the Tensors iterator.
func (r *ggufReader) readTensorInfo() (TensorInfo, error) {
	name, err := r.readString()
	if err != nil {
		return TensorInfo{}, err
	}
	numDimensions, err := r.readUint32()
	if err != nil {
		return TensorInfo{}, err
	}
	if numDimensions > MaxDimensions {
		return TensorInfo{}, fmt.Errorf("%w: dimensions length (%d) exceeds %d", ErrMangled, numDimensions, MaxDimensions)
	}
	dims := make([]uint64, numDimensions)
	for i := range dims {
		d, err := r.readUint64()
		if err != nil {
			return TensorInfo{}, err
		}
		dims[i] = d
	}
	typ, err := r.readUint32()
	if err != nil {
		return TensorInfo{}, err
	}
	offset, err := r.readUint64()
	if err != nil {
		return TensorInfo{}, err
	}
	// TODO(bmizerany): check offset is multiple of ALIGNMENT
	return TensorInfo{
		Name:       string(name),
		Dimensions: dims,
		Type:       Type(typ),
		Offset:     offset,
	}, nil
}
// next returns the next n bytes of the stream without copying. The slice
// aliases the reader's internal buffer and is only valid until the next
// read. A negative n reports ErrMangled; a short stream reports
// io.ErrUnexpectedEOF.
func (r *ggufReader) next(n int) ([]byte, error) {
	if n < 0 {
		return nil, errors.Join(fmt.Errorf("invalid read length: %d", n), ErrMangled)
	}
	w := r.r.window()
	for len(w) < n {
		if r.r.extend() == 0 {
			return nil, io.ErrUnexpectedEOF
		}
		w = r.r.window()
	}
	// Consume n bytes. w stays valid: release only advances the offset,
	// it never reallocates the buffer.
	r.r.release(n)
	r.n += n
	return w[:n], nil
}
// readString decodes a uint64-length-prefixed byte string. The returned
// bytes alias the reader's window.
func (r *ggufReader) readString() ([]byte, error) {
	n, err := r.readUint64()
	if err != nil {
		return nil, err
	}
	// TODO(bmizerany): limit max string length
	return r.next(int(n))
}

// readUint32 decodes a little-endian uint32.
func (r *ggufReader) readUint32() (uint32, error) {
	b, err := r.next(4)
	if err != nil {
		return 0, err
	}
	return binary.LittleEndian.Uint32(b), nil
}

// readUint64 decodes a little-endian uint64.
func (r *ggufReader) readUint64() (uint64, error) {
	b, err := r.next(8)
	if err != nil {
		return 0, err
	}
	return binary.LittleEndian.Uint64(b), nil
}

View File

@@ -1,70 +0,0 @@
package gguf
import "io"
// A reader implements a sliding window over an io.Reader.
type reader struct {
	data   []byte    // buffer; the live window is data[offset:]
	offset int       // start of the live window within data
	r      io.Reader // underlying source
	err    error     // sticky error from the last Read
}
// release discards n bytes from the front of the window.
func (b *reader) release(n int) {
	b.offset += n
}

// window returns the current window.
// The window is invalidated by calls to release or extend.
func (b *reader) window() []byte {
	return b.data[b.offset:]
}

// tuning constants for reader.extend.
const (
	// newBufferSize is the size of a freshly allocated buffer.
	newBufferSize = 8 << 10

	// minReadSize is the smallest amount of free space worth issuing a
	// Read into; below this, extend compacts or grows first.
	minReadSize = newBufferSize >> 2
)
// extend extends the window with data from the underlying reader.
// It returns the number of bytes read; 0 indicates a sticky error or EOF.
func (b *reader) extend() int {
	if b.err != nil {
		return 0
	}

	remaining := len(b.data) - b.offset
	if remaining == 0 {
		// Window is empty: rewind so the whole buffer is reusable.
		b.data = b.data[:0]
		b.offset = 0
	}

	if cap(b.data)-len(b.data) >= minReadSize {
		// nothing to do, enough space exists between len and cap.
	} else if cap(b.data)-remaining >= minReadSize {
		// buffer has enough space if we move the data to the front.
		b.compact()
	} else {
		// otherwise, we must allocate/extend a new buffer
		b.grow()
	}
	// remaining+offset is the end of the live data; read into the free
	// tail of the buffer.
	remaining += b.offset
	n, err := b.r.Read(b.data[remaining:cap(b.data)])
	// reduce length to the existing plus the data we read.
	b.data = b.data[:remaining+n]
	b.err = err
	return n
}
// grow allocates a larger buffer (at least double the current capacity,
// and no smaller than newBufferSize) and moves the live window to its
// front.
func (b *reader) grow() {
	next := make([]byte, max(cap(b.data)*2, newBufferSize))
	copy(next, b.data[b.offset:])
	b.data = next
	b.offset = 0
}

// compact slides the live window to the front of the existing buffer.
func (b *reader) compact() {
	copy(b.data, b.data[b.offset:])
	b.offset = 0
}

View File

@@ -1,2 +0,0 @@
go test fuzz v1
string("GGUF\x03\x00\x00\x00\x00\x00\x800\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\xa6\x00\x00\x00\x00\x00\x00\x00\x02\x000\x00\x10\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\xe3\xe3\xe3\xe3\x00")

View File

@@ -1,2 +0,0 @@
go test fuzz v1
string("GGUF\x03\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\xe3\xe3\xe3\xe3\x00")

View File

@@ -1,2 +0,0 @@
go test fuzz v1
string("GGUF\x03\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\xfd\xff\xff\xff\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\xe3\xe3\xe3\xe3\x00")

View File

@@ -1,2 +0,0 @@
go test fuzz v1
string("GGUF\x03\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00K\b\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00VV\x01\x00\x00\x00\x00\\x00\\x00\\x00\\x00")

View File

@@ -1,2 +0,0 @@
go test fuzz v1
string("GGUF\x03\x00\x00\x00\x00\x00\x800\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\xa6\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x10\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\xe3\xe3\xe3\xe3\x00")

View File

@@ -1,2 +0,0 @@
go test fuzz v1
string("GGUF\x03\x00\x00\x00\x00\x00\x800\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\xe3\xe3\xe3\xe3\x00")

View File

@@ -1,2 +0,0 @@
go test fuzz v1
string("GGUF\x03\x00\x00\x0000000000000000000000000\xe5")

View File

@@ -1,2 +0,0 @@
go test fuzz v1
string("GGUF\x03\x00\x00\x0000000000\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x000\b\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x0000\x01\x00\x00\x00\x00\x00\x00\x000\x01\x00\x001\x01\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x05\x00\x00\x00\x00\x00\x00\a")

View File

@@ -1,134 +0,0 @@
package model
import (
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"log/slog"
"strings"
"unicode"
)
// Digest represents a digest of a model Manifest. It is a comparable value
// type and is immutable.
//
// The zero Digest is not a valid digest.
type Digest struct {
	s string // "<digest-type>-<digest>"; empty when invalid
}

// Type returns the digest type of the digest.
//
// Example:
//
//	ParseDigest("sha256-1234").Type() // returns "sha256"
func (d Digest) Type() string {
	typ, _, _ := strings.Cut(d.s, "-")
	return typ
}

// String returns the digest in the form of "<digest-type>-<digest>", or the
// empty string if the digest is invalid.
func (d Digest) String() string { return d.s }

// IsValid returns true if the digest is valid (not zero).
//
// A valid digest may be created only by ParseDigest, or
// ParseName(name).Digest().
func (d Digest) IsValid() bool { return d.s != "" }

// MarshalText implements encoding.TextMarshaler.
func (d Digest) MarshalText() ([]byte, error) {
	return []byte(d.String()), nil
}

// UnmarshalText implements encoding.TextUnmarshaler.
//
// It is an error to call UnmarshalText on an already-valid Digest. Note
// that unparseable input leaves d as the zero Digest and still returns
// nil; callers that must detect bad input should check d.IsValid().
func (d *Digest) UnmarshalText(text []byte) error {
	if d.IsValid() {
		return errors.New("model.Digest: illegal UnmarshalText on valid Digest")
	}
	*d = ParseDigest(string(text))
	return nil
}

// LogValue implements slog.LogValuer.
func (d Digest) LogValue() slog.Value {
	return slog.StringValue(d.String())
}

// Compile-time interface checks.
var (
	_ driver.Valuer  = Digest{}
	_ sql.Scanner    = (*Digest)(nil)
	_ slog.LogValuer = Digest{}
)
// Scan implements the sql.Scanner interface.
//
// Like UnmarshalText, input that fails to parse leaves d as the zero
// Digest and returns nil.
func (d *Digest) Scan(src any) error {
	if d.IsValid() {
		return errors.New("model.Digest: illegal Scan on valid Digest")
	}
	switch v := src.(type) {
	case string:
		*d = ParseDigest(v)
		return nil
	case []byte:
		*d = ParseDigest(string(v))
		return nil
	}
	return fmt.Errorf("model.Digest: invalid Scan source %T", src)
}

// Value implements the driver.Valuer interface.
func (d Digest) Value() (driver.Value, error) {
	return d.String(), nil
}
// ParseDigest parses a string in the form of "<digest-type>-<digest>" into
// a Digest. It returns the zero Digest if s is not well-formed.
func ParseDigest(s string) Digest {
	typ, hexPart, ok := strings.Cut(s, "-")
	if !ok || !isValidDigestType(typ) || !isValidHex(hexPart) {
		return Digest{}
	}
	return Digest{s: s}
}
// isValidDigest reports whether s is in the form
// "<digest-type>-<digest>", where <digest-type> matches [a-z0-9]+ and
// <digest> is a lowercase hex string.
//
// It does not check if the digest is a valid hash for the given digest
// type, or restrict the digest type to a known set of types. This is left
// up to users of this package.
func isValidDigest(s string) bool {
	// Previously this printed a DEBUG line to stdout on every call;
	// that leftover debug output has been removed.
	typ, digest, ok := strings.Cut(s, "-")
	return ok && isValidDigestType(typ) && isValidHex(digest)
}
// isValidDigestType reports whether s is non-empty and consists solely of
// lowercase letters and digits (per unicode classification).
func isValidDigestType(s string) bool {
	if s == "" {
		return false
	}
	for _, r := range s {
		ok := unicode.IsLower(r) || unicode.IsDigit(r)
		if !ok {
			return false
		}
	}
	return true
}
// isValidHex reports whether s is a non-empty string of lowercase
// hexadecimal digits ([0-9a-f]).
func isValidHex(s string) bool {
	if s == "" {
		return false
	}
	for i := 0; i < len(s); i++ {
		switch c := s[i]; {
		case '0' <= c && c <= '9':
			// digit: ok
		case 'a' <= c && c <= 'f':
			// lowercase hex letter: ok
		default:
			return false
		}
	}
	return true
}

View File

@@ -1,83 +0,0 @@
package model
import "testing"
// - test scan
// - test marshal text
// - test unmarshal text
// - test log value
// - test string
// - test type
// - test digest
// - test valid
// - test driver valuer
// - test sql scanner
// - test parse digest
// testDigests maps input strings to the Digest that ParseDigest should
// produce; the zero Digest marks inputs that must be rejected.
var testDigests = map[string]Digest{
	"":                 {},
	"sha256-1234":      {s: "sha256-1234"},
	"sha256-5678":      {s: "sha256-5678"},
	"blake2-9abc":      {s: "blake2-9abc"},
	"-1234":            {},
	"sha256-":          {},
	"sha256-1234-5678": {},
	"sha256-P":         {}, // invalid hex
	"sha256-1234P":     {},
	"---":              {},
}

// TestDigestParse checks ParseDigest over the shared table above.
func TestDigestParse(t *testing.T) {
	// Test cases.
	for s, want := range testDigests {
		got := ParseDigest(s)
		t.Logf("ParseDigest(%q) = %#v", s, got)
		if got != want {
			t.Errorf("ParseDigest(%q) = %q; want %q", s, got, want)
		}
	}
}

// TestDigestString checks that valid digests round-trip through String and
// that invalid ones render as the empty string.
func TestDigestString(t *testing.T) {
	// Test cases.
	for s, d := range testDigests {
		want := s
		if !d.IsValid() {
			want = ""
		}
		got := d.String()
		if got != want {
			t.Errorf("ParseDigest(%q).String() = %q; want %q", s, got, want)
		}

		got = ParseDigest(s).String()
		if got != want {
			t.Errorf("roundtrip ParseDigest(%q).String() = %q; want %q", s, got, want)
		}
	}
}

// TestDigestUnmarshalText checks UnmarshalText's two invariants: it is an
// error on an already-valid Digest, and it copies its input.
func TestDigestUnmarshalText(t *testing.T) {
	const testDigest = "sha256-1234"
	t.Run("UnmarshalText (into Valid)", func(t *testing.T) {
		d := ParseDigest(testDigest)
		if !d.IsValid() {
			panic("invalid test")
		}
		if err := d.UnmarshalText(nil); err == nil {
			t.Errorf("UnmarshalText on valid Digest did not return error")
		}
		if d.String() != testDigest {
			t.Errorf("UnmarshalText on valid Digest changed Digest: %q", d.String())
		}
	})
	t.Run("UnmarshalText make safe copy", func(t *testing.T) {
		data := []byte(testDigest)
		var d Digest
		// Error intentionally ignored: the input is a valid digest.
		d.UnmarshalText(data)
		data[0] = 'x'
		if d.String() != testDigest {
			t.Errorf("UnmarshalText did not make a safe copy")
		}
	})
}

View File

@@ -1,132 +0,0 @@
// Package model implements the File and Name types for working with and
// representing Modelfiles and model Names.
//
// The Name type should be used when working with model names, and the File
// type should be used when working with Modelfiles.
package model
import (
"bufio"
"io"
"iter"
"strings"
)
// ParamPragma is a single PARAMETER key/value pair from a Modelfile.
type ParamPragma struct {
	Key   string
	Value string
}

// MessagePragma is a single MESSAGE role/content pair from a Modelfile.
type MessagePragma struct {
	Role    string
	Content string
}

// File is the parsed form of a Modelfile.
type File struct {
	// From is a required pragma that specifies the source of the model,
	// either on disk, or by reference (see model.ParseName).
	From string

	// Optional
	Params   []ParamPragma
	Template string
	System   string
	Adapter  string
	Messages []MessagePragma

	License string
}

// FileError describes a problem with a specific pragma in a Modelfile.
type FileError struct {
	Pragma  string
	Message string
}

// Error formats the error as "<pragma>: <message>".
func (e *FileError) Error() string {
	return e.Pragma + ": " + e.Message
}
// Pragma represents a single pragma (directive line) in a Modelfile.
type Pragma struct {
	// Name is the pragma name, upper-cased (e.g. "FROM", "PARAMETER").
	Name string

	// Args contains the user-defined arguments for the pragma. If no
	// arguments were provided, it is nil.
	Args []string
}

// Arg returns the i'th argument of the pragma, or the empty string if i
// is out of range. Negative indexes are treated as out of range rather
// than panicking (previously Arg(-1) panicked).
func (p Pragma) Arg(i int) string {
	if i < 0 || i >= len(p.Args) {
		return ""
	}
	return p.Args[i]
}
// FilePragmas scans r line by line and yields one Pragma per non-blank
// line. The pragma name is upper-cased; the remaining whitespace-separated
// fields become Args. Any scanner error is yielded as the final element.
func FilePragmas(r io.Reader) iter.Seq2[Pragma, error] {
	return func(yield func(Pragma, error) bool) {
		sc := bufio.NewScanner(r)
		for sc.Scan() {
			line := sc.Text()

			// TODO(bmizerany): set a max num fields/args to
			// prevent mem bloat
			args := strings.Fields(line)
			if len(args) == 0 {
				continue
			}

			p := Pragma{
				Name: strings.ToUpper(args[0]),
			}
			if p.Name == "MESSAGE" {
				// handle special case where message content
				// is space separated on the _rest_ of the
				// line like: `MESSAGE user Is Ontario in
				// Canada?`
				// NOTE(review): unimplemented — any Modelfile
				// containing a MESSAGE pragma currently panics.
				panic("TODO")
			}
			if len(args) > 1 {
				p.Args = args[1:]
			}
			if !yield(p, nil) {
				return
			}
		}
		if sc.Err() != nil {
			yield(Pragma{}, sc.Err())
		}
	}
}
// ParseFile parses a Modelfile from r into a File, dispatching each pragma
// to its matching field. Unknown pragmas are silently ignored; the first
// scan error aborts the parse.
func ParseFile(r io.Reader) (File, error) {
	var f File
	for p, err := range FilePragmas(r) {
		if err != nil {
			return File{}, err
		}
		switch p.Name {
		case "FROM":
			f.From = p.Arg(0)
		case "PARAMETER":
			// Parameter keys are case-insensitive; normalize to lower.
			f.Params = append(f.Params, ParamPragma{
				Key:   strings.ToLower(p.Arg(0)),
				Value: p.Arg(1),
			})
		case "TEMPLATE":
			f.Template = p.Arg(0)
		case "SYSTEM":
			f.System = p.Arg(0)
		case "ADAPTER":
			f.Adapter = p.Arg(0)
		case "MESSAGE":
			f.Messages = append(f.Messages, MessagePragma{
				Role:    p.Arg(0),
				Content: p.Arg(1),
			})
		case "LICENSE":
			f.License = p.Arg(0)
		}
	}
	return f, nil
}

View File

@@ -1,593 +0,0 @@
package model
import (
"bytes"
"cmp"
"database/sql"
"database/sql/driver"
"errors"
"hash/maphash"
"io"
"iter"
"log/slog"
"slices"
"strings"
"sync"
"github.com/ollama/ollama/x/types/structs"
)
// Errors
var (
	// ErrIncompleteName is not used by this package, but is exported so
	// that other packages do not need to invent their own error type
	// when they need to return an error for an incomplete model name.
	ErrIncompleteName = errors.New("incomplete model name")

	// ErrInvalidDigest reports a malformed "<type>-<hex>" digest string.
	ErrInvalidDigest = errors.New("invalid digest")
)

// MaxNamePartLen bounds name input length. NOTE(review): Parts applies
// this limit to the entire input string, not to individual parts.
const MaxNamePartLen = 128

// PartKind identifies one positional part of a model Name.
type PartKind int

// Levels of concreteness
const (
	// Each value aligns with its index in the Name.parts array.
	PartHost PartKind = iota
	PartNamespace
	PartModel
	PartTag
	PartBuild
	PartDigest

	// Invalid is a special part that is used to indicate that a part is
	// invalid. It is not a valid part of a Name.
	//
	// It should be kept as the last part in the list.
	PartInvalid
)

// kindNames maps each PartKind to its display name.
var kindNames = map[PartKind]string{
	PartHost:      "Host",
	PartNamespace: "Namespace",
	PartModel:     "Name",
	PartTag:       "Tag",
	PartBuild:     "Build",
	PartDigest:    "Digest",
	PartInvalid:   "Invalid",
}

// String returns the display name of k, or "Unknown" for unmapped values.
func (k PartKind) String() string {
	return cmp.Or(kindNames[k], "Unknown")
}
// Name is an opaque reference to a model. It holds the parts of a model
// with the case preserved, but is not directly comparable with other Names
// since model names can be represented with different casing depending on
// the use case. For instance, "Mistral" and "mistral" are the same model
// but each version may have come from different sources (e.g. copied from a
// Web page, or from a file path).
//
// Valid Names can ONLY be constructed by calling [ParseName].
//
// A Name is valid if and only if it has a valid Model part. The other parts
// are optional.
//
// A Name is considered "complete" if it has all parts present. To check if a
// Name is complete, use [Name.IsComplete].
//
// To compare two names in a case-insensitive manner, use [Name.EqualFold].
//
// The parts of a Name are:
//
//   - Host: the domain of the model (optional)
//   - Namespace: the namespace of the model (optional)
//   - Model: the name of the model (required)
//   - Tag: the tag of the model (optional)
//   - Build: the build of the model; usually the quantization or "file type" (optional)
//
// The parts can be obtained in their original form by calling [Name.Parts].
//
// To check if a Name has at minimum a valid model part, use [Name.IsValid].
//
// To make a Name by filling in missing parts from another Name, use [Fill].
type Name struct {
	_     structs.Incomparable
	parts [6]string // host, namespace, model, tag, build, digest

	// TODO(bmizerany): track offsets and hold s (raw string) here? We
	// could pack the offsets all into a single uint64 since the first
	// parts take less bits since their max offset is less than the max
	// offset of the next part. This would save a ton of bytes per Name
	// and mean zero allocations for String.
}
// ParseName parses s into a Name. The input string must be a valid string
// representation of a model name in the form:
//
//	[host/][namespace/]<model>[:tag][+build][@<digest-type>-<digest>]
//
// The name part is required, all others are optional. If a part is missing,
// it is left empty in the returned Name. If a part is invalid, the zero Ref
// value is returned.
//
// The build part is normalized to uppercase.
//
// Examples of valid paths:
//
//	"example.com/library/mistral:7b+x"
//	"example.com/eva/mistral:7b+Q4_0"
//	"mistral:7b+x"
//	"example.com/mike/mistral:latest+Q4_0"
//	"example.com/bruce/mistral:latest"
//	"example.com/mistral:7b+Q4_0@sha256-1234567890abcdef"
//
// Examples of invalid paths:
//
//	"example.com/mistral:7b+"
//	"example.com/mistral:7b+Q4_0+"
//	"x/y/z/z:8n+I"
//	""
//
// It returns the zero value if any part is invalid.
//
// As a rule of thumb, a valid name is one that can be round-tripped with
// the [Name.String] method. That means ("x+") is invalid because
// [Name.String] will not print a "+" if the build is empty.
func ParseName(s string) Name {
	var r Name
	for kind, part := range Parts(s) {
		if kind == PartInvalid {
			return Name{}
		}
		if kind == PartDigest && !ParseDigest(part).IsValid() {
			return Name{}
		}
		r.parts[kind] = part
	}
	// Accept either a name with a valid model part, or a digest-only
	// ("resolved") reference.
	if r.IsValid() || r.IsResolved() {
		return r
	}
	return Name{}
}

// MustParseName is like ParseName but panics if s does not parse into a
// valid Name.
//
// NOTE(review): digest-only strings are accepted by ParseName (via
// IsResolved) yet fail IsValid, so MustParseName panics on them — confirm
// this asymmetry is intended.
func MustParseName(s string) Name {
	r := ParseName(s)
	if !r.IsValid() {
		panic("model.MustParseName: invalid name: " + s)
	}
	return r
}
// Fill fills in the missing parts of dst with the parts of src.
//
// The returned Name will only be valid if dst is valid.
func Fill(dst, src Name) Name {
	var merged Name
	for i, p := range dst.parts {
		merged.parts[i] = cmp.Or(p, src.parts[i])
	}
	return merged
}

// WithBuild returns a copy of r with the build part replaced by build.
func (r Name) WithBuild(build string) Name {
	r.parts[PartBuild] = build
	return r
}

// WithDigest returns a copy of r with the digest part replaced by digest.
func (r Name) WithDigest(digest Digest) Name {
	r.parts[PartDigest] = digest.String()
	return r
}
// mapHashSeed is the process-wide seed shared by all MapHash calls.
var mapHashSeed = maphash.MakeSeed()

// MapHash returns a case insensitive hash for use in maps and equality
// checks. For a convenient way to compare names, use [Name.EqualFold].
//
// NOTE(review): only ASCII letters are folded here, matching the ASCII-only
// downcase used by EqualFold; non-ASCII case variants hash differently.
func (r Name) MapHash() uint64 {
	// correctly hash the parts with case insensitive comparison
	var h maphash.Hash
	h.SetSeed(mapHashSeed)
	for _, part := range r.Parts() {
		// downcase the part for hashing
		for i := range part {
			c := part[i]
			if c >= 'A' && c <= 'Z' {
				c = c - 'A' + 'a'
			}
			h.WriteByte(c)
		}
	}
	return h.Sum64()
}
// slice returns a copy of r retaining only the parts in [from, to].
func (r Name) slice(from, to PartKind) Name {
	var v Name
	copy(v.parts[from:to+1], r.parts[from:to+1])
	return v
}

// DisplayModel returns a display string composed of the model only.
func (r Name) DisplayModel() string {
	return r.parts[PartModel]
}

// DisplayFullest returns the fullest possible display string in form:
//
//	<host>/<namespace>/<model>:<tag>
//
// If any part is missing, it is omitted from the display string.
//
// It does not include the build part. For the fullest possible display
// string with the build, use [Name.String].
func (r Name) DisplayFullest() string {
	return r.slice(PartHost, PartTag).String()
}

// DisplayShort returns a display string in form:
//
//	<model>:<tag>
//
// If any part is missing, it is omitted from the display string.
func (r Name) DisplayShort() string {
	return r.slice(PartModel, PartTag).String()
}

// DisplayLong returns a display string in form:
//
//	<namespace>/<model>:<tag>
//
// If any part is missing, it is omitted from the display string.
func (r Name) DisplayLong() string {
	return r.slice(PartNamespace, PartTag).String()
}
// seps[k] is the separator written between part k and the part after it.
var seps = [...]string{
	PartHost:      "/",
	PartNamespace: "/",
	PartModel:     ":",
	PartTag:       "+",
	PartBuild:     "@",
	PartDigest:    "",
}

// writeTo writes the fullest possible display string in form:
//
//	<host>/<namespace>/<model>:<tag>+<build>@<digest-type>-<digest>
//
// Missing parts and their separators are not written.
//
// The full digest is always prefixed with "@". That is if [Name.IsValid]
// reports false and [Name.IsResolved] reports true, then the string is
// returned as "@<digest-type>-<digest>".
func (r Name) writeTo(w io.StringWriter) {
	var partsWritten int
	for i := range r.parts {
		if r.parts[i] == "" {
			continue
		}
		// seps[i-1] is the separator that precedes part i; the
		// digest part always gets its "@", even when written first.
		if partsWritten > 0 || i == int(PartDigest) {
			w.WriteString(seps[i-1])
		}
		w.WriteString(r.parts[i])
		partsWritten++
	}
}

// builderPool recycles the strings.Builders used by Name.String.
var builderPool = sync.Pool{
	New: func() interface{} {
		return &strings.Builder{}
	},
}
// String returns the fullest possible display string in form:
//
//	<host>/<namespace>/<model>:<tag>+<build>
//
// If any part is missing, it is omitted from the display string.
//
// For the fullest possible display string without the build, use
// [Name.DisplayFullest].
func (r Name) String() string {
	b := builderPool.Get().(*strings.Builder)
	// Pooling a Builder is safe here: Builder.Reset drops its buffer
	// rather than reusing it, so the string returned below is never
	// clobbered by a later user of the pooled Builder.
	defer builderPool.Put(b)
	b.Reset()
	b.Grow(50) // arbitrarily long enough for most names
	r.writeTo(b)
	return b.String()
}

// GoString implements fmt.GoStringer. It returns a string suitable for
// debugging and logging. It is similar to [Name.String] but it always
// returns a string that includes all parts of the Name, with missing parts
// replaced with a ("?").
func (r Name) GoString() string {
	// r is a value receiver, so mutating r.parts here cannot affect the
	// caller's Name.
	for i := range r.parts {
		r.parts[i] = cmp.Or(r.parts[i], "?")
	}
	return r.String()
}

// LogValue implements slog.LogValuer.
func (r Name) LogValue() slog.Value {
	return slog.StringValue(r.GoString())
}
// bufPool recycles the bytes.Buffers used by Name.MarshalText.
var bufPool = sync.Pool{
	New: func() interface{} {
		return new(bytes.Buffer)
	},
}
// MarshalText implements [encoding.TextMarshaler].
func (r Name) MarshalText() ([]byte, error) {
	b := bufPool.Get().(*bytes.Buffer)
	b.Reset()
	b.Grow(50) // arbitrarily long enough for most names
	defer bufPool.Put(b)
	r.writeTo(b)
	// Clone before returning: the deferred Put above hands the buffer
	// back to bufPool, and returning b.Bytes() directly would give the
	// caller a slice aliasing memory that is about to be Reset and
	// rewritten by the next user of the pool.
	// TODO: We can remove this alloc if/when
	// https://github.com/golang/go/issues/62384 lands.
	return bytes.Clone(b.Bytes()), nil
}
// UnmarshalText implements [encoding.TextUnmarshaler].
//
// It is an error to call UnmarshalText on a valid Name. Note that text
// that fails to parse leaves r as the zero Name and still returns nil.
func (r *Name) UnmarshalText(text []byte) error {
	if r.IsValid() {
		// The invariant of UnmarshalText is that it should only be
		// called on an invalid/zero Name. If we allow UnmarshalText
		// on a valid Name, then the Name will be mutated, breaking
		// the immutability of the Name.
		return errors.New("model.Name: illegal UnmarshalText on valid Name")
	}

	// The contract of UnmarshalText is that we copy to keep the text.
	*r = ParseName(string(text))
	return nil
}

// Compile-time interface checks.
var (
	_ driver.Valuer = Name{}
	_ sql.Scanner   = (*Name)(nil)
)

// Scan implements [database/sql.Scanner].
func (r *Name) Scan(src any) error {
	if r.IsValid() {
		// The invariant of Scan is that it should only be called on an
		// invalid/zero Name. If we allow Scan on a valid Name, then the
		// Name will be mutated, breaking the immutability of the Name.
		return errors.New("model.Name: illegal Scan on valid Name")
	}
	switch v := src.(type) {
	case string:
		*r = ParseName(v)
		return nil
	case []byte:
		*r = ParseName(string(v))
		return nil
	}
	return errors.New("model.Name: invalid Scan source")
}

// Value implements [database/sql/driver.Valuer].
func (r Name) Value() (driver.Value, error) {
	return r.String(), nil
}
// IsComplete reports whether the Name is fully qualified. That is it has a
// domain, namespace, name, tag, and build.
func (r Name) IsComplete() bool {
return !slices.Contains(r.parts[:PartDigest], "")
}
// IsCompleteNoBuild is like [Name.IsComplete] but it does not require the
// build part to be present.
func (r Name) IsCompleteNoBuild() bool {
return !slices.Contains(r.parts[:PartBuild], "")
}
// IsResolved reports true if the Name has a valid digest.
//
// It is possible to have a valid Name, or a complete Name that is not
// resolved.
func (r Name) IsResolved() bool {
	return r.Digest().IsValid()
}
// Digest returns the digest part of the Name, if any.
//
// If Digest returns a non-empty string, then [Name.IsResolved] will return
// true, and digest is considered valid.
func (r Name) Digest() Digest {
	// This was already validated by ParseName, so we can just return it.
	return Digest{r.parts[PartDigest]}
}
// EqualFold reports whether r and o are equivalent model names, ignoring
// case.
func (r Name) EqualFold(o Name) bool {
	return r.CompareFold(o) == 0
}
// CompareFold performs a case-insensitive cmp.Compare on r and o.
//
// This can be used with [slices.SortFunc].
//
// For simple equality checks, use [Name.EqualFold].
func (r Name) CompareFold(o Name) int {
	// Compare part-by-part, in PartKind order, via compareFold.
	return slices.CompareFunc(r.parts[:], o.parts[:], compareFold)
}
// compareFold compares two part strings rune-by-rune, folding ASCII case
// via downcase. Non-ASCII runes compare as-is.
//
// NOTE(review): each call converts both strings to []rune, which
// allocates; fine for occasional sorts, worth revisiting if this ever
// shows up on a hot path.
func compareFold(a, b string) int {
	return slices.CompareFunc([]rune(a), []rune(b), func(a, b rune) int {
		return cmp.Compare(downcase(a), downcase(b))
	})
}
// downcase maps an ASCII uppercase letter to its lowercase form; every
// other rune (including non-ASCII) is returned unchanged.
func downcase(r rune) rune {
	if 'A' <= r && r <= 'Z' {
		return r + ('a' - 'A')
	}
	return r
}
// TODO(bmizerany): driver.Value? (MarshalText etc should be enough)

// Parts returns the parts of the Name in order of concreteness.
//
// The length of the returned slice is always 5.
func (r Name) Parts() []string {
	// Return a copy so callers cannot mutate the Name's backing array.
	parts := make([]string, len(r.parts))
	copy(parts, r.parts[:])
	return parts
}
// Parts returns a sequence of the parts of a Name string from most specific
// to least specific.
//
// It normalizes the input string by removing "http://" and "https://" only.
// No other normalization is done.
//
// The parser scans s right-to-left as a small state machine: it starts in
// the PartDigest state at the end of the string and each separator byte
// ('@', '+', ':', '/') yields the part to its right and moves the state
// one step toward PartHost. Any separator seen in an impossible state, or
// any byte invalid for the current part, yields (PartInvalid, "") and
// stops.
func Parts(s string) iter.Seq2[PartKind, string] {
	return func(yield func(PartKind, string) bool) {
		if strings.HasPrefix(s, "http://") {
			s = s[len("http://"):]
		}
		if strings.HasPrefix(s, "https://") {
			s = s[len("https://"):]
		}
		if len(s) > MaxNamePartLen || len(s) == 0 {
			return
		}
		// yieldValid validates part before yielding it; on failure it
		// emits the PartInvalid sentinel and signals the caller loop
		// to stop.
		yieldValid := func(kind PartKind, part string) bool {
			if !isValidPart(kind, part) {
				yield(PartInvalid, "")
				return false
			}
			return yield(kind, part)
		}
		// partLen counts bytes since the last separator; j marks the
		// exclusive end of the part currently being scanned.
		partLen := 0
		state, j := PartDigest, len(s)
		for i := len(s) - 1; i >= 0; i-- {
			if partLen++; partLen > MaxNamePartLen {
				// catch a part that is too long early, so
				// we don't keep spinning on it, waiting for
				// an isValidPart check which would scan
				// over it again.
				yield(PartInvalid, "")
				return
			}
			switch s[i] {
			case '@':
				switch state {
				case PartDigest:
					if !yieldValid(PartDigest, s[i+1:j]) {
						return
					}
					if i == 0 {
						// This is the form
						// "@<digest>" which is valid.
						//
						// We're done.
						return
					}
					state, j, partLen = PartBuild, i, 0
				default:
					yield(PartInvalid, "")
					return
				}
			case '+':
				switch state {
				case PartBuild, PartDigest:
					if !yieldValid(PartBuild, s[i+1:j]) {
						return
					}
					state, j, partLen = PartTag, i, 0
				default:
					yield(PartInvalid, "")
					return
				}
			case ':':
				switch state {
				case PartTag, PartBuild, PartDigest:
					if !yieldValid(PartTag, s[i+1:j]) {
						return
					}
					state, j, partLen = PartModel, i, 0
				default:
					yield(PartInvalid, "")
					return
				}
			case '/':
				switch state {
				case PartModel, PartTag, PartBuild, PartDigest:
					if !yieldValid(PartModel, s[i+1:j]) {
						return
					}
					// NOTE(review): unlike the other
					// transitions, partLen is not reset
					// here. Harmless today because the
					// whole input is capped at
					// MaxNamePartLen up front, but confirm
					// before relying on the per-part limit
					// for the namespace.
					state, j = PartNamespace, i
				case PartNamespace:
					if !yieldValid(PartNamespace, s[i+1:j]) {
						return
					}
					state, j, partLen = PartHost, i, 0
				default:
					yield(PartInvalid, "")
					return
				}
			default:
				if !isValidByte(state, s[i]) {
					yield(PartInvalid, "")
					return
				}
			}
		}
		// Whatever remains at the front of the string is the final
		// part: a host/namespace if we got that far, otherwise the
		// model.
		if state <= PartNamespace {
			yieldValid(state, s[:j])
		} else {
			yieldValid(PartModel, s[:j])
		}
	}
}
// IsValid reports whether the Name has a valid model part. To know if a
// Name is "complete", use [Name.IsComplete].
func (r Name) IsValid() bool {
	// Parts ensures we only have valid parts, so no need to validate
	// them here, only check if we have a name or not.
	return r.parts[PartModel] != ""
}
// isValidPart reports whether the given part is non-empty and consists
// only of bytes valid for kind (ASCII [a-zA-Z0-9_.-], with kind-specific
// restrictions applied by isValidByte).
func isValidPart(kind PartKind, s string) bool {
	if s == "" {
		return false
	}
	for _, c := range []byte(s) {
		if !isValidByte(kind, c) {
			return false
		}
	}
	return true
}
// isValidByte reports whether c may appear in a part of the given kind:
// ASCII letters, digits, '_', '-', and '.', except that '.' is forbidden
// in the namespace part.
func isValidByte(kind PartKind, c byte) bool {
	switch {
	case kind == PartNamespace && c == '.':
		return false
	case c == '.', c == '-', c == '_':
		return true
	case 'a' <= c && c <= 'z':
		return true
	case 'A' <= c && c <= 'Z':
		return true
	case '0' <= c && c <= '9':
		return true
	}
	return false
}

View File

@@ -1,572 +0,0 @@
package model
import (
"bytes"
"cmp"
"errors"
"fmt"
"log/slog"
"slices"
"strings"
"testing"
)
type fields struct {
host, namespace, model, tag, build string
digest string
}
func fieldsFromName(p Name) fields {
return fields{
host: p.parts[PartHost],
namespace: p.parts[PartNamespace],
model: p.parts[PartModel],
tag: p.parts[PartTag],
build: p.parts[PartBuild],
digest: p.parts[PartDigest],
}
}
var testNames = map[string]fields{
"mistral:latest": {model: "mistral", tag: "latest"},
"mistral": {model: "mistral"},
"mistral:30B": {model: "mistral", tag: "30B"},
"mistral:7b": {model: "mistral", tag: "7b"},
"mistral:7b+Q4_0": {model: "mistral", tag: "7b", build: "Q4_0"},
"mistral+KQED": {model: "mistral", build: "KQED"},
"mistral.x-3:7b+Q4_0": {model: "mistral.x-3", tag: "7b", build: "Q4_0"},
"mistral:7b+q4_0": {model: "mistral", tag: "7b", build: "q4_0"},
"llama2": {model: "llama2"},
"user/model": {namespace: "user", model: "model"},
"example.com/ns/mistral:7b+Q4_0": {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "Q4_0"},
"example.com/ns/mistral:7b+X": {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "X"},
// invalid digest
"mistral:latest@invalid256-": {},
"mistral:latest@-123": {},
"mistral:latest@!-123": {},
"mistral:latest@1-!": {},
"mistral:latest@": {},
// resolved
"x@sha123-1": {model: "x", digest: "sha123-1"},
"@sha456-2": {digest: "sha456-2"},
"@@sha123-1": {},
// preserves case for build
"x+b": {model: "x", build: "b"},
// invalid (includes fuzzing trophies)
" / / : + ": {},
" / : + ": {},
" : + ": {},
" + ": {},
" : ": {},
" / ": {},
" /": {},
"/ ": {},
"/": {},
":": {},
"+": {},
// (".") in namespace is not allowed
"invalid.com/7b+x": {},
"invalid:7b+Q4_0:latest": {},
"in valid": {},
"invalid/y/z/foo": {},
"/0": {},
"0 /0": {},
"0 /": {},
"0/": {},
":/0": {},
"+0/00000": {},
"0+.\xf2\x80\xf6\x9d00000\xe5\x99\xe6\xd900\xd90\xa60\x91\xdc0\xff\xbf\x99\xe800\xb9\xdc\xd6\xc300\x970\xfb\xfd0\xe0\x8a\xe1\xad\xd40\x9700\xa80\x980\xdd0000\xb00\x91000\xfe0\x89\x9b\x90\x93\x9f0\xe60\xf7\x84\xb0\x87\xa5\xff0\xa000\x9a\x85\xf6\x85\xfe\xa9\xf9\xe9\xde00\xf4\xe0\x8f\x81\xad\xde00\xd700\xaa\xe000000\xb1\xee0\x91": {},
"0//0": {},
"m+^^^": {},
"file:///etc/passwd": {},
"file:///etc/passwd:latest": {},
"file:///etc/passwd:latest+u": {},
strings.Repeat("a", MaxNamePartLen): {model: strings.Repeat("a", MaxNamePartLen)},
strings.Repeat("a", MaxNamePartLen+1): {},
}
func TestNameParts(t *testing.T) {
var p Name
if w, g := int(PartDigest+1), len(p.Parts()); w != g {
t.Errorf("Parts() = %d; want %d", g, w)
}
}
func TestNamePartString(t *testing.T) {
if g := PartKind(-2).String(); g != "Unknown" {
t.Errorf("Unknown part = %q; want %q", g, "Unknown")
}
for kind, name := range kindNames {
if g := kind.String(); g != name {
t.Errorf("%s = %q; want %q", kind, g, name)
}
}
}
func TestParseName(t *testing.T) {
for baseName, want := range testNames {
for _, prefix := range []string{"", "https://", "http://"} {
// We should get the same results with or without the
// http(s) prefixes
s := prefix + baseName
t.Run(s, func(t *testing.T) {
for kind, part := range Parts(s) {
t.Logf("Part: %s: %q", kind, part)
}
name := ParseName(s)
got := fieldsFromName(name)
if got != want {
t.Errorf("ParseName(%q) = %q; want %q", s, got, want)
}
// test round-trip
if !ParseName(name.String()).EqualFold(name) {
t.Errorf("ParseName(%q).String() = %s; want %s", s, name.String(), baseName)
}
if name.IsValid() && name.DisplayModel() == "" {
t.Errorf("Valid() = true; Model() = %q; want non-empty name", got.model)
} else if !name.IsValid() && name.DisplayModel() != "" {
t.Errorf("Valid() = false; Model() = %q; want empty name", got.model)
}
if name.IsResolved() && !name.Digest().IsValid() {
t.Errorf("Resolved() = true; Digest() = %q; want non-empty digest", got.digest)
} else if !name.IsResolved() && name.Digest().IsValid() {
t.Errorf("Resolved() = false; Digest() = %q; want empty digest", got.digest)
}
})
}
}
}
func TestCompleteWithAndWithoutBuild(t *testing.T) {
cases := []struct {
in string
complete bool
completeNoBuild bool
}{
{"", false, false},
{"incomplete/mistral:7b+x", false, false},
{"incomplete/mistral:7b+Q4_0", false, false},
{"incomplete:7b+x", false, false},
{"complete.com/x/mistral:latest+Q4_0", true, true},
{"complete.com/x/mistral:latest", false, true},
}
for _, tt := range cases {
t.Run(tt.in, func(t *testing.T) {
p := ParseName(tt.in)
t.Logf("ParseName(%q) = %#v", tt.in, p)
if g := p.IsComplete(); g != tt.complete {
t.Errorf("Complete(%q) = %v; want %v", tt.in, g, tt.complete)
}
if g := p.IsCompleteNoBuild(); g != tt.completeNoBuild {
t.Errorf("CompleteNoBuild(%q) = %v; want %v", tt.in, g, tt.completeNoBuild)
}
})
}
// Complete uses Parts which returns a slice, but it should be
// inlined when used in Complete, preventing any allocations or
// escaping to the heap.
allocs := testing.AllocsPerRun(1000, func() {
keep(ParseName("complete.com/x/mistral:latest+Q4_0").IsComplete())
})
if allocs > 0 {
t.Errorf("Complete allocs = %v; want 0", allocs)
}
}
func TestNameLogValue(t *testing.T) {
cases := []string{
"example.com/library/mistral:latest+Q4_0",
"mistral:latest",
"mistral:7b+Q4_0",
}
for _, s := range cases {
t.Run(s, func(t *testing.T) {
var b bytes.Buffer
log := slog.New(slog.NewTextHandler(&b, nil))
name := ParseName(s)
log.Info("", "name", name)
want := fmt.Sprintf("name=%s", name.GoString())
got := b.String()
if !strings.Contains(got, want) {
t.Errorf("expected log output to contain %q; got %q", want, got)
}
})
}
}
func TestNameDisplay(t *testing.T) {
cases := []struct {
name string
in string
wantShort string
wantLong string
wantComplete string
wantString string
wantModel string
wantGoString string // default is tt.in
}{
{
name: "Complete Name",
in: "example.com/library/mistral:latest+Q4_0",
wantShort: "mistral:latest",
wantLong: "library/mistral:latest",
wantComplete: "example.com/library/mistral:latest",
wantModel: "mistral",
wantGoString: "example.com/library/mistral:latest+Q4_0@?",
},
{
name: "Short Name",
in: "mistral:latest",
wantShort: "mistral:latest",
wantLong: "mistral:latest",
wantComplete: "mistral:latest",
wantModel: "mistral",
wantGoString: "?/?/mistral:latest+?@?",
},
{
name: "Long Name",
in: "library/mistral:latest",
wantShort: "mistral:latest",
wantLong: "library/mistral:latest",
wantComplete: "library/mistral:latest",
wantModel: "mistral",
wantGoString: "?/library/mistral:latest+?@?",
},
{
name: "Case Preserved",
in: "Library/Mistral:Latest",
wantShort: "Mistral:Latest",
wantLong: "Library/Mistral:Latest",
wantComplete: "Library/Mistral:Latest",
wantModel: "Mistral",
wantGoString: "?/Library/Mistral:Latest+?@?",
},
{
name: "With digest",
in: "Library/Mistral:Latest@sha256-123456",
wantShort: "Mistral:Latest",
wantLong: "Library/Mistral:Latest",
wantComplete: "Library/Mistral:Latest",
wantModel: "Mistral",
wantGoString: "?/Library/Mistral:Latest+?@sha256-123456",
},
}
for _, tt := range cases {
t.Run(tt.name, func(t *testing.T) {
p := ParseName(tt.in)
if g := p.DisplayShort(); g != tt.wantShort {
t.Errorf("DisplayShort = %q; want %q", g, tt.wantShort)
}
if g := p.DisplayLong(); g != tt.wantLong {
t.Errorf("DisplayLong = %q; want %q", g, tt.wantLong)
}
if g := p.DisplayFullest(); g != tt.wantComplete {
t.Errorf("DisplayFullest = %q; want %q", g, tt.wantComplete)
}
if g := p.String(); g != tt.in {
t.Errorf("String(%q) = %q; want %q", tt.in, g, tt.in)
}
if g := p.DisplayModel(); g != tt.wantModel {
t.Errorf("Model = %q; want %q", g, tt.wantModel)
}
tt.wantGoString = cmp.Or(tt.wantGoString, tt.in)
if g := fmt.Sprintf("%#v", p); g != tt.wantGoString {
t.Errorf("GoString() = %q; want %q", g, tt.wantGoString)
}
})
}
}
func TestParseNameAllocs(t *testing.T) {
allocs := testing.AllocsPerRun(1000, func() {
keep(ParseName("example.com/mistral:7b+Q4_0"))
})
if allocs > 0 {
t.Errorf("ParseName allocs = %v; want 0", allocs)
}
}
func BenchmarkParseName(b *testing.B) {
b.ReportAllocs()
for range b.N {
keep(ParseName("example.com/mistral:7b+Q4_0"))
}
}
func BenchmarkNameDisplay(b *testing.B) {
b.ReportAllocs()
r := ParseName("example.com/mistral:7b+Q4_0")
b.Run("Short", func(b *testing.B) {
for range b.N {
keep(r.DisplayShort())
}
})
}
func FuzzParseName(f *testing.F) {
f.Add("example.com/mistral:7b+Q4_0")
f.Add("example.com/mistral:7b+q4_0")
f.Add("example.com/mistral:7b+x")
f.Add("x/y/z:8n+I")
f.Fuzz(func(t *testing.T, s string) {
r0 := ParseName(s)
if !r0.IsValid() {
if !r0.EqualFold(Name{}) {
t.Errorf("expected invalid path to be zero value; got %#v", r0)
}
t.Skipf("invalid path: %q", s)
}
for _, p := range r0.Parts() {
if len(p) > MaxNamePartLen {
t.Errorf("part too long: %q", p)
}
}
if !strings.EqualFold(r0.String(), s) {
t.Errorf("String() did not round-trip with case insensitivity: %q\ngot = %q\nwant = %q", s, r0.String(), s)
}
r1 := ParseName(r0.String())
if !r0.EqualFold(r1) {
t.Errorf("round-trip mismatch: %+v != %+v", r0, r1)
}
})
}
func TestFill(t *testing.T) {
cases := []struct {
dst string
src string
want string
}{
{"mistral", "o.com/library/PLACEHOLDER:latest+Q4_0", "o.com/library/mistral:latest+Q4_0"},
{"o.com/library/mistral", "PLACEHOLDER:latest+Q4_0", "o.com/library/mistral:latest+Q4_0"},
{"", "o.com/library/mistral:latest+Q4_0", "o.com/library/mistral:latest+Q4_0"},
}
for _, tt := range cases {
t.Run(tt.dst, func(t *testing.T) {
r := Fill(ParseName(tt.dst), ParseName(tt.src))
if r.String() != tt.want {
t.Errorf("Fill(%q, %q) = %q; want %q", tt.dst, tt.src, r, tt.want)
}
})
}
}
func TestNameTextMarshal(t *testing.T) {
cases := []struct {
in string
want string
wantErr error
}{
{"example.com/mistral:latest+Q4_0", "", nil},
{"mistral:latest+Q4_0", "mistral:latest+Q4_0", nil},
{"mistral:latest", "mistral:latest", nil},
{"mistral", "mistral", nil},
{"mistral:7b", "mistral:7b", nil},
{"example.com/library/mistral:latest+Q4_0", "example.com/library/mistral:latest+Q4_0", nil},
}
for _, tt := range cases {
t.Run(tt.in, func(t *testing.T) {
p := ParseName(tt.in)
got, err := p.MarshalText()
if !errors.Is(err, tt.wantErr) {
t.Fatalf("MarshalText() error = %v; want %v", err, tt.wantErr)
}
if string(got) != tt.want {
t.Errorf("MarshalText() = %q; want %q", got, tt.want)
}
var r Name
if err := r.UnmarshalText(got); err != nil {
t.Fatalf("UnmarshalText() error = %v; want nil", err)
}
if !r.EqualFold(p) {
t.Errorf("UnmarshalText() = %q; want %q", r, p)
}
})
}
t.Run("UnmarshalText into valid Name", func(t *testing.T) {
// UnmarshalText should not be called on a valid Name.
p := MustParseName("x")
if err := p.UnmarshalText([]byte("mistral:latest+Q4_0")); err == nil {
t.Error("UnmarshalText() = nil; want error")
}
})
t.Run("TextMarshal allocs", func(t *testing.T) {
var data []byte
name := ParseName("example.com/ns/mistral:latest+Q4_0")
if !name.IsComplete() {
// sanity check
panic("sanity check failed")
}
allocs := testing.AllocsPerRun(1000, func() {
var err error
data, err = name.MarshalText()
if err != nil {
t.Fatal(err)
}
if len(data) == 0 {
t.Fatal("MarshalText() = 0; want non-zero")
}
})
if allocs > 0 {
// TODO: Update when/if this lands:
// https://github.com/golang/go/issues/62384
//
// Currently, the best we can do is 1 alloc.
t.Errorf("MarshalText allocs = %v; want <= 1", allocs)
}
})
t.Run("UnmarshalTest makes safe copy", func(t *testing.T) {
// UnmarshalText should make a copy of the data.
data := []byte("mistral:latest+Q4_0")
p := Name{}
if err := p.UnmarshalText(data); err != nil {
t.Fatal(err)
}
data[0] = 'x'
if p.String() != "mistral:latest+Q4_0" {
t.Errorf("UnmarshalText() did not make a copy")
}
})
}
func TestSQL(t *testing.T) {
t.Run("Scan for already valid Name", func(t *testing.T) {
p := MustParseName("x")
if err := p.Scan("mistral:latest+Q4_0"); err == nil {
t.Error("Scan() = nil; want error")
}
})
t.Run("Scan for invalid Name", func(t *testing.T) {
p := Name{}
if err := p.Scan("mistral:latest+Q4_0"); err != nil {
t.Errorf("Scan() = %v; want nil", err)
}
if p.String() != "mistral:latest+Q4_0" {
t.Errorf("String() = %q; want %q", p, "mistral:latest+Q4_0")
}
})
t.Run("Value", func(t *testing.T) {
p := MustParseName("x")
if g, err := p.Value(); err != nil {
t.Errorf("Value() error = %v; want nil", err)
} else if g != "x" {
t.Errorf("Value() = %q; want %q", g, "x")
}
})
}
// TestNameStringAllocs verifies String stays within its allocation
// budget: at most one allocation per call (the returned string).
func TestNameStringAllocs(t *testing.T) {
	name := ParseName("example.com/ns/mistral:latest+Q4_0")
	allocs := testing.AllocsPerRun(1000, func() {
		keep(name.String())
	})
	// The condition permits a single alloc, so the failure message must
	// say so (previously it claimed "want 0").
	if allocs > 1 {
		t.Errorf("String allocs = %v; want <= 1", allocs)
	}
}
func ExampleFill() {
defaults := ParseName("registry.ollama.com/library/PLACEHOLDER:latest+Q4_0")
r := Fill(ParseName("mistral"), defaults)
fmt.Println(r)
// Output:
// registry.ollama.com/library/mistral:latest+Q4_0
}
func ExampleName_MapHash() {
m := map[uint64]bool{}
// key 1
m[ParseName("mistral:latest+q4").MapHash()] = true
m[ParseName("miSTRal:latest+Q4").MapHash()] = true
m[ParseName("mistral:LATest+Q4").MapHash()] = true
// key 2
m[ParseName("mistral:LATest").MapHash()] = true
fmt.Println(len(m))
// Output:
// 2
}
func ExampleName_CompareFold_sort() {
names := []Name{
ParseName("mistral:latest"),
ParseName("mistRal:7b+q4"),
ParseName("MIstral:7b"),
}
slices.SortFunc(names, Name.CompareFold)
for _, n := range names {
fmt.Println(n)
}
// Output:
// MIstral:7b
// mistRal:7b+q4
// mistral:latest
}
func ExampleName_completeAndResolved() {
for _, s := range []string{
"x/y/z:latest+q4_0@sha123-1",
"x/y/z:latest+q4_0",
"@sha123-1",
} {
p := ParseName(s)
fmt.Printf("complete:%v resolved:%v digest:%s\n", p.IsComplete(), p.IsResolved(), p.Digest())
}
// Output:
// complete:true resolved:true digest:sha123-1
// complete:true resolved:false digest:
// complete:false resolved:true digest:sha123-1
}
func ExampleName_DisplayFullest() {
for _, s := range []string{
"example.com/jmorganca/mistral:latest+Q4_0",
"mistral:latest+Q4_0",
"mistral:latest",
} {
fmt.Println(ParseName(s).DisplayFullest())
}
// Output:
// example.com/jmorganca/mistral:latest
// mistral:latest
// mistral:latest
}
func keep[T any](v T) T { return v }

View File

@@ -1,2 +0,0 @@
go test fuzz v1
string("/0")

View File

@@ -1,2 +0,0 @@
go test fuzz v1
string("0//0")

View File

@@ -1,2 +0,0 @@
go test fuzz v1
string("0 /0")

View File

@@ -1,2 +0,0 @@
go test fuzz v1
string("+0/00000")

View File

@@ -1,2 +0,0 @@
go test fuzz v1
string(":")

View File

@@ -1,2 +0,0 @@
go test fuzz v1
string("0+.\xf2\x80\xf6\x9d00000\xe5\x99\xe6\xd900\xd90\xa60\x91\xdc0\xff\xbf\x99\xe800\xb9\xdc\xd6\xc300\x970\xfb\xfd0\xe0\x8a\xe1\xad\xd40\x9700\xa80\x980\xdd0000\xb00\x91000\xfe0\x89\x9b\x90\x93\x9f0\xe60\xf7\x84\xb0\x87\xa5\xff0\xa000\x9a\x85\xf6\x85\xfe\xa9\xf9\xe9\xde00\xf4\xe0\x8f\x81\xad\xde00\xd700\xaa\xe000000\xb1\xee0\x91")

View File

@@ -1,89 +0,0 @@
package oweb
import (
"cmp"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"github.com/ollama/ollama/x/client/ollama"
)
// Missing returns a 400 [ollama.Error] reporting that the named request
// field is required but was absent.
func Missing(field string) error {
	err := &ollama.Error{
		Status: 400,
		Code: "missing",
		Field: field,
		Message: fmt.Sprintf("%s is required", field),
	}
	return err
}
// Invalid returns a 400 [ollama.Error] for a field whose value failed
// validation; format and args build the user-facing message.
func Invalid(field, value, format string, args ...any) error {
	return &ollama.Error{
		Status: 400,
		Code: "invalid",
		Field: field,
		Value: value,
		Message: fmt.Sprintf(format, args...),
	}
}
// Convenience errors shared across handlers. These are package-level
// pointers to mutable structs; treat them as read-only sentinels.
var (
	ErrNotFound = &ollama.Error{Status: 404, Code: "not_found"}
	ErrInternal = &ollama.Error{Status: 500, Code: "internal_error"}
	ErrMethodNotAllowed = &ollama.Error{Status: 405, Code: "method_not_allowed"}
)
// HandlerFunc is an HTTP handler that may return an error, to be rendered
// by Serve as a JSON ollama.Error response.
type HandlerFunc func(w http.ResponseWriter, r *http.Request) error

// Serve invokes h and, on failure, writes the error to w as JSON.
// Errors that are not *ollama.Error are masked as ErrInternal so internal
// details never leak to clients; a zero Status defaults to 400.
func Serve(h HandlerFunc, w http.ResponseWriter, r *http.Request) {
	if err := h(w, r); err != nil {
		// TODO: take a slog.Logger
		log.Printf("error: %v", err)
		var oe *ollama.Error
		if !errors.As(err, &oe) {
			oe = ErrInternal
		}
		// NOTE(review): this writes through oe, which may point at a
		// shared sentinel (e.g. ErrInternal). Harmless today because
		// the sentinels already carry a non-zero Status, but a
		// zero-Status sentinel would be mutated globally — confirm.
		oe.Status = cmp.Or(oe.Status, 400)
		w.WriteHeader(oe.Status)
		if err := EncodeJSON(w, oe); err != nil {
			log.Printf("error encoding error: %v", err)
		}
	}
}
// DecodeUserJSON decodes JSON from r like DecodeJSON, but converts common
// decoding failures into 400-level ollama.Errors attributed to field,
// suitable for returning directly to the user. Other errors pass through
// unchanged.
func DecodeUserJSON[T any](field string, r io.Reader) (*T, error) {
	v, err := DecodeJSON[T](r)
	// Handle common JSON syntax errors
	var e *json.SyntaxError
	if errors.As(err, &e) {
		return nil, Invalid(field, "", e.Error())
	}
	// Handle type errors
	var se *json.UnmarshalTypeError
	if errors.As(err, &se) {
		return nil, Invalid(field, se.Value, "expected %s", se.Type)
	}
	// Return v and err as they were.
	return v, err
}
func DecodeJSON[T any](r io.Reader) (*T, error) {
var v *T
if err := json.NewDecoder(r).Decode(&v); err != nil {
var zero T
return &zero, err
}
return v, nil
}
// EncodeJSON writes v to w as JSON followed by a newline, returning any
// encoding or write error.
func EncodeJSON(w io.Writer, v any) error {
	return json.NewEncoder(w).Encode(v)
}

View File

@@ -1,46 +0,0 @@
package apitype
import "encoding/json"
// Manifest describes the set of layers that make up a model.
type Manifest struct {
	Layers []Layer `json:"layers"`
}

// CompletePart records one finished part upload; the client echoes it back
// to the registry so the multipart upload can be completed.
type CompletePart struct {
	URL string `json:"url"` // contains partNumber and uploadId from server
	ETag string `json:"etag"`
}

// Layer identifies a blob by digest, with its media type and size in bytes.
type Layer struct {
	Digest string `json:"digest"`
	MediaType string `json:"mediaType"`
	Size int64 `json:"size"`
}

// PushRequest is the body of a /v1/push request.
type PushRequest struct {
	Name string `json:"ref"`
	Manifest json.RawMessage `json:"manifest"`
	// Parts is a list of upload parts that the client upload in the previous
	// push.
	CompleteParts []CompletePart `json:"part_uploads"`
}

// Requirement instructs the client to upload a byte range of a blob.
type Requirement struct {
	Digest string `json:"digest"`
	Offset int64 `json:"offset"`
	// NOTE(review): this tag is capital "Size" while every other field
	// here is lowercase — confirm whether that is intentional before
	// changing it, since it is part of the wire format.
	Size int64 `json:"Size"`
	// URL is the url to PUT the layer to.
	//
	// Clients must include it as the URL, along with the ETag in the
	// response headers from the PUT request, in the next push request
	// in the Uploaded field.
	URL string `json:"url"`
}

// PushResponse lists what, if anything, the client must upload before the
// manifest can be committed.
type PushResponse struct {
	// Requirements is a list of digests that the client needs to push before
	// repushing the manifest.
	Requirements []Requirement `json:"requirements,omitempty"`
}

View File

@@ -1,102 +0,0 @@
package registry
import (
"cmp"
"context"
"encoding/xml"
"errors"
"fmt"
"io"
"net/http"
"strings"
"github.com/ollama/ollama/x/client/ollama"
"github.com/ollama/ollama/x/registry/apitype"
)
// Client is an Ollama registry client.
type Client struct {
	BaseURL string
	HTTPClient *http.Client
}

// oclient reinterprets the Client as an *ollama.Client; the conversion
// relies on the two struct types having identical layouts.
func (c *Client) oclient() *ollama.Client {
	return (*ollama.Client)(c)
}

// PushParams carries optional arguments to Push.
type PushParams struct {
	CompleteParts []apitype.CompletePart
}
// Push pushes a manifest to the server.
//
// It POSTs the manifest, plus any parts completed in a previous round, to
// /v1/push and returns the outstanding upload requirements. An empty
// result means the server accepted and committed the manifest.
func (c *Client) Push(ctx context.Context, ref string, manifest []byte, p *PushParams) ([]apitype.Requirement, error) {
	p = cmp.Or(p, &PushParams{})
	// TODO(bmizerany): backoff
	v, err := ollama.Do[apitype.PushResponse](ctx, c.oclient(), "POST", "/v1/push", &apitype.PushRequest{
		Name: ref,
		Manifest: manifest,
		CompleteParts: p.CompleteParts,
	})
	if err != nil {
		return nil, err
	}
	return v.Requirements, nil
}
// PushLayer PUTs n bytes of body, starting at offset off, to the presigned
// url. On success it returns the CompletePart (URL and ETag) that must be
// echoed back to the registry in the next push request.
func PushLayer(ctx context.Context, body io.ReaderAt, url string, off, n int64) (apitype.CompletePart, error) {
	var zero apitype.CompletePart
	if off < 0 {
		// Message previously said ">0", contradicting the check.
		return zero, errors.New("off must be >= 0")
	}
	file := io.NewSectionReader(body, off, n)
	// Bind the request to ctx so cancellation/deadlines are honored;
	// the ctx parameter was previously accepted but ignored.
	req, err := http.NewRequestWithContext(ctx, "PUT", url, file)
	if err != nil {
		return zero, err
	}
	req.ContentLength = n
	// TODO(bmizerany): take content type param
	req.Header.Set("Content-Type", "text/plain")
	if n >= 0 {
		req.Header.Set("x-amz-copy-source-range", fmt.Sprintf("bytes=%d-%d", off, off+n-1))
	}
	res, err := http.DefaultClient.Do(req)
	if err != nil {
		return zero, err
	}
	defer res.Body.Close()
	if res.StatusCode != 200 {
		e := parseS3Error(res)
		return zero, fmt.Errorf("unexpected status code: %d; %w", res.StatusCode, e)
	}
	// S3 returns the ETag quoted; strip the quotes for the caller.
	etag := strings.Trim(res.Header.Get("ETag"), `"`)
	cp := apitype.CompletePart{
		URL: url,
		ETag: etag,
		// TODO(bmizerany): checksum
	}
	return cp, nil
}
// s3Error mirrors the XML error document that S3-compatible services
// return for failed requests.
type s3Error struct {
	XMLName xml.Name `xml:"Error"`
	Code string `xml:"Code"`
	Message string `xml:"Message"`
	Resource string `xml:"Resource"`
	RequestId string `xml:"RequestId"`
}

// Error implements the error interface, including all identifying fields.
func (e *s3Error) Error() string {
	return fmt.Sprintf("S3 (%s): %s: %s: %s", e.RequestId, e.Resource, e.Code, e.Message)
}

// parseS3Error parses an XML error response from S3.
// It returns the decoded *s3Error, or the XML decoding error if the body
// could not be parsed.
//
// NOTE(review): if Decode ever succeeded while leaving se nil, the return
// would be a non-nil error interface wrapping a nil pointer — confirm
// Decode always allocates here.
func parseS3Error(res *http.Response) error {
	var se *s3Error
	if err := xml.NewDecoder(res.Body).Decode(&se); err != nil {
		return err
	}
	return se
}

View File

@@ -1,256 +0,0 @@
// Package registry implements an Ollama registry client and server.
package registry
import (
"bytes"
"cmp"
"context"
"errors"
"fmt"
"log"
"net/http"
"net/url"
"path"
"strconv"
"strings"
"time"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
"github.com/ollama/ollama/x/client/ollama"
"github.com/ollama/ollama/x/model"
"github.com/ollama/ollama/x/oweb"
"github.com/ollama/ollama/x/registry/apitype"
"github.com/ollama/ollama/x/utils/upload"
)
// Defaults
const (
	// DefaultUploadChunkSize is the multipart part size used when
	// Server.UploadChunkSize is zero.
	DefaultUploadChunkSize = 50 * 1024 * 1024
)

// Server implements the registry HTTP API, backed by an S3-compatible
// object store.
type Server struct {
	UploadChunkSize int64 // default is DefaultUploadChunkSize
	S3Client *minio.Client
}
// ServeHTTP implements http.Handler. It delegates to serveHTTP and renders
// any returned error as a JSON ollama.Error; non-ollama errors are masked
// as oweb.ErrInternal, and a zero Status defaults to 400.
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if err := s.serveHTTP(w, r); err != nil {
		log.Printf("error: %v", err) // TODO(bmizerany): take a slog.Logger
		var e *ollama.Error
		if !errors.As(err, &e) {
			e = oweb.ErrInternal
		}
		w.WriteHeader(cmp.Or(e.Status, 400))
		if err := oweb.EncodeJSON(w, e); err != nil {
			log.Printf("error encoding error: %v", err)
		}
	}
}
// serveHTTP routes requests by exact URL path. Unknown paths return
// oweb.ErrNotFound, which ServeHTTP renders as a 404.
func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) error {
	switch r.URL.Path {
	case "/v1/push":
		return s.handlePush(w, r)
	case "/v1/pull":
		return s.handlePull(w, r)
	default:
		return oweb.ErrNotFound
	}
}
// uploadChunkSize returns the configured multipart chunk size, falling
// back to DefaultUploadChunkSize when unset.
func (s *Server) uploadChunkSize() int64 {
	return cmp.Or(s.UploadChunkSize, DefaultUploadChunkSize)
}
func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error {
const bucketTODO = "test"
const minimumMultipartSize = 5 * 1024 * 1024 // S3 spec
pr, err := oweb.DecodeUserJSON[apitype.PushRequest]("", r.Body)
if err != nil {
return err
}
mp := model.ParseName(pr.Name)
if !mp.IsComplete() {
return oweb.Invalid("name", pr.Name, "must be complete")
}
m, err := oweb.DecodeUserJSON[apitype.Manifest]("manifest", bytes.NewReader(pr.Manifest))
if err != nil {
return err
}
mcc := &minio.Core{Client: s.s3()}
// TODO(bmizerany): complete uploads before stats for any with ETag
type completeParts struct {
key string
parts []minio.CompletePart
}
completePartsByUploadID := make(map[string]completeParts)
for _, mcp := range pr.CompleteParts {
// parse the URL
u, err := url.Parse(mcp.URL)
if err != nil {
return err
}
q := u.Query()
// Check if this is a part upload, if not, skip
uploadID := q.Get("uploadId")
if uploadID == "" {
// not a part upload
continue
}
// PartNumber is required
queryPartNumber := q.Get("partNumber")
partNumber, err := strconv.Atoi(queryPartNumber)
if err != nil {
return oweb.Invalid("partNumber", queryPartNumber, "")
}
if partNumber < 1 {
return oweb.Invalid("partNumber", queryPartNumber, "must be >= 1")
}
// ETag is required
if mcp.ETag == "" {
return oweb.Missing("etag")
}
cp := completePartsByUploadID[uploadID]
cp.key = u.Path
cp.parts = append(cp.parts, minio.CompletePart{
PartNumber: partNumber,
ETag: mcp.ETag,
})
completePartsByUploadID[uploadID] = cp
}
for uploadID, cp := range completePartsByUploadID {
var zeroOpts minio.PutObjectOptions
// TODO: gross fix!!!!!!!!!!!!!!!
key := strings.TrimPrefix(cp.key, "/"+bucketTODO+"/")
fmt.Printf("Completing multipart upload %s %s %v\n", bucketTODO, key, cp.parts)
_, err := mcc.CompleteMultipartUpload(r.Context(), bucketTODO, key, uploadID, cp.parts, zeroOpts)
if err != nil {
var e minio.ErrorResponse
if errors.As(err, &e) && e.Code == "NoSuchUpload" {
return oweb.Invalid("uploadId", uploadID, "")
}
return err
}
}
var requirements []apitype.Requirement
for _, l := range m.Layers {
// TODO(bmizerany): do in parallel
if l.Size == 0 {
continue
}
// TODO(bmizerany): "global" throttle of rate of transfer
pushed, err := s.statObject(r.Context(), l.Digest)
if err != nil {
println("ERROR:", "statObject", err)
return err
}
if !pushed {
key := path.Join("blobs", l.Digest)
if l.Size < minimumMultipartSize {
// single part upload
fmt.Printf("Presigning single %s %s\n", bucketTODO, key)
signedURL, err := s.s3().PresignedPutObject(r.Context(), bucketTODO, key, 15*time.Minute)
if err != nil {
println("ERROR:", "presign single", err)
return err
}
requirements = append(requirements, apitype.Requirement{
Digest: l.Digest,
Size: l.Size,
URL: signedURL.String(),
})
} else {
uploadID, err := mcc.NewMultipartUpload(r.Context(), bucketTODO, key, minio.PutObjectOptions{})
if err != nil {
return err
}
fmt.Printf("Presigning multi %s %s %s\n", bucketTODO, key, uploadID)
for partNumber, c := range upload.Chunks(l.Size, s.uploadChunkSize()) {
const timeToStartUpload = 15 * time.Minute
signedURL, err := s.s3().Presign(r.Context(), "PUT", bucketTODO, key, timeToStartUpload, url.Values{
"partNumber": []string{strconv.Itoa(partNumber)},
"uploadId": []string{uploadID},
})
if err != nil {
println("ERROR:", "presign multi", err)
return err
}
requirements = append(requirements, apitype.Requirement{
Digest: l.Digest,
Offset: c.Offset,
Size: c.N,
URL: signedURL.String(),
})
}
}
}
}
if len(requirements) == 0 {
// Commit the manifest
body := bytes.NewReader(pr.Manifest)
path := path.Join("manifests", path.Join(mp.Parts()...))
_, err := s.s3().PutObject(r.Context(), bucketTODO, path, body, int64(len(pr.Manifest)), minio.PutObjectOptions{})
if err != nil {
return err
}
}
return oweb.EncodeJSON(w, &apitype.PushResponse{Requirements: requirements})
}
// handlePull will serve manifest lookups for pulls; it is not yet
// implemented and currently panics.
func (s *Server) handlePull(w http.ResponseWriter, r *http.Request) error {
	// lookup manifest
	panic("TODO")
}
// statObject reports whether the blob for digest already exists in the
// "test" bucket by issuing a HEAD (StatObject) request. A missing key is
// not an error: the method returns (false, nil) so the caller knows the
// blob still needs to be pushed.
func (s *Server) statObject(ctx context.Context, digest string) (pushed bool, err error) {
	// HEAD the object
	path := path.Join("blobs", digest)
	_, err = s.s3().StatObject(ctx, "test", path, minio.StatObjectOptions{})
	if err != nil {
		if isNoSuchKey(err) {
			err = nil
		}
		return false, err
	}
	return true, nil
}
// isNoSuchKey reports whether err is the S3 "NoSuchKey" error, i.e. the
// requested object does not exist.
func isNoSuchKey(err error) bool {
	var e minio.ErrorResponse
	return errors.As(err, &e) && e.Code == "NoSuchKey"
}
// s3 returns the configured S3 client, or constructs one against a local
// MinIO instance with its well-known default credentials (dev/test
// convenience only).
//
// NOTE(review): when S3Client is nil a new client is built on every call,
// and a bad endpoint panics rather than returning an error — acceptable
// for experimental code, but worth revisiting.
func (s *Server) s3() *minio.Client {
	if s.S3Client != nil {
		return s.S3Client
	}
	s3, err := minio.New("localhost:9000", &minio.Options{
		Creds: credentials.NewStaticV4("minioadmin", "minioadmin", ""),
		Secure: false,
	})
	if err != nil {
		panic(err)
	}
	return s3
}

View File

@@ -1,473 +0,0 @@
package registry
import (
"bufio"
"bytes"
"cmp"
"context"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http/httptest"
"net/url"
"os"
"os/exec"
"strconv"
"syscall"
"testing"
"time"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
"github.com/ollama/ollama/x/registry/apitype"
"github.com/ollama/ollama/x/utils/backoff"
"github.com/ollama/ollama/x/utils/upload"
"kr.dev/diff"
)
// const ref = "registry.ollama.ai/x/y:latest+Z"
// const manifest = `{
// "layers": [
// {"digest": "sha256-1", "size": 1},
// {"digest": "sha256-2", "size": 2},
// {"digest": "sha256-3", "size": 3}
// ]
// }`
// ts := newTestServer(t)
// ts.pushNotOK(ref, `{}`, &ollama.Error{
// Status: 400,
// Code: "invalid",
// Message: "name must be fully qualified",
// })
// ts.push(ref, `{
// "layers": [
// {"digest": "sha256-1", "size": 1},
// {"digest": "sha256-2", "size": 2},
// {"digest": "sha256-3", "size": 3}
// ]
// }`)
// tWriter adapts a *testing.T into an io.Writer so subprocess and client
// output can be streamed into the test log.
type tWriter struct {
	t *testing.T
}

// Write logs p via t.Logf and always reports the full length as written.
func (w tWriter) Write(p []byte) (n int, err error) {
	w.t.Logf("%s", p)
	return len(p), nil
}
// TestPushBasic exercises the full push flow against a live MinIO
// server: the first Push reports the required layer uploads (one layer
// is large enough to force a multipart upload), the layers are pushed,
// and a second Push with the completed parts commits the manifest. It
// then verifies the stored objects and their contents.
func TestPushBasic(t *testing.T) {
	const MB = 1024 * 1024

	mc := startMinio(t, true)

	defer func() {
		mcc := &minio.Core{Client: mc}
		// fail if there are any incomplete uploads
		for x := range mcc.ListIncompleteUploads(context.Background(), "test", "theKey", true) {
			t.Errorf("incomplete: %v", x)
		}
	}()

	const ref = "registry.ollama.ai/x/y:latest+Z"

	// Upload two small layers and one large layer that will
	// trigger a multipart upload.
	manifest := []byte(`{
		"layers": [
			{"digest": "sha256-1", "size": 1},
			{"digest": "sha256-2", "size": 2},
			{"digest": "sha256-3", "size": 11000000}
		]
	}`)

	hs := httptest.NewServer(&Server{
		S3Client:        mc,
		UploadChunkSize: 5 * MB,
	})
	t.Cleanup(hs.Close)
	c := &Client{BaseURL: hs.URL}

	requirements, err := c.Push(context.Background(), ref, manifest, nil)
	if err != nil {
		t.Fatal(err)
	}
	// Three layers, and the large one may be split into multiple parts,
	// so expect at least three requirements.
	if len(requirements) < 3 {
		t.Errorf("expected at least 3 requirements; got %d", len(requirements))
		t.Logf("requirements: %v", requirements)
	}

	var uploaded []apitype.CompletePart
	for i, r := range requirements {
		t.Logf("[%d] pushing layer: offset=%d size=%d", i, r.Offset, r.Size)
		cp, err := PushLayer(context.Background(), &abcReader{}, r.URL, r.Offset, r.Size)
		if err != nil {
			t.Fatal(err)
		}
		uploaded = append(uploaded, cp)
	}

	// A second push with the completed parts should have nothing left
	// to upload.
	requirements, err = c.Push(context.Background(), ref, manifest, &PushParams{
		CompleteParts: uploaded,
	})
	if err != nil {
		t.Fatal(err)
	}
	if len(requirements) != 0 {
		t.Errorf("unexpected requirements: %v", requirements)
	}

	var paths []string
	keys := mc.ListObjects(context.Background(), "test", minio.ListObjectsOptions{
		Recursive: true,
	})
	for k := range keys {
		paths = append(paths, k.Key)
	}

	t.Logf("paths: %v", paths)

	diff.Test(t, t.Errorf, paths, []string{
		"blobs/sha256-1",
		"blobs/sha256-2",
		"blobs/sha256-3",
		"manifests/registry.ollama.ai/x/y/latest/Z",
	})

	obj, err := mc.GetObject(context.Background(), "test", "manifests/registry.ollama.ai/x/y/latest/Z", minio.GetObjectOptions{})
	if err != nil {
		t.Fatal(err)
	}
	defer obj.Close()

	var gotM apitype.Manifest
	if err := json.NewDecoder(obj).Decode(&gotM); err != nil {
		t.Fatal(err)
	}
	diff.Test(t, t.Errorf, gotM, apitype.Manifest{
		Layers: []apitype.Layer{
			{Digest: "sha256-1", Size: 1},
			{Digest: "sha256-2", Size: 2},
			{Digest: "sha256-3", Size: 11000000},
		},
	})

	// Checksum the blobs. Each layer is checked inside its own closure
	// so the deferred Close runs per iteration rather than piling up
	// until the test returns (defer-in-loop).
	for i, l := range gotM.Layers {
		func() {
			obj, err := mc.GetObject(context.Background(), "test", "blobs/"+l.Digest, minio.GetObjectOptions{})
			if err != nil {
				t.Fatal(err)
			}
			defer obj.Close()
			info, err := obj.Stat()
			if err != nil {
				t.Fatal(err)
			}
			t.Logf("[%d] layer info: name=%q l.Size=%d size=%d", i, info.Key, l.Size, info.Size)
			if msg := checkABCs(obj, int(l.Size)); msg != "" {
				t.Errorf("[%d] %s", i, msg)
			}
		}()
	}
}
// TestBasicPresignS3MultipartReferenceDoNotDelete tests the basic flow of
// presigning a multipart upload, uploading the parts, and completing the
// upload. It is for future reference and should not be deleted. This flow
// is tricky and if we get it wrong in our server, we can refer back to this
// as a "back to basics" test/reference.
func TestBasicPresignS3MultipartReferenceDoNotDelete(t *testing.T) {
	t.Skip("skipping reference test; unskip when needed")

	mc := startMinio(t, true)
	mcc := &minio.Core{Client: mc}

	// Begin a multipart upload directly against MinIO; uploadID ties
	// all of the subsequent part PUTs together.
	uploadID, err := mcc.NewMultipartUpload(context.Background(), "test", "theKey", minio.PutObjectOptions{})
	if err != nil {
		t.Fatal(err)
	}

	var completed []minio.CompletePart
	const size int64 = 10 * 1024 * 1024 // yields exactly two chunks below
	const chunkSize = 5 * 1024 * 1024

	// Presign a URL per part, upload each part through PushLayer, and
	// record the (partNumber, ETag) pairs needed to complete the upload.
	for partNumber, c := range upload.Chunks(size, chunkSize) {
		u, err := mcc.Presign(context.Background(), "PUT", "test", "theKey", 15*time.Minute, url.Values{
			"partNumber": {strconv.Itoa(partNumber)},
			"uploadId":   {uploadID},
		})
		if err != nil {
			t.Fatalf("[partNumber=%d]: %v", partNumber, err)
		}
		t.Logf("[partNumber=%d]: %v", partNumber, u)

		var body abcReader
		cp, err := PushLayer(context.Background(), &body, u.String(), c.Offset, c.N)
		if err != nil {
			t.Fatalf("[partNumber=%d]: %v", partNumber, err)
		}
		t.Logf("completed part: %v", cp)

		// behave like server here (don't cheat and use partNumber)
		// instead get partNumber from the URL
		retPartNumber, err := strconv.Atoi(u.Query().Get("partNumber"))
		if err != nil {
			t.Fatalf("[partNumber=%d]: %v", partNumber, err)
		}
		completed = append(completed, minio.CompletePart{
			PartNumber: retPartNumber,
			ETag:       cp.ETag,
		})
	}

	defer func() {
		// fail if there are any incomplete uploads
		for x := range mcc.ListIncompleteUploads(context.Background(), "test", "theKey", true) {
			t.Errorf("incomplete: %v", x)
		}
	}()

	info, err := mcc.CompleteMultipartUpload(context.Background(), "test", "theKey", uploadID, completed, minio.PutObjectOptions{})
	if err != nil {
		t.Fatal(err)
	}
	t.Logf("completed: %v", info)

	// Check key in bucket
	obj, err := mc.GetObject(context.Background(), "test", "theKey", minio.GetObjectOptions{})
	if err != nil {
		t.Fatal(err)
	}
	defer obj.Close()

	// Compare the SHA-256 of the stored object against the SHA-256 of
	// the same number of bytes from a fresh abcReader.
	h := sha256.New()
	if _, err := io.Copy(h, obj); err != nil {
		t.Fatal(err)
	}
	gotSum := h.Sum(nil)

	h.Reset()
	var body abcReader
	if _, err := io.CopyN(h, &body, size); err != nil {
		t.Fatal(err)
	}
	wantSum := h.Sum(nil)

	if !bytes.Equal(gotSum, wantSum) {
		t.Errorf("got sum = %x; want %x", gotSum, wantSum)
	}
}
// availableAddr asks the kernel for a free localhost TCP port and
// returns it as a "host:port" string. The temporary listener is closed
// before returning, so the port is only likely — not guaranteed — to
// still be free when it is next used.
func availableAddr() string {
	ln, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		panic(err)
	}
	addr := ln.Addr().String()
	ln.Close()
	return addr
}
// startMinio launches a local `minio server` on an ephemeral port,
// waits for it to come up, creates the "test" bucket, and returns a
// client connected to it. The server (and the optional tracer) is shut
// down via test cleanup.
//
// Tracing is "experimental" and may be removed in the future; it has
// not worked consistently. It is enabled by passing trace=true or by
// setting the OLLAMA_MINIO_TRACE environment variable.
func startMinio(t *testing.T, trace bool) *minio.Client {
	t.Helper()

	trace = cmp.Or(trace, os.Getenv("OLLAMA_MINIO_TRACE") != "")

	dir := t.TempDir()
	t.Cleanup(func() {
		// TODO(bmizerany): trim temp dir based on dates so that
		// future runs may be able to inspect results for some time.
	})

	// waitAndMaybeLogError reaps cmd and logs diagnostics unless it
	// exited normally or was canceled by us.
	waitAndMaybeLogError := func(cmd *exec.Cmd) {
		if err := cmd.Wait(); err != nil {
			var e *exec.ExitError
			if errors.As(err, &e) {
				if e.Exited() {
					// Clean exit; nothing to report.
					return
				}
				t.Logf("startMinio: %s stderr: %s", cmd.Path, e.Stderr)
				t.Logf("startMinio: %s exit status: %v", cmd.Path, e.ExitCode())
				t.Logf("startMinio: %s exited: %v", cmd.Path, e.Exited())
			} else {
				if errors.Is(err, context.Canceled) {
					return
				}
				t.Logf("startMinio: %s exit error: %v", cmd.Path, err)
			}
		}
	}

	// Cancel must run before the wait, so register cancel+wait together
	// in a single Cleanup instead of pushing them onto the cleanup
	// stack separately.
	ctx, cancel := context.WithCancel(context.Background())
	deadline, ok := t.Deadline()
	if ok {
		// Leave a little headroom so cleanup can finish before the
		// test binary is killed.
		ctx, cancel = context.WithDeadline(ctx, deadline.Add(-100*time.Millisecond))
	}

	t.Logf(">> minio: minio server %s", dir)
	addr := availableAddr()
	cmd := exec.CommandContext(ctx, "minio", "server", "--address", addr, dir)
	cmd.Env = os.Environ()
	cmd.WaitDelay = 3 * time.Second
	cmd.Cancel = func() error {
		return cmd.Process.Signal(syscall.SIGQUIT)
	}
	if err := cmd.Start(); err != nil {
		t.Fatalf("startMinio: %v", err)
	}
	t.Cleanup(func() {
		cancel()
		waitAndMaybeLogError(cmd)
	})

	mc, err := minio.New(addr, &minio.Options{
		Creds:  credentials.NewStaticV4("minioadmin", "minioadmin", ""),
		Secure: false,
	})
	if err != nil {
		t.Fatalf("startMinio: %v", err)
	}

	// wait for server to start with exponential backoff
	for _, err := range backoff.Upto(ctx, 1*time.Second) {
		if err != nil {
			t.Fatalf("startMinio: %v", err)
		}
		// try list buckets to see if server is up
		if _, err := mc.ListBuckets(ctx); err == nil {
			break
		}
		t.Logf("startMinio: server is offline; retrying")
	}

	if trace {
		cmd := exec.CommandContext(ctx, "mc", "admin", "trace", "--verbose", "test")
		cmd.Env = append(os.Environ(),
			"MC_HOST_test=http://minioadmin:minioadmin@"+addr,
		)
		cmd.WaitDelay = 3 * time.Second
		cmd.Cancel = func() error {
			return cmd.Process.Signal(syscall.SIGQUIT)
		}
		stdout, err := cmd.StdoutPipe()
		if err != nil {
			t.Fatalf("startMinio: %v", err)
		}
		if err := cmd.Start(); err != nil {
			t.Fatalf("startMinio: %v", err)
		}
		doneLogging := make(chan struct{})
		sc := bufio.NewScanner(stdout)
		go func() {
			defer close(doneLogging)
			// Scan lines until the process exits.
			for sc.Scan() {
				t.Logf("startMinio: mc trace: %s", sc.Text())
			}
			_ = sc.Err() // ignore (not important)
		}()
		t.Cleanup(func() {
			cancel()
			waitAndMaybeLogError(cmd)
			// Make sure we do not log after the test exits to
			// avoid a panic.
			<-doneLogging
		})
	}

	if err := mc.MakeBucket(context.Background(), "test", minio.MakeBucketOptions{}); err != nil {
		t.Fatalf("startMinio: %v", err)
	}
	return mc
}
// contextForTest returns a context that is canceled shortly before the
// test deadline, if one is set. The returned doneLogging function must
// be called after all Log/Error/Fatalf calls are finished, before the
// test returns; cleanup blocks on it so nothing logs into a dead test.
func contextForTest(t *testing.T) (_ context.Context, doneLogging func()) {
	deadline, ok := t.Deadline()
	if !ok {
		// No deadline: nothing to cancel, nothing to wait for.
		return context.Background(), func() {}
	}

	// Cancel 100ms early so cleanup has time to run before the
	// test binary is killed.
	ctx, cancel := context.WithDeadline(context.Background(), deadline.Add(-100*time.Millisecond))
	done := make(chan struct{})
	t.Cleanup(func() {
		cancel()
		<-done
	})
	return ctx, func() { close(done) }
}
// abcReader is an infinite reader that repeats the lowercase alphabet.
type abcReader struct {
	pos int // next index into theABCs
}

// theABCs is the repeating pattern produced by abcReader.
const theABCs = "abcdefghijklmnopqrstuvwxyz"

// Read fills p entirely with the repeating alphabet, advancing the
// reader's position. It never returns an error.
func (r *abcReader) Read(p []byte) (n int, err error) {
	for i := range p {
		p[i] = theABCs[r.pos]
		r.pos = (r.pos + 1) % len(theABCs)
	}
	return len(p), nil
}

// ReadAt fills p with the alphabet pattern as it appears at offset off.
// It is stateless and does not move the reader's position.
func (r *abcReader) ReadAt(p []byte, off int64) (n int, err error) {
	period := int64(len(theABCs))
	for i := range p {
		p[i] = theABCs[(off+int64(i))%period]
	}
	return len(p), nil
}

// checkABCs reports whether r yields exactly size bytes of the
// repeating alphabet pattern. It returns "" on success, or a
// human-readable reason describing the mismatch (read error, wrong
// length, or wrong checksum).
func checkABCs(r io.Reader, size int) (reason string) {
	// Hash of the expected pattern.
	wantHash := sha256.New()
	n, err := io.CopyN(wantHash, &abcReader{}, int64(size))
	if err != nil {
		return err.Error()
	}
	if n != int64(size) {
		// abcReader never returns short; CopyN would have errored.
		panic("short read; should not happen")
	}
	want := wantHash.Sum(nil)

	// Hash of what r actually contains.
	gotHash := sha256.New()
	n, err = io.Copy(gotHash, r)
	if err != nil {
		return err.Error()
	}
	if n != int64(size) {
		return fmt.Sprintf("got len(r) = %d; want %d", n, size)
	}
	got := gotHash.Sum(nil)

	if !bytes.Equal(got, want) {
		return fmt.Sprintf("got sum = %x; want %x", got, want)
	}
	return ""
}

View File

@@ -1,4 +0,0 @@
package empty
// Message is a placeholder type used when encoding json messages.
// As an empty struct it carries no fields and marshals to "{}".
type Message struct{}

View File

@@ -1,15 +0,0 @@
// Copyright (c) Tailscale Inc & AUTHORS
// SPDX-License-Identifier: BSD-3-Clause
// Package structs contains the Incomparable type.
package structs
// Incomparable is a zero-width incomparable type. If added as the
// first field in a struct, it marks that struct as not comparable
// (can't do == or be a map key) and usually doesn't add any width to
// the struct (unless the struct has only small fields).
//
// By making a struct incomparable, you can prevent misuse (prevent
// people from using ==), but also you can shrink generated binaries,
// as the compiler can omit equality funcs from the binary.
//
// It works because function values are not comparable in Go, and a
// zero-length array of them occupies no space.
type Incomparable [0]func()

View File

@@ -1,12 +0,0 @@
package they
import (
"net/http"
"strings"
)
// Want reports whether the request uses the given method and its URL
// path begins with pathPrefix.
func Want(r *http.Request, method string, pathPrefix string) bool {
	if r.Method != method {
		return false
	}
	return strings.HasPrefix(r.URL.Path, pathPrefix)
}

View File

@@ -1,58 +0,0 @@
package backoff
import (
"context"
"errors"
"iter"
"math/rand"
"time"
)
// Errors provided by this package for use by callers.
var (
	// ErrMaxAttempts is not used by backoff but is available for use by
	// callers that want to signal that a maximum number of retries has
	// been exceeded. This should eliminate the need for callers to invent
	// their own error.
	ErrMaxAttempts = errors.New("max retries exceeded")
)
// Upto implements a backoff strategy that yields nil errors until the
// context is canceled, the maxRetries is exceeded, or yield returns false.
//
// The backoff strategy is a simple exponential backoff with a maximum
// backoff of maxBackoff. The backoff is randomized between 0.5-1.5 times
// the current backoff, in order to prevent accidental "thundering herd"
// problems.
func Upto(ctx context.Context, maxBackoff time.Duration) iter.Seq2[int, error] {
var n int
return func(yield func(int, error) bool) {
for {
if ctx.Err() != nil {
yield(n, ctx.Err())
return
}
n++
// n^2 backoff timer is a little smoother than the
// common choice of 2^n.
d := time.Duration(n*n) * 10 * time.Millisecond
if d > maxBackoff {
d = maxBackoff
}
// Randomize the delay between 0.5-1.5 x msec, in order
// to prevent accidental "thundering herd" problems.
d = time.Duration(float64(d) * (rand.Float64() + 0.5))
t := time.NewTimer(d)
select {
case <-ctx.Done():
t.Stop()
case <-t.C:
if !yield(n, nil) {
return
}
}
}
}
}

View File

@@ -1,29 +0,0 @@
package upload
import (
"iter"
"golang.org/x/exp/constraints"
)
// Chunk describes one contiguous piece of a larger whole, as yielded
// by Chunks.
type Chunk[I constraints.Integer] struct {
	Offset I // starting offset of the chunk within the whole
	N      I // number of units in the chunk
}
// Chunks yields a sequence of a part number and a Chunk. The Chunk is the offset
// and size of the chunk. The last chunk may be smaller than chunkSize if size is
// not a multiple of chunkSize.
//
// The first part number is 1 and increases monotonically.
func Chunks[I constraints.Integer](size, chunkSize I) iter.Seq2[int, Chunk[I]] {
	return func(yield func(int, Chunk[I]) bool) {
		part := 0
		for offset := I(0); offset < size; offset += chunkSize {
			part++
			// The final chunk is whatever remains past the last
			// full chunkSize boundary.
			length := size - offset
			if length > chunkSize {
				length = chunkSize
			}
			if !yield(part, Chunk[I]{Offset: offset, N: length}) {
				return
			}
		}
	}
}

View File

@@ -1,44 +0,0 @@
package upload
import (
"testing"
"kr.dev/diff"
)
// TestChunks verifies that Chunks covers the full size with sequential
// 1-based part numbers and a short final chunk when size is not a
// multiple of chunkSize.
func TestChunks(t *testing.T) {
	const size = 101
	const chunkSize = 10

	var chunks []Chunk[int]
	prev := 0
	for part, c := range Chunks(size, chunkSize) {
		if part != prev+1 {
			t.Errorf("n = %d; want %d", part, prev+1)
		}
		chunks = append(chunks, c)
		prev = part
	}

	want := []Chunk[int]{
		{0, 10},
		{10, 10},
		{20, 10},
		{30, 10},
		{40, 10},
		{50, 10},
		{60, 10},
		{70, 10},
		{80, 10},
		{90, 10},
		{100, 1},
	}
	diff.Test(t, t.Errorf, chunks, want)
}
// TestChunksBreak verifies that the iterator returned by Chunks stops
// cleanly when the range body breaks out on the first element.
func TestChunksBreak(t *testing.T) {
	// Use the bare `for range` form over the func iterator; the old
	// `for _, _ = range` spelling is non-idiomatic and flagged by vet.
	for range Chunks(1, 1) {
		return
	}
	t.Fatal("expected break")
}