Compare commits

bmizerany/ ... bmizerany/

135 commits:

a2cf25caf7, aca112a308, 1389e6926a, a292cde2f3, ab9e476551,
d721228a2b, d3c6400487, 06c21f00eb, 8b62eaf059, 6c1c0f9f1a,
6a4b3c3823, 38e7ddb39d, f595dea189, be7fe0d6d8, 98dbc1202a,
bdff89bc4c, 2100129e83, 4eb7acf84b, ff68227ca1, d2aef85dda,
1407fd3d4a, 07f27312fa, 6ba495d4a3, bd446a72cc, 5615f60bb0,
348378ef56, 81e8c49ac2, 0172726a58, c84f9b07b0, 712eaa4b09,
2f241692bd, d35a6a577f, 14a6f85e9e, 45d8d22785, e201627c63,
5e76860c47, fb0782b7a9, 0bee38f6b5, b24f1ad587, 0bea2b8916,
e4d65d5aef, 6e464ebef8, f2c17682b0, f0e6c563e2, a187851900,
2e600aa398, 0c78e6c23d, c0d4f55f3e, d67ff60643, 3cb07b3ac9,
fc595d89bf, e1de015fbc, 1c04951cac, 95559adee3, 9821ca28e8,
c5768ceffe, 0b65220936, a5f05236f7, 3cf109ec59, 7c7f56a7fb,
bf8e0c09c9, a6b8bdf938, f51197a814, 713e2feacf, 7f1faf6c12,
ad6f020bd8, e052bd8c0f, aef832b298, 805a92e6f2, 0c46151700,
bfe89d6fa0, 92b7e40fde, d510a90214, a4fd06d603, cfe0bb6bb6,
6aa9795c4f, 5041000a28, 7cd939690a, 42cda9dd46, 6917865bf3,
2633fcb149, 58de2b8d4a, de72688b35, cbb367b1df, 18160475c4,
31e9b3dd15, e28cfdf813, 2751c26da7, 45ca3c80e8, d85fbd0e99,
acf1cb1dc4, c787b8b2dd, 9f2d8d2117, d42c3f6be1, 4ea3e9efa6,
2e1ea6ecaa, 6d2da77ce2, def4d902bf, 76a202c04e, f7cfe946dc,
005b6373e2, d54e0fb3b2, bdd05e0ae0, 1a346640db, f5883070f8,
adc23d5f96, a10a11b9d3, 94befe366a, c95f97689b, 618eb5b909,
eb75418be9, 9959da05de, aff7970628, 628f1feb36, ce3125afd5,
f488652ba7, 2318ed2919, b1b8be33d9, 876f7eab81, 7cfc8a0838,
fd411b3cf6, 04f38cf3f4, c0eddb10fd, 60ef0e6b4a, 48c60c01e2,
eb2c442a01, c87fe7df48, 5182a1dfb1, a32e7857b2, 6acc205de0,
f6e02d4bc7, e1d457c73e, cd5df121a5, 112ffed189, c49947dcf5
`.github/workflows/release.yaml` (vendored, 67 changed lines)
```diff
@@ -8,7 +8,7 @@ on:
 jobs:
   # Full build of the Mac assets
   build-darwin:
-    runs-on: macos-12
+    runs-on: macos-latest
     environment: release
     steps:
       - uses: actions/checkout@v4
@@ -30,7 +30,7 @@ jobs:
           security set-key-partition-list -S apple-tool:,apple:,codesign: -s -k password build.keychain
       - uses: actions/setup-go@v5
         with:
-          go-version-file: go.mod
+          go-version: '1.22'
           cache: true
       - name: Build Darwin
         env:
@@ -38,11 +38,9 @@ jobs:
           APPLE_PASSWORD: ${{ secrets.APPLE_PASSWORD }}
           APPLE_TEAM_ID: ${{ vars.APPLE_TEAM_ID }}
           APPLE_ID: ${{ vars.APPLE_ID }}
-          SDKROOT: /Applications/Xcode_13.4.1.app/Contents/Developer/Platforms/MacOSX.platform/Developer/SDKs/MacOSX.sdk
-          DEVELOPER_DIR: /Applications/Xcode_13.4.1.app/Contents/Developer
         run: |
           ./scripts/build_darwin.sh
 
       - uses: actions/upload-artifact@v4
         with:
           name: dist-darwin
@@ -50,6 +48,7 @@ jobs:
           dist/*arwin*
           !dist/*-cov
 
+
   # Windows builds take a long time to both install the dependencies and build, so parallelize
   # CPU generation step
   generate-windows-cpu:
@@ -86,7 +85,7 @@ jobs:
           write-host "plugin installed"
       - uses: actions/setup-go@v5
         with:
-          go-version-file: go.mod
+          go-version: '1.22'
           cache: true
       - run: go get ./...
       - run: |
@@ -100,9 +99,7 @@ jobs:
       - uses: actions/upload-artifact@v4
         with:
           name: generate-windows-cpu
-          path: |
-            llm/build/**/bin/*
-            llm/build/**/*.a
+          path: llm/llama.cpp/build/**/lib/*
 
   # ROCm generation step
   generate-windows-rocm:
@@ -139,9 +136,9 @@ jobs:
           write-host "plugin installed"
       - uses: actions/setup-go@v5
         with:
-          go-version-file: go.mod
+          go-version: '1.22'
           cache: true
-      - name: 'Install ROCm'
+      - name: "Install ROCm"
         run: |
           $ErrorActionPreference = "Stop"
           write-host "downloading AMD HIP Installer"
@@ -149,7 +146,7 @@ jobs:
           write-host "Installing AMD HIP"
           Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait
           write-host "Completed AMD HIP"
-      - name: 'Verify ROCm'
+      - name: "Verify ROCm"
        run: |
          & 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' --version
      - run: go get ./...
@@ -163,7 +160,7 @@ jobs:
           $env:HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path)
           go generate -x ./...
         name: go generate
-      - name: 'gather rocm dependencies'
+      - name: "gather rocm dependencies"
        run: |
          $HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path)
          md "dist\deps\bin\rocblas\library"
@@ -173,7 +170,7 @@ jobs:
       - uses: actions/upload-artifact@v4
         with:
           name: generate-windows-rocm
-          path: llm/build/**/bin/*
+          path: llm/llama.cpp/build/**/lib/*
       - uses: actions/upload-artifact@v4
         with:
           name: windows-rocm-deps
@@ -214,9 +211,9 @@ jobs:
           write-host "plugin installed"
       - uses: actions/setup-go@v5
         with:
-          go-version-file: go.mod
+          go-version: '1.22'
           cache: true
-      - name: 'Install CUDA'
+      - name: "Install CUDA"
        run: |
          $ErrorActionPreference = "Stop"
          write-host "downloading CUDA Installer"
@@ -230,7 +227,7 @@ jobs:
           echo "CUDA_PATH=$cudaPath" >> $env:GITHUB_ENV
           echo "CUDA_PATH_V${cudaVer}=$cudaPath" >> $env:GITHUB_ENV
           echo "CUDA_PATH_VX_Y=CUDA_PATH_V${cudaVer}" >> $env:GITHUB_ENV
-      - name: 'Verify CUDA'
+      - name: "Verify CUDA"
        run: nvcc -V
      - run: go get ./...
      - name: go generate
@@ -243,7 +240,7 @@ jobs:
           $env:PATH="$gopath;$cudabin;$env:PATH"
           $env:OLLAMA_SKIP_CPU_GENERATE="1"
           go generate -x ./...
-      - name: 'gather cuda dependencies'
+      - name: "gather cuda dependencies"
        run: |
          $NVIDIA_DIR=(resolve-path 'C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\*\bin\')[0]
          md "dist\deps"
@@ -253,7 +250,7 @@ jobs:
       - uses: actions/upload-artifact@v4
         with:
           name: generate-windows-cuda
-          path: llm/build/**/bin/*
+          path: llm/llama.cpp/build/**/lib/*
       - uses: actions/upload-artifact@v4
         with:
           name: windows-cuda-deps
@@ -300,17 +297,17 @@ jobs:
           write-host "plugin installed"
       - uses: actions/setup-go@v5
         with:
-          go-version-file: go.mod
+          go-version: '1.22'
           cache: true
       - run: go get
       - uses: actions/download-artifact@v4
         with:
           name: generate-windows-cpu
-          path: llm/build
+          path: llm/llama.cpp/build
       - uses: actions/download-artifact@v4
         with:
           name: generate-windows-cuda
-          path: llm/build
+          path: llm/llama.cpp/build
       - uses: actions/download-artifact@v4
         with:
           name: windows-cuda-deps
@@ -322,8 +319,8 @@ jobs:
       - uses: actions/download-artifact@v4
         with:
           name: generate-windows-rocm
-          path: llm/build
-      - run: dir llm/build
+          path: llm/llama.cpp/build
+      - run: dir llm/llama.cpp/build
       - run: |
           $gopath=(get-command go).source | split-path -parent
           & "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Launch-VsDevShell.ps1"
@@ -339,14 +336,14 @@ jobs:
           name: dist-windows
           path: dist/*.exe
 
   # Linux x86 assets built using the container based build
   build-linux-amd64:
     environment: release
     runs-on: linux
     env:
-      OLLAMA_SKIP_MANIFEST_CREATE: '1'
+      OLLAMA_SKIP_MANIFEST_CREATE: "1"
       BUILD_ARCH: amd64
-      PUSH: '1'
+      PUSH: "1"
     steps:
       - uses: actions/checkout@v4
         with:
@@ -376,9 +373,9 @@ jobs:
     environment: release
     runs-on: linux-arm64
     env:
-      OLLAMA_SKIP_MANIFEST_CREATE: '1'
+      OLLAMA_SKIP_MANIFEST_CREATE: "1"
       BUILD_ARCH: arm64
-      PUSH: '1'
+      PUSH: "1"
     steps:
       - uses: actions/checkout@v4
         with:
@@ -386,7 +383,7 @@ jobs:
       - name: Set Version
         shell: bash
         run: echo "VERSION=${GITHUB_REF_NAME#v}" >> $GITHUB_ENV
-      - name: 'Install Docker'
+      - name: "Install Docker"
        run: |
          # Add Docker's official GPG key:
          env
@@ -423,7 +420,7 @@ jobs:
           !dist/*-cov
 
   # Aggregate all the assets and ship a release
   release:
     needs:
       - build-darwin
       - build-windows
@@ -434,8 +431,8 @@ jobs:
     permissions:
       contents: write
     env:
-      OLLAMA_SKIP_IMAGE_BUILD: '1'
-      PUSH: '1'
+      OLLAMA_SKIP_IMAGE_BUILD: "1"
+      PUSH: "1"
     steps:
       - uses: actions/checkout@v4
       - name: Set Version
@@ -463,11 +460,11 @@ jobs:
         with:
           name: ${{ env.RELEASE_VERSION }}
           allowUpdates: true
-          artifacts: 'dist/*'
+          artifacts: "dist/*"
           draft: true
           prerelease: true
           omitBodyDuringUpdate: true
           generateReleaseNotes: true
           omitDraftDuringUpdate: true
           omitPrereleaseDuringUpdate: true
           replacesArtifacts: true
```
`.github/workflows/test.yaml` (vendored, 60 changed lines)
```diff
@@ -5,6 +5,7 @@ on:
     paths:
       - '**/*'
       - '!docs/**'
       - '!examples/**'
       - '!README.md'
+
 
 jobs:
@@ -50,7 +51,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions/setup-go@v5
         with:
-          go-version-file: go.mod
+          go-version: '1.22'
           cache: true
       - run: go get ./...
       - run: |
@@ -63,10 +64,10 @@ jobs:
           echo $env:PATH
           go generate -x ./...
         if: ${{ startsWith(matrix.os, 'windows-') }}
-        name: 'Windows Go Generate'
+        name: "Windows Go Generate"
       - run: go generate -x ./...
         if: ${{ ! startsWith(matrix.os, 'windows-') }}
-        name: 'Unix Go Generate'
+        name: "Unix Go Generate"
       - uses: actions/upload-artifact@v4
         with:
           name: ${{ matrix.os }}-${{ matrix.arch }}-libraries
@@ -92,7 +93,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions/setup-go@v4
         with:
-          go-version-file: go.mod
+          go-version: '1.22'
           cache: true
       - run: go get ./...
       - run: |
@@ -123,7 +124,7 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions/setup-go@v4
         with:
-          go-version-file: go.mod
+          go-version: '1.22'
           cache: true
       - run: go get ./...
       - run: |
@@ -134,7 +135,7 @@ jobs:
       - uses: actions/upload-artifact@v4
         with:
           name: rocm-${{ matrix.rocm-version }}-libraries
-          path: llm/build/**/bin/*
+          path: llm/build/**/lib/*
 
   # ROCm generation step
   generate-windows-rocm:
@@ -145,9 +146,9 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions/setup-go@v5
         with:
-          go-version-file: go.mod
+          go-version: '1.22'
           cache: true
-      - name: 'Install ROCm'
+      - name: "Install ROCm"
        run: |
          $ErrorActionPreference = "Stop"
          write-host "downloading AMD HIP Installer"
@@ -155,7 +156,7 @@ jobs:
           write-host "Installing AMD HIP"
           Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait
           write-host "Completed AMD HIP"
-      - name: 'Verify ROCm'
+      - name: "Verify ROCm"
        run: |
          & 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' --version
      - run: go get ./...
@@ -182,9 +183,9 @@ jobs:
       - uses: actions/checkout@v4
       - uses: actions/setup-go@v5
         with:
-          go-version-file: go.mod
+          go-version: '1.22'
           cache: true
-      - name: 'Install CUDA'
+      - name: "Install CUDA"
        run: |
          $ErrorActionPreference = "Stop"
          write-host "downloading CUDA Installer"
@@ -198,7 +199,7 @@ jobs:
           echo "CUDA_PATH=$cudaPath" >> $env:GITHUB_ENV
           echo "CUDA_PATH_V${cudaVer}=$cudaPath" >> $env:GITHUB_ENV
           echo "CUDA_PATH_VX_Y=CUDA_PATH_V${cudaVer}" >> $env:GITHUB_ENV
-      - name: 'Verify CUDA'
+      - name: "Verify CUDA"
        run: nvcc -V
      - run: go get ./...
      - name: go generate
@@ -215,6 +216,7 @@ jobs:
       OLLAMA_SKIP_CPU_GENERATE: '1'
     # TODO - do we need any artifacts?
 
+
   lint:
     strategy:
       matrix:
@@ -237,7 +239,7 @@ jobs:
           submodules: recursive
       - uses: actions/setup-go@v5
         with:
-          go-version-file: go.mod
+          go-version: '1.22'
           cache: false
       - run: |
           case ${{ matrix.arch }} in
@@ -246,18 +248,18 @@ jobs:
           esac >>$GITHUB_ENV
         shell: bash
       - run: |
-          mkdir -p llm/build/linux/$ARCH/stub/bin
-          touch llm/build/linux/$ARCH/stub/bin/ollama_llama_server
+          mkdir -p llm/build/linux/$ARCH/stub/bin/
+          touch llm/build/linux/$ARCH/stub/bin/stub.so
         if: ${{ startsWith(matrix.os, 'ubuntu-') }}
       - run: |
-          mkdir -p llm/build/darwin/$ARCH/stub/bin
-          touch llm/build/darwin/$ARCH/stub/bin/ollama_llama_server
+          mkdir -p llm/build/darwin/$ARCH/stub/bin/
+          touch llm/build/darwin/$ARCH/stub/bin/stub.dylib
+          touch llm/ggml-metal.metal
         if: ${{ startsWith(matrix.os, 'macos-') }}
       - run: |
-          mkdir -p llm/build/windows/$ARCH/stub/bin
-          touch llm/build/windows/$ARCH/stub/bin/ollama_llama_server
+          mkdir -p llm/build/windows/$ARCH/stub/stub/bin/
+          touch llm/build/windows/$ARCH/stub/stub/bin/stub.dll
         if: ${{ startsWith(matrix.os, 'windows-') }}
         shell: bash
       - uses: golangci/golangci-lint-action@v4
         with:
           args: --timeout 8m0s
@@ -275,14 +277,14 @@ jobs:
     env:
       GOARCH: ${{ matrix.arch }}
       CGO_ENABLED: '1'
-      OLLAMA_CPU_TARGET: 'static'
+      OLLAMA_CPU_TARGET: "static"
     steps:
       - uses: actions/checkout@v4
         with:
           submodules: recursive
       - uses: actions/setup-go@v5
         with:
-          go-version-file: go.mod
+          go-version: '1.22'
           cache: true
       - run: go get
       - run: |
@@ -292,18 +294,18 @@ jobs:
           esac >>$GITHUB_ENV
         shell: bash
       - run: |
-          mkdir -p llm/build/linux/$ARCH/stub/bin
-          touch llm/build/linux/$ARCH/stub/bin/ollama_llama_server
+          mkdir -p llm/build/linux/$ARCH/stub/bin/
+          touch llm//build/linux/$ARCH/stub/bin/stub.so
         if: ${{ startsWith(matrix.os, 'ubuntu-') }}
       - run: |
-          mkdir -p llm/build/darwin/$ARCH/stub/bin
-          touch llm/build/darwin/$ARCH/stub/bin/ollama_llama_server
+          mkdir -p llm/build/darwin/$ARCH/stub/bin/
+          touch llm/build/darwin/$ARCH/stub/bin/stub.dylib
+          touch llm/ggml-metal.metal
         if: ${{ startsWith(matrix.os, 'macos-') }}
       - run: |
-          mkdir -p llm/build/windows/$ARCH/stub/bin
-          touch llm/build/windows/$ARCH/stub/bin/ollama_llama_server
+          mkdir -p llm/build/windows/$ARCH/stub/stub/bin/
+          touch llm/build/windows/$ARCH/stub/stub/bin/stub.dll
         if: ${{ startsWith(matrix.os, 'windows-') }}
         shell: bash
       - run: go generate ./...
       - run: go build
       - run: go test -v ./...
```
Dockerfile

```diff
@@ -42,7 +42,7 @@ ARG CGO_CFLAGS
 ARG AMDGPU_TARGETS
 RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
 RUN mkdir /tmp/scratch && \
-    for dep in $(zcat /go/src/github.com/ollama/ollama/llm/build/linux/x86_64/rocm*/bin/deps.txt.gz) ; do \
+    for dep in $(cat /go/src/github.com/ollama/ollama/llm/llama.cpp/build/linux/x86_64/rocm*/lib/deps.txt) ; do \
        cp ${dep} /tmp/scratch/ || exit 1 ; \
    done && \
    (cd /opt/rocm/lib && tar cf - rocblas/library) | (cd /tmp/scratch/ && tar xf - ) && \
```
README.md

```diff
@@ -64,7 +64,6 @@ Here are some example models that can be downloaded:
 | LLaVA | 7B | 4.5GB | `ollama run llava` |
 | Gemma | 2B | 1.4GB | `ollama run gemma:2b` |
 | Gemma | 7B | 4.8GB | `ollama run gemma:7b` |
 | Solar | 10.7B | 6.1GB | `ollama run solar` |
 
 > Note: You should have at least 8 GB of RAM available to run the 7B models, 16 GB to run the 13B models, and 32 GB to run the 33B models.
 
@@ -293,7 +292,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
 - [Ollama-chats RPG](https://github.com/drazdra/ollama-chats)
 - [ChatOllama: Open Source Chatbot based on Ollama with Knowledge Bases](https://github.com/sugarforever/chat-ollama)
 - [CRAG Ollama Chat: Simple Web Search with Corrective RAG](https://github.com/Nagi-ovo/CRAG-Ollama-Chat)
 - [RAGFlow: Open-source Retrieval-Augmented Generation engine based on deep document understanding](https://github.com/infiniflow/ragflow)
 
 ### Terminal
 
@@ -317,7 +315,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
 
 ### Database
 
-- [MindsDB](https://github.com/mindsdb/mindsdb/blob/staging/mindsdb/integrations/handlers/ollama_handler/README.md) (Connects Ollama models with nearly 200 data platforms and apps)
+- [MindsDB](https://github.com/mindsdb/mindsdb/blob/staging/mindsdb/integrations/handlers/ollama_handler/README.md)
 - [chromem-go](https://github.com/philippgille/chromem-go/blob/v0.5.0/embed_ollama.go) with [example](https://github.com/philippgille/chromem-go/tree/v0.5.0/examples/rag-wikipedia-ollama)
 
 ### Package managers
```
api/client.go

```diff
@@ -1,9 +1,3 @@
-// Package api implements the client-side API for code wishing to interact
-// with the ollama service. The methods of the [Client] type correspond to
-// the ollama REST API as described in https://github.com/ollama/ollama/blob/main/docs/api.md
-//
-// The ollama command-line client itself uses this package to interact with
-// the backend service.
 package api
 
 import (
@@ -11,6 +5,7 @@ import (
 	"bytes"
 	"context"
 	"encoding/json"
+	"errors"
 	"fmt"
 	"io"
 	"net"
@@ -24,8 +19,6 @@ import (
 	"github.com/ollama/ollama/version"
 )
 
-// Client encapsulates client state for interacting with the ollama
-// service. Use [ClientFromEnvironment] to create new Clients.
 type Client struct {
 	base *url.URL
 	http *http.Client
@@ -47,15 +40,6 @@ func checkError(resp *http.Response, body []byte) error {
 	return apiError
 }
 
-// ClientFromEnvironment creates a new [Client] using configuration from the
-// environment variable OLLAMA_HOST, which points to the network host and
-// port on which the ollama service is listenting. The format of this variable
-// is:
-//
-//	<scheme>://<host>:<port>
-//
-// If the variable is not specified, a default ollama host and port will be
-// used.
 func ClientFromEnvironment() (*Client, error) {
 	defaultPort := "11434"
 
@@ -207,14 +191,8 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn func([]byte) error) error {
 	return nil
 }
 
-// GenerateResponseFunc is a function that [Client.Generate] invokes every time
-// a response is received from the service. If this function returns an error,
-// [Client.Generate] will stop generating and return this error.
 type GenerateResponseFunc func(GenerateResponse) error
 
-// Generate generates a response for a given prompt. The req parameter should
-// be populated with prompt details. fn is called for each response (there may
-// be multiple responses, e.g. in case streaming is enabled).
 func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn GenerateResponseFunc) error {
 	return c.stream(ctx, http.MethodPost, "/api/generate", req, func(bts []byte) error {
 		var resp GenerateResponse
@@ -226,15 +204,8 @@ func (c *Client) Generate(ctx context.Context, req *GenerateRequest, fn GenerateResponseFunc) error {
 	})
 }
 
-// ChatResponseFunc is a function that [Client.Chat] invokes every time
-// a response is received from the service. If this function returns an error,
-// [Client.Chat] will stop generating and return this error.
 type ChatResponseFunc func(ChatResponse) error
 
-// Chat generates the next message in a chat. [ChatRequest] may contain a
-// sequence of messages which can be used to maintain chat history with a model.
-// fn is called for each response (there may be multiple responses, e.g. if case
-// streaming is enabled).
 func (c *Client) Chat(ctx context.Context, req *ChatRequest, fn ChatResponseFunc) error {
 	return c.stream(ctx, http.MethodPost, "/api/chat", req, func(bts []byte) error {
 		var resp ChatResponse
@@ -246,14 +217,8 @@ func (c *Client) Chat(ctx context.Context, req *ChatRequest, fn ChatResponseFunc) error {
 	})
 }
 
-// PullProgressFunc is a function that [Client.Pull] invokes every time there
-// is progress with a "pull" request sent to the service. If this function
-// returns an error, [Client.Pull] will stop the process and return this error.
 type PullProgressFunc func(ProgressResponse) error
 
-// Pull downloads a model from the ollama library. fn is called each time
-// progress is made on the request and can be used to display a progress bar,
-// etc.
 func (c *Client) Pull(ctx context.Context, req *PullRequest, fn PullProgressFunc) error {
 	return c.stream(ctx, http.MethodPost, "/api/pull", req, func(bts []byte) error {
 		var resp ProgressResponse
@@ -336,7 +301,18 @@ func (c *Client) Embeddings(ctx context.Context, req *EmbeddingRequest) (*EmbeddingResponse, error) {
 }
 
 func (c *Client) CreateBlob(ctx context.Context, digest string, r io.Reader) error {
-	return c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, nil)
+	if err := c.do(ctx, http.MethodHead, fmt.Sprintf("/api/blobs/%s", digest), nil, nil); err != nil {
+		var statusError StatusError
+		if !errors.As(err, &statusError) || statusError.StatusCode != http.StatusNotFound {
+			return err
+		}
+
+		if err := c.do(ctx, http.MethodPost, fmt.Sprintf("/api/blobs/%s", digest), r, nil); err != nil {
+			return err
+		}
+	}
+
+	return nil
 }
 
 func (c *Client) Version(ctx context.Context) (string, error) {
```
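A minimal caller of this client surface, sketched against the signatures visible in the hunks above (the `Response` field on `GenerateResponse` is assumed from the upstream package; it does not appear in this diff):

```go
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	// Reads OLLAMA_HOST (<scheme>://<host>:<port>), falling back to the
	// default host and port 11434, per the doc comment removed above.
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	stream := true
	req := &api.GenerateRequest{
		Model:  "llama2",
		Prompt: "Why is the sky blue?",
		Stream: &stream, // streaming is the default; set explicitly for clarity
	}

	// The callback runs once per streamed chunk; returning an error stops generation.
	fn := func(resp api.GenerateResponse) error {
		fmt.Print(resp.Response)
		return nil
	}

	if err := client.Generate(context.Background(), req, fn); err != nil {
		log.Fatal(err)
	}
}
```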
`api/types.go` (110 changed lines)
```diff
@@ -33,46 +33,18 @@ func (e StatusError) Error() string {
 
 type ImageData []byte
 
-// GenerateRequest describes a request sent by [Client.Generate]. While you
-// have to specify the Model and Prompt fields, all the other fields have
-// reasonable defaults for basic uses.
 type GenerateRequest struct {
-	// Model is the model name; it should be a name familiar to Ollama from
-	// the library at https://ollama.com/library
-	Model string `json:"model"`
-
-	// Prompt is the textual prompt to send to the model.
-	Prompt string `json:"prompt"`
-
-	// System overrides the model's default system message/prompt.
-	System string `json:"system"`
-
-	// Template overrides the model's default prompt template.
-	Template string `json:"template"`
-
-	// Context is the context parameter returned from a previous call to
-	// Generate call. It can be used to keep a short conversational memory.
-	Context []int `json:"context,omitempty"`
-
-	// Stream specifies whether the response is streaming; it is true by default.
-	Stream *bool `json:"stream,omitempty"`
-
-	// Raw set to true means that no formatting will be applied to the prompt.
-	Raw bool `json:"raw,omitempty"`
-
-	// Format specifies the format to return a response in.
-	Format string `json:"format"`
-
-	// KeepAlive controls how long the model will stay loaded in memory following
-	// this request.
-	KeepAlive *Duration `json:"keep_alive,omitempty"`
-
-	// Images is an optional list of base64-encoded images accompanying this
-	// request, for multimodal models.
-	Images []ImageData `json:"images,omitempty"`
+	Model     string      `json:"model"`
+	Prompt    string      `json:"prompt"`
+	System    string      `json:"system"`
+	Template  string      `json:"template"`
+	Context   []int       `json:"context,omitempty"`
+	Stream    *bool       `json:"stream,omitempty"`
+	Raw       bool        `json:"raw,omitempty"`
+	Format    string      `json:"format"`
+	KeepAlive *Duration   `json:"keep_alive,omitempty"`
+	Images    []ImageData `json:"images,omitempty"`
 
 	// Options lists model-specific options. For example, temperature can be
 	// set through this field, if the model supports it.
 	Options map[string]interface{} `json:"options"`
 }
@@ -137,24 +109,21 @@ type Options struct {
 
 // Runner options which must be set when the model is loaded into memory
 type Runner struct {
-	UseNUMA   bool `json:"numa,omitempty"`
-	NumCtx    int  `json:"num_ctx,omitempty"`
-	NumBatch  int  `json:"num_batch,omitempty"`
-	NumGQA    int  `json:"num_gqa,omitempty"`
-	NumGPU    int  `json:"num_gpu,omitempty"`
-	MainGPU   int  `json:"main_gpu,omitempty"`
-	LowVRAM   bool `json:"low_vram,omitempty"`
-	F16KV     bool `json:"f16_kv,omitempty"`
-	LogitsAll bool `json:"logits_all,omitempty"`
-	VocabOnly bool `json:"vocab_only,omitempty"`
-	UseMMap   bool `json:"use_mmap,omitempty"`
-	UseMLock  bool `json:"use_mlock,omitempty"`
-	NumThread int  `json:"num_thread,omitempty"`
-
-	// Unused: RopeFrequencyBase is ignored. Instead the value in the model will be used
-	RopeFrequencyBase float32 `json:"rope_frequency_base,omitempty"`
-	// Unused: RopeFrequencyScale is ignored. Instead the value in the model will be used
+	UseNUMA            bool    `json:"numa,omitempty"`
+	NumCtx             int     `json:"num_ctx,omitempty"`
+	NumBatch           int     `json:"num_batch,omitempty"`
+	NumGQA             int     `json:"num_gqa,omitempty"`
+	NumGPU             int     `json:"num_gpu,omitempty"`
+	MainGPU            int     `json:"main_gpu,omitempty"`
+	LowVRAM            bool    `json:"low_vram,omitempty"`
+	F16KV              bool    `json:"f16_kv,omitempty"`
+	LogitsAll          bool    `json:"logits_all,omitempty"`
+	VocabOnly          bool    `json:"vocab_only,omitempty"`
+	UseMMap            bool    `json:"use_mmap,omitempty"`
+	UseMLock           bool    `json:"use_mlock,omitempty"`
+	RopeFrequencyBase  float32 `json:"rope_frequency_base,omitempty"`
 	RopeFrequencyScale float32 `json:"rope_frequency_scale,omitempty"`
 	NumThread          int     `json:"num_thread,omitempty"`
 }
 
 type EmbeddingRequest struct {
@@ -170,11 +139,10 @@ type EmbeddingResponse struct {
 }
 
 type CreateRequest struct {
-	Model        string `json:"model"`
-	Path         string `json:"path"`
-	Modelfile    string `json:"modelfile"`
-	Stream       *bool  `json:"stream,omitempty"`
-	Quantization string `json:"quantization,omitempty"`
+	Model     string `json:"model"`
+	Path      string `json:"path"`
+	Modelfile string `json:"modelfile"`
+	Stream    *bool  `json:"stream,omitempty"`
 
 	// Name is deprecated, see Model
 	Name string `json:"name"`
@@ -414,16 +382,18 @@ func DefaultOptions() Options {
 
 		Runner: Runner{
 			// options set when the model is loaded
-			NumCtx:    2048,
-			NumBatch:  512,
-			NumGPU:    -1, // -1 here indicates that NumGPU should be set dynamically
-			NumGQA:    1,
-			NumThread: 0, // let the runtime decide
-			LowVRAM:   false,
-			F16KV:     true,
-			UseMLock:  false,
-			UseMMap:   true,
-			UseNUMA:   false,
+			NumCtx:             2048,
+			RopeFrequencyBase:  10000.0,
+			RopeFrequencyScale: 1.0,
+			NumBatch:           512,
+			NumGPU:             -1, // -1 here indicates that NumGPU should be set dynamically
+			NumGQA:             1,
+			NumThread:          0, // let the runtime decide
+			LowVRAM:            false,
+			F16KV:              true,
+			UseMLock:           false,
+			UseMMap:            true,
+			UseNUMA:            false,
 		},
 	}
 }
```
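One detail worth noting in the GenerateRequest and CreateRequest hunks: Stream is a *bool rather than a bool. A pointer field lets the JSON layer distinguish "unset" from an explicit false, which matters because (per the doc comment removed above) streaming defaults to true. A self-contained sketch of that behavior:

```go
package main

import (
	"encoding/json"
	"fmt"
)

// req mirrors the Stream field shape from the structs above.
type req struct {
	Stream *bool `json:"stream,omitempty"`
}

func main() {
	f := false
	a, _ := json.Marshal(req{})           // {}                -> server applies its default (true)
	b, _ := json.Marshal(req{Stream: &f}) // {"stream":false}  -> explicit opt-out
	fmt.Println(string(a), string(b))
}
```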
app/lifecycle/server.go

```diff
@@ -9,6 +9,7 @@ import (
 	"os"
 	"os/exec"
 	"path/filepath"
+	"syscall"
 	"time"
 
 	"github.com/ollama/ollama/api"
@@ -86,29 +87,19 @@ func SpawnServer(ctx context.Context, command string) (chan int, error) {
 	// Re-wire context done behavior to attempt a graceful shutdown of the server
 	cmd.Cancel = func() error {
 		if cmd.Process != nil {
-			err := terminate(cmd)
-			if err != nil {
-				slog.Warn("error trying to gracefully terminate server", "err", err)
-				return cmd.Process.Kill()
-			}
-
+			cmd.Process.Signal(os.Interrupt) //nolint:errcheck
 			tick := time.NewTicker(10 * time.Millisecond)
 			defer tick.Stop()
 
 			for {
 				select {
 				case <-tick.C:
-					exited, err := isProcessExited(cmd.Process.Pid)
-					if err != nil {
-						return err
-					}
-
-					if exited {
-						return nil
+					// OS agnostic "is it still running"
+					if proc, err := os.FindProcess(int(cmd.Process.Pid)); err != nil || errors.Is(proc.Signal(syscall.Signal(0)), os.ErrProcessDone) {
+						return nil //nolint:nilerr
 					}
 				case <-time.After(5 * time.Second):
 					slog.Warn("graceful server shutdown timeout, killing", "pid", cmd.Process.Pid)
-					return cmd.Process.Kill()
+					cmd.Process.Kill() //nolint:errcheck
 				}
 			}
 		}
```
app/lifecycle/server_unix.go

```diff
@@ -4,35 +4,9 @@ package lifecycle
 
 import (
 	"context"
-	"errors"
-	"fmt"
-	"os"
 	"os/exec"
-	"syscall"
 )
 
 func getCmd(ctx context.Context, cmd string) *exec.Cmd {
 	return exec.CommandContext(ctx, cmd, "serve")
 }
-
-func terminate(cmd *exec.Cmd) error {
-	return cmd.Process.Signal(os.Interrupt)
-}
-
-func isProcessExited(pid int) (bool, error) {
-	proc, err := os.FindProcess(pid)
-	if err != nil {
-		return false, fmt.Errorf("failed to find process: %v", err)
-	}
-
-	err = proc.Signal(syscall.Signal(0))
-	if err != nil {
-		if errors.Is(err, os.ErrProcessDone) || errors.Is(err, syscall.ESRCH) {
-			return true, nil
-		}
-
-		return false, fmt.Errorf("error signaling process: %v", err)
-	}
-
-	return false, nil
-}
```
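The inline check that replaces isProcessExited in SpawnServer relies on the POSIX signal-0 convention: sending signal 0 performs the permission and existence checks without delivering anything. As a standalone sketch (Unix semantics assumed):

```go
package main

import (
	"fmt"
	"os"
	"syscall"
)

// isAlive reports whether pid still refers to a running, signalable process.
// On Unix, os.FindProcess always succeeds, so the real test is Signal(0).
func isAlive(pid int) bool {
	proc, err := os.FindProcess(pid)
	if err != nil {
		return false
	}
	return proc.Signal(syscall.Signal(0)) == nil
}

func main() {
	fmt.Println(isAlive(os.Getpid())) // true: the current process exists
}
```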
app/lifecycle/server_windows.go

```diff
@@ -2,88 +2,12 @@ package lifecycle
 
 import (
 	"context"
-	"fmt"
 	"os/exec"
 	"syscall"
-
-	"golang.org/x/sys/windows"
 )
 
 func getCmd(ctx context.Context, exePath string) *exec.Cmd {
 	cmd := exec.CommandContext(ctx, exePath, "serve")
-	cmd.SysProcAttr = &syscall.SysProcAttr{
-		HideWindow:    true,
-		CreationFlags: windows.CREATE_NEW_PROCESS_GROUP,
-	}
-
+	cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true, CreationFlags: 0x08000000}
 	return cmd
 }
-
-func terminate(cmd *exec.Cmd) error {
-	dll, err := windows.LoadDLL("kernel32.dll")
-	if err != nil {
-		return err
-	}
-	defer dll.Release() // nolint: errcheck
-
-	pid := cmd.Process.Pid
-
-	f, err := dll.FindProc("AttachConsole")
-	if err != nil {
-		return err
-	}
-
-	r1, _, err := f.Call(uintptr(pid))
-	if r1 == 0 && err != syscall.ERROR_ACCESS_DENIED {
-		return err
-	}
-
-	f, err = dll.FindProc("SetConsoleCtrlHandler")
-	if err != nil {
-		return err
-	}
-
-	r1, _, err = f.Call(0, 1)
-	if r1 == 0 {
-		return err
-	}
-
-	f, err = dll.FindProc("GenerateConsoleCtrlEvent")
-	if err != nil {
-		return err
-	}
-
-	r1, _, err = f.Call(windows.CTRL_BREAK_EVENT, uintptr(pid))
-	if r1 == 0 {
-		return err
-	}
-
-	r1, _, err = f.Call(windows.CTRL_C_EVENT, uintptr(pid))
-	if r1 == 0 {
-		return err
-	}
-
-	return nil
-}
-
-const STILL_ACTIVE = 259
-
-func isProcessExited(pid int) (bool, error) {
-	hProcess, err := windows.OpenProcess(windows.PROCESS_QUERY_INFORMATION, false, uint32(pid))
-	if err != nil {
-		return false, fmt.Errorf("failed to open process: %v", err)
-	}
-	defer windows.CloseHandle(hProcess) // nolint: errcheck
-
-	var exitCode uint32
-	err = windows.GetExitCodeProcess(hProcess, &exitCode)
-	if err != nil {
-		return false, fmt.Errorf("failed to get exit code: %v", err)
-	}
-
-	if exitCode == STILL_ACTIVE {
-		return false, nil
-	}
-
-	return true, nil
-}
```
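The bare 0x08000000 in the new getCmd is Windows' CREATE_NO_WINDOW creation flag, inlined so the file can drop its golang.org/x/sys/windows dependency; spelled out as a sketch (not part of the diff):

```go
//go:build windows

package main

import (
	"context"
	"fmt"
	"os/exec"
	"syscall"
)

// CREATE_NO_WINDOW (winbase.h): start the child without allocating a console
// window. This is the literal 0x08000000 above, replacing the previous
// windows.CREATE_NEW_PROCESS_GROUP flag.
const createNoWindow = 0x08000000

func main() {
	cmd := exec.CommandContext(context.Background(), "ollama.exe", "serve") // hypothetical path
	cmd.SysProcAttr = &syscall.SysProcAttr{HideWindow: true, CreationFlags: createNoWindow}
	fmt.Println(cmd.SysProcAttr.CreationFlags == 0x08000000)
}
```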
app/tray/tray.go

```diff
@@ -24,5 +24,10 @@ func NewTray() (commontray.OllamaTray, error) {
 		return nil, fmt.Errorf("failed to load icon %s: %w", iconName, err)
 	}
 
-	return InitPlatformTray(icon, updateIcon)
+	tray, err := InitPlatformTray(icon, updateIcon)
+	if err != nil {
+		return nil, err
+	}
+
+	return tray, nil
 }
```
`cmd/cmd.go` (56 changed lines)
```diff
@@ -105,48 +105,24 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 
 		zf := zip.NewWriter(tf)
 
-		files := []string{}
-
-		tfiles, err := filepath.Glob(filepath.Join(path, "pytorch_model-*.bin"))
+		files, err := filepath.Glob(filepath.Join(path, "model-*.safetensors"))
 		if err != nil {
 			return err
-		} else if len(tfiles) == 0 {
-			tfiles, err = filepath.Glob(filepath.Join(path, "model-*.safetensors"))
-			if err != nil {
-				return err
-			}
 		}
 
-		files = append(files, tfiles...)
-
 		if len(files) == 0 {
-			return fmt.Errorf("no models were found in '%s'", path)
+			return fmt.Errorf("no safetensors files were found in '%s'", path)
 		}
 
-		// add the safetensor/torch config file + tokenizer
+		// add the safetensor config file + tokenizer
 		files = append(files, filepath.Join(path, "config.json"))
-		files = append(files, filepath.Join(path, "params.json"))
 		files = append(files, filepath.Join(path, "added_tokens.json"))
+		files = append(files, filepath.Join(path, "tokenizer.model"))
 
 		for _, fn := range files {
 			f, err := os.Open(fn)
-
-			// just skip whatever files aren't there
-			if os.IsNotExist(err) {
-				if strings.HasSuffix(fn, "tokenizer.model") {
-					// try the parent dir before giving up
-					parentDir := filepath.Dir(path)
-					newFn := filepath.Join(parentDir, "tokenizer.model")
-					f, err = os.Open(newFn)
-					if os.IsNotExist(err) {
-						continue
-					} else if err != nil {
-						return err
-					}
-				} else {
-					continue
-				}
+			if os.IsNotExist(err) && strings.HasSuffix(fn, "added_tokens.json") {
+				continue
 			} else if err != nil {
 				return err
 			}
@@ -218,9 +194,7 @@ func CreateHandler(cmd *cobra.Command, args []string) error {
 		return nil
 	}
 
-	quantization, _ := cmd.Flags().GetString("quantization")
-
-	request := api.CreateRequest{Name: args[0], Modelfile: string(modelfile), Quantization: quantization}
+	request := api.CreateRequest{Name: args[0], Modelfile: string(modelfile)}
 	if err := client.Create(cmd.Context(), &request, fn); err != nil {
 		return err
 	}
@@ -252,6 +226,14 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, error) {
 }
 
 func RunHandler(cmd *cobra.Command, args []string) error {
+	if os.Getenv("OLLAMA_MODELS") != "" {
+		return errors.New("OLLAMA_MODELS must only be set for 'ollama serve'")
+	}
+
+	if err := checkServerHeartbeat(cmd, args); err != nil {
+		return err
+	}
+
 	client, err := api.ClientFromEnvironment()
 	if err != nil {
 		return err
@@ -961,7 +943,6 @@ func NewCLI() *cobra.Command {
 	}
 
 	createCmd.Flags().StringP("file", "f", "Modelfile", "Name of the Modelfile (default \"Modelfile\")")
-	createCmd.Flags().StringP("quantization", "q", "", "Quantization level.")
 
 	showCmd := &cobra.Command{
 		Use: "show MODEL",
@@ -978,11 +959,10 @@ func NewCLI() *cobra.Command {
 	showCmd.Flags().Bool("system", false, "Show system message of a model")
 
 	runCmd := &cobra.Command{
-		Use:     "run MODEL [PROMPT]",
-		Short:   "Run a model",
-		Args:    cobra.MinimumNArgs(1),
-		PreRunE: checkServerHeartbeat,
-		RunE:    RunHandler,
+		Use:   "run MODEL [PROMPT]",
+		Short: "Run a model",
+		Args:  cobra.MinimumNArgs(1),
+		RunE:  RunHandler,
 	}
 
 	runCmd.Flags().Bool("verbose", false, "Show timings for response")
```
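For reference on the CreateHandler hunk above: the model-*.safetensors pattern matches Hugging Face's sharded checkpoint names, while the dropped pytorch_model-*.bin glob covered the torch shards that are no longer picked up. A quick illustration with hypothetical file names:

```go
package main

import (
	"fmt"
	"path/filepath"
)

func main() {
	pat := "model-*.safetensors"
	for _, name := range []string{
		"model-00001-of-00002.safetensors", // matched
		"model-00002-of-00002.safetensors", // matched
		"pytorch_model-00001-of-00002.bin", // no longer picked up
	} {
		ok, _ := filepath.Match(pat, name)
		fmt.Println(name, ok)
	}
}
```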
convert/convert.go

```diff
@@ -1,16 +1,21 @@
 package convert
 
 import (
+	"bytes"
-	"cmp"
 	"encoding/binary"
 	"encoding/json"
 	"fmt"
+	"io"
 	"log/slog"
 	"os"
 	"path/filepath"
+	"regexp"
 	"slices"
 	"strings"
 
+	"github.com/d4l3k/go-bfloat16"
+	"github.com/mitchellh/mapstructure"
+	"github.com/x448/float16"
 	"google.golang.org/protobuf/proto"
 
 	"github.com/ollama/ollama/convert/sentencepiece"
@@ -27,6 +32,7 @@ type Params struct {
 	AttentionHeads int     `json:"num_attention_heads"` // n_head
 	KeyValHeads    int     `json:"num_key_value_heads"`
 	NormEPS        float64 `json:"rms_norm_eps"`
+	RopeFreqBase   float64 `json:"rope_theta"`
 	BoSTokenID     int     `json:"bos_token_id"`
 	EoSTokenID     int     `json:"eos_token_id"`
 	HeadDimension  int     `json:"head_dim"`
@@ -40,45 +46,157 @@ type ByteOrder interface {
 	binary.AppendByteOrder
 }
 
+type MetaData struct {
+	Type    string `mapstructure:"dtype"`
+	Shape   []int  `mapstructure:"shape"`
+	Offsets []int  `mapstructure:"data_offsets"`
+}
+
 type ModelArch interface {
 	GetTensors() error
 	LoadVocab() error
 	WriteGGUF() (string, error)
 }
 
-type ModelFormat interface {
-	GetLayerName(string) (string, error)
-	GetTensors(string, *Params) ([]llm.Tensor, error)
-	GetParams(string) (*Params, error)
-	GetModelArch(string, string, *Params) (ModelArch, error)
-}
-
 type ModelData struct {
 	Path    string
 	Name    string
 	Params  *Params
 	Vocab   *Vocab
 	Tensors []llm.Tensor
-	Format  ModelFormat
 }
 
-func GetModelFormat(dirname string) (ModelFormat, error) {
-	files, err := filepath.Glob(filepath.Join(dirname, "*"))
-	if err != nil {
-		return nil, err
-	}
-
-	for _, fn := range files {
-		slog.Debug(fmt.Sprintf("file = %s", fn))
-		if strings.HasSuffix(fn, ".safetensors") {
-			return &SafetensorFormat{}, nil
-		} else if strings.HasSuffix(fn, ".bin") {
-			slog.Debug("model is torch")
-			return &TorchFormat{}, nil
-		}
-	}
-
-	return nil, fmt.Errorf("couldn't determine model format")
-}
+func ReadSafeTensors(fn string, offset uint64, params *Params) ([]llm.Tensor, uint64, error) {
+	f, err := os.Open(fn)
+	if err != nil {
+		return nil, 0, err
+	}
+	defer f.Close()
+
+	var jsonSize uint64
+	if err := binary.Read(f, binary.LittleEndian, &jsonSize); err != nil {
+		return nil, 0, err
+	}
+
+	buf := make([]byte, jsonSize)
+	_, err = io.ReadFull(f, buf)
+	if err != nil {
+		return nil, 0, err
+	}
+
+	d := json.NewDecoder(bytes.NewBuffer(buf))
+	d.UseNumber()
+	var parsed map[string]interface{}
+	if err = d.Decode(&parsed); err != nil {
+		return nil, 0, err
+	}
+
+	var keys []string
+	for k := range parsed {
+		keys = append(keys, k)
+	}
+
+	slices.Sort(keys)
+
+	slog.Info("converting layers")
+
+	var tensors []llm.Tensor
+	for _, k := range keys {
+		vals := parsed[k].(map[string]interface{})
+		var data MetaData
+		if err = mapstructure.Decode(vals, &data); err != nil {
+			return nil, 0, err
+		}
+
+		var size uint64
+		var kind uint32
+		switch len(data.Shape) {
+		case 0:
+			// metadata
+			continue
+		case 1:
+			// convert to float32
+			kind = 0
+			size = uint64(data.Shape[0] * 4)
+		case 2:
+			// convert to float16
+			kind = 1
+			size = uint64(data.Shape[0] * data.Shape[1] * 2)
+		}
+
+		ggufName, err := GetTensorName(k)
+		if err != nil {
+			slog.Error("%v", err)
+			return nil, 0, err
+		}
+
+		shape := []uint64{0, 0, 0, 0}
+		for i := range data.Shape {
+			shape[i] = uint64(data.Shape[i])
+		}
+
+		t := llm.Tensor{
+			Name:   ggufName,
+			Kind:   kind,
+			Offset: offset,
+			Shape:  shape[:],
+		}
+
+		t.WriterTo = safetensorWriterTo{
+			t:        &t,
+			params:   params,
+			bo:       params.ByteOrder,
+			filename: fn,
+			start:    uint64(data.Offsets[0]),
+			end:      uint64(data.Offsets[1]),
+			padding:  8 + jsonSize,
+		}
+
+		slog.Debug(fmt.Sprintf("%v", t))
+		tensors = append(tensors, t)
+		offset += size
+	}
+	return tensors, offset, nil
+}
+
+func GetSafeTensors(dirpath string, params *Params) ([]llm.Tensor, error) {
+	var tensors []llm.Tensor
+	files, err := filepath.Glob(filepath.Join(dirpath, "/model-*.safetensors"))
+	if err != nil {
+		return nil, err
+	}
+
+	var offset uint64
+	for _, f := range files {
+		var t []llm.Tensor
+		var err error
+		t, offset, err = ReadSafeTensors(f, offset, params)
+		if err != nil {
+			slog.Error("%v", err)
+			return nil, err
+		}
+		tensors = append(tensors, t...)
+	}
+	return tensors, nil
+}
+
+func GetParams(dirpath string) (*Params, error) {
+	f, err := os.Open(filepath.Join(dirpath, "config.json"))
+	if err != nil {
+		return nil, err
+	}
+	defer f.Close()
+
+	var params Params
+
+	d := json.NewDecoder(f)
+	err = d.Decode(&params)
+	if err != nil {
+		return nil, err
+	}
+
+	params.ByteOrder = binary.LittleEndian
+	return &params, nil
+}
 
 // Details on gguf's tokenizer can be found at:
@@ -89,7 +207,7 @@ type Vocab struct {
 	Types  []int32
 }
 
-func LoadSentencePieceTokens(dirpath string, params *Params) (*Vocab, error) {
+func LoadSentencePieceTokens(dirpath string, vocabSize int) (*Vocab, error) {
 	slog.Info(fmt.Sprintf("reading vocab from %s", filepath.Join(dirpath, "tokenizer.model")))
 	in, err := os.ReadFile(filepath.Join(dirpath, "tokenizer.model"))
 	if err != nil {
@@ -169,8 +287,8 @@ func LoadSentencePieceTokens(dirpath string, params *Params) (*Vocab, error) {
 	}
 	slog.Info(fmt.Sprintf("vocab size w/ extra tokens: %d", len(v.Tokens)))
 
-	if params.VocabSize > len(v.Tokens) {
-		missingTokens := params.VocabSize - len(v.Tokens)
+	if vocabSize > len(v.Tokens) {
+		missingTokens := vocabSize - len(v.Tokens)
 		slog.Warn(fmt.Sprintf("vocab is missing %d tokens", missingTokens))
 		for cnt := 0; cnt < missingTokens; cnt++ {
 			v.Tokens = append(v.Tokens, fmt.Sprintf("<dummy%05d>", cnt+1))
@@ -181,3 +299,136 @@ func LoadSentencePieceTokens(dirpath string, params *Params) (*Vocab, error) {
 
 	return v, nil
 }
+
+func GetTensorName(n string) (string, error) {
+	tMap := map[string]string{
+		"model.embed_tokens.weight":                           "token_embd.weight",
+		"model.layers.(\\d+).input_layernorm.weight":          "blk.$1.attn_norm.weight",
+		"model.layers.(\\d+).mlp.down_proj.weight":            "blk.$1.ffn_down.weight",
+		"model.layers.(\\d+).mlp.gate_proj.weight":            "blk.$1.ffn_gate.weight",
+		"model.layers.(\\d+).mlp.up_proj.weight":              "blk.$1.ffn_up.weight",
+		"model.layers.(\\d+).post_attention_layernorm.weight": "blk.$1.ffn_norm.weight",
+		"model.layers.(\\d+).self_attn.k_proj.weight":         "blk.$1.attn_k.weight",
+		"model.layers.(\\d+).self_attn.o_proj.weight":         "blk.$1.attn_output.weight",
+		"model.layers.(\\d+).self_attn.q_proj.weight":         "blk.$1.attn_q.weight",
+		"model.layers.(\\d+).self_attn.v_proj.weight":         "blk.$1.attn_v.weight",
+		"lm_head.weight":                                      "output.weight",
+		"model.norm.weight":                                   "output_norm.weight",
+	}
+
+	v, ok := tMap[n]
+	if ok {
+		return v, nil
+	}
+
+	// quick hack to rename the layers to gguf format
+	for k, v := range tMap {
+		re := regexp.MustCompile(k)
+		newName := re.ReplaceAllString(n, v)
+		if newName != n {
+			return newName, nil
+		}
+	}
+
+	return "", fmt.Errorf("couldn't find a layer name for '%s'", n)
+}
+
+type safetensorWriterTo struct {
+	t *llm.Tensor
+
+	params *Params
+	bo     ByteOrder
+
+	filename string
+
+	start, end, padding uint64
+	handler             func(w io.Writer, r safetensorWriterTo, f *os.File) error
+}
+
+func (r safetensorWriterTo) WriteTo(w io.Writer) (n int64, err error) {
+	f, err := os.Open(r.filename)
+	if err != nil {
+		return 0, err
+	}
+	defer f.Close()
+
+	if _, err = f.Seek(int64(r.padding+r.start), 0); err != nil {
+		return 0, err
+	}
+
+	// use the handler if one is present
+	if r.handler != nil {
+		return 0, r.handler(w, r, f)
+	}
+
+	remaining := r.end - r.start
+
+	bufSize := uint64(10240)
+	var finished bool
+	for {
+		data := make([]byte, min(bufSize, remaining))
+
+		b, err := io.ReadFull(f, data)
+		remaining -= uint64(b)
+
+		if err == io.EOF || remaining <= 0 {
+			finished = true
+		} else if err != nil {
+			return 0, err
+		}
+
+		// convert bfloat16 -> ieee float32
+		tDataF32 := bfloat16.DecodeFloat32(data)
+
+		switch r.t.Kind {
+		case 0:
+			if err := binary.Write(w, r.bo, tDataF32); err != nil {
+				return 0, err
+			}
+		case 1:
+			// convert float32 -> float16
+			tempBuf := make([]uint16, len(data)/2)
+			for cnt, v := range tDataF32 {
+				tDataF16 := float16.Fromfloat32(v)
+				tempBuf[cnt] = uint16(tDataF16)
+			}
+			if err := binary.Write(w, binary.LittleEndian, tempBuf); err != nil {
+				return 0, err
+			}
+		}
+		if finished {
+			break
+		}
+	}
+	return 0, nil
+}
+
+func GetModelArchFromParams(name, dirPath string, params *Params) (ModelArch, error) {
+	switch len(params.Architectures) {
+	case 0:
+		return nil, fmt.Errorf("No architecture specified to convert")
+	case 1:
+		switch params.Architectures[0] {
+		case "MistralForCausalLM":
+			return &MistralModel{
+				ModelData{
+					Name:   name,
+					Path:   dirPath,
+					Params: params,
+				},
+			}, nil
+		case "GemmaForCausalLM":
+			return &GemmaModel{
+				ModelData{
+					Name:   name,
+					Path:   dirPath,
+					Params: params,
+				},
+			}, nil
+		default:
+			return nil, fmt.Errorf("Models based on '%s' are not yet supported", params.Architectures[0])
+		}
+	}
+
+	return nil, fmt.Errorf("Unknown error")
+}
```
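The tMap table in the GetTensorName function added above doubles as a regex rewrite table for layer-indexed tensors; a worked example of the rename it performs (hypothetical layer index):

```go
package main

import (
	"fmt"
	"regexp"
)

func main() {
	// One entry from tMap above: the attention query weight for layer N.
	re := regexp.MustCompile(`model.layers.(\d+).self_attn.q_proj.weight`)
	src := "model.layers.17.self_attn.q_proj.weight" // hypothetical layer 17
	fmt.Println(re.ReplaceAllString(src, "blk.$1.attn_q.weight"))
	// Output: blk.17.attn_q.weight
}
```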
convert/gemma.go

```diff
@@ -65,14 +65,13 @@ func addOnes(data []float32, vectorSize int) ([]float32, error) {
 }
 
 func (m *GemmaModel) GetTensors() error {
-	t, err := m.Format.GetTensors(m.Path, m.Params)
+	t, err := GetSafeTensors(m.Path, m.Params)
 	if err != nil {
 		return err
 	}
 
 	slog.Debug(fmt.Sprintf("Total tensors: %d", len(t)))
-
 	m.Tensors = []llm.Tensor{}
 
 	for _, l := range t {
 		if strings.HasSuffix(l.Name, "norm.weight") {
 			wt := l.WriterTo.(safetensorWriterTo)
@@ -86,7 +85,7 @@ func (m *GemmaModel) GetTensors() error {
 }
 
 func (m *GemmaModel) LoadVocab() error {
-	v, err := LoadSentencePieceTokens(m.Path, m.Params)
+	v, err := LoadSentencePieceTokens(m.Path, m.Params.VocabSize)
 	if err != nil {
 		return err
 	}
```
`convert/llama.go` (176 changed lines)
```diff
@@ -1,176 +0,0 @@
-package convert
-
-import (
-	"encoding/binary"
-	"fmt"
-	"io"
-	"log/slog"
-	"os"
-	"regexp"
-	"strings"
-
-	"github.com/nlpodyssey/gopickle/pytorch"
-	"github.com/pdevine/tensor"
-	"github.com/pdevine/tensor/native"
-	"github.com/x448/float16"
-
-	"github.com/ollama/ollama/llm"
-)
-
-type LlamaModel struct {
-	ModelData
-}
-
-func llamaLayerHandler(w io.Writer, r torchWriterTo) error {
-	slog.Debug(fmt.Sprintf("repacking layer '%s'", r.t.Name))
-
-	data := r.storage.(*pytorch.HalfStorage).Data
-	tData := make([]uint16, len(data))
-	for cnt, v := range data {
-		tData[cnt] = uint16(float16.Fromfloat32(v))
-	}
-
-	var err error
-	var heads uint32
-	if strings.Contains(r.t.Name, "attn_q") {
-		heads = uint32(r.params.AttentionHeads)
-	} else if strings.Contains(r.t.Name, "attn_k") {
-		heads = uint32(r.params.KeyValHeads)
-		if heads == 0 {
-			heads = uint32(r.params.AttentionHeads)
-		}
-	} else {
-		return fmt.Errorf("unknown layer type")
-	}
-
-	slog.Debug(fmt.Sprintf("heads = %d", heads))
-
-	tData, err = llamaRepack(tData, int(heads), r.t.Shape)
-	if err != nil {
-		return err
-	}
-
-	if err = binary.Write(w, r.bo, tData); err != nil {
-		return err
-	}
-	return nil
-}
-
-func llamaRepack(data []uint16, heads int, shape []uint64) ([]uint16, error) {
-	n := tensor.New(tensor.WithShape(int(shape[0]), int(shape[1])), tensor.WithBacking(data))
-	origShape := n.Shape().Clone()
-
-	// reshape the tensor and swap axes 1 and 2 to unpack the layer for gguf
-	if err := n.Reshape(heads, 2, origShape[0]/heads/2, origShape[1]); err != nil {
-		return nil, err
-	}
-
-	if err := n.T(0, 2, 1, 3); err != nil {
-		return nil, err
-	}
-
-	if err := n.Reshape(origShape...); err != nil {
-		return nil, err
-	}
-
-	if err := n.Transpose(); err != nil {
-		return nil, err
-	}
-	newN, err := native.SelectU16(n, 1)
-	if err != nil {
-		return nil, err
-	}
-
-	var fullTensor []uint16
-	for _, v := range newN {
-		fullTensor = append(fullTensor, v...)
-	}
-	return fullTensor, nil
-}
-
-func (m *LlamaModel) GetTensors() error {
-	t, err := m.Format.GetTensors(m.Path, m.Params)
-	if err != nil {
-		return err
-	}
-
-	m.Tensors = []llm.Tensor{}
-
-	pattern := `^blk\.[0-9]+\.attn_(?P<layer>q|k)\.weight$`
-	re, err := regexp.Compile(pattern)
-	if err != nil {
-		return err
-	}
-
-	for _, l := range t {
-		matches := re.FindAllStringSubmatch(l.Name, -1)
-		if len(matches) > 0 {
-			slog.Debug(fmt.Sprintf("setting handler for: %s", l.Name))
-			wt := l.WriterTo.(torchWriterTo)
-			wt.handler = llamaLayerHandler
-			l.WriterTo = wt
-		}
-		m.Tensors = append(m.Tensors, l)
-	}
-
-	return nil
-}
-
-func (m *LlamaModel) LoadVocab() error {
-	var v *Vocab
-	var err error
-
-	slog.Debug("loading vocab")
-	v, err = LoadSentencePieceTokens(m.Path, m.Params)
-	if err != nil {
-		return err
-	}
-
-	slog.Debug("vocab loaded")
-
-	m.Vocab = v
-	return nil
-}
-
-func (m *LlamaModel) WriteGGUF() (string, error) {
-	kv := llm.KV{
-		"general.architecture":                   "llama",
-		"general.name":                           m.Name,
-		"llama.vocab_size":                       uint32(len(m.Vocab.Tokens)),
-		"llama.context_length":                   uint32(m.Params.ContextSize),
-		"llama.embedding_length":                 uint32(m.Params.HiddenSize),
-		"llama.block_count":                      uint32(m.Params.HiddenLayers),
-		"llama.feed_forward_length":              uint32(m.Params.IntermediateSize),
-		"llama.rope.dimension_count":             uint32(m.Params.HiddenSize / m.Params.AttentionHeads),
-		"llama.attention.head_count":             uint32(m.Params.AttentionHeads),
-		"llama.attention.head_count_kv":          uint32(m.Params.KeyValHeads),
-		"llama.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS),
-		"general.file_type":                      uint32(1),
-		"tokenizer.ggml.model":                   "llama",
-
-		"tokenizer.ggml.tokens":     m.Vocab.Tokens,
-		"tokenizer.ggml.scores":     m.Vocab.Scores,
-		"tokenizer.ggml.token_type": m.Vocab.Types,
-
-		"tokenizer.ggml.bos_token_id":     uint32(m.Params.BoSTokenID),
-		"tokenizer.ggml.eos_token_id":     uint32(m.Params.EoSTokenID),
-		"tokenizer.ggml.unknown_token_id": uint32(0),
-		"tokenizer.ggml.add_bos_token":    true,
-		"tokenizer.ggml.add_eos_token":    false,
-	}
-
-	f, err := os.CreateTemp("", "ollama-gguf")
-	if err != nil {
-		return "", err
-	}
-	defer f.Close()
-
-	mod := llm.NewGGUFV3(m.Params.ByteOrder)
-	if err := mod.Encode(f, kv, m.Tensors); err != nil {
-		return "", err
-	}
-
-	slog.Debug(fmt.Sprintf("gguf file = %s", f.Name()))
-
-	return f.Name(), nil
-}
```
convert/mistral.go

```diff
@@ -97,7 +97,7 @@ func repack(data []uint16, heads int, shape []uint64) ([]uint16, error) {
 }
 
 func (m *MistralModel) GetTensors() error {
-	t, err := m.Format.GetTensors(m.Path, m.Params)
+	t, err := GetSafeTensors(m.Path, m.Params)
 	if err != nil {
 		return err
 	}
@@ -124,7 +124,7 @@ func (m *MistralModel) GetTensors() error {
 }
 
 func (m *MistralModel) LoadVocab() error {
-	v, err := LoadSentencePieceTokens(m.Path, m.Params)
+	v, err := LoadSentencePieceTokens(m.Path, m.Params.VocabSize)
 	if err != nil {
 		return err
 	}
@@ -144,6 +144,7 @@ func (m *MistralModel) WriteGGUF() (string, error) {
 		"llama.attention.head_count":             uint32(m.Params.AttentionHeads),
 		"llama.attention.head_count_kv":          uint32(m.Params.KeyValHeads),
 		"llama.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS),
+		"llama.rope.freq_base":                   float32(m.Params.RopeFreqBase),
 		"general.file_type":                      uint32(1),
 		"tokenizer.ggml.model":                   "llama",
```
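The new llama.rope.freq_base KV entry is fed by RopeFreqBase, which the Params hunk in convert.go binds to the rope_theta key of the model's config.json; a small sketch of that plumbing (the sample value is typical for Mistral 7B, not taken from this diff):

```go
package main

import (
	"encoding/json"
	"fmt"
)

// params trims Params down to the json tag added in the convert diff.
type params struct {
	RopeFreqBase float64 `json:"rope_theta"`
}

func main() {
	// Illustrative config.json excerpt.
	cfg := []byte(`{"rope_theta": 10000.0}`)
	var p params
	if err := json.Unmarshal(cfg, &p); err != nil {
		panic(err)
	}
	fmt.Println(float32(p.RopeFreqBase)) // value written as llama.rope.freq_base
}
```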
@@ -1,304 +0,0 @@
package convert

import (
	"bytes"
	"encoding/binary"
	"encoding/json"
	"fmt"
	"io"
	"log/slog"
	"os"
	"path/filepath"
	"regexp"
	"slices"

	"github.com/d4l3k/go-bfloat16"
	"github.com/mitchellh/mapstructure"
	"github.com/x448/float16"

	"github.com/ollama/ollama/llm"
)

type safetensorWriterTo struct {
	t *llm.Tensor

	params *Params
	bo     ByteOrder

	filename string

	start, end, padding uint64
	handler             func(w io.Writer, r safetensorWriterTo, f *os.File) error
}

type tensorMetaData struct {
	Type    string `mapstructure:"dtype"`
	Shape   []int  `mapstructure:"shape"`
	Offsets []int  `mapstructure:"data_offsets"`
}

type SafetensorFormat struct{}

func (m *SafetensorFormat) GetTensors(dirpath string, params *Params) ([]llm.Tensor, error) {
	slog.Debug("getting tensor data")
	var tensors []llm.Tensor
	files, err := filepath.Glob(filepath.Join(dirpath, "/model-*.safetensors"))
	if err != nil {
		return nil, err
	}

	var offset uint64
	for _, f := range files {
		var t []llm.Tensor
		var err error
		t, offset, err = m.readTensors(f, offset, params)
		if err != nil {
			slog.Error("%v", err)
			return nil, err
		}
		tensors = append(tensors, t...)
	}
	slog.Debug(fmt.Sprintf("all tensors = %d", len(tensors)))
	return tensors, nil
}

func (m *SafetensorFormat) readTensors(fn string, offset uint64, params *Params) ([]llm.Tensor, uint64, error) {
	f, err := os.Open(fn)
	if err != nil {
		return nil, 0, err
	}
	defer f.Close()

	var jsonSize uint64
	if err := binary.Read(f, binary.LittleEndian, &jsonSize); err != nil {
		return nil, 0, err
	}

	buf := make([]byte, jsonSize)
	_, err = io.ReadFull(f, buf)
	if err != nil {
		return nil, 0, err
	}

	d := json.NewDecoder(bytes.NewBuffer(buf))
	d.UseNumber()
	var parsed map[string]interface{}
	if err = d.Decode(&parsed); err != nil {
		return nil, 0, err
	}

	var keys []string
	for k := range parsed {
		keys = append(keys, k)
	}

	slices.Sort(keys)

	slog.Info("converting layers")

	var tensors []llm.Tensor
	for _, k := range keys {
		vals := parsed[k].(map[string]interface{})
		var data tensorMetaData
		if err = mapstructure.Decode(vals, &data); err != nil {
			slog.Error("couldn't decode properly")
			return nil, 0, err
		}

		slog.Debug(fmt.Sprintf("metadata = %#v", data))
		var size uint64
		var kind uint32
		switch len(data.Shape) {
		case 0:
			// metadata
			continue
		case 1:
			// convert to float32
			kind = 0
			size = uint64(data.Shape[0] * 4)
		case 2:
			// convert to float16
			kind = 1
			size = uint64(data.Shape[0] * data.Shape[1] * 2)
		}

		ggufName, err := m.GetLayerName(k)
		if err != nil {
			slog.Error("%v", err)
			return nil, 0, err
		}

		shape := []uint64{0, 0, 0, 0}
		for i := range data.Shape {
			shape[i] = uint64(data.Shape[i])
		}

		t := llm.Tensor{
			Name:   ggufName,
			Kind:   kind,
			Offset: offset,
			Shape:  shape[:],
		}

		t.WriterTo = safetensorWriterTo{
			t:        &t,
			params:   params,
			bo:       params.ByteOrder,
			filename: fn,
			start:    uint64(data.Offsets[0]),
			end:      uint64(data.Offsets[1]),
			padding:  8 + jsonSize,
		}

		tensors = append(tensors, t)
		offset += size
	}
	slog.Debug(fmt.Sprintf("total tensors for file = %d", len(tensors)))
	slog.Debug(fmt.Sprintf("offset = %d", offset))
	return tensors, offset, nil
}

func (m *SafetensorFormat) GetParams(dirpath string) (*Params, error) {
	f, err := os.Open(filepath.Join(dirpath, "config.json"))
	if err != nil {
		return nil, err
	}
	defer f.Close()

	var params Params

	d := json.NewDecoder(f)
	err = d.Decode(&params)
	if err != nil {
		return nil, err
	}

	params.ByteOrder = binary.LittleEndian
	return &params, nil
}

func (m *SafetensorFormat) GetLayerName(n string) (string, error) {
	directMap := map[string]string{
		"model.embed_tokens.weight": "token_embd.weight",
		"lm_head.weight":            "output.weight",
		"model.norm.weight":         "output_norm.weight",
	}

	tMap := map[string]string{
		"model.layers.(\\d+).input_layernorm.weight":           "blk.$1.attn_norm.weight",
		"model.layers.(\\d+).mlp.down_proj.weight":             "blk.$1.ffn_down.weight",
		"model.layers.(\\d+).mlp.gate_proj.weight":             "blk.$1.ffn_gate.weight",
		"model.layers.(\\d+).mlp.up_proj.weight":               "blk.$1.ffn_up.weight",
		"model.layers.(\\d+).post_attention_layernorm.weight":  "blk.$1.ffn_norm.weight",
		"model.layers.(\\d+).self_attn.k_proj.weight":          "blk.$1.attn_k.weight",
		"model.layers.(\\d+).self_attn.o_proj.weight":          "blk.$1.attn_output.weight",
		"model.layers.(\\d+).self_attn.q_proj.weight":          "blk.$1.attn_q.weight",
		"model.layers.(\\d+).self_attn.v_proj.weight":          "blk.$1.attn_v.weight",
	}

	v, ok := directMap[n]
	if ok {
		return v, nil
	}

	// quick hack to rename the layers to gguf format
	for k, v := range tMap {
		re := regexp.MustCompile(k)
		newName := re.ReplaceAllString(n, v)
		if newName != n {
			return newName, nil
		}
	}

	return "", fmt.Errorf("couldn't find a layer name for '%s'", n)
}
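The renaming in `GetLayerName` leans entirely on `ReplaceAllString`: the `(\d+)` capture keeps the layer index while the rest of the pattern is swapped for the GGUF-style name, and an unchanged result means the pattern didn't match. A standalone sketch of one mapping from the table above:

```go
package main

import (
	"fmt"
	"regexp"
)

func main() {
	re := regexp.MustCompile(`model.layers.(\d+).self_attn.q_proj.weight`)
	name := re.ReplaceAllString("model.layers.3.self_attn.q_proj.weight", "blk.$1.attn_q.weight")
	fmt.Println(name) // blk.3.attn_q.weight
}
```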

func (r safetensorWriterTo) WriteTo(w io.Writer) (n int64, err error) {
	f, err := os.Open(r.filename)
	if err != nil {
		return 0, err
	}
	defer f.Close()

	if _, err = f.Seek(int64(r.padding+r.start), 0); err != nil {
		return 0, err
	}

	// use the handler if one is present
	if r.handler != nil {
		return 0, r.handler(w, r, f)
	}

	remaining := r.end - r.start

	bufSize := uint64(10240)
	var finished bool
	for {
		data := make([]byte, min(bufSize, remaining))

		b, err := io.ReadFull(f, data)
		remaining -= uint64(b)

		if err == io.EOF || remaining <= 0 {
			finished = true
		} else if err != nil {
			return 0, err
		}

		// convert bfloat16 -> ieee float32
		tDataF32 := bfloat16.DecodeFloat32(data)

		switch r.t.Kind {
		case 0:
			if err := binary.Write(w, r.bo, tDataF32); err != nil {
				return 0, err
			}
		case 1:
			// convert float32 -> float16
			tempBuf := make([]uint16, len(data)/2)
			for cnt, v := range tDataF32 {
				tDataF16 := float16.Fromfloat32(v)
				tempBuf[cnt] = uint16(tDataF16)
			}
			if err := binary.Write(w, r.bo, tempBuf); err != nil {
				return 0, err
			}
		}
		if finished {
			break
		}
	}
	return 0, nil
}
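The bfloat16 handling above works because bfloat16 is simply the upper half of an IEEE-754 float32: widening is a 16-bit shift, which is effectively what `bfloat16.DecodeFloat32` does for each pair of bytes before the optional float16 re-encode. A stdlib-only sketch of the widening step:

```go
package main

import (
	"fmt"
	"math"
)

// bf16ToF32 widens one bfloat16 bit pattern to float32.
func bf16ToF32(bits uint16) float32 {
	return math.Float32frombits(uint32(bits) << 16)
}

func main() {
	fmt.Println(bf16ToF32(0x3F80)) // 1 (0x3F800000 is float32 1.0)
}
```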

func (m *SafetensorFormat) GetModelArch(name, dirPath string, params *Params) (ModelArch, error) {
	switch len(params.Architectures) {
	case 0:
		return nil, fmt.Errorf("No architecture specified to convert")
	case 1:
		switch params.Architectures[0] {
		case "MistralForCausalLM":
			return &MistralModel{
				ModelData{
					Name:   name,
					Path:   dirPath,
					Params: params,
					Format: m,
				},
			}, nil
		case "GemmaForCausalLM":
			return &GemmaModel{
				ModelData{
					Name:   name,
					Path:   dirPath,
					Params: params,
					Format: m,
				},
			}, nil
		default:
			return nil, fmt.Errorf("Models based on '%s' are not yet supported", params.Architectures[0])
		}
	}

	return nil, fmt.Errorf("Unknown error")
}
286
convert/torch.go
@@ -1,286 +0,0 @@
package convert

import (
	"encoding/binary"
	"encoding/json"
	"fmt"
	"io"
	"log/slog"
	"os"
	"path/filepath"
	"regexp"
	"strings"

	"github.com/nlpodyssey/gopickle/pytorch"
	"github.com/nlpodyssey/gopickle/types"
	"github.com/x448/float16"

	"github.com/ollama/ollama/llm"
)

type torchWriterTo struct {
	t *llm.Tensor

	params *Params
	bo     ByteOrder

	storage pytorch.StorageInterface
	handler func(w io.Writer, r torchWriterTo) error
}

type TorchFormat struct{}

func (tf *TorchFormat) GetTensors(dirpath string, params *Params) ([]llm.Tensor, error) {
	slog.Debug("getting torch tensors")

	files, err := filepath.Glob(filepath.Join(dirpath, "pytorch_model-*.bin"))
	if err != nil {
		slog.Error("didn't find any torch files")
		return nil, err
	}

	var offset uint64

	var tensors []llm.Tensor
	for _, fn := range files {
		m, err := pytorch.Load(fn)
		if err != nil {
			slog.Error(fmt.Sprintf("error unpickling: %q", err))
			return []llm.Tensor{}, err
		}

		for _, k := range m.(*types.Dict).Keys() {
			if strings.HasSuffix(k.(string), "self_attn.rotary_emb.inv_freq") {
				continue
			}

			t, _ := m.(*types.Dict).Get(k)
			tshape := t.(*pytorch.Tensor).Size

			var size uint64
			var kind uint32
			switch len(tshape) {
			case 0:
				continue
			case 1:
				// convert to float32
				kind = 0
				size = uint64(tshape[0] * 4)
			case 2:
				// convert to float16
				kind = 1
				size = uint64(tshape[0] * tshape[1] * 2)
			}

			ggufName, err := tf.GetLayerName(k.(string))
			if err != nil {
				slog.Error("%v", err)
				return nil, err
			}
			slog.Debug(fmt.Sprintf("finding name for '%s' -> '%s'", k.(string), ggufName))

			shape := []uint64{0, 0, 0, 0}
			for i := range tshape {
				shape[i] = uint64(tshape[i])
			}

			tensor := llm.Tensor{
				Name:   ggufName,
				Kind:   kind,
				Offset: offset, // calculate the offset
				Shape:  shape[:],
			}

			tensor.WriterTo = torchWriterTo{
				t:       &tensor,
				params:  params,
				bo:      params.ByteOrder,
				storage: t.(*pytorch.Tensor).Source,
			}

			tensors = append(tensors, tensor)
			offset += size
		}
	}

	return tensors, nil

}
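The `size` computed in each case above is what advances the running `offset`: 1-D tensors are widened to float32 (4 bytes per element) and 2-D tensors are narrowed to float16 (2 bytes per element). A worked check of that bookkeeping, with an illustrative 4096-element norm vector and a 4096x4096 weight matrix:

```go
package main

import "fmt"

func main() {
	fmt.Println(4096 * 4)        // 1-D stored as float32: 16384 bytes
	fmt.Println(4096 * 4096 * 2) // 2-D stored as float16: 33554432 bytes
}
```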

func getAltParams(dirpath string) (*Params, error) {
	f, err := os.Open(filepath.Join(dirpath, "params.json"))
	if err != nil {
		slog.Error("no params.json")
		return nil, err
	}
	defer f.Close()

	type TorchParams struct {
		HiddenSize     int     `json:"dim"`
		AttentionHeads int     `json:"n_heads"`
		KeyValHeads    int     `json:"n_kv_heads"`
		HiddenLayers   int     `json:"n_layers"`
		RopeTheta      int     `json:"rope_theta"`
		NormEPS        float64 `json:"norm_eps"`
	}

	var tparams TorchParams

	d := json.NewDecoder(f)
	err = d.Decode(&tparams)
	if err != nil {
		return nil, err
	}

	params := &Params{
		HiddenSize:     tparams.HiddenSize,
		AttentionHeads: tparams.AttentionHeads,
		KeyValHeads:    tparams.KeyValHeads,
		HiddenLayers:   tparams.HiddenLayers,
		NormEPS:        tparams.NormEPS,
	}

	switch {
	case tparams.RopeTheta == 1000000:
		// Codellama
		params.ContextSize = 16384
	case tparams.NormEPS == 1e-06:
		// llama2
		slog.Debug("Found llama2 - setting context size to 4096")
		params.ContextSize = 4096
	default:
		params.ContextSize = 2048
	}

	params.ByteOrder = binary.LittleEndian
	return params, nil
}

func (m *TorchFormat) GetParams(dirpath string) (*Params, error) {
	f, err := os.Open(filepath.Join(dirpath, "config.json"))
	if err != nil {
		if os.IsNotExist(err) {
			// try params.json instead
			return getAltParams(dirpath)
		} else {
			return nil, err
		}
	}

	var params Params
	d := json.NewDecoder(f)
	err = d.Decode(&params)
	if err != nil {
		return nil, err
	}

	params.ByteOrder = binary.LittleEndian
	return &params, nil
}

func (m *TorchFormat) GetLayerName(n string) (string, error) {
	directMap := map[string]string{
		"tok_embeddings.weight":     "token_embd.weight",
		"output.weight":             "output.weight",
		"norm.weight":               "output_norm.weight",
		"rope.freqs":                "rope_freqs.weight",
		"model.embed_tokens.weight": "token_embd.weight",
		"lm_head.weight":            "output.weight",
		"model.norm.weight":         "output_norm.weight",
	}

	lMap := map[string]string{
		"layers.(\\d+).attention_norm.weight":                  "blk.$1.attn_norm.weight",
		"layers.(\\d+).attention_output_norm.weight":           "blk.$1.attn_norm.weight",
		"layers.(\\d+).feed_forward.w2.weight":                 "blk.$1.ffn_down.weight",
		"layers.(\\d+).feed_forward.w1.weight":                 "blk.$1.ffn_gate.weight",
		"layers.(\\d+).feed_forward.w3.weight":                 "blk.$1.ffn_up.weight",
		"layers.(\\d+).ffn_norm.weight":                        "blk.$1.ffn_norm.weight",
		"layers.(\\d+).attention.wk.weight":                    "blk.$1.attn_k.weight",
		"layers.(\\d+).attention.wo.weight":                    "blk.$1.attn_output.weight",
		"layers.(\\d+).attention.wq.weight":                    "blk.$1.attn_q.weight",
		"layers.(\\d+).attention.wv.weight":                    "blk.$1.attn_v.weight",
		"model.layers.(\\d+).input_layernorm.weight":           "blk.$1.attn_norm.weight",
		"model.layers.(\\d+).mlp.down_proj.weight":             "blk.$1.ffn_down.weight",
		"model.layers.(\\d+).mlp.gate_proj.weight":             "blk.$1.ffn_gate.weight",
		"model.layers.(\\d+).mlp.up_proj.weight":               "blk.$1.ffn_up.weight",
		"model.layers.(\\d+).post_attention_layernorm.weight":  "blk.$1.ffn_norm.weight",
		"model.layers.(\\d+).self_attn.k_proj.weight":          "blk.$1.attn_k.weight",
		"model.layers.(\\d+).self_attn.o_proj.weight":          "blk.$1.attn_output.weight",
		"model.layers.(\\d+).self_attn.q_proj.weight":          "blk.$1.attn_q.weight",
		"model.layers.(\\d+).self_attn.v_proj.weight":          "blk.$1.attn_v.weight",
	}

	v, ok := directMap[n]
	if ok {
		return v, nil
	}

	// quick hack to rename the layers to gguf format
	for k, v := range lMap {
		re := regexp.MustCompile(k)
		newName := re.ReplaceAllString(n, v)
		if newName != n {
			return newName, nil
		}
	}

	return "", fmt.Errorf("couldn't find a layer name for '%s'", n)
}

func (r torchWriterTo) WriteTo(w io.Writer) (n int64, err error) {
	// use the handler if one is present
	if r.handler != nil {
		return 0, r.handler(w, r)
	}

	switch r.storage.(type) {
	case *pytorch.FloatStorage:
		slog.Warn(fmt.Sprintf("unexpected storage found for layer '%s'; skipping", r.t.Name))
		return 0, nil
	case *pytorch.HalfStorage:
		switch r.t.Kind {
		case 0:
			data := r.storage.(*pytorch.HalfStorage).Data
			slog.Debug(fmt.Sprintf("%35s F32 (%d)", r.t.Name, len(data)))
			if err := binary.Write(w, r.bo, data); err != nil {
				return 0, err
			}
		case 1:
			data := r.storage.(*pytorch.HalfStorage).Data
			tData := make([]uint16, len(data))
			for cnt, v := range data {
				tData[cnt] = uint16(float16.Fromfloat32(v))
			}
			slog.Debug(fmt.Sprintf("%35s F16 (%d)", r.t.Name, len(tData)))
			if err := binary.Write(w, r.bo, tData); err != nil {
				return 0, err
			}
		}
	}

	return 0, nil
}

func (m *TorchFormat) GetModelArch(name, dirPath string, params *Params) (ModelArch, error) {
	switch len(params.Architectures) {
	case 0:
		return nil, fmt.Errorf("No architecture specified to convert")
	case 1:
		switch params.Architectures[0] {
		case "LlamaForCausalLM":
			return &LlamaModel{
				ModelData{
					Name:   name,
					Path:   dirPath,
					Params: params,
					Format: m,
				},
			}, nil
		default:
			return nil, fmt.Errorf("Models based on '%s' are not yet supported", params.Architectures[0])
		}
	}

	return nil, fmt.Errorf("Unknown error")
}
@@ -394,6 +394,7 @@ Advanced parameters (optional):

- `format`: the format to return a response in. Currently the only accepted value is `json`
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `template`: the prompt template to use (overrides what is defined in the `Modelfile`)
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
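A minimal Go sketch of the two newer knobs, modeled on the api-client examples removed later in this diff; the `KeepAlive` field name is an assumption about the Go client mirroring the `keep_alive` parameter, so treat it as illustrative:

```go
package main

import (
	"context"
	"fmt"
	"log"
	"time"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	req := &api.GenerateRequest{
		Model:  "llama2",
		Prompt: "why is the sky blue?",
		// stream=false: one response object instead of a stream
		Stream: new(bool),
		// keep the model loaded for 10m after this request
		// (assumed field name for the keep_alive parameter)
		KeepAlive: &api.Duration{Duration: 10 * time.Minute},
	}

	respFunc := func(resp api.GenerateResponse) error {
		fmt.Println(resp.Response)
		return nil
	}

	if err := client.Generate(context.Background(), req, respFunc); err != nil {
		log.Fatal(err)
	}
}
```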
@@ -139,6 +139,9 @@ PARAMETER <parameter> <parametervalue>

| mirostat_eta | Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. (Default: 0.1) | float | mirostat_eta 0.1 |
| mirostat_tau | Controls the balance between coherence and diversity of the output. A lower value will result in more focused and coherent text. (Default: 5.0) | float | mirostat_tau 5.0 |
| num_ctx | Sets the size of the context window used to generate the next token. (Default: 2048) | int | num_ctx 4096 |
| num_gqa | The number of GQA groups in the transformer layer. Required for some models, for example it is 8 for llama2:70b | int | num_gqa 1 |
| num_gpu | The number of layers to send to the GPU(s). On macOS it defaults to 1 to enable metal support, 0 to disable. | int | num_gpu 50 |
| num_thread | Sets the number of threads to use during computation. By default, Ollama will detect this for optimal performance. It is recommended to set this value to the number of physical CPU cores your system has (as opposed to the logical number of cores). | int | num_thread 8 |
| repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 |
| temperature | The temperature of the model. Increasing the temperature will make the model answer more creatively. (Default: 0.8) | float | temperature 0.7 |
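The same parameters are accepted per-request through the API's `options` field, which is how the integration test later in this diff sets them; a fragment that would drop into the `GenerateRequest` from the sketch above:

```go
req := &api.GenerateRequest{
	Model:  "llama2",
	Prompt: "tell me a short story",
	Options: map[string]interface{}{
		"temperature":    0.7,  // more creative answers
		"num_ctx":        4096, // larger context window
		"repeat_penalty": 1.1,  // default repetition penalty
	},
}
```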
@@ -18,7 +18,7 @@ const ollama = new Ollama({
  model: "llama2",
});

const answer = await ollama.invoke(`why is the sky blue?`);
const answer = await ollama.call(`why is the sky blue?`);

console.log(answer);
```
@@ -1,51 +0,0 @@
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	messages := []api.Message{
		api.Message{
			Role:    "system",
			Content: "Provide very brief, concise responses",
		},
		api.Message{
			Role:    "user",
			Content: "Name some unusual animals",
		},
		api.Message{
			Role:    "assistant",
			Content: "Monotreme, platypus, echidna",
		},
		api.Message{
			Role:    "user",
			Content: "which of these is the most dangerous?",
		},
	}

	ctx := context.Background()
	req := &api.ChatRequest{
		Model:    "llama2",
		Messages: messages,
	}

	respFunc := func(resp api.ChatResponse) error {
		fmt.Print(resp.Message.Content)
		return nil
	}

	err = client.Chat(ctx, req, respFunc)
	if err != nil {
		log.Fatal(err)
	}
}
@@ -1,40 +0,0 @@
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	// By default, GenerateRequest is streaming.
	req := &api.GenerateRequest{
		Model:  "gemma",
		Prompt: "how many planets are there?",
	}

	ctx := context.Background()
	respFunc := func(resp api.GenerateResponse) error {
		// Only print the response here; GenerateResponse has a number of other
		// interesting fields you want to examine.

		// In streaming mode, responses are partial so we call fmt.Print (and not
		// Println) in order to avoid spurious newlines being introduced. The
		// model will insert its own newlines if it wants.
		fmt.Print(resp.Response)
		return nil
	}

	err = client.Generate(ctx, req, respFunc)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println()
}
@@ -1,37 +0,0 @@
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	req := &api.GenerateRequest{
		Model:  "gemma",
		Prompt: "how many planets are there?",

		// set streaming to false
		Stream: new(bool),
	}

	ctx := context.Background()
	respFunc := func(resp api.GenerateResponse) error {
		// Only print the response here; GenerateResponse has a number of other
		// interesting fields you want to examine.
		fmt.Println(resp.Response)
		return nil
	}

	err = client.Generate(ctx, req, respFunc)
	if err != nil {
		log.Fatal(err)
	}
}
@@ -1,47 +0,0 @@
package main

import (
	"context"
	"fmt"
	"log"
	"os"

	"github.com/ollama/ollama/api"
)

func main() {
	if len(os.Args) <= 1 {
		log.Fatal("usage: <image name>")
	}

	imgData, err := os.ReadFile(os.Args[1])
	if err != nil {
		log.Fatal(err)
	}

	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	req := &api.GenerateRequest{
		Model:  "llava",
		Prompt: "describe this image",
		Images: []api.ImageData{imgData},
	}

	ctx := context.Background()
	respFunc := func(resp api.GenerateResponse) error {
		// In streaming mode, responses are partial so we call fmt.Print (and not
		// Println) in order to avoid spurious newlines being introduced. The
		// model will insert its own newlines if it wants.
		fmt.Print(resp.Response)
		return nil
	}

	err = client.Generate(ctx, req, respFunc)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println()
}
@@ -1,31 +0,0 @@
package main

import (
	"context"
	"fmt"
	"log"

	"github.com/ollama/ollama/api"
)

func main() {
	client, err := api.ClientFromEnvironment()
	if err != nil {
		log.Fatal(err)
	}

	ctx := context.Background()

	req := &api.PullRequest{
		Model: "mistral",
	}
	progressFunc := func(resp api.ProgressResponse) error {
		fmt.Printf("Progress: status=%v, total=%v, completed=%v\n", resp.Status, resp.Total, resp.Completed)
		return nil
	}

	err = client.Pull(ctx, req, progressFunc)
	if err != nil {
		log.Fatal(err)
	}
}
@@ -50,7 +50,7 @@ func HumanBytes(b int64) string {
	}
}

func HumanBytes2(b uint64) string {
func HumanBytes2(b int64) string {
	switch {
	case b >= MebiByte:
		return fmt.Sprintf("%.1f MiB", float64(b)/MebiByte)
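A self-contained re-creation of the formatting path above, assuming the usual binary constants (`KibiByte = 1024`, `MebiByte = 1024 * KibiByte`); the signature change to `int64` only affects callers, not the arithmetic:

```go
package main

import "fmt"

const (
	KibiByte = 1024
	MebiByte = 1024 * KibiByte
)

func humanBytes2(b int64) string {
	switch {
	case b >= MebiByte:
		return fmt.Sprintf("%.1f MiB", float64(b)/MebiByte)
	case b >= KibiByte:
		return fmt.Sprintf("%.1f KiB", float64(b)/KibiByte)
	default:
		return fmt.Sprintf("%d B", b)
	}
}

func main() {
	fmt.Println(humanBytes2(3 * MebiByte)) // 3.0 MiB
}
```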
23
go.mod
@@ -10,7 +10,7 @@ require (
	github.com/emirpasic/gods v1.18.1
	github.com/gin-gonic/gin v1.9.1
	github.com/golang/protobuf v1.5.0 // indirect
	github.com/google/uuid v1.0.0
	github.com/google/uuid v1.6.0
	github.com/mitchellh/mapstructure v1.5.0
	github.com/olekukonko/tablewriter v0.0.5
	github.com/spf13/cobra v1.7.0
@@ -20,8 +20,9 @@ require (
)

require (
	github.com/nlpodyssey/gopickle v0.3.0
	github.com/minio/minio-go/v7 v7.0.69
	github.com/pdevine/tensor v0.0.0-20240228013915-64ccaa8d9ca9
	kr.dev/diff v0.3.0
)

require (
@@ -29,16 +30,24 @@ require (
	github.com/chewxy/hm v1.0.0 // indirect
	github.com/chewxy/math32 v1.0.8 // indirect
	github.com/davecgh/go-spew v1.1.1 // indirect
	github.com/dustin/go-humanize v1.0.1 // indirect
	github.com/gogo/protobuf v1.3.2 // indirect
	github.com/google/flatbuffers v1.12.0 // indirect
	github.com/klauspost/compress v1.17.6 // indirect
	github.com/mattn/go-runewidth v0.0.14 // indirect
	github.com/minio/md5-simd v1.1.2 // indirect
	github.com/minio/sha256-simd v1.0.1 // indirect
	github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e // indirect
	github.com/pkg/errors v0.9.1 // indirect
	github.com/pmezard/go-difflib v1.0.0 // indirect
	github.com/rivo/uniseg v0.2.0 // indirect
	github.com/rogpeppe/go-internal v1.8.1 // indirect
	github.com/rs/xid v1.5.0 // indirect
	github.com/xtgo/set v1.0.0 // indirect
	go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect
	golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
	gonum.org/v1/gonum v0.8.2 // indirect
	gopkg.in/ini.v1 v1.67.0 // indirect
	gorgonia.org/vecf32 v0.9.0 // indirect
	gorgonia.org/vecf64 v0.9.0 // indirect
)
@@ -56,7 +65,7 @@ require (
	github.com/google/go-cmp v0.5.9 // indirect
	github.com/inconshreveable/mousetrap v1.1.0 // indirect
	github.com/json-iterator/go v1.1.12 // indirect
	github.com/klauspost/cpuid/v2 v2.2.4 // indirect
	github.com/klauspost/cpuid/v2 v2.2.6 // indirect
	github.com/leodido/go-urn v1.2.4 // indirect
	github.com/mattn/go-isatty v0.0.19 // indirect
	github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
@@ -66,11 +75,11 @@ require (
	github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
	github.com/ugorji/go/codec v1.2.11 // indirect
	golang.org/x/arch v0.3.0 // indirect
	golang.org/x/crypto v0.14.0
	golang.org/x/crypto v0.19.0
	golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63
	golang.org/x/net v0.17.0 // indirect
	golang.org/x/sys v0.13.0
	golang.org/x/term v0.13.0
	golang.org/x/net v0.21.0 // indirect
	golang.org/x/sys v0.17.0
	golang.org/x/term v0.17.0
	golang.org/x/text v0.14.0 // indirect
	google.golang.org/protobuf v1.30.0
	gopkg.in/yaml.v3 v3.0.1 // indirect
49
go.sum
@@ -26,6 +26,8 @@ github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1/go.mod h1:uw2gLc
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
@@ -86,8 +88,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.0.0 h1:b4Gk+7WdP/d3HZH8EJsZpvV7EtDOgaZLtnaNGIu1adA=
github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
@@ -95,9 +97,12 @@ github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHm
github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/klauspost/compress v1.17.6 h1:60eq2E/jlfwQXtvZEeBUYADs+BwKBWURIY+Gj2eRGjI=
github.com/klauspost/compress v1.17.6/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk=
github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
github.com/klauspost/cpuid/v2 v2.2.6 h1:ndNyv040zDGIDh8thGkXYjnFtiN02M1PVVF+JE/48xc=
github.com/klauspost/cpuid/v2 v2.2.6/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
@@ -115,6 +120,12 @@ github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU=
github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34=
github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM=
github.com/minio/minio-go/v7 v7.0.69 h1:l8AnsQFyY1xiwa/DaQskY4NXSLA2yrGsW5iD9nRPVS0=
github.com/minio/minio-go/v7 v7.0.69/go.mod h1:XAvOPJQ5Xlzk5o3o/ArO2NMbhSGkimC+bpW/ngRKDmQ=
github.com/minio/sha256-simd v1.0.1 h1:6kaan5IFmwTNynnKKpDHe6FWHohJOHhCPchzK49dzMM=
github.com/minio/sha256-simd v1.0.1/go.mod h1:Pz6AKMiUdngCLpeTL/RJY1M9rUuPMYujV5xJjtbRSN8=
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@@ -122,8 +133,6 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/nlpodyssey/gopickle v0.3.0 h1:BLUE5gxFLyyNOPzlXxt6GoHEMMxD0qhsE4p0CIQyoLw=
github.com/nlpodyssey/gopickle v0.3.0/go.mod h1:f070HJ/yR+eLi5WmM1OXJEGaTpuJEUiib19olXgYha0=
github.com/olekukonko/tablewriter v0.0.5 h1:P2Ga83D34wi1o9J6Wh1mRuqd4mF/x/lgBS7N7AbDhec=
github.com/olekukonko/tablewriter v0.0.5/go.mod h1:hPp6KlRPjbx+hW8ykQs1w3UBbZlj6HuIJcUGPhkA7kY=
github.com/pdevine/tensor v0.0.0-20240228013915-64ccaa8d9ca9 h1:DV4iXjNn6fGeDl1AkZ1I0QB/0DBjrc7kPpxHrmuDzW4=
@@ -131,6 +140,7 @@ github.com/pdevine/tensor v0.0.0-20240228013915-64ccaa8d9ca9/go.mod h1:nR7l3gM6u
github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo=
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e h1:aoZm08cpOy4WuID//EZDgcC4zIxODThtZNPirFr42+A=
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
@@ -140,8 +150,11 @@ github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
github.com/rogpeppe/go-internal v1.8.1 h1:geMPLpDpQOgVyCg5z5GoRwLHepNdb71NXb67XFkP+Eg=
github.com/rogpeppe/go-internal v1.8.1/go.mod h1:JeRgkft04UBgHMgCIwADu4Pn6Mtm5d4nPKWu0nJ5d+o=
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
@@ -183,8 +196,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo=
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@@ -207,8 +220,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
golang.org/x/net v0.0.0-20200904194848-62affa334b73/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@@ -228,13 +241,13 @@ golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.13.0 h1:bb+I9cTfFazGW51MZqBVmZy7+JEJMouUHTUSKVQLBek=
golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U=
golang.org/x/term v0.17.0 h1:mkTF7LCd6WGJNL3K1Ad7kwxNfYAW6a8a8QqtMblp/4U=
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
@@ -294,6 +307,8 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
@@ -305,4 +320,6 @@ gorgonia.org/vecf64 v0.9.0 h1:bgZDP5x0OzBF64PjMGC3EvTdOoMEcmfAh1VCUnZFm1A=
gorgonia.org/vecf64 v0.9.0/go.mod h1:hp7IOWCnRiVQKON73kkC/AUMtEXyf9kGlVrtPQ9ccVA=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
kr.dev/diff v0.3.0 h1:o/T8/tkAq9IuRIuFqCupyKPC5iSY3WXpVZ2p6ZK3Emw=
kr.dev/diff v0.3.0/go.mod h1:XiTaLOg2/PD0cmXY7WQXUR8RAF3RwWpqIQEj910J2NY=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
@@ -243,7 +243,7 @@ func getCPUMem() (memInfo, error) {
	return ret, nil
}

func CheckVRAM() (uint64, error) {
func CheckVRAM() (int64, error) {
	userLimit := os.Getenv("OLLAMA_MAX_VRAM")
	if userLimit != "" {
		avail, err := strconv.ParseInt(userLimit, 10, 64)
@@ -251,11 +251,11 @@ func CheckVRAM() (uint64, error) {
		return 0, fmt.Errorf("Invalid OLLAMA_MAX_VRAM setting %s: %s", userLimit, err)
	}
	slog.Info(fmt.Sprintf("user override OLLAMA_MAX_VRAM=%d", avail))
	return uint64(avail), nil
	return avail, nil
	}
	gpuInfo := GetGPUInfo()
	if gpuInfo.FreeMemory > 0 && (gpuInfo.Library == "cuda" || gpuInfo.Library == "rocm") {
		return gpuInfo.FreeMemory, nil
		return int64(gpuInfo.FreeMemory), nil
	}

	return 0, fmt.Errorf("no GPU detected") // TODO - better handling of CPU based memory determination
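The override path is plain `strconv.ParseInt` on a byte count, so the env var wins over whatever the GPU reports; a standalone sketch of just that branch:

```go
package main

import (
	"fmt"
	"os"
	"strconv"
)

// maxVRAMOverride returns the OLLAMA_MAX_VRAM byte count, if set and valid.
func maxVRAMOverride() (int64, bool) {
	userLimit := os.Getenv("OLLAMA_MAX_VRAM")
	if userLimit == "" {
		return 0, false
	}
	avail, err := strconv.ParseInt(userLimit, 10, 64)
	if err != nil {
		return 0, false
	}
	return avail, true
}

func main() {
	os.Setenv("OLLAMA_MAX_VRAM", "8589934592") // 8 GiB, in bytes
	fmt.Println(maxVRAMOverride())             // 8589934592 true
}
```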
@@ -17,7 +17,7 @@ import (
)

// CheckVRAM returns the free VRAM in bytes on Linux machines with NVIDIA GPUs
func CheckVRAM() (uint64, error) {
func CheckVRAM() (int64, error) {
	userLimit := os.Getenv("OLLAMA_MAX_VRAM")
	if userLimit != "" {
		avail, err := strconv.ParseInt(userLimit, 10, 64)
@@ -25,15 +25,15 @@ func CheckVRAM() (uint64, error) {
		return 0, fmt.Errorf("Invalid OLLAMA_MAX_VRAM setting %s: %s", userLimit, err)
	}
	slog.Info(fmt.Sprintf("user override OLLAMA_MAX_VRAM=%d", avail))
	return uint64(avail), nil
	return avail, nil
	}

	if runtime.GOARCH == "amd64" {
		// gpu not supported, this may not be metal
		return 0, nil
	}

	return uint64(C.getRecommendedMaxVRAM()), nil
	recommendedMaxVRAM := int64(C.getRecommendedMaxVRAM())
	return recommendedMaxVRAM, nil
}

func GetGPUInfo() GpuInfo {
@@ -53,7 +53,7 @@ func GetGPUInfo() GpuInfo {

func getCPUMem() (memInfo, error) {
	return memInfo{
		TotalMemory: uint64(C.getPhysicalMemory()),
		TotalMemory: 0,
		FreeMemory:  0,
		DeviceCount: 0,
	}, nil
@@ -1,4 +1,3 @@
#import <Metal/Metal.h>
#include <stdint.h>
uint64_t getRecommendedMaxVRAM();
uint64_t getPhysicalMemory();
@@ -1,13 +1,11 @@
// go:build darwin
//go:build darwin
#include "gpu_info_darwin.h"

uint64_t getRecommendedMaxVRAM() {
	id<MTLDevice> device = MTLCreateSystemDefaultDevice();
	uint64_t result = device.recommendedMaxWorkingSetSize;
	CFRelease(device);
	return result;
uint64_t getRecommendedMaxVRAM()
{
	id<MTLDevice> device = MTLCreateSystemDefaultDevice();
	uint64_t result = device.recommendedMaxWorkingSetSize;
	CFRelease(device);
	return result;
}

uint64_t getPhysicalMemory() {
	return [[NSProcessInfo processInfo] physicalMemory];
}
@@ -15,7 +15,7 @@ type GpuInfo struct {
	Variant string `json:"variant,omitempty"`

	// MinimumMemory represents the minimum memory required to use the GPU
	MinimumMemory uint64 `json:"-"`
	MinimumMemory int64 `json:"-"`

	// TODO add other useful attributes about the card here for discovery information
}
@@ -1,29 +0,0 @@
//go:build integration

package integration

import (
	"context"
	"net/http"
	"testing"
	"time"

	"github.com/ollama/ollama/api"
)

func TestContextExhaustion(t *testing.T) {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute) // TODO maybe shorter?
	defer cancel()
	// Set up the test data
	req := api.GenerateRequest{
		Model:  "llama2",
		Prompt: "Write me a story with a ton of emojis?",
		Stream: &stream,
		Options: map[string]interface{}{
			"temperature": 0,
			"seed":        123,
			"num_ctx":     128,
		},
	}
	GenerateTestHelper(ctx, t, &http.Client{}, req, []string{"once", "upon", "lived"})
}
@@ -15,6 +15,10 @@ import (
// TODO - this would ideally be in the llm package, but that would require some refactoring of interfaces in the server
// package to avoid circular dependencies

// WARNING - these tests will fail on mac if you don't manually copy ggml-metal.metal to this dir (./server)
//
// TODO - Fix this ^^

var (
	stream = false
	req    = [2]api.GenerateRequest{
@@ -18,7 +18,7 @@ sign() {
    fi
}

COMMON_DARWIN_DEFS="-DCMAKE_OSX_DEPLOYMENT_TARGET=11.3 -DLLAMA_METAL_MACOSX_VERSION_MIN=11.3 -DCMAKE_SYSTEM_NAME=Darwin -DLLAMA_METAL_EMBED_LIBRARY=on"
COMMON_DARWIN_DEFS="-DCMAKE_OSX_DEPLOYMENT_TARGET=11.0 -DCMAKE_SYSTEM_NAME=Darwin -DLLAMA_METAL_EMBED_LIBRARY=on"

case "${GOARCH}" in
"amd64")
@@ -41,7 +41,7 @@ case "${GOARCH}" in
        BUILD_DIR="../build/darwin/${ARCH}/cpu"
        echo "Building LCD CPU"
        build
        sign ${BUILD_DIR}/bin/ollama_llama_server
        sign ${BUILD_DIR}/lib/libext_server.dylib
        compress

        #
@@ -53,7 +53,7 @@ case "${GOARCH}" in
        BUILD_DIR="../build/darwin/${ARCH}/cpu_avx"
        echo "Building AVX CPU"
        build
        sign ${BUILD_DIR}/bin/ollama_llama_server
        sign ${BUILD_DIR}/lib/libext_server.dylib
        compress

        #
@@ -66,7 +66,7 @@ case "${GOARCH}" in
        echo "Building AVX2 CPU"
        EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation"
        build
        sign ${BUILD_DIR}/bin/ollama_llama_server
        sign ${BUILD_DIR}/lib/libext_server.dylib
        compress
    ;;
"arm64")
@@ -74,17 +74,17 @@ case "${GOARCH}" in
    # Static build for linking into the Go binary
    init_vars
    CMAKE_TARGETS="--target llama --target ggml"
    CMAKE_DEFS="-DCMAKE_OSX_DEPLOYMENT_TARGET=11.3 -DCMAKE_SYSTEM_NAME=Darwin -DBUILD_SHARED_LIBS=off -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=off -DLLAMA_ACCELERATE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
    CMAKE_DEFS="${COMMON_CPU_DEFS} -DBUILD_SHARED_LIBS=off -DLLAMA_ACCELERATE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
    BUILD_DIR="../build/darwin/${ARCH}_static"
    echo "Building static library"
    build

    init_vars
    CMAKE_DEFS="${COMMON_DARWIN_DEFS} -DLLAMA_ACCELERATE=on -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=on ${CMAKE_DEFS}"
    CMAKE_DEFS="${COMMON_DARWIN_DEFS} -DLLAMA_METAL_EMBED_LIBRARY=on -DLLAMA_ACCELERATE=on -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=on ${CMAKE_DEFS}"
    BUILD_DIR="../build/darwin/${ARCH}/metal"
    EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders"
    build
    sign ${BUILD_DIR}/bin/ollama_llama_server
    sign ${BUILD_DIR}/lib/libext_server.dylib
    compress
    ;;
*)
@@ -172,7 +172,7 @@ if [ -d "${CUDA_LIB_DIR}" ]; then
        # Disabling has minimal performance effect while maintaining compatibility.
        ARM64_DEFS="-DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_CUDA_F16=off"
    fi
    CMAKE_DEFS="-DLLAMA_CUDA=on -DLLAMA_CUDA_FORCE_MMQ=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} ${ARM64_DEFS}"
    CMAKE_DEFS="-DLLAMA_CUBLAS=on -DLLAMA_CUDA_FORCE_MMQ=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} ${ARM64_DEFS}"
    BUILD_DIR="../build/linux/${ARCH}/cuda${CUDA_VARIANT}"
    EXTRA_LIBS="-L${CUDA_LIB_DIR} -lcudart -lcublas -lcublasLt -lcuda"
    build
@@ -146,7 +146,7 @@ function compress {
}

write-host "Compressing dlls..."
$dlls = dir "${script:buildDir}/bin/*.dll"
$binaries = dir "${script:buildDir}/bin/*.dll"
foreach ($file in $dlls) {
    & "$script:GZIP" --best -f $file
}
@@ -183,17 +183,9 @@ if ($null -eq ${env:OLLAMA_SKIP_CPU_GENERATE}) {

# GCC build for direct linking into the Go binary
init_vars
# cmake will silently fallback to msvc compilers if mingw isn't in the path, so detect and fail fast
# as we need this to be compiled by gcc for golang to be able to link with it
write-host "Checking for MinGW..."
# error action ensures we exit on failure
get-command gcc
get-command mingw32-make
$script:cmakeTargets = @("llama", "ggml")
$script:cmakeDefs = @(
    "-G", "MinGW Makefiles"
    "-DCMAKE_C_COMPILER=gcc.exe",
    "-DCMAKE_CXX_COMPILER=g++.exe",
    "-DBUILD_SHARED_LIBS=off",
    "-DLLAMA_NATIVE=off",
    "-DLLAMA_AVX=off",
@@ -242,7 +234,7 @@ if ($null -ne $script:CUDA_LIB_DIR) {
}
init_vars
$script:buildDir="../build/windows/${script:ARCH}/cuda$script:CUDA_VARIANT"
$script:cmakeDefs += @("-A", "x64", "-DLLAMA_CUDA=ON", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DCUDAToolkit_INCLUDE_DIR=$script:CUDA_INCLUDE_DIR", "-DCMAKE_CUDA_ARCHITECTURES=${script:CMAKE_CUDA_ARCHITECTURES}")
$script:cmakeDefs += @("-A", "x64", "-DLLAMA_CUBLAS=ON", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DCUDAToolkit_INCLUDE_DIR=$script:CUDA_INCLUDE_DIR", "-DCMAKE_CUDA_ARCHITECTURES=${script:CMAKE_CUDA_ARCHITECTURES}")
build
sign
compress
@@ -261,7 +253,6 @@ if ($null -ne $env:HIP_PATH) {
    "-DCMAKE_C_COMPILER=clang.exe",
    "-DCMAKE_CXX_COMPILER=clang++.exe",
    "-DLLAMA_HIPBLAS=on",
    "-DHIP_PLATFORM=amd",
    "-DLLAMA_AVX=on",
    "-DLLAMA_AVX2=off",
    "-DCMAKE_POSITION_INDEPENDENT_CODE=on",
@@ -49,7 +49,7 @@ func (llm *ggla) KV() KV {
	return llm.kv
}

func (llm *ggla) Tensors() Tensors {
func (llm *ggla) Tensors() []*Tensor {
	return llm.tensors
}

115
llm/ggml.go
@@ -13,6 +13,16 @@ type GGML struct {
	model
}

func (ggml *GGML) LayerSize(prefix string) (n int64) {
	for _, t := range ggml.Tensors() {
		if strings.HasPrefix(t.Name, prefix) {
			n += int64(t.size())
		}
	}

	return
}

const (
	fileTypeF32 uint32 = iota
	fileTypeF16
@@ -91,7 +101,7 @@ func fileType(fileType uint32) string {

type model interface {
	KV() KV
	Tensors() Tensors
	Tensors() []*Tensor
}

type KV map[string]any
@@ -138,15 +148,15 @@ func (kv KV) HeadCount() uint64 {
}

func (kv KV) HeadCountKV() uint64 {
	if headCountKV := kv.u64(fmt.Sprintf("%s.attention.head_count_kv", kv.Architecture())); headCountKV > 0 {
		return headCountKV
	}

	return 1
	return kv.u64(fmt.Sprintf("%s.attention.head_count_kv", kv.Architecture()))
}

func (kv KV) GQA() uint64 {
	return kv.HeadCount() / kv.HeadCountKV()
	if headCountKV := kv.HeadCountKV(); headCountKV > 0 {
		return kv.HeadCount() / headCountKV
	}

	return 0
}

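Reading the two rewrites together: `HeadCountKV` now reports the raw metadata value (no more silent fallback to 1) and `GQA` guards the division instead. With the llama2:70b numbers from the Modelfile docs earlier in this diff (64 attention heads, 8 KV heads), the ratio comes out to the documented `num_gqa 8`:

```go
package main

import "fmt"

// gqa mirrors the rewritten GQA(): guard the division and report 0
// when the KV head count is missing from the metadata.
func gqa(heads, headsKV uint64) uint64 {
	if headsKV > 0 {
		return heads / headsKV
	}
	return 0
}

func main() {
	fmt.Println(gqa(64, 8)) // 8
	fmt.Println(gqa(64, 0)) // 0
}
```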
func (kv KV) EmbeddingLength() uint64 {
@@ -157,36 +167,6 @@ func (kv KV) ContextLength() uint64 {
	return kv.u64(fmt.Sprintf("%s.context_length", kv.Architecture()))
}

type Tensors []*Tensor

func (ts Tensors) Layers() map[string]Layer {
	layers := make(map[string]Layer)
	for _, t := range ts {
		parts := strings.Split(t.Name, ".")
		if parts[0] == "blk" {
			parts = parts[1:]
		}

		if _, ok := layers[parts[0]]; !ok {
			layers[parts[0]] = make(Layer)
		}

		layers[parts[0]][strings.Join(parts[1:], ".")] = t
	}

	return layers
}

type Layer map[string]*Tensor

func (l Layer) size() (size uint64) {
	for _, t := range l {
		size += t.size()
	}

	return size
}

type Tensor struct {
	Name string `json:"name"`
	Kind uint32 `json:"kind"`
@@ -323,64 +303,3 @@ func DecodeGGML(rs io.ReadSeeker) (*GGML, int64, error) {
		model: model,
	}, offset, nil
}

func (llm GGML) GraphSize(context, batch uint64) (partialOffload, fullOffload uint64) {
	embedding := llm.KV().EmbeddingLength()
	heads := llm.KV().HeadCount()
	headsKV := llm.KV().HeadCountKV()
	vocab := uint64(len(llm.KV()["tokenizer.ggml.tokens"].([]any)))

	layers := llm.Tensors().Layers()

	switch llm.KV().Architecture() {
	case "llama":
		fullOffload = 4 * batch * (1 + 4*embedding + context*(1+heads))

		partialOffload = 4 * batch * embedding
		partialOffload += max(
			4*batch*(1+embedding+max(context, embedding))+embedding*embedding*9/16+4*context*(batch*heads+embedding/heads*headsKV),
			4*batch*(embedding+vocab)+embedding*vocab*105/128,
		)

		if ffnGateWeight, ok := layers["0"]["ffn_gate.0.weight"]; ok {
			ffnGateWeight1 := ffnGateWeight.Shape[1]
			fullOffload = 4 * batch * (2 + 3*embedding + context*(1+heads) + 2*headsKV + ffnGateWeight1)
			partialOffload = max(
				4*batch*(3+embedding/heads*headsKV+embedding+context*(1+heads)+ffnGateWeight1)+(embedding*embedding+3*embedding*headsKV*ffnGateWeight1)*9/16,
				4*batch*(1+2*embedding+context*(1+heads))+embedding*(6*context*headsKV/heads+embedding*9/16),
			)
		}
	case "gemma":
		fullOffload = 4 * batch * (embedding + vocab)
		partialOffload = 4*batch*(2*embedding+vocab+1) + embedding*vocab*105/128
	case "command-r":
		fullOffload = max(
			4*batch*(embedding+vocab),
			4*batch*(2+4*embedding+context*(1+heads)),
		)

		partialOffload = max(
			4*batch*(embedding+vocab)+embedding*vocab*105/128,
			4*batch*(1+2*embedding+context*(1+heads))+4*embedding*context+embedding*embedding*9/16,
		)
	case "qwen2":
		fullOffload = max(
			4*batch*(embedding+vocab),
			4*batch*(1+2*embedding+context+context*heads),
		)

		partialOffload = max(
			4*batch*(embedding+vocab)+embedding*vocab*105/128,
			4*(batch*(1+2*embedding+context*(1+heads))+embedding*(1+context)),
		)
	case "phi2":
		fullOffload = max(
			4*batch*(embedding+vocab),
			4*batch*(1+4*embedding+context+context*heads),
		)

		partialOffload = 4*batch*(2*embedding+vocab) + embedding*vocab*105/128
	}

	return
}

22
llm/gguf.go
@@ -6,8 +6,6 @@ import (
	"fmt"
	"io"
	"strings"

	"log/slog"
)

type containerGGUF struct {
@@ -54,7 +52,6 @@ func (c *containerGGUF) Decode(rs io.ReadSeeker) (model, error) {
	}

	model := newGGUF(c)
	slog.Debug(fmt.Sprintf("model = %#v", model))
	if err := model.Decode(rs); err != nil {
		return nil, err
	}
@@ -112,7 +109,7 @@ func (llm *gguf) KV() KV {
	return llm.kv
}

func (llm *gguf) Tensors() Tensors {
func (llm *gguf) Tensors() []*Tensor {
	return llm.tensors
}

@@ -190,8 +187,6 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
		llm.kv[k] = v
	}

	slog.Debug(fmt.Sprintf("general.architecture = %s", llm.kv["general.architecture"]))

	// decode tensors
	for i := 0; uint64(i) < llm.numTensor(); i++ {
		name, err := readGGUFString(llm, rs)
@@ -248,7 +243,7 @@ func (llm *gguf) Decode(rs io.ReadSeeker) error {
	}

	padding := llm.padding(offset, int64(alignment))
	if _, err := rs.Seek(padding-offset, io.SeekCurrent); err != nil {
	if _, err := rs.Seek(padding, io.SeekCurrent); err != nil {
		return err
	}

@@ -456,7 +451,6 @@ var ggufKVOrder = map[string][]string{
|
||||
"llama": {
|
||||
"general.architecture",
|
||||
"general.name",
|
||||
"llama.vocab_size",
|
||||
"llama.context_length",
|
||||
"llama.embedding_length",
|
||||
"llama.block_count",
|
||||
@@ -515,17 +509,11 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
|
||||
return err
|
||||
}
|
||||
|
||||
kvCheck := make(map[string]bool)
|
||||
for k := range kv {
|
||||
kvCheck[k] = false
|
||||
}
|
||||
|
||||
for _, k := range ggufKVOrder["llama"] {
|
||||
v, ok := kv[k]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
kvCheck[k] = true
|
||||
|
||||
if err := binary.Write(ws, llm.ByteOrder, uint64(len(k))); err != nil {
|
||||
return err
|
||||
@@ -579,12 +567,6 @@ func (llm *gguf) Encode(ws io.WriteSeeker, kv KV, tensors []Tensor) error {
|
||||
}
|
||||
}
|
||||
|
||||
for k, v := range kvCheck {
|
||||
if !v {
|
||||
return fmt.Errorf("Didn't know how to write kv %s", k)
|
||||
}
|
||||
}
|
||||
|
||||
for _, tensor := range tensors {
|
||||
if err := binary.Write(ws, llm.ByteOrder, uint64(len(tensor.Name))); err != nil {
|
||||
return err
|
||||
|
||||
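The Seek fix above changes the argument from `padding-offset` to `padding`, which implies `llm.padding` now returns the number of bytes left to the next alignment boundary rather than the aligned offset itself. A sketch of such a helper, under that assumption (the helper body is not shown in this diff):

```go
package main

import "fmt"

// padding returns how many bytes to skip so offset lands on the next multiple
// of align, matching the new Seek(padding, io.SeekCurrent) call site.
func padding(offset, align int64) int64 {
	return (align - offset%align) % align
}

func main() {
	fmt.Println(padding(100, 32)) // 28: skipping 28 bytes reaches offset 128
}
```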
Submodule llm/llama.cpp updated: 7593639ce3...37e7854c10
71  llm/llm.go

@@ -6,81 +6,10 @@ package llm
 // #cgo windows,amd64 LDFLAGS: ${SRCDIR}/build/windows/amd64_static/libllama.a -static -lstdc++
 // #cgo linux,amd64 LDFLAGS: ${SRCDIR}/build/linux/x86_64_static/libllama.a -lstdc++
 // #cgo linux,arm64 LDFLAGS: ${SRCDIR}/build/linux/arm64_static/libllama.a -lstdc++
 // #include <stdlib.h>
 // #include "llama.h"
 import "C"
-import (
-	"fmt"
-	"unsafe"
-)

 // SystemInfo is an unused example of calling llama.cpp functions using CGo
 func SystemInfo() string {
 	return C.GoString(C.llama_print_system_info())
 }
-
-func Quantize(infile, outfile, filetype string) error {
-	cinfile := C.CString(infile)
-	defer C.free(unsafe.Pointer(cinfile))
-
-	coutfile := C.CString(outfile)
-	defer C.free(unsafe.Pointer(coutfile))
-
-	params := C.llama_model_quantize_default_params()
-	params.nthread = -1
-
-	switch filetype {
-	case "F32":
-		params.ftype = fileTypeF32
-	case "F16":
-		params.ftype = fileTypeF16
-	case "Q4_0":
-		params.ftype = fileTypeQ4_0
-	case "Q4_1":
-		params.ftype = fileTypeQ4_1
-	case "Q4_1_F16":
-		params.ftype = fileTypeQ4_1_F16
-	case "Q8_0":
-		params.ftype = fileTypeQ8_0
-	case "Q5_0":
-		params.ftype = fileTypeQ5_0
-	case "Q5_1":
-		params.ftype = fileTypeQ5_1
-	case "Q2_K":
-		params.ftype = fileTypeQ2_K
-	case "Q3_K_S":
-		params.ftype = fileTypeQ3_K_S
-	case "Q3_K_M":
-		params.ftype = fileTypeQ3_K_M
-	case "Q3_K_L":
-		params.ftype = fileTypeQ3_K_L
-	case "Q4_K_S":
-		params.ftype = fileTypeQ4_K_S
-	case "Q4_K_M":
-		params.ftype = fileTypeQ4_K_M
-	case "Q5_K_S":
-		params.ftype = fileTypeQ5_K_S
-	case "Q5_K_M":
-		params.ftype = fileTypeQ5_K_M
-	case "Q6_K":
-		params.ftype = fileTypeQ6_K
-	case "IQ2_XXS":
-		params.ftype = fileTypeIQ2_XXS
-	case "IQ2_XS":
-		params.ftype = fileTypeIQ2_XS
-	case "Q2_K_S":
-		params.ftype = fileTypeQ2_K_S
-	case "Q3_K_XS":
-		params.ftype = fileTypeQ3_K_XS
-	case "IQ3_XXS":
-		params.ftype = fileTypeIQ3_XXS
-	default:
-		return fmt.Errorf("unknown filetype: %s", filetype)
-	}
-
-	if retval := C.llama_model_quantize(cinfile, coutfile, &params); retval != 0 {
-		return fmt.Errorf("llama_model_quantize: %d", retval)
-	}
-
-	return nil
-}
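For context, the removed wrapper mapped a file-type string onto llama.cpp's quantization enum and invoked `llama_model_quantize`. A hypothetical call site, assuming a checkout where `llm.Quantize` still exists; the GGUF paths are placeholders:

```go
package main

import (
	"log"

	"github.com/ollama/ollama/llm"
)

func main() {
	// Quantize an F16 GGUF down to Q4_0; both paths are placeholders.
	if err := llm.Quantize("model-f16.gguf", "model-q4_0.gguf", "Q4_0"); err != nil {
		log.Fatalf("quantize: %v", err)
	}
}
```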
129  llm/server.go

@@ -17,6 +17,7 @@ import (
 	"os/exec"
 	"path/filepath"
 	"runtime"
+	"slices"
 	"strconv"
 	"strings"
 	"time"

@@ -35,7 +36,15 @@ type LlamaServer struct {
 	options api.Options
 }

+var cpuOnlyFamilies = []string{
+	"mamba",
+}
+
 func NewLlamaServer(model string, adapters, projectors []string, opts api.Options) (*LlamaServer, error) {
+	if _, err := os.Stat(model); err != nil {
+		return nil, err
+	}
+
 	f, err := os.Open(model)
 	if err != nil {
 		return nil, err

@@ -56,83 +65,66 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
 		opts.NumCtx = 4
 	}

-	memoryAvailable, _ := gpu.CheckVRAM()
+	availableMemory, _ := gpu.CheckVRAM()
 	info := gpu.GetGPUInfo()

-	memoryMinimum := info.MinimumMemory
+	usedMemory := info.MinimumMemory
 	for _, projector := range projectors {
-		memoryMinimum += projectorMemoryRequirements(projector)
+		usedMemory += projectorMemoryRequirements(projector)

 		// multimodal models require at least 2048 context
 		opts.NumCtx = max(opts.NumCtx, 2048)
 	}

 	// fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv
-	var kv uint64 = 2 * 2 * uint64(opts.NumCtx) * ggml.KV().BlockCount() * ggml.KV().EmbeddingLength() / ggml.KV().HeadCount() * ggml.KV().HeadCountKV()
+	kv := 2 * 2 * int64(opts.NumCtx) * int64(ggml.KV().BlockCount()) * int64(ggml.KV().EmbeddingLength()) / int64(ggml.KV().HeadCount()) * int64(ggml.KV().HeadCountKV())

-	graphPartialOffload, graphFullOffload := ggml.GraphSize(uint64(opts.NumCtx), uint64(min(opts.NumCtx, opts.NumBatch)))
-	if graphPartialOffload == 0 {
-		graphPartialOffload = ggml.KV().GQA() * kv / 6
-	}
+	// this amount is the overhead + tensors in memory
+	// TODO: get this from llama.cpp's graph calculations instead of
+	// estimating it as 1/6 * kv_cache_size * num_gqa
+	graph := int64(ggml.KV().GQA()) * kv / 6
+	usedMemory += graph
+
+	if (usedMemory > availableMemory || slices.Contains(cpuOnlyFamilies, ggml.KV().Architecture())) && info.Library != "metal" {
+		info.Library = "cpu"
+	}

-	if graphFullOffload == 0 {
-		graphFullOffload = graphPartialOffload
-	}
+	requiredMemory := usedMemory

-	// memoryRequiredTotal represents the memory required for full GPU offloading (all layers)
-	memoryRequiredTotal := memoryMinimum + graphFullOffload
-
-	// memoryRequiredPartial represents the memory required for partial GPU offloading (n > 0, n < layers)
-	memoryRequiredPartial := memoryMinimum + graphPartialOffload
-
-	if info.Library != "metal" {
-		if memoryRequiredPartial > memoryAvailable {
-			info.Library = "cpu"
-		}
-	}
-
-	var layerCount int
-	layers := ggml.Tensors().Layers()
+	var layers int
 	for i := 0; i < int(ggml.KV().BlockCount()); i++ {
-		memoryLayer := layers[fmt.Sprintf("%d", i)].size()
-
-		// KV is proportional to the number of layers
-		memoryLayer += kv / ggml.KV().BlockCount()
-
-		memoryRequiredTotal += memoryLayer
-		if memoryAvailable > memoryRequiredPartial+memoryLayer {
-			memoryRequiredPartial += memoryLayer
-			layerCount++
+		layerMemory := ggml.LayerSize(fmt.Sprintf("blk.%d.", i)) + kv/int64(ggml.KV().BlockCount())
+		requiredMemory += layerMemory
+
+		if availableMemory > usedMemory+layerMemory && (opts.NumGPU < 0 || layers < opts.NumGPU) {
+			usedMemory += layerMemory
+			layers++
 		}
 	}

-	memoryLayerOutput := layers["output"].size()
-	memoryRequiredTotal += memoryLayerOutput
+	memOutputLayer := ggml.LayerSize("output.")
+	requiredMemory += memOutputLayer

-	if info.Library == "metal" && memoryRequiredTotal > info.TotalMemory {
-		// disable partial offloading when model is greater than total system memory
-		opts.NumGPU = 0
-	} else if memoryAvailable > memoryRequiredTotal {
-		layerCount = int(ggml.KV().BlockCount()) + 1
-		memoryRequiredPartial = memoryRequiredTotal
-	}
-
-	if opts.NumGPU < 0 {
-		opts.NumGPU = layerCount
+	// only offload output layer if all repeating layers are offloaded
+	if layers >= int(ggml.KV().BlockCount()) && availableMemory > usedMemory+memOutputLayer {
+		usedMemory += memOutputLayer
+		layers++
 	}

 	slog.Info(
 		"offload to gpu",
-		"reallayers", opts.NumGPU,
-		"layers", layerCount,
-		"required", format.HumanBytes2(memoryRequiredTotal),
-		"used", format.HumanBytes2(memoryRequiredPartial),
-		"available", format.HumanBytes2(memoryAvailable),
+		"layers", layers,
+		"required", format.HumanBytes2(requiredMemory),
+		"used", format.HumanBytes2(usedMemory),
+		"available", format.HumanBytes2(availableMemory),
 		"kv", format.HumanBytes2(kv),
-		"fulloffload", format.HumanBytes2(graphFullOffload),
-		"partialoffload", format.HumanBytes2(graphPartialOffload),
+		"graph", format.HumanBytes2(graph),
 	)

+	if opts.NumGPU < 0 && info.Library != "cpu" {
+		opts.NumGPU = layers
+	}
+
 	if len(adapters) > 1 {
 		return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
 	}

@@ -179,6 +171,14 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
 		params = append(params, "--main-gpu", fmt.Sprintf("%d", opts.MainGPU))
 	}

+	if opts.RopeFrequencyBase > 0 {
+		params = append(params, "--rope-freq-base", fmt.Sprintf("%f", opts.RopeFrequencyBase))
+	}
+
+	if opts.RopeFrequencyScale > 0 {
+		params = append(params, "--rope-freq-scale", fmt.Sprintf("%f", opts.RopeFrequencyScale))
+	}
+
 	if len(adapters) > 0 {
 		// TODO: applying multiple adapters is not supported by the llama.cpp server yet
 		params = append(params, "--lora", adapters[0])

@@ -276,6 +276,12 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
 			_ = s.cmd.Wait()
 		}()

+		if err = s.waitUntilRunning(); err != nil {
+			slog.Error("error starting llama server", "server", servers[i], "error", err)
+			s.Close()
+			finalErr = err
+			continue
+		}
 		return s, nil
 	}

@@ -283,7 +289,7 @@ func NewLlamaServer(model string, adapters, projectors []string, opts api.Option
 	return nil, finalErr
 }

-func projectorMemoryRequirements(filename string) uint64 {
+func projectorMemoryRequirements(filename string) int64 {
 	file, err := os.Open(filename)
 	if err != nil {
 		return 0

@@ -295,12 +301,18 @@ func projectorMemoryRequirements(filename string) int64 {
 		return 0
 	}

-	var mem uint64
-	for _, layer := range ggml.Tensors().Layers() {
-		mem += layer.size()
+	prefixes := make(map[string]struct{})
+	for _, layer := range ggml.Tensors() {
+		parts := strings.Split(layer.Name, ".")
+		prefixes[strings.Join(parts[:2], ".")] = struct{}{}
 	}

-	return mem
+	var ask int64
+	for prefix := range prefixes {
+		ask += ggml.LayerSize(prefix)
+	}
+
+	return ask
 }

 type ServerStatus int

@@ -376,10 +388,9 @@ func (s *LlamaServer) Ping(ctx context.Context) error {
 	return nil
 }

-func (s *LlamaServer) WaitUntilRunning() error {
+func (s *LlamaServer) waitUntilRunning() error {
 	start := time.Now()
 	// TODO we need to wire up a better way to detect hangs during model load and startup of the server
-	expiresAt := time.Now().Add(10 * time.Minute) // be generous with timeout, large models can take a while to load
+	expiresAt := time.Now().Add(3 * time.Minute) // be generous with timeout, large models can take a while to load
 	ticker := time.NewTicker(50 * time.Millisecond)
 	defer ticker.Stop()
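The fp16 KV-cache comment above is easiest to read as arithmetic. A self-contained worked example with hypothetical 7B-class dimensions (none of these numbers come from the diff):

```go
package main

import "fmt"

func main() {
	// 2 (k and v) * 2 bytes (float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv
	var (
		nCtx    int64 = 2048
		nLayer  int64 = 32
		nEmbd   int64 = 4096
		nHead   int64 = 32
		nHeadKV int64 = 32
	)
	kv := 2 * 2 * nCtx * nLayer * nEmbd / nHead * nHeadKV
	fmt.Printf("kv cache = %d MiB\n", kv>>20) // 1024 MiB for these values
}
```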
@@ -14,7 +14,7 @@ go build .
 Then run the desktop app with `npm start`:

 ```
-cd macapp
+cd app
 npm install
 npm start
 ```
@@ -247,8 +247,7 @@ func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w
 	}

 	if !part.lastUpdated.IsZero() && time.Since(part.lastUpdated) > 5*time.Second {
-		const msg = "%s part %d stalled; retrying. If this persists, press ctrl-c to exit, then 'ollama pull' to find a faster connection."
-		slog.Info(fmt.Sprintf(msg, b.Digest[7:19], part.N))
+		slog.Info(fmt.Sprintf("%s part %d stalled; retrying", b.Digest[7:19], part.N))
 		// reset last updated
 		part.lastUpdated = time.Time{}
 		return errPartStalled
@@ -284,7 +284,7 @@ func realpath(mfDir, from string) string {
 	return abspath
 }

-func CreateModel(ctx context.Context, name, modelFileDir, quantization string, commands []parser.Command, fn func(resp api.ProgressResponse)) error {
+func CreateModel(ctx context.Context, name, modelFileDir string, commands []parser.Command, fn func(resp api.ProgressResponse)) error {
 	deleteMap := make(map[string]struct{})
 	if manifest, _, err := GetManifest(ParseModelPath(name)); err == nil {
 		for _, layer := range append(manifest.Layers, manifest.Config) {

@@ -322,7 +322,7 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
 			pathName := realpath(modelFileDir, c.Args)

-			ggufName, err := convertModel(name, pathName, fn)
+			ggufName, err := convertSafetensors(name, pathName, fn)
 			if err != nil {
 				var pathErr *fs.PathError
 				switch {

@@ -337,27 +337,8 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, c
 			if ggufName != "" {
 				pathName = ggufName
 				slog.Debug(fmt.Sprintf("new image layer path: %s", pathName))
 				defer os.RemoveAll(ggufName)
-
-				if quantization != "" {
-					quantization = strings.ToUpper(quantization)
-					fn(api.ProgressResponse{Status: fmt.Sprintf("quantizing %s model to %s", "F16", quantization)})
-					tempfile, err := os.CreateTemp(filepath.Dir(ggufName), quantization)
-					if err != nil {
-						return err
-					}
-					defer os.RemoveAll(tempfile.Name())
-
-					if err := llm.Quantize(ggufName, tempfile.Name(), quantization); err != nil {
-						return err
-					}
-
-					if err := tempfile.Close(); err != nil {
-						return err
-					}
-
-					pathName = tempfile.Name()
-				}
 			}

 			bin, err := os.Open(pathName)

@@ -633,7 +614,7 @@ func CreateModel(ctx context.Context, name, modelFileDir, quantization string, c
 	return nil
 }

-func convertModel(name, path string, fn func(resp api.ProgressResponse)) (string, error) {
+func convertSafetensors(name, path string, fn func(resp api.ProgressResponse)) (string, error) {
 	r, err := zip.OpenReader(path)
 	if err != nil {
 		return "", err

@@ -668,22 +649,17 @@ func convertSafetensors(name, path string, fn func(resp api.ProgressResponse)) (string
 		rc.Close()
 	}

-	mf, err := convert.GetModelFormat(tempDir)
+	params, err := convert.GetParams(tempDir)
 	if err != nil {
 		return "", err
 	}

-	params, err := mf.GetParams(tempDir)
+	mArch, err := convert.GetModelArchFromParams(name, tempDir, params)
 	if err != nil {
 		return "", err
 	}

-	mArch, err := mf.GetModelArch(name, tempDir, params)
-	if err != nil {
-		return "", err
-	}
-
-	fn(api.ProgressResponse{Status: "processing tensors"})
+	fn(api.ProgressResponse{Status: "processing safetensors"})
 	if err := mArch.GetTensors(); err != nil {
 		return "", err
 	}
@@ -68,18 +68,6 @@ var loaded struct {

 var defaultSessionDuration = 5 * time.Minute

-func unload() {
-	if loaded.llama != nil {
-		loaded.llama.Close()
-	}
-
-	loaded.llama = nil
-	loaded.model = ""
-	loaded.adapters = nil
-	loaded.projectors = nil
-	loaded.Options = nil
-}
-
 // load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
 func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.Duration) error {
 	ctx, cancel := context.WithTimeout(c, 10*time.Second)

@@ -95,7 +83,12 @@ func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.D
 	if needLoad {
 		if loaded.llama != nil {
 			slog.Info("changing loaded model")
-			unload()
+			loaded.llama.Close()
+			loaded.llama = nil
+			loaded.model = ""
+			loaded.adapters = nil
+			loaded.projectors = nil
+			loaded.Options = nil
 		}

 		llama, err := llm.NewLlamaServer(model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts)

@@ -115,19 +108,22 @@ func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.D
 		loaded.projectors = model.ProjectorPaths
 		loaded.llama = llama
 		loaded.Options = &opts
-
-		if err = llama.WaitUntilRunning(); err != nil {
-			slog.Error("error loading llama server", "error", err)
-			unload()
-			return err
-		}
 	}

 	if loaded.expireTimer == nil {
 		loaded.expireTimer = time.AfterFunc(sessionDuration, func() {
 			loaded.mu.Lock()
 			defer loaded.mu.Unlock()
-			unload()
+
+			if loaded.llama != nil {
+				loaded.llama.Close()
+			}
+
+			loaded.llama = nil
+			loaded.model = ""
+			loaded.adapters = nil
+			loaded.projectors = nil
+			loaded.Options = nil
 		})
 	}

@@ -651,7 +647,7 @@ func CreateModelHandler(c *gin.Context) {
 		ctx, cancel := context.WithCancel(c.Request.Context())
 		defer cancel()

-		if err := CreateModel(ctx, model, filepath.Dir(req.Path), req.Quantization, commands, fn); err != nil {
+		if err := CreateModel(ctx, model, filepath.Dir(req.Path), commands, fn); err != nil {
 			ch <- gin.H{"error": err.Error()}
 		}
 	}()

@@ -917,24 +913,6 @@ func HeadBlobHandler(c *gin.Context) {
 }

 func CreateBlobHandler(c *gin.Context) {
-	path, err := GetBlobsPath(c.Param("digest"))
-	if err != nil {
-		c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
-		return
-	}
-
-	_, err = os.Stat(path)
-	switch {
-	case errors.Is(err, os.ErrNotExist):
-		// noop
-	case err != nil:
-		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
-		return
-	default:
-		c.Status(http.StatusOK)
-		return
-	}
-
 	layer, err := NewLayer(c.Request.Body, "")
 	if err != nil {
 		c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": err.Error()})

@@ -1150,7 +1128,9 @@ func Serve(ln net.Listener) error {
 	signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
 	go func() {
 		<-signals
-		unload()
+		if loaded.llama != nil {
+			loaded.llama.Close()
+		}
 		gpu.Cleanup()
 		os.Exit(0)
 	}()

@@ -61,7 +61,7 @@ func Test_Routes(t *testing.T) {
 	fn := func(resp api.ProgressResponse) {
 		t.Logf("Status: %s", resp.Status)
 	}
-	err = CreateModel(context.TODO(), name, "", "", commands, fn)
+	err = CreateModel(context.TODO(), name, "", commands, fn)
 	assert.Nil(t, err)
 }
@@ -1,83 +0,0 @@
package model

import (
	"fmt"
	"log/slog"
	"strings"
	"unicode"
)

// Digest represents a digest of a model Manifest. It is a comparable value
// type and is immutable.
//
// The zero Digest is not a valid digest.
type Digest struct {
	s string
}

// Type returns the digest type of the digest.
//
// Example:
//
//	ParseDigest("sha256-1234").Type() // returns "sha256"
func (d Digest) Type() string {
	typ, _, _ := strings.Cut(d.s, "-")
	return typ
}

// String returns the digest in the form of "<digest-type>-<digest>", or the
// empty string if the digest is invalid.
func (d Digest) String() string { return d.s }

// IsValid returns true if the digest is valid (not zero).
//
// A valid digest may be created only by ParseDigest, or
// ParseName(name).Digest().
func (d Digest) IsValid() bool { return d.s != "" }

// LogValue implements slog.Value.
func (d Digest) LogValue() slog.Value {
	return slog.StringValue(d.String())
}

var (
	_ slog.LogValuer = Digest{}
)

// ParseDigest parses a string in the form of "<digest-type>-<digest>" into a
// Digest.
func ParseDigest(s string) Digest {
	typ, digest, ok := strings.Cut(s, "-")
	if !ok {
		typ, digest, ok = strings.Cut(s, ":")
	}
	if ok && isValidDigestType(typ) && isValidHex(digest) {
		return Digest{s: fmt.Sprintf("%s-%s", typ, digest)}
	}
	return Digest{}
}

func isValidDigestType(s string) bool {
	if len(s) == 0 {
		return false
	}
	for _, r := range s {
		if !unicode.IsLower(r) && !unicode.IsDigit(r) {
			return false
		}
	}
	return true
}

func isValidHex(s string) bool {
	if len(s) == 0 {
		return false
	}
	for i := range s {
		c := s[i]
		if c < '0' || c > '9' && c < 'a' || c > 'f' {
			return false
		}
	}
	return true
}
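The deleted digest.go is small enough that a usage sketch captures it. This assumes a checkout from before this change, where x/model still exported these helpers:

```go
package main

import (
	"fmt"

	"github.com/ollama/ollama/x/model"
)

func main() {
	d := model.ParseDigest("sha256:abc123") // ":" is normalized to "-"
	fmt.Println(d.Type(), d.String(), d.IsValid()) // sha256 sha256-abc123 true
	fmt.Println(model.ParseDigest("not a digest").IsValid()) // false
}
```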
@@ -1,708 +0,0 @@
package model

import (
	"bytes"
	"cmp"
	"fmt"
	"log/slog"
	"path/filepath"
	"slices"
	"strings"
	"testing"
)

type fields struct {
	host, namespace, model, tag, build string
	digest                             string
}

func fieldsFromName(p Name) fields {
	return fields{
		host:      p.parts[PartHost],
		namespace: p.parts[PartNamespace],
		model:     p.parts[PartModel],
		tag:       p.parts[PartTag],
		build:     p.parts[PartBuild],
		digest:    p.parts[PartDigest],
	}
}

var testNames = map[string]fields{
	"mistral:latest":                 {model: "mistral", tag: "latest"},
	"mistral":                        {model: "mistral"},
	"mistral:30B":                    {model: "mistral", tag: "30B"},
	"mistral:7b":                     {model: "mistral", tag: "7b"},
	"mistral:7b+Q4_0":                {model: "mistral", tag: "7b", build: "Q4_0"},
	"mistral+KQED":                   {model: "mistral", build: "KQED"},
	"mistral.x-3:7b+Q4_0":            {model: "mistral.x-3", tag: "7b", build: "Q4_0"},
	"mistral:7b+q4_0":                {model: "mistral", tag: "7b", build: "q4_0"},
	"llama2":                         {model: "llama2"},
	"user/model":                     {namespace: "user", model: "model"},
	"example.com/ns/mistral:7b+Q4_0": {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "Q4_0"},
	"example.com/ns/mistral:7b+X":    {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "X"},

	// invalid digest
	"mistral:latest@invalid256-": {},
	"mistral:latest@-123":        {},
	"mistral:latest@!-123":       {},
	"mistral:latest@1-!":         {},
	"mistral:latest@":            {},

	// resolved
	"x@sha123-1": {model: "x", digest: "sha123-1"},
	"@sha456-2":  {digest: "sha456-2"},

	"@@sha123-1": {},

	// preserves case for build
	"x+b": {model: "x", build: "b"},

	// invalid (includes fuzzing trophies)
	" / / : + ": {},
	" / : + ":   {},
	" : + ":     {},
	" + ":       {},
	" : ":       {},
	" / ":       {},
	" /":        {},
	"/ ":        {},
	"/":         {},
	":":         {},
	"+":         {},

	// (".") in namespace is not allowed
	"invalid.com/7b+x": {},

	"invalid:7b+Q4_0:latest": {},
	"in valid":               {},
	"invalid/y/z/foo":        {},
	"/0":                     {},
	"0 /0":                   {},
	"0 /":                    {},
	"0/":                     {},
	":/0":                    {},
	"+0/00000":               {},
	"0+.\xf2\x80\xf6\x9d00000\xe5\x99\xe6\xd900\xd90\xa60\x91\xdc0\xff\xbf\x99\xe800\xb9\xdc\xd6\xc300\x970\xfb\xfd0\xe0\x8a\xe1\xad\xd40\x9700\xa80\x980\xdd0000\xb00\x91000\xfe0\x89\x9b\x90\x93\x9f0\xe60\xf7\x84\xb0\x87\xa5\xff0\xa000\x9a\x85\xf6\x85\xfe\xa9\xf9\xe9\xde00\xf4\xe0\x8f\x81\xad\xde00\xd700\xaa\xe000000\xb1\xee0\x91": {},
	"0//0":                        {},
	"m+^^^":                       {},
	"file:///etc/passwd":          {},
	"file:///etc/passwd:latest":   {},
	"file:///etc/passwd:latest+u": {},

	":x": {},
	"+x": {},
	"x+": {},

	// Disallow ("\.+") in any part to prevent path traversal anywhere
	// we convert the name to a path.
	"../etc/passwd":  {},
	".../etc/passwd": {},
	"./../passwd":    {},
	"./0+..":         {},

	strings.Repeat("a", MaxNamePartLen):   {model: strings.Repeat("a", MaxNamePartLen)},
	strings.Repeat("a", MaxNamePartLen+1): {},
}

// TestConsecutiveDots tests that consecutive dots are not allowed in any
// part, to avoid path traversal. There also are some tests in testNames, but
// this test is more exhaustive and exists to emphasize the importance of
// preventing path traversal.
func TestNameConsecutiveDots(t *testing.T) {
	for i := 1; i < 10; i++ {
		s := strings.Repeat(".", i)
		if i > 1 {
			if g := ParseName(s, FillNothing).DisplayLong(); g != "" {
				t.Errorf("ParseName(%q) = %q; want empty string", s, g)
			}
		} else {
			if g := ParseName(s, FillNothing).DisplayLong(); g != s {
				t.Errorf("ParseName(%q) = %q; want %q", s, g, s)
			}
		}
	}
}

func TestNameParts(t *testing.T) {
	var p Name
	if w, g := int(NumParts), len(p.parts); w != g {
		t.Errorf("Parts() = %d; want %d", g, w)
	}
}

func TestNamePartString(t *testing.T) {
	if g := PartKind(-2).String(); g != "Unknown" {
		t.Errorf("Unknown part = %q; want %q", g, "Unknown")
	}
	for kind, name := range kindNames {
		if g := kind.String(); g != name {
			t.Errorf("%s = %q; want %q", kind, g, name)
		}
	}
}

func TestParseName(t *testing.T) {
	for baseName, want := range testNames {
		for _, prefix := range []string{"", "https://", "http://"} {
			// We should get the same results with or without the
			// http(s) prefixes
			s := prefix + baseName

			t.Run(s, func(t *testing.T) {
				name := ParseName(s, FillNothing)
				got := fieldsFromName(name)
				if got != want {
					t.Errorf("ParseName(%q) = %q; want %q", s, got, want)
				}

				// test round-trip
				if !ParseName(name.DisplayLong(), FillNothing).EqualFold(name) {
					t.Errorf("ParseName(%q).String() = %s; want %s", s, name.DisplayLong(), baseName)
				}
			})
		}
	}
}

func TestParseNameFill(t *testing.T) {
	cases := []struct {
		in   string
		fill string
		want string
	}{
		{"mistral", "example.com/library/?:latest+Q4_0", "example.com/library/mistral:latest+Q4_0"},
		{"mistral", "example.com/library/?:latest", "example.com/library/mistral:latest"},
		{"llama2:x", "example.com/library/?:latest+Q4_0", "example.com/library/llama2:x+Q4_0"},

		// Invalid
		{"", "example.com/library/?:latest+Q4_0", ""},
		{"llama2:?", "example.com/library/?:latest+Q4_0", ""},
	}

	for _, tt := range cases {
		t.Run(tt.in, func(t *testing.T) {
			name := ParseName(tt.in, tt.fill)
			if g := name.DisplayLong(); g != tt.want {
				t.Errorf("ParseName(%q, %q) = %q; want %q", tt.in, tt.fill, g, tt.want)
			}
		})
	}

	t.Run("invalid fill", func(t *testing.T) {
		defer func() {
			if recover() == nil {
				t.Fatal("expected panic")
			}
		}()
		ParseName("x", "^")
	})
}

func TestParseNameHTTPDoublePrefixStrip(t *testing.T) {
	cases := []string{
		"http://https://valid.com/valid/valid:latest",
		"https://http://valid.com/valid/valid:latest",
	}
	for _, s := range cases {
		t.Run(s, func(t *testing.T) {
			name := ParseName(s, FillNothing)
			if name.IsValid() {
				t.Errorf("expected invalid path; got %#v", name)
			}
		})
	}
}

func TestCompleteWithAndWithoutBuild(t *testing.T) {
	cases := []struct {
		in              string
		complete        bool
		completeNoBuild bool
	}{
		{"", false, false},
		{"incomplete/mistral:7b+x", false, false},
		{"incomplete/mistral:7b+Q4_0", false, false},
		{"incomplete:7b+x", false, false},
		{"complete.com/x/mistral:latest+Q4_0", true, true},
		{"complete.com/x/mistral:latest", false, true},
	}

	for _, tt := range cases {
		t.Run(tt.in, func(t *testing.T) {
			p := ParseName(tt.in, FillNothing)
			t.Logf("ParseName(%q) = %#v", tt.in, p)
			if g := p.IsComplete(); g != tt.complete {
				t.Errorf("Complete(%q) = %v; want %v", tt.in, g, tt.complete)
			}
			if g := p.IsCompleteNoBuild(); g != tt.completeNoBuild {
				t.Errorf("CompleteNoBuild(%q) = %v; want %v", tt.in, g, tt.completeNoBuild)
			}
		})
	}

	// Complete uses Parts which returns a slice, but it should be
	// inlined when used in Complete, preventing any allocations or
	// escaping to the heap.
	allocs := testing.AllocsPerRun(1000, func() {
		keep(ParseName("complete.com/x/mistral:latest+Q4_0", FillNothing).IsComplete())
	})
	if allocs > 0 {
		t.Errorf("Complete allocs = %v; want 0", allocs)
	}
}

func TestNameLogValue(t *testing.T) {
	cases := []string{
		"example.com/library/mistral:latest+Q4_0",
		"mistral:latest",
		"mistral:7b+Q4_0",
	}
	for _, s := range cases {
		t.Run(s, func(t *testing.T) {
			var b bytes.Buffer
			log := slog.New(slog.NewTextHandler(&b, nil))
			name := ParseName(s, FillNothing)
			log.Info("", "name", name)
			want := fmt.Sprintf("name=%s", name.GoString())
			got := b.String()
			if !strings.Contains(got, want) {
				t.Errorf("expected log output to contain %q; got %q", want, got)
			}
		})
	}
}

func TestNameGoString(t *testing.T) {
	cases := []struct {
		name         string
		in           string
		wantString   string
		wantGoString string // default is tt.in
	}{
		{
			name:         "Complete Name",
			in:           "example.com/library/mistral:latest+Q4_0",
			wantGoString: "example.com/library/mistral:latest+Q4_0@?",
		},
		{
			name:         "Short Name",
			in:           "mistral:latest",
			wantGoString: "?/?/mistral:latest+?@?",
		},
		{
			name:         "Long Name",
			in:           "library/mistral:latest",
			wantGoString: "?/library/mistral:latest+?@?",
		},
		{
			name:         "Case Preserved",
			in:           "Library/Mistral:Latest",
			wantGoString: "?/Library/Mistral:Latest+?@?",
		},
		{
			name:         "With digest",
			in:           "Library/Mistral:Latest@sha256-123456",
			wantGoString: "?/Library/Mistral:Latest+?@sha256-123456",
		},
	}

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			p := ParseName(tt.in, FillNothing)
			tt.wantGoString = cmp.Or(tt.wantGoString, tt.in)
			if g := fmt.Sprintf("%#v", p); g != tt.wantGoString {
				t.Errorf("GoString() = %q; want %q", g, tt.wantGoString)
			}
		})
	}
}

func TestDisplayLongest(t *testing.T) {
	g := ParseName("example.com/library/mistral:latest+Q4_0", FillNothing).DisplayLongest()
	if g != "example.com/library/mistral:latest" {
		t.Errorf("got = %q; want %q", g, "example.com/library/mistral:latest")
	}
}

func TestDisplayShortest(t *testing.T) {
	cases := []struct {
		in        string
		mask      string
		want      string
		wantPanic bool
	}{
		{"example.com/library/mistral:latest+Q4_0", "example.com/library/_:latest", "mistral", false},
		{"example.com/library/mistral:latest+Q4_0", "example.com/_/_:latest", "library/mistral", false},
		{"example.com/library/mistral:latest+Q4_0", "", "example.com/library/mistral", false},
		{"example.com/library/mistral:latest+Q4_0", "", "example.com/library/mistral", false},

		// case-insensitive
		{"Example.com/library/mistral:latest+Q4_0", "example.com/library/_:latest", "mistral", false},
		{"example.com/Library/mistral:latest+Q4_0", "example.com/library/_:latest", "mistral", false},
		{"example.com/library/Mistral:latest+Q4_0", "example.com/library/_:latest", "Mistral", false},
		{"example.com/library/mistral:Latest+Q4_0", "example.com/library/_:latest", "mistral", false},
		{"example.com/library/mistral:Latest+q4_0", "example.com/library/_:latest", "mistral", false},

		// zero value
		{"", MaskDefault, "", true},

		// invalid mask
		{"example.com/library/mistral:latest+Q4_0", "example.com/mistral", "", true},

		// DefaultMask
		{"registry.ollama.ai/library/mistral:latest+Q4_0", MaskDefault, "mistral", false},

		// Auto-Fill
		{"x", "example.com/library/_:latest", "x", false},
		{"x", "example.com/library/_:latest+Q4_0", "x", false},
		{"x/y:z", "a.com/library/_:latest+Q4_0", "x/y:z", false},
		{"x/y:z", "a.com/library/_:latest+Q4_0", "x/y:z", false},
	}

	for _, tt := range cases {
		t.Run("", func(t *testing.T) {
			defer func() {
				if tt.wantPanic {
					if recover() == nil {
						t.Errorf("expected panic")
					}
				}
			}()

			p := ParseName(tt.in, FillNothing)
			t.Logf("ParseName(%q) = %#v", tt.in, p)
			if g := p.DisplayShortest(tt.mask); g != tt.want {
				t.Errorf("got = %q; want %q", g, tt.want)
			}
		})
	}
}

func TestParseNameAllocs(t *testing.T) {
	allocs := testing.AllocsPerRun(1000, func() {
		keep(ParseName("example.com/mistral:7b+Q4_0", FillNothing))
	})
	if allocs > 0 {
		t.Errorf("ParseName allocs = %v; want 0", allocs)
	}
}

func BenchmarkParseName(b *testing.B) {
	b.ReportAllocs()

	for range b.N {
		keep(ParseName("example.com/mistral:7b+Q4_0", FillNothing))
	}
}

func FuzzParseNameFromFilepath(f *testing.F) {
	f.Add("example.com/library/mistral/7b/Q4_0")
	f.Add("example.com/../mistral/7b/Q4_0")
	f.Add("example.com/x/../7b/Q4_0")
	f.Add("example.com/x/../7b")
	f.Fuzz(func(t *testing.T, s string) {
		name := ParseNameFromFilepath(s, FillNothing)
		if strings.Contains(s, "..") && !name.IsZero() {
			t.Fatalf("non-zero value for path with '..': %q", s)
		}
		if name.IsValid() == name.IsZero() {
			t.Errorf("expected valid path to be non-zero value; got %#v", name)
		}
	})
}

func FuzzParseName(f *testing.F) {
	f.Add("example.com/mistral:7b+Q4_0")
	f.Add("example.com/mistral:7b+q4_0")
	f.Add("example.com/mistral:7b+x")
	f.Add("x/y/z:8n+I")
	f.Add(":x")
	f.Add("@sha256-123456")
	f.Add("example.com/mistral:latest+Q4_0@sha256-123456")
	f.Add(":@!@")
	f.Add("...")
	f.Fuzz(func(t *testing.T, s string) {
		r0 := ParseName(s, FillNothing)

		if strings.Contains(s, "..") && !r0.IsZero() {
			t.Fatalf("non-zero value for path with '..': %q", s)
		}

		if !r0.IsValid() && !r0.IsResolved() {
			if !r0.EqualFold(Name{}) {
				t.Errorf("expected invalid path to be zero value; got %#v", r0)
			}
			t.Skipf("invalid path: %q", s)
		}

		for _, p := range r0.parts {
			if len(p) > MaxNamePartLen {
				t.Errorf("part too long: %q", p)
			}
		}

		if !strings.EqualFold(r0.DisplayLong(), s) {
			t.Errorf("String() did not round-trip with case insensitivity: %q\ngot = %q\nwant = %q", s, r0.DisplayLong(), s)
		}

		r1 := ParseName(r0.DisplayLong(), FillNothing)
		if !r0.EqualFold(r1) {
			t.Errorf("round-trip mismatch: %+v != %+v", r0, r1)
		}
	})
}

func TestNameStringAllocs(t *testing.T) {
	name := ParseName("example.com/ns/mistral:latest+Q4_0", FillNothing)
	allocs := testing.AllocsPerRun(1000, func() {
		keep(name.DisplayLong())
	})
	if allocs > 1 {
		t.Errorf("String allocs = %v; want 0", allocs)
	}
}

func TestNamePath(t *testing.T) {
	cases := []struct {
		in   string
		want string
	}{
		{"example.com/library/mistral:latest+Q4_0", "example.com/library/mistral:latest"},

		// incomplete
		{"example.com/library/mistral:latest", "example.com/library/mistral:latest"},
		{"", ""},
	}
	for _, tt := range cases {
		t.Run(tt.in, func(t *testing.T) {
			p := ParseName(tt.in, FillNothing)
			t.Logf("ParseName(%q) = %#v", tt.in, p)
			if g := p.URLPath(); g != tt.want {
				t.Errorf("got = %q; want %q", g, tt.want)
			}
		})
	}
}

func TestNameFilepath(t *testing.T) {
	cases := []struct {
		in          string
		want        string
		wantNoBuild string
	}{
		{
			in:          "example.com/library/mistral:latest+Q4_0",
			want:        "example.com/library/mistral/latest/Q4_0",
			wantNoBuild: "example.com/library/mistral/latest",
		},
		{
			in:          "Example.Com/Library/Mistral:Latest+Q4_0",
			want:        "example.com/library/mistral/latest/Q4_0",
			wantNoBuild: "example.com/library/mistral/latest",
		},
		{
			in:          "Example.Com/Library/Mistral:Latest+Q4_0",
			want:        "example.com/library/mistral/latest/Q4_0",
			wantNoBuild: "example.com/library/mistral/latest",
		},
		{
			in:          "example.com/library/mistral:latest",
			want:        "example.com/library/mistral/latest",
			wantNoBuild: "example.com/library/mistral/latest",
		},
		{
			in:          "",
			want:        "",
			wantNoBuild: "",
		},
	}
	for _, tt := range cases {
		t.Run(tt.in, func(t *testing.T) {
			p := ParseName(tt.in, FillNothing)
			t.Logf("ParseName(%q) = %#v", tt.in, p)
			g := p.Filepath()
			g = filepath.ToSlash(g)
			if g != tt.want {
				t.Errorf("got = %q; want %q", g, tt.want)
			}
			g = p.FilepathNoBuild()
			g = filepath.ToSlash(g)
			if g != tt.wantNoBuild {
				t.Errorf("got = %q; want %q", g, tt.wantNoBuild)
			}
		})
	}
}

func TestParseNameFilepath(t *testing.T) {
	cases := []struct {
		in   string
		fill string // default is FillNothing
		want string
	}{
		{
			in:   "example.com/library/mistral/latest/Q4_0",
			want: "example.com/library/mistral:latest+Q4_0",
		},
		{
			in:   "example.com/library/mistral/latest",
			fill: "?/?/?:latest+Q4_0",
			want: "example.com/library/mistral:latest+Q4_0",
		},
		{
			in:   "example.com/library/mistral",
			fill: "?/?/?:latest+Q4_0",
			want: "example.com/library/mistral:latest+Q4_0",
		},
		{
			in:   "example.com/library",
			want: "",
		},
		{
			in:   "example.com/",
			want: "",
		},
		{
			in:   "example.com/^/mistral/latest/Q4_0",
			want: "",
		},
		{
			in:   "example.com/library/mistral/../Q4_0",
			want: "",
		},
		{
			in:   "example.com/library/mistral/latest/Q4_0/extra",
			want: "",
		},
	}
	for _, tt := range cases {
		t.Run(tt.in, func(t *testing.T) {
			in := strings.ReplaceAll(tt.in, "/", string(filepath.Separator))
			fill := cmp.Or(tt.fill, FillNothing)
			want := ParseName(tt.want, fill)
			if g := ParseNameFromFilepath(in, fill); !g.EqualFold(want) {
				t.Errorf("got = %q; want %q", g.DisplayLong(), tt.want)
			}
		})
	}
}

func TestParseNameFromPath(t *testing.T) {
	cases := []struct {
		in   string
		want string
		fill string // default is FillNothing
	}{
		{
			in:   "example.com/library/mistral:latest+Q4_0",
			want: "example.com/library/mistral:latest+Q4_0",
		},
		{
			in:   "/example.com/library/mistral:latest+Q4_0",
			want: "example.com/library/mistral:latest+Q4_0",
		},
		{
			in:   "/example.com/library/mistral",
			want: "example.com/library/mistral",
		},
		{
			in:   "/example.com/library/mistral",
			fill: "?/?/?:latest+Q4_0",
			want: "example.com/library/mistral:latest+Q4_0",
		},
		{
			in:   "/example.com/library",
			want: "",
		},
		{
			in:   "/example.com/",
			want: "",
		},
		{
			in:   "/example.com/^/mistral/latest",
			want: "",
		},
	}
	for _, tt := range cases {
		t.Run(tt.in, func(t *testing.T) {
			fill := cmp.Or(tt.fill, FillNothing)
			if g := ParseNameFromURLPath(tt.in, fill); g.DisplayLong() != tt.want {
				t.Errorf("got = %q; want %q", g.DisplayLong(), tt.want)
			}
		})
	}
}

func ExampleName_MapHash() {
	m := map[uint64]bool{}

	// key 1
	m[ParseName("mistral:latest+q4", FillNothing).MapHash()] = true
	m[ParseName("miSTRal:latest+Q4", FillNothing).MapHash()] = true
	m[ParseName("mistral:LATest+Q4", FillNothing).MapHash()] = true

	// key 2
	m[ParseName("mistral:LATest", FillNothing).MapHash()] = true

	fmt.Println(len(m))
	// Output:
	// 2
}

func ExampleName_CompareFold_sort() {
	names := []Name{
		ParseName("mistral:latest", FillNothing),
		ParseName("mistRal:7b+q4", FillNothing),
		ParseName("MIstral:7b", FillNothing),
	}

	slices.SortFunc(names, Name.CompareFold)

	for _, n := range names {
		fmt.Println(n.DisplayLong())
	}

	// Output:
	// MIstral:7b
	// mistRal:7b+q4
	// mistral:latest
}

func ExampleName_completeAndResolved() {
	for _, s := range []string{
		"x/y/z:latest+q4_0@sha123-1",
		"x/y/z:latest+q4_0",
		"@sha123-1",
	} {
		name := ParseName(s, FillNothing)
		fmt.Printf("complete:%v resolved:%v digest:%s\n", name.IsComplete(), name.IsResolved(), name.Digest())
	}

	// Output:
	// complete:true resolved:true digest:sha123-1
	// complete:true resolved:false digest:
	// complete:false resolved:true digest:sha123-1
}

func ExampleName_DisplayShortest() {
	name := ParseName("example.com/jmorganca/mistral:latest+Q4_0", FillNothing)

	fmt.Println(name.DisplayShortest("example.com/jmorganca/_:latest"))
	fmt.Println(name.DisplayShortest("example.com/_/_:latest"))
	fmt.Println(name.DisplayShortest("example.com/_/_:_"))
	fmt.Println(name.DisplayShortest("_/_/_:_"))

	// Default
	name = ParseName("registry.ollama.ai/library/mistral:latest+Q4_0", FillNothing)
	fmt.Println(name.DisplayShortest(""))

	// Output:
	// mistral
	// jmorganca/mistral
	// jmorganca/mistral:latest
	// example.com/jmorganca/mistral:latest
	// mistral
}

func keep[T any](v T) T { return v }
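The deleted tests exercise the name grammar `[host/][namespace/]model[:tag][+build][@digest]`. Below is a usage sketch of the two-argument `ParseName(s, fill)` API they rely on; note that elsewhere in this compare, new code calls a one-argument `model.ParseName`, so this assumes a checkout where the two-argument form still exists:

```go
package main

import (
	"fmt"

	"github.com/ollama/ollama/x/model"
)

func main() {
	n := model.ParseName("example.com/ns/mistral:7b+Q4_0", model.FillNothing)
	fmt.Println(n.IsComplete())  // true: host, namespace, model, tag, and build all set
	fmt.Println(n.DisplayLong()) // example.com/ns/mistral:7b+Q4_0
}
```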
113  x/api/api.go  Normal file

@@ -0,0 +1,113 @@
package api

import (
	"errors"
	"fmt"
	"net/http"
	"os"

	"github.com/ollama/ollama/x/build"
	"github.com/ollama/ollama/x/client/ollama/apitype"
	"github.com/ollama/ollama/x/oweb"
	"github.com/ollama/ollama/x/registry"
	regtype "github.com/ollama/ollama/x/registry/apitype"
)

// Common API Errors
var (
	errUnqualifiedRef = oweb.Invalid("invalid", "name", "must be fully qualified")
	errRefNotFound    = oweb.Invalid("not_found", "name", "no such model")
)

type Server struct {
	Build *build.Server
}

func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	oweb.Serve(s.serveHTTP, w, r)
}

func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) error {
	switch r.URL.Path {
	case "/v1/push":
		return s.handlePush(w, r)
	default:
		return oweb.ErrNotFound
	}
}

func want(r *http.Request, method, path string) bool {
	return r.Method == method && r.URL.Path == path
}

func (s *Server) handlePush(_ http.ResponseWriter, r *http.Request) error {
	if r.Method != "POST" {
		return oweb.ErrMethodNotAllowed
	}

	params, err := oweb.DecodeJSON[apitype.PushRequest](r.Body)
	if err != nil {
		return err
	}

	if params.Name == "" {
		return oweb.Missing("name")
	}

	const registryURLTODO = "http://localhost:8888"

	man, err := s.Build.ManifestData(params.Name)
	if err != nil {
		if errors.Is(err, build.ErrNotFound) {
			return errRefNotFound
		}
		return err
	}

	c := registry.Client{BaseURL: registryURLTODO}
	requirements, err := c.Push(r.Context(), params.Name, man, nil)
	if err != nil {
		return err
	}

	var uploads []regtype.CompletePart
	for _, rq := range requirements {
		l, err := s.Build.LayerFile(rq.Digest)
		if err != nil {
			return err
		}
		err = func() error {
			f, err := os.Open(l)
			if err != nil {
				return err
			}
			defer f.Close()
			cp, err := registry.PushLayer(r.Context(), f, rq.URL, rq.Offset, rq.Size)
			if err != nil {
				return err
			}
			uploads = append(uploads, cp)
			return nil
		}()
		if err != nil {
			return err
		}
	}

	// commit the manifest to the registry
	requirements, err = c.Push(r.Context(), params.Name, man, &registry.PushParams{
		CompleteParts: uploads,
	})
	if err != nil {
		return err
	}
	for _, r := range requirements {
		err = errors.Join(err, fmt.Errorf("push failed for %q", r.Digest))
	}
	return err
}

func (s *Server) handlePull(w http.ResponseWriter, r *http.Request) error {
	return oweb.ErrNotFound
}
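The handler above performs a two-phase push: a first Push reports which layers the registry still needs, each one is uploaded with PushLayer, and a second Push carrying CompleteParts commits the manifest. A condensed sketch of that flow; the base URL mirrors the handler's TODO constant, and the `layerPath` lookup is a placeholder for whatever maps a digest to a local file (the handler uses `build.Server.LayerFile`):

```go
package main

import (
	"context"
	"os"

	"github.com/ollama/ollama/x/registry"
	regtype "github.com/ollama/ollama/x/registry/apitype"
)

func pushModel(ctx context.Context, name string, manifest []byte, layerPath func(digest string) string) error {
	c := registry.Client{BaseURL: "http://localhost:8888"} // from the handler's TODO constant

	// Phase 1: a dry-run push reports the layers the registry still needs.
	reqs, err := c.Push(ctx, name, manifest, nil)
	if err != nil {
		return err
	}

	var uploads []regtype.CompletePart
	for _, rq := range reqs {
		f, err := os.Open(layerPath(rq.Digest))
		if err != nil {
			return err
		}
		cp, err := registry.PushLayer(ctx, f, rq.URL, rq.Offset, rq.Size)
		f.Close()
		if err != nil {
			return err
		}
		uploads = append(uploads, cp)
	}

	// Phase 2: resend with CompleteParts to commit the manifest.
	_, err = c.Push(ctx, name, manifest, &registry.PushParams{CompleteParts: uploads})
	return err
}
```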
209  x/build/build.go  Normal file

@@ -0,0 +1,209 @@
package build

import (
	"encoding/json"
	"errors"
	"fmt"
	"io/fs"
	"os"
	"path/filepath"

	"github.com/ollama/ollama/x/build/internal/blobstore"
	"github.com/ollama/ollama/x/model"
)

// Errors
var (
	ErrIncompleteRef          = errors.New("unqualified ref")
	ErrBuildPresentInRef      = errors.New("build present in ref")
	ErrUnsupportedModelFormat = errors.New("unsupported model format")
	ErrMissingFileType        = errors.New("missing 'general.file_type' key")
	ErrNotFound               = errors.New("not found")
)

type mediaType string

// Known media types
const (
	mediaTypeModel mediaType = "application/vnd.ollama.image.model"
)

type Server struct {
	st *blobstore.Store
}

// Open starts a new build server that uses dir as the base directory for all
// build artifacts. If dir is empty, DefaultDir is used.
//
// It returns an error if the provided or default dir cannot be initialized.
func Open(dir string) (*Server, error) {
	if dir == "" {
		var err error
		dir, err = DefaultDir()
		if err != nil {
			return nil, err
		}
	}
	st, err := blobstore.Open(dir)
	if err != nil {
		return nil, err
	}
	return &Server{st: st}, nil
}

func (s *Server) Build(ref string, f model.File) error {
	mp := model.ParseName(ref)
	if !mp.IsCompleteNoBuild() {
		return fmt.Errorf("%w: %q", ErrIncompleteRef, ref)
	}

	// 1. Resolve FROM
	//    a. If it's a local file (gguf), hash it and add it to the store.
	//    c. If it's a remote file (http), refuse.
	// 2. Turn other pragmas into layers, and add them to the store.
	// 3. Create a manifest from the layers.
	// 4. Store the manifest in the manifest cache
	// 5. Done.

	if f.From == "" {
		return &model.FileError{Pragma: "FROM", Message: "missing"}
	}

	var layers []layerJSON

	id, info, size, err := s.importModel(f.From)
	if err != nil {
		return err
	}
	layers = append(layers, layerJSON{
		ID:        id,
		MediaType: mediaTypeModel,
		Size:      size,
	})

	id, size, err = blobstore.PutString(s.st, f.License)
	if err != nil {
		return err
	}
	layers = append(layers, layerJSON{
		ID:        id,
		MediaType: "text/plain",
		Size:      size,
	})

	data, err := json.Marshal(manifestJSON{Layers: layers})
	if err != nil {
		return err
	}

	return s.setManifestData(
		mp.WithBuild(info.FileType.String()),
		data,
	)
}

func (s *Server) LayerFile(digest string) (string, error) {
	fileName := s.st.OutputFilename(blobstore.ParseID(digest))
	_, err := os.Stat(fileName)
	if errors.Is(err, fs.ErrNotExist) {
		return "", fmt.Errorf("%w: %q", ErrNotFound, digest)
	}
	return fileName, nil
}

func (s *Server) ManifestData(ref string) ([]byte, error) {
	data, _, err := s.resolve(model.ParseName(ref))
	return data, err
}

// WeightsFile returns the absolute path to the weights file for the given model ref.
func (s *Server) WeightsFile(ref string) (string, error) {
	m, err := s.getManifest(model.ParseName(ref))
	if err != nil {
		return "", err
	}
	for _, l := range m.Layers {
		if l.MediaType == mediaTypeModel {
			return s.st.OutputFilename(l.ID), nil
		}
	}
	return "", fmt.Errorf("missing weights layer for %q", ref)
}

// resolve returns the data for the given ref, if any.
//
// TODO: This should ideally return an ID, but the current on
// disk layout is that the actual manifest is stored in the "ref" instead of
// a pointer to a content-addressed blob. I (bmizerany) think we should
// change the on-disk layout to store the manifest in a content-addressed
// blob, and then have the ref point to that blob. This would simplify the
// code, allow us to have integrity checks on the manifest, and clean up
// this interface.
func (s *Server) resolve(ref model.Name) (data []byte, fileName string, err error) {
	fileName, err = s.refFileName(ref)
	if err != nil {
		return nil, "", err
	}
	data, err = os.ReadFile(fileName)
	if errors.Is(err, fs.ErrNotExist) {
		return nil, "", fmt.Errorf("%w: %q", ErrNotFound, ref)
	}
	if err != nil {
		// do not wrap the error here, as it is likely an I/O error
		// and we want to preserve the abstraction since we may not
		// be on disk later.
		return nil, "", fmt.Errorf("manifest read error: %v", err)
	}
	return data, fileName, nil
}

func (s *Server) SetManifestData(ref string, data []byte) error {
	return s.setManifestData(model.ParseName(ref), data)
}

// setManifestData sets the manifest data for the given ref.
func (s *Server) setManifestData(mp model.Name, data []byte) error {
	path, err := s.refFileName(mp)
	if err != nil {
		return err
	}
	if err := os.MkdirAll(filepath.Dir(path), 0777); err != nil {
		return err
	}
	if err := os.WriteFile(path, data, 0666); err != nil {
		return err
	}
	return nil
}

func (s *Server) refFileName(mp model.Name) (string, error) {
	if !mp.IsComplete() {
		return "", fmt.Errorf("ref not fully qualified: %q", mp)
	}
	return filepath.Join(s.st.Dir(), "manifests", filepath.Join(mp.Parts()...)), nil
}

type manifestJSON struct {
	// Layers is the list of layers in the manifest.
	Layers []layerJSON `json:"layers"`
}

// Layer is a layer in a model manifest.
type layerJSON struct {
	// ID is the ID of the layer.
	ID        blobstore.ID `json:"digest"`
	MediaType mediaType    `json:"mediaType"`
	Size      int64        `json:"size"`
}

func (s *Server) getManifest(ref model.Name) (manifestJSON, error) {
	data, path, err := s.resolve(ref)
	if err != nil {
		return manifestJSON{}, err
	}
	var m manifestJSON
	if err := json.Unmarshal(data, &m); err != nil {
		return manifestJSON{}, &fs.PathError{Op: "unmarshal", Path: path, Err: err}
	}
	return m, nil
}
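For reference, the manifest that setManifestData persists is just the JSON encoding of manifestJSON above. An in-package sketch; the digest below is a placeholder, not a real blob ID:

```go
// exampleManifest shows the bytes setManifestData persists, using the structs
// above. It would marshal to something like:
// {"layers":[{"digest":"sha256-0000...","mediaType":"application/vnd.ollama.image.model","size":4096}]}
func exampleManifest() ([]byte, error) {
	return json.Marshal(manifestJSON{
		Layers: []layerJSON{{
			ID:        blobstore.ParseID("sha256-0000000000000000000000000000000000000000000000000000000000000000"),
			MediaType: mediaTypeModel,
			Size:      4096,
		}},
	})
}
```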
163  x/build/build_test.go  Normal file

@@ -0,0 +1,163 @@
package build

import (
	"errors"
	"os"
	"path/filepath"
	"testing"

	"github.com/ollama/ollama/x/encoding/gguf"
	"github.com/ollama/ollama/x/model"
)

const qualifiedRef = "x/y/z:latest+Q4_0"

func TestServerBuildErrors(t *testing.T) {
	dir := t.TempDir()

	s, err := Open(dir)
	if err != nil {
		t.Fatal(err)
	}

	t.Run("unqualified ref", func(t *testing.T) {
		err := s.Build("x", model.File{})
		if !errors.Is(err, ErrIncompleteRef) {
			t.Fatalf("Build() err = %v; want unqualified ref", err)
		}
	})

	t.Run("FROM pragma missing", func(t *testing.T) {
		err := s.Build(qualifiedRef, model.File{})
		var e *model.FileError
		if !errors.As(err, &e) {
			t.Fatalf("unexpected error: %v", err)
		}
		if e.Pragma != "FROM" {
			t.Errorf("e.Pragma = %s; want FROM", e.Pragma)
		}
		if e.Message != "missing" {
			t.Errorf("e.Message = %s; want missing", e.Message)
		}
	})

	t.Run("FROM file not found", func(t *testing.T) {
		err := s.Build(qualifiedRef, model.File{From: "bar"})
		if !errors.Is(err, os.ErrNotExist) {
			t.Fatalf("Build() err = %v; want file not found", err)
		}
	})

	t.Run("FROM gguf", func(t *testing.T) {
		w := newWorkDir(t)
		// Write a gguf file without general.file_type metadata.
		w.write("gguf", ""+
			"GGUF"+ // magic
			"\x03\x00\x00\x00"+ // version
			"\x00\x00\x00\x00\x00\x00\x00\x00"+ // numTensors
			"\x00\x00\x00\x00\x00\x00\x00\x00"+ // numMetaValues
			"",
		)

		err := s.Build(qualifiedRef, model.File{From: w.fileName("gguf")})
		if !errors.Is(err, ErrMissingFileType) {
			t.Fatalf("Build() err = %#v; want missing file type", err)
		}
	})

	t.Run("FROM obscure dir", func(t *testing.T) {
		w := newWorkDir(t)
		w.mkdirAll("unknown")
		if err := s.Build(qualifiedRef, model.File{From: w.fileName("unknown")}); err != ErrUnsupportedModelFormat {
			t.Fatalf("Build() err = %#v; want unsupported model type", err)
		}
	})

	t.Run("FROM unsupported model type", func(t *testing.T) {
		w := newWorkDir(t)
		from := w.write("unknown", "unknown content")
		err := s.Build(qualifiedRef, model.File{From: from})
		if !errors.Is(err, ErrUnsupportedModelFormat) {
			t.Fatalf("Build() err = %#v; want unsupported model type", err)
		}
	})
}

func TestBuildBasicGGUF(t *testing.T) {
	w := newWorkDir(t)
	w.write("gguf", ""+
		"GGUF"+ // magic
		"\x03\x00\x00\x00"+ // version
		"\x00\x00\x00\x00\x00\x00\x00\x00"+ // numTensors
		"\x01\x00\x00\x00\x00\x00\x00\x00"+ // numMetaValues

		// general.file_type key
		"\x11\x00\x00\x00\x00\x00\x00\x00"+ // key length
		"general.file_type"+ // key
		"\x04\x00\x00\x00"+ // type (uint32)
		"\x02\x00\x00\x00\x00\x00\x00\x00"+ // uint32 value
		"",
	)

	dir := t.TempDir()
	s, err := Open(dir)
	if err != nil {
		t.Fatal(err)
	}
	if err := s.Build(qualifiedRef, model.File{From: w.fileName("gguf")}); err != nil {
		t.Fatal(err)
	}

	filepath.Walk(dir, func(p string, info os.FileInfo, err error) error {
		t.Logf("file: %s", p)
		return nil
	})

	_, err = s.WeightsFile("unknown/y/z:latest+Q4_0")
	if !errors.Is(err, ErrNotFound) {
		t.Fatalf("WeightsFile() err = %v; want not found", err)
	}

	path, err := s.WeightsFile("x/y/z:latest+Q4_0")
	if err != nil {
		t.Fatal(err)
	}

	info, err := gguf.Stat(path)
	if err != nil {
		t.Fatal(err)
	}
	if info.FileType != gguf.TypeQ4_0 {
		t.Errorf("info.FileType = %d; want %d", info.FileType, gguf.TypeQ4_0)
	}
}

type work struct {
	t   testing.TB
	dir string
}

func newWorkDir(t *testing.T) work {
	return work{t: t, dir: t.TempDir()}
}

func (w work) write(name, content string) (path string) {
	w.t.Helper()
	path = w.fileName(name)
	if err := os.WriteFile(path, []byte(content), 0644); err != nil {
		w.t.Fatal(err)
	}
	return path
}

func (w work) fileName(name string) string {
	w.t.Helper()
	return filepath.Join(w.dir, name)
}

func (w work) mkdirAll(path string) {
	w.t.Helper()
	if err := os.MkdirAll(filepath.Join(w.dir, path), 0755); err != nil {
		w.t.Fatal(err)
	}
}
12 x/build/convert.go Normal file
@@ -0,0 +1,12 @@
package build

func convertSafeTensorToGGUF(path string) (ggufPath string, err error) {
	// TODO: decide on a heuristic for converting safetensors to gguf
	// and the errors that can be returned. For now, we just say
	// "unsupported"; however, the file may be intended as a valid
	// safetensor and we simply hit an error in the conversion.
	//
	// I (bmizerany) think this will naturally evolve as we implement
	// the conversion.
	return "", ErrUnsupportedModelFormat
}
28 x/build/default.go Normal file
@@ -0,0 +1,28 @@
package build

import (
	"os"
	"path/filepath"
	"sync"
)

var (
	defaultDir = sync.OnceValues(func() (string, error) {
		dir := os.Getenv("OLLAMA_MODELS")
		if dir == "" {
			home, err := os.UserHomeDir()
			if err != nil {
				return "", err
			}
			dir = filepath.Join(home, ".ollama", "models")
		}
		return dir, nil
	})
)

// DefaultDir returns the default directory for models. It returns the value
// of the OLLAMA_MODELS environment variable if set; otherwise it returns
// "$HOME/.ollama/models".
func DefaultDir() (string, error) {
	return defaultDir()
}
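A short sketch (in package build) of how a caller might combine DefaultDir with Open. The MkdirAll is a hedge on the assumption, matching blobstore.Open below, that the directory must already exist; Open's return type is assumed to be (*Server, error) as the tests use it:

func openDefault() (*Server, error) {
	dir, err := DefaultDir()
	if err != nil {
		return nil, err
	}
	// Assumption: Open stats the directory first, so create it on a
	// fresh machine.
	if err := os.MkdirAll(dir, 0777); err != nil {
		return nil, err
	}
	return Open(dir)
}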
59 x/build/import.go Normal file
@@ -0,0 +1,59 @@
package build

import (
	"errors"
	"fmt"
	"os"

	"github.com/ollama/ollama/x/build/internal/blobstore"
	"github.com/ollama/ollama/x/encoding/gguf"
)

func importError(err error) (blobstore.ID, gguf.Info, int64, error) {
	return blobstore.ID{}, gguf.Info{}, 0, err
}

func (s *Server) importModel(path string) (_ blobstore.ID, _ gguf.Info, size int64, _ error) {
	info, err := os.Stat(path)
	if err != nil {
		return importError(err)
	}
	if info.IsDir() {
		return s.importSafeTensor(path)
	}
	return s.importGGUF(path)
}

func (s *Server) importGGUF(path string) (_ blobstore.ID, _ gguf.Info, size int64, _ error) {
	f, err := os.Open(path)
	if err != nil {
		return importError(err)
	}
	defer f.Close()

	info, err := gguf.StatReader(f)
	if errors.Is(err, gguf.ErrBadMagic) {
		return importError(ErrUnsupportedModelFormat)
	}
	if err != nil {
		return importError(err)
	}

	if info.FileType == 0 {
		return importError(fmt.Errorf("%w: %q", ErrMissingFileType, path))
	}
	id, size, err := s.st.Put(f)
	if err != nil {
		return importError(err)
	}
	return id, info, size, nil
}

func (s *Server) importSafeTensor(path string) (_ blobstore.ID, _ gguf.Info, size int64, _ error) {
	path, err := convertSafeTensorToGGUF(path)
	if err != nil {
		return importError(err)
	}
	return s.importGGUF(path)
}
329 x/build/internal/blobstore/blob.go Normal file
@@ -0,0 +1,329 @@
// Package blobstore implements a blob store.
package blobstore

import (
	"bytes"
	"crypto/sha256"
	"errors"
	"fmt"
	"io"
	"io/fs"
	"os"
	"path/filepath"
	"strings"
	"time"

	"github.com/ollama/ollama/x/types/structs"
)

var (
	ErrInvalidID = errors.New("invalid ID")
)

const HashSize = 32

// An ID is a blob output key, the hash of an output of a computation.
type ID struct {
	a [HashSize]byte
}

func (id ID) MarshalText() ([]byte, error) {
	return []byte(id.String()), nil
}

func (id *ID) UnmarshalText(text []byte) error {
	*id = ParseID(string(text))
	return nil
}

func ParseID(s string) ID {
	const prefix = "sha256-"
	h, ok := strings.CutPrefix(s, prefix)
	if !ok {
		return ID{}
	}

	if len(h) != HashSize*2 {
		return ID{}
	}

	var b []byte
	_, err := fmt.Sscanf(h, "%x", &b)
	if err != nil {
		return ID{}
	}

	var id ID
	copy(id.a[:], b)
	return id
}

func (id ID) String() string {
	if !id.Valid() {
		return ""
	}
	return fmt.Sprintf("sha256-%x", id.a[:])
}

func (id ID) Valid() bool {
	return id != ID{}
}

func (id ID) Match(h [HashSize]byte) bool {
	return id.a == h
}

// A Store is a blob store, backed by a file system directory tree.
type Store struct {
	dir string
	now func() time.Time
}

// Open opens and returns the store in the given directory.
//
// It is safe for multiple processes on a single machine to use the
// same store directory in a local file system simultaneously.
// They will coordinate using operating system file locks and may
// duplicate effort but will not corrupt the store.
//
// However, it is NOT safe for multiple processes on different machines
// to share a store directory (for example, if the directory were stored
// in a network file system). File locking is notoriously unreliable in
// network file systems and may not suffice to protect the store.
func Open(dir string) (*Store, error) {
	info, err := os.Stat(dir)
	if err != nil {
		return nil, err
	}
	if !info.IsDir() {
		return nil, &fs.PathError{Op: "open", Path: dir, Err: fmt.Errorf("not a directory")}
	}
	if err := os.MkdirAll(filepath.Join(dir, "blobs"), 0777); err != nil {
		return nil, err
	}
	c := &Store{
		dir: dir,
		now: time.Now,
	}
	return c, nil
}

func (s *Store) Dir() string {
	return s.dir
}

// fileName returns the name of the blob file corresponding to the given id.
func (s *Store) fileName(id ID) string {
	return filepath.Join(s.dir, "blobs", fmt.Sprintf("sha256-%x", id.a[:]))
}

// An entryNotFoundError indicates that a store entry was not found, with an
// optional underlying reason.
type entryNotFoundError struct {
	Err error
}

func (e *entryNotFoundError) Error() string {
	if e.Err == nil {
		return "store entry not found"
	}
	return fmt.Sprintf("store entry not found: %v", e.Err)
}

func (e *entryNotFoundError) Unwrap() error {
	return e.Err
}

type Entry struct {
	_ structs.Incomparable

	ID   ID
	Size int64
	Time time.Time // when added to store
}

// GetFile looks up the blob ID in the store and returns
// the name of the corresponding data file.
func GetFile(s *Store, id ID) (file string, entry Entry, err error) {
	entry, err = s.Get(id)
	if err != nil {
		return "", Entry{}, err
	}
	file = s.OutputFilename(entry.ID)
	info, err := os.Stat(file)
	if err != nil {
		return "", Entry{}, &entryNotFoundError{Err: err}
	}
	if info.Size() != entry.Size {
		return "", Entry{}, &entryNotFoundError{Err: errors.New("file incomplete")}
	}
	return file, entry, nil
}

// GetBytes looks up the blob ID in the store and returns
// the corresponding output bytes.
// GetBytes should only be used for data that can be expected to fit in memory.
func GetBytes(s *Store, id ID) ([]byte, Entry, error) {
	entry, err := s.Get(id)
	if err != nil {
		return nil, entry, err
	}
	data, _ := os.ReadFile(s.OutputFilename(entry.ID))
	if !entry.ID.Match(sha256.Sum256(data)) {
		return nil, entry, &entryNotFoundError{Err: errors.New("bad checksum")}
	}
	return data, entry, nil
}

// OutputFilename returns the name of the blob file for the given ID.
func (s *Store) OutputFilename(id ID) string {
	file := s.fileName(id)
	// TODO(bmizerany): touch as "used" for cache trimming (see cache.go
	// in cmd/go/internal/cache for the full reference implementation to
	// go off of).
	return file
}

// Get looks up the blob ID in the store,
// returning the corresponding output ID and file size, if any.
// Note that finding an output ID does not guarantee that the
// saved file for that output ID is still available.
func (s *Store) Get(id ID) (Entry, error) {
	file := s.fileName(id)
	info, err := os.Stat(file)
	if err != nil {
		return Entry{}, &entryNotFoundError{Err: err}
	}
	return Entry{
		ID:   id,
		Size: info.Size(),
		Time: info.ModTime(),
	}, nil
}

func (s *Store) Close() error {
	// TODO(bmizerany): return c.Trim()
	return nil
}

// Put stores the data read from the given file into the store as ID.
//
// It may read file twice. The content of file must not change between the
// two passes.
func (s *Store) Put(file io.ReadSeeker) (ID, int64, error) {
	return s.put(file)
}

func PutBytes(s *Store, data []byte) (ID, int64, error) {
	return s.Put(bytes.NewReader(data))
}

func PutString(s *Store, data string) (ID, int64, error) {
	return s.Put(strings.NewReader(data))
}

func (s *Store) put(file io.ReadSeeker) (ID, int64, error) {
	// Compute output ID.
	h := sha256.New()
	if _, err := file.Seek(0, 0); err != nil {
		return ID{}, 0, err
	}
	size, err := io.Copy(h, file)
	if err != nil {
		return ID{}, 0, err
	}
	var out ID
	h.Sum(out.a[:0])

	// Copy to blob file (if not already present).
	if err := s.copyFile(file, out, size); err != nil {
		return out, size, err
	}

	// TODO: Add to manifest index.
	return out, size, nil
}

// copyFile copies file into the store, expecting it to have the given
// output ID and size, if that file is not present already.
func (s *Store) copyFile(file io.ReadSeeker, out ID, size int64) error {
	name := s.fileName(out)
	info, err := os.Stat(name)
	if err == nil && info.Size() == size {
		// Check hash.
		if f, err := os.Open(name); err == nil {
			h := sha256.New()
			io.Copy(h, f)
			f.Close()
			var out2 ID
			h.Sum(out2.a[:0])
			if out == out2 {
				return nil
			}
		}
		// Hash did not match. Fall through and rewrite file.
	}

	// Copy file to blobs directory.
	mode := os.O_RDWR | os.O_CREATE
	if err == nil && info.Size() > size { // shouldn't happen but fix in case
		mode |= os.O_TRUNC
	}
	f, err := os.OpenFile(name, mode, 0666)
	if err != nil {
		return err
	}
	defer f.Close()
	if size == 0 {
		// File now exists with correct size.
		// Only one possible zero-length file, so contents are OK too.
		// Early return here makes sure there's a "last byte" for code below.
		return nil
	}

	// From here on, if any of the I/O writing the file fails,
	// we make a best-effort attempt to truncate the file f
	// before returning, to avoid leaving bad bytes in the file.

	// Copy file to f, but also into h to double-check hash.
	if _, err := file.Seek(0, 0); err != nil {
		f.Truncate(0)
		return err
	}
	h := sha256.New()
	w := io.MultiWriter(f, h)
	if _, err := io.CopyN(w, file, size-1); err != nil {
		f.Truncate(0)
		return err
	}
	// Check last byte before writing it; writing it will make the size match
	// what other processes expect to find and might cause them to start
	// using the file.
	buf := make([]byte, 1)
	if _, err := file.Read(buf); err != nil {
		f.Truncate(0)
		return err
	}
	h.Write(buf)
	sum := h.Sum(nil)
	if !bytes.Equal(sum, out.a[:]) {
		f.Truncate(0)
		return fmt.Errorf("file content changed underfoot")
	}

	// Write the last byte, completing the file.
	if _, err := f.Write(buf); err != nil {
		f.Truncate(0)
		return err
	}
	if err := f.Close(); err != nil {
		// Data might not have been written,
		// but file may look like it is the right size.
		// To be extra careful, remove stored file.
		os.Remove(name)
		return err
	}
	os.Chtimes(name, s.now(), s.now()) // mainly for tests

	return nil
}
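Taken together, a round trip through the store looks like the sketch below, written as if from outside the package (ignoring that blobstore is internal to x/build); the directory must already exist, since Open stats it first:

package main

import (
	"fmt"
	"log"

	"github.com/ollama/ollama/x/build/internal/blobstore"
)

func main() {
	st, err := blobstore.Open("/tmp/models") // directory must already exist
	if err != nil {
		log.Fatal(err)
	}
	defer st.Close()

	// Content-addressed write: the ID is the SHA-256 of the data.
	id, size, err := blobstore.PutBytes(st, []byte("hello"))
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(id, size) // sha256-2cf24dba... 5

	// Read it back; GetBytes re-verifies the checksum, GetFile the size.
	data, _, err := blobstore.GetBytes(st, id)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Printf("%s\n", data) // hello
}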
54 x/build/internal/blobstore/blob_test.go Normal file
@@ -0,0 +1,54 @@
package blobstore

import (
	"strings"
	"testing"
)

func TestParseID(t *testing.T) {
	const valid = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
	var invalid = strings.Repeat("\x00", HashSize*2)

	cases := []struct {
		in   string
		want string
	}{
		{"", invalid},
		{"sha256-", invalid},
		{"sha256-" + valid, valid},

		{"" + valid, invalid},              // no prefix
		{"sha123-" + valid, invalid},       // invalid prefix
		{"sha256-" + valid[1:], invalid},   // too short
		{"sha256-" + valid + "a", invalid}, // too long
		{"sha256-!" + valid[1:], invalid},  // invalid hex
	}

	for _, tt := range cases {
		t.Run("", func(t *testing.T) {
			// sanity check
			if len(tt.want) > HashSize*2 {
				panic("invalid test")
			}

			got := ParseID(tt.in)

			wantValid := tt.want != invalid
			if wantValid {
				if !got.Valid() {
					t.Errorf("ParseID(%q).Valid() = false; want true", tt.in)
				}
				if got.String() != "sha256-"+tt.want {
					t.Errorf("ParseID(%q).String() = %q; want %q", tt.in, got.String(), "sha256-"+tt.want)
				}
			} else {
				if got.Valid() {
					t.Errorf("ParseID(%q).Valid() = true; want false", tt.in)
				}
				if got.String() != "" {
					t.Errorf("ParseID(%q).String() = %q; want %q", tt.in, got.String(), "")
				}
			}
		})
	}
}
128 x/build/internal/blobstore/store_test.go Normal file
@@ -0,0 +1,128 @@
package blobstore

import (
	"errors"
	"iter"
	"os"
	"path/filepath"
	"testing"
	"time"

	"github.com/ollama/ollama/x/model"
	"kr.dev/diff"
)

const (
	blobNameHello = "sha256-2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824"
)

func TestStoreBasicBlob(t *testing.T) {
	dir := t.TempDir()

	checkDir(t, dir, nil)

	st, err := Open(dir)
	if err != nil {
		t.Fatal(err)
	}

	now := time.Now()
	st.now = func() time.Time { return now }

	checkDir(t, dir, []string{
		"blobs/",
	})

	id, size, err := PutBytes(st, []byte("hello"))
	if err != nil {
		t.Fatal(err)
	}

	if id != ParseID(blobNameHello) {
		t.Errorf("unexpected ID: %s", id)
	}
	if size != 5 {
		t.Errorf("unexpected size: %d", size)
	}

	checkDir(t, dir, []string{
		"blobs/",
		"blobs/" + blobNameHello,
	})

	got, err := st.Get(id)
	if err != nil {
		t.Fatal(err)
	}

	diff.Test(t, t.Errorf, got, Entry{
		ID:   id,
		Size: 5,
		Time: now,
	})

	file := st.OutputFilename(id)
	wantFile := filepath.Join(dir, "blobs", blobNameHello)
	if file != wantFile {
		t.Errorf("unexpected file: %s", file)
	}

	// Check tags
	name := model.ParseName("registry.ollama.ai/library/test:latest+KQED")

	t.Logf("RESOLVING: %q", name.Parts())
}

// checkDir checks that the directory at dir contains the files in want. The
// files in want must be relative to dir.
//
// Directories are suffixed with a slash (e.g. "foo/" instead of "foo").
//
// want must be in lexicographic order.
func checkDir(t testing.TB, dir string, want []string) {
	t.Helper()

	var matches []string
	for path, err := range walkDir(dir) {
		t.Helper()
		if err != nil {
			t.Fatal(err)
		}
		t.Logf("found %s", path)
		if path == "./" {
			continue
		}
		path = filepath.ToSlash(path)
		matches = append(matches, path)
	}

	diff.Test(t, t.Errorf, matches, want)
}

var errStop = errors.New("stop")

func walkDir(dir string) iter.Seq2[string, error] {
	return func(yield func(string, error) bool) {
		err := filepath.WalkDir(dir, func(path string, info os.DirEntry, err error) error {
			if err != nil {
				return err
			}
			path, err = filepath.Rel(dir, path)
			if err != nil {
				return err
			}
			path = filepath.ToSlash(path)
			if info.IsDir() {
				path += "/"
			}
			if !yield(path, nil) {
				return errStop
			}
			return nil
		})
		if !errors.Is(err, errStop) && err != nil {
			yield("", err)
		}
	}
}
31 x/client/ollama/apitype/apitype.go Normal file
@@ -0,0 +1,31 @@
package apitype

import "time"

type Message struct {
	Role    string `json:"role"`
	Content string `json:"content"`
}

type Model struct {
	Ref        string `json:"ref"`
	Digest     string `json:"digest"`
	Size       int64  `json:"size"`
	ModifiedAt int64  `json:"modified"`
}

func (m Model) Modified() time.Time {
	return time.Unix(0, m.ModifiedAt)
}

type PushRequest struct {
	Name     string `json:"name"` // Ref is the official term; "name" is for backward compatibility with existing clients.
	Insecure bool   `json:"insecure"`
	Stream   bool   `json:"stream"`
}

type PushStatus struct {
	Status string `json:"status"`
	Digest string `json:"digest"`
	Total  int64  `json:"total"`
}
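As the Modified accessor above shows, ModifiedAt carries Unix nanoseconds; a tiny sketch of the conversion:

package main

import (
	"fmt"
	"time"

	"github.com/ollama/ollama/x/client/ollama/apitype"
)

func main() {
	m := apitype.Model{Ref: "x/y/z:latest+Q4_0", ModifiedAt: time.Now().UnixNano()}
	fmt.Println(m.Modified()) // time.Unix(0, ns): nanosecond-resolution wall clock
}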
173 x/client/ollama/ollama.go Normal file
@@ -0,0 +1,173 @@
package ollama

import (
	"bytes"
	"cmp"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"io/fs"
	"iter"
	"net/http"
	"os"
	"strings"

	"github.com/ollama/ollama/x/client/ollama/apitype"
	"github.com/ollama/ollama/x/types/empty"
)

// TODO(bmizerany): PROGRESS INDICATORS!!!!

const DefaultBaseURL = "http://localhost:11434"

var envBaseURL = cmp.Or(os.Getenv("OLLAMA_BASE_URL"), DefaultBaseURL)

// Default returns a new client with the default base URL.
func Default() *Client {
	return &Client{BaseURL: envBaseURL}
}

// I_Acknowledge_This_API_Is_Under_Development is a flag that must be set to
// true for any instance of Client to work.
var I_Acknowledge_This_API_Is_Under_Development bool

// Client is a client for the Ollama API.
type Client struct {
	// BaseURL is the base URL of the Ollama API.
	BaseURL string

	HTTPClient *http.Client // The HTTP client to use. If nil, http.DefaultClient is used.
}

// Build requests the remote Ollama service to build a model. It uploads any
// source files the server needs.
func (c *Client) Build(ctx context.Context, ref string, modelfile []byte, source fs.FS) error {
	panic("not implemented")
}

// Push requests the remote Ollama service to push a model to the server.
func (c *Client) Push(ctx context.Context, ref string) error {
	_, err := Do[empty.Message](ctx, c, "POST", "/v1/push", apitype.PushRequest{Name: ref})
	return err
}

func (c *Client) Pull(ctx context.Context, ref string) error {
	panic("not implemented")
}

func (c *Client) List(ctx context.Context) iter.Seq2[apitype.Model, error] {
	panic("not implemented")
}

func (c *Client) Show(ctx context.Context, ref string) (*apitype.Model, error) {
	panic("not implemented")
}

func (c *Client) Remove(ctx context.Context, ref string) error {
	panic("not implemented")
}

func (c *Client) Copy(ctx context.Context, dstRef, srcRef string) error {
	panic("not implemented")
}

func (c *Client) Run(ctx context.Context, ref string, messages []apitype.Message) error {
	panic("not implemented")
}

type Error struct {
	// Status is the HTTP status code returned by the server.
	Status int `json:"status"`

	// Code specifies a machine readable code indicating the class of
	// error this error is. See http://docs.ollama.com/errors for a full
	// list of error codes.
	Code string `json:"code"`

	// Message is a human readable message that describes the error. It
	// may change across versions of the API, so it should not be used for
	// programmatic decisions.
	Message string `json:"message,omitempty"`

	// Field is the field in the request that caused the error, if any.
	Field string `json:"field,omitempty"`

	// Value is the value of the field that caused the error, if any.
	Value string `json:"value,omitempty"`
}

func (e *Error) Error() string {
	var b strings.Builder
	b.WriteString("ollama: ")
	b.WriteString(e.Code)
	if e.Field != "" {
		b.WriteString(" ")
		b.WriteString(e.Field)
	}
	if e.Value != "" {
		b.WriteString(": ")
		b.WriteString(e.Value)
	}
	if e.Message != "" {
		b.WriteString(": ")
		b.WriteString(e.Message)
	}
	return b.String()
}

// Do encodes in and sends it in a request to the Ollama server and decodes
// the response into Res, or an error response (non-2xx) into an *Error, or
// any error encountered decoding the response.
func Do[Res any](ctx context.Context, c *Client, method, path string, in any) (*Res, error) {
	var body bytes.Buffer
	// TODO(bmizerany): pool and reuse this buffer AND the encoder
	if err := encodeJSON(&body, in); err != nil {
		return nil, err
	}
	urlStr := c.BaseURL + path
	req, err := http.NewRequestWithContext(ctx, method, urlStr, &body)
	if err != nil {
		return nil, err
	}

	hc := cmp.Or(c.HTTPClient, http.DefaultClient)
	res, err := hc.Do(req)
	if err != nil {
		return nil, err
	}
	defer res.Body.Close()

	if res.StatusCode/100 != 2 {
		var buf bytes.Buffer
		body := io.TeeReader(res.Body, &buf)
		e, err := decodeJSON[Error](body)
		if err != nil {
			err := fmt.Errorf("ollama: invalid error response from server (status %d): %q", res.StatusCode, buf.String())
			return nil, err
		}
		return nil, e
	}

	return decodeJSON[Res](res.Body)
}

// decodeJSON decodes JSON from r into a new value of type T.
//
// NOTE: This (and encodeJSON) is copied and pasted from oweb.go; please
// do not try to consolidate, so we can keep ollama/client free from
// dependencies which are moving targets and not pulling enough weight to
// justify their inclusion.
func decodeJSON[T any](r io.Reader) (*T, error) {
	var v T
	if err := json.NewDecoder(r).Decode(&v); err != nil {
		return nil, err
	}
	return &v, nil
}

// NOTE: see NOTE above decodeJSON
func encodeJSON(w io.Writer, v any) error {
	// TODO(bmizerany): pool and reuse encoder
	return json.NewEncoder(w).Encode(v)
}
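A sketch of the client in use; the acknowledge flag is set for form's sake (note the code above does not yet enforce it), and the ref is a placeholder:

package main

import (
	"context"
	"log"

	"github.com/ollama/ollama/x/client/ollama"
	"github.com/ollama/ollama/x/client/ollama/apitype"
)

func main() {
	ollama.I_Acknowledge_This_API_Is_Under_Development = true

	c := ollama.Default() // honors OLLAMA_BASE_URL, else localhost:11434

	// High-level helper:
	if err := c.Push(context.TODO(), "x/y/z:latest+Q4_0"); err != nil {
		log.Fatal(err) // non-2xx responses surface as *ollama.Error
	}

	// Or the generic escape hatch, decoding into any response type:
	st, err := ollama.Do[apitype.PushStatus](context.TODO(), c, "POST", "/v1/push",
		apitype.PushRequest{Name: "x/y/z:latest+Q4_0"})
	if err != nil {
		log.Fatal(err)
	}
	log.Printf("status: %s", st.Status)
}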
100 x/cmd/bllamo/bllamo.go Normal file
@@ -0,0 +1,100 @@
// Bllamo is a (new) tool for managing Ollama models.
//
// Usage:
//
//	bllamo <command> [arguments]
//
// The commands are:
//
//	build	build a model from a Modelfile
//	list	list all models
//	push	push a model to an ollama registry
//	pull	pull a model from an ollama registry
//	delete	delete a model from an ollama registry
//	help	display help for a command
package main

import (
	"cmp"
	"context"
	"flag"
	"fmt"
	"net/http"
	"os"

	"github.com/ollama/ollama/x/api"
	"github.com/ollama/ollama/x/build"
	"github.com/ollama/ollama/x/client/ollama"
	"github.com/ollama/ollama/x/registry"
)

func main() {
	flag.Parse()
	args := flag.Args()
	if len(args) < 1 {
		fmt.Fprintln(os.Stderr, "bllamo: no command provided")
		os.Exit(2)
	}
	if err := Main(flag.Args()...); err != nil {
		fmt.Fprintf(os.Stderr, "%v\n", err)
		os.Exit(1)
	}
}

var TODOUsage = fmt.Errorf("TODO: usage")

var commands = map[string]func(ctx context.Context, args ...string) error{
	"build":    cmdBuild,
	"push":     cmdPush,
	"serve":    cmdServe,
	"registry": cmdRegistry,
}

// Main is the entry point for the bllamo command.
func Main(args ...string) error {
	cmd := args[0]
	args = args[1:]
	if f, ok := commands[cmd]; ok {
		ctx := context.TODO()
		return f(ctx, args...)
	}
	return fmt.Errorf("bllamo: unknown command %q", cmd)
}

func cmdBuild(ctx context.Context, args ...string) error {
	var v struct {
		Modelfile string `flag:"f,the Modelfile to use"`
	}

	fs := readFlags("build", args, &v)
	if fs.NArg() != 1 {
		return TODOUsage
	}

	modelfile, err := os.ReadFile(cmp.Or(v.Modelfile, "Modelfile"))
	if err != nil {
		return err
	}
	return ollama.Default().Build(ctx, fs.Arg(0), modelfile, os.DirFS("."))
}

func cmdRegistry(_ context.Context, _ ...string) error {
	var s registry.Server
	return http.ListenAndServe(":8888", &s)
}

func cmdServe(ctx context.Context, args ...string) error {
	bs, err := build.Open("")
	if err != nil {
		return err
	}
	return http.ListenAndServe(":11434", &api.Server{Build: bs})
}

func cmdPush(ctx context.Context, args ...string) error {
	fs := readFlags("push", args, nil)
	if fs.NArg() != 1 {
		return TODOUsage
	}
	return ollama.Default().Push(ctx, fs.Arg(0))
}
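For a feel of the dispatch, Main can be driven directly from a test in the same package; a small sketch (fmt assumed imported in the test file):

func ExampleMain() {
	// Unknown commands surface as errors rather than exiting.
	if err := Main("frobnicate"); err != nil {
		fmt.Println(err)
	}
	// Output: bllamo: unknown command "frobnicate"
}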
59 x/cmd/bllamo/flags.go Normal file
@@ -0,0 +1,59 @@
package main

import (
	"flag"
	"fmt"
	"reflect"
	"strings"
)

// readFlags parses the provided args using a flag.FlagSet that is
// dynamically built using reflection over the provided type. Struct fields
// that have a "flag" tag become flags. The tag holds the flag name and,
// optionally, a usage string, separated by a comma. Example usage:
//
//	func main() {
//		var flags struct {
//			Modelfile string `flag:"f,path to the Modelfile"`
//		}
//
//		fs := readFlags("build", os.Args[1:], &flags)
//		// fs has already parsed os.Args[1:]
//	}
func readFlags(name string, args []string, v any) *flag.FlagSet {
	fs := flag.NewFlagSet(name, flag.ExitOnError)
	defer fs.Parse(args)
	if v == nil {
		return fs
	}

	rv := reflect.ValueOf(v).Elem() // v must be a pointer to a struct
	for i := 0; i < rv.NumField(); i++ {
		f := rv.Field(i)
		if !f.CanSet() {
			continue
		}

		tag := rv.Type().Field(i).Tag.Get("flag")
		if tag == "" {
			continue
		}
		var name, usage string
		if i := strings.Index(tag, ","); i != -1 {
			name = tag[:i]
			usage = tag[i+1:]
		} else {
			name = tag
		}

		// TODO(bmizerany): add more types as needed
		switch f.Kind() {
		case reflect.String:
			fs.StringVar(f.Addr().Interface().(*string), name, "", usage)
		case reflect.Bool:
			fs.BoolVar(f.Addr().Interface().(*bool), name, false, usage)
		default:
			panic(fmt.Sprintf("unsupported type %v", f.Kind()))
		}
	}
	return fs
}
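A quick illustration (hypothetical flag values) of how tagged struct fields become flags while positional arguments remain on the returned FlagSet:

func exampleReadFlags() {
	var opts struct {
		Modelfile string `flag:"f,the Modelfile to use"`
		Insecure  bool   `flag:"insecure,skip TLS verification"`
	}
	fs := readFlags("example", []string{"-f", "My.Modelfile", "-insecure", "x/y/z:latest+Q4_0"}, &opts)
	fmt.Println(opts.Modelfile, opts.Insecure, fs.Arg(0))
	// Output: My.Modelfile true x/y/z:latest+Q4_0
}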
97 x/cmd/gguf/gguf.go Normal file
@@ -0,0 +1,97 @@
// Gguf is a tool for learning about GGUF files.
//
// Usage:
//
//	gguf [flags] <file>
package main

import (
	"flag"
	"fmt"
	"io"
	"log"
	"os"
	"text/tabwriter"

	"github.com/ollama/ollama/x/encoding/gguf"
)

func main() {
	if err := Main(os.Stdout, os.Args[1:]...); err != nil {
		log.Fatal(err)
	}
}

func Main(stdout io.Writer, args ...string) error {
	fs := flag.NewFlagSet("gguf", flag.ExitOnError)
	flagGPU := fs.Uint64("gpu", 0, "use N bytes of GPU memory (default is 0)")

	fs.Usage = func() {
		io.WriteString(stdout, "Gguf is a tool for learning about GGUF files.\n")
		io.WriteString(stdout, "\n")
		io.WriteString(stdout, "Usage:\n")
		io.WriteString(stdout, "\n")
		io.WriteString(stdout, "\tgguf [flags] <file>\n")
		io.WriteString(stdout, "\n")
		var numFlags int
		fs.VisitAll(func(*flag.Flag) { numFlags++ })
		if numFlags > 0 {
			io.WriteString(stdout, "Flags:\n")
			fs.PrintDefaults()
		}
	}
	fs.Parse(args)

	if fs.NArg() != 1 {
		fs.Usage()
		os.Exit(2)
	}

	file := fs.Arg(0)
	f, err := os.Open(file)
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()

	g, err := gguf.ReadFile(f)
	if err != nil {
		log.Fatal(err)
	}

	tw := tabwriter.NewWriter(stdout, 0, 2, 2, ' ', 0)
	defer tw.Flush()

	fmt.Fprintf(tw, "version:\t%d\n", g.Version())

	for m, err := range g.Metadata {
		if err != nil {
			log.Fatal(err)
		}
		if len(m.Values) > 5 {
			fmt.Fprintf(tw, "meta:\t%q: ... (%d values)\n", m.Key, len(m.Values))
		} else {
			fmt.Fprintf(tw, "meta:\t%q: %v\n", m.Key, m.Values)
		}
	}

	var i int
	var totalLayerBytes uint64
	var offGPU bool
	for t, err := range g.Tensors {
		if err != nil {
			log.Fatal(err)
		}

		totalLayerBytes += t.Size
		if totalLayerBytes > *flagGPU {
			offGPU = true
		}

		const msg = "tensor (layer %000d):\t%q\t%s\tdims=%v\toffset=%d\tsize=%d\tonGPU=%v\n"
		fmt.Fprintf(tw, msg, i, t.Name, t.Type, t.Dimensions, t.Offset, t.Size, !offGPU)

		i++
	}
	return nil
}
376 x/encoding/gguf/gguf.go Normal file
@@ -0,0 +1,376 @@
package gguf

import (
	"bytes"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"os"
	"strconv"
	"strings"

	"github.com/ollama/ollama/x/types/structs"
)

// TODO(bmizerany): determine a more reasonable value for MaxDimensions.

// MaxDimensions is the maximum number of dimensions a tensor can have.
const MaxDimensions uint32 = 1e6

// Errors
var (
	// ErrBadMagic is returned when the magic bytes at the start of the
	// file are missing or wrong. This is useful for detecting if the
	// file is not a gguf file.
	ErrBadMagic = errors.New("gguf: bad magic")

	ErrUnsupportedVersion = errors.New("gguf: unsupported version")
	ErrMangled            = errors.New("gguf: mangled data")
)

type Type uint32

const (
	TypeF32   Type = 0
	TypeF16   Type = 1
	TypeQ4_0  Type = 2
	TypeQ4_1  Type = 3
	TypeQ5_0  Type = 6
	TypeQ5_1  Type = 7
	TypeQ8_0  Type = 8
	TypeQ8_1  Type = 9
	TypeQ2_K  Type = 10
	TypeQ3_K  Type = 11
	TypeQ4_K  Type = 12
	TypeQ5_K  Type = 13
	TypeQ6_K  Type = 14
	TypeQ8_K  Type = 15
	TypeI8    Type = 16
	TypeI16   Type = 17
	TypeI32   Type = 18
	TypeCount Type = 19
)

var typeNames = map[Type]string{
	TypeF32:   "F32",
	TypeF16:   "F16",
	TypeQ4_0:  "Q4_0",
	TypeQ4_1:  "Q4_1",
	TypeQ5_0:  "Q5_0",
	TypeQ5_1:  "Q5_1",
	TypeQ8_0:  "Q8_0",
	TypeQ8_1:  "Q8_1",
	TypeQ2_K:  "Q2_K",
	TypeQ3_K:  "Q3_K",
	TypeQ4_K:  "Q4_K",
	TypeQ5_K:  "Q5_K",
	TypeQ6_K:  "Q6_K",
	TypeQ8_K:  "Q8_K",
	TypeI8:    "I8",
	TypeI16:   "I16",
	TypeI32:   "I32",
	TypeCount: "COUNT",
}

func (t Type) String() string {
	if name := typeNames[t]; name != "" {
		return name
	}
	return fmt.Sprintf("(!unknown_type %d!)", t)
}

// ValueType is the type of a metadata value.
type ValueType uint32

func (t ValueType) String() string {
	if name := metaTypeNames[t]; name != "" {
		return name
	}
	return fmt.Sprintf("(!unknown_value_type %d!)", t)
}

const (
	ValueTypeUint8   ValueType = 0
	ValueTypeInt8    ValueType = 1
	ValueTypeUint16  ValueType = 2
	ValueTypeInt16   ValueType = 3
	ValueTypeUint32  ValueType = 4
	ValueTypeInt32   ValueType = 5
	ValueTypeFloat32 ValueType = 6
	ValueTypeBool    ValueType = 7
	ValueTypeString  ValueType = 8
	ValueTypeArray   ValueType = 9
	ValueTypeUint64  ValueType = 10
	ValueTypeInt64   ValueType = 11
	ValueTypeFloat64 ValueType = 12
)

var metaTypeNames = map[ValueType]string{
	ValueTypeUint8:   "uint8",
	ValueTypeInt8:    "int8",
	ValueTypeUint16:  "uint16",
	ValueTypeInt16:   "int16",
	ValueTypeUint32:  "uint32",
	ValueTypeInt32:   "int32",
	ValueTypeFloat32: "float32",
	ValueTypeBool:    "bool",
	ValueTypeString:  "string",
	ValueTypeArray:   "array",
	ValueTypeUint64:  "uint64",
	ValueTypeInt64:   "int64",
	ValueTypeFloat64: "float64",
}

type TensorInfo struct {
	Name       string
	Dimensions []uint64
	Type       Type
	Offset     uint64
	Size       uint64
}

type MetaValue struct {
	Type  ValueType
	Value []byte
}

func (v MetaValue) String() string {
	var b strings.Builder
	b.WriteString(v.Type.String())
	b.WriteString("(")
	switch v.Type {
	case ValueTypeArray:
		b.WriteString("[...]")
	case ValueTypeString:
		b.WriteString(strconv.Quote(string(v.Value)))
	case ValueTypeBool:
		if len(v.Value) == 0 {
			b.WriteString("!invalid bool")
			break
		}
		switch v.Value[0] {
		case 0:
			b.WriteString("false")
		case 1:
			b.WriteString("true")
		default:
			b.WriteString("!invalid bool")
		}
	case ValueTypeUint8, ValueTypeInt8, ValueTypeUint16, ValueTypeInt16, ValueTypeUint32, ValueTypeInt32, ValueTypeUint64, ValueTypeInt64, ValueTypeFloat32, ValueTypeFloat64:
		var buf [8]byte
		copy(buf[:], v.Value)
		fmt.Fprintf(&b, "%v", binary.LittleEndian.Uint64(buf[:]))
	default:
		fmt.Fprintf(&b, "%v", v.Value)
	}
	b.WriteString(")")
	return b.String()
}

type MetaEntry struct {
	Key    string
	Type   ValueType
	Values []MetaValue
}

func (e MetaEntry) String() string {
	if len(e.Values) == 0 {
		return ""
	}
	return string(e.Values[0].Value)
}

func (e MetaEntry) Uint32() uint32 {
	if len(e.Values) == 0 {
		return 0
	}
	return binary.LittleEndian.Uint32(e.Values[0].Value)
}

func (e MetaEntry) FileType() Type {
	if len(e.Values) == 0 {
		return TypeCount
	}
	return Type(e.Uint32())
}

func (e MetaEntry) GoString() string {
	var b strings.Builder
	b.WriteString(e.Key)
	b.WriteString(": ")
	b.WriteString(e.Type.String())
	b.WriteString("(")
	for i, v := range e.Values {
		if i > 0 {
			b.WriteString(", ")
		}
		b.WriteString(v.String())
	}
	b.WriteString(")")
	return b.String()
}

type Info struct {
	_ structs.Incomparable // prevent comparison of Info values so we can change the implementation later

	Version  int
	FileType Type
}

func Stat(path string) (Info, error) {
	f, err := os.Open(path)
	if err != nil {
		return Info{}, err
	}
	defer f.Close()
	return StatReader(f)
}

// StatReader reads the header information from r and returns an Info
// struct with the version and file type.
//
// It returns an error if any.
//
// As a special case, it returns ErrBadMagic if the file does not start with
// the magic bytes. This can be used to detect if the file is not a GGUF
// file.
func StatReader(r io.ReadSeeker) (Info, error) {
	if _, err := r.Seek(0, 0); err != nil {
		return Info{}, err
	}
	f, err := ReadFile(r)
	if err != nil {
		return Info{}, err
	}
	info := Info{Version: f.Version()}
	for m, err := range f.Metadata {
		if err != nil {
			return Info{}, err
		}
		if m.Key == "general.file_type" {
			if m.Type != ValueTypeUint32 {
				return Info{}, fmt.Errorf("unexpected type for metadata key %q: %v, want %v", m.Key, m.Type, ValueTypeUint32)
			}
			info.FileType = m.FileType()
		}
	}
	return info, nil
}

type File struct {
	version       uint32
	numMetaValues uint64
	numTensors    uint64

	gr *ggufReader
}

// ReadFile reads header information from r and returns a File, ready for
// iteration over Metadata and Tensors.
func ReadFile(r io.Reader) (*File, error) {
	f, err := readFile(r)
	if err != nil {
		return nil, err
	}
	return f, nil
}

func (f *File) Version() int {
	return int(f.version)
}

// Metadata iterates over the metadata in the file. It must be exhausted
// before calling Tensors.
//
// It is not resumable.
func (f *File) Metadata(yield func(MetaEntry, error) bool) {
	var n int
	for range f.numMetaValues {
		meta, err := f.gr.readMetaEntry()
		if err != nil {
			err = fmt.Errorf("error reading metadata entry %d: %w", n, err)
			yield(MetaEntry{}, err)
			return
		}
		if !yield(meta, nil) {
			return
		}
		n++
	}
}

// Tensors iterates over the tensors in the file. It must only be called
// after exhausting the metadata iterator.
//
// It is not resumable.
func (f *File) Tensors(yield func(TensorInfo, error) bool) {
	var last TensorInfo
	for range f.numTensors {
		info, err := f.gr.readTensorInfo()

		// If the last tensor had a valid offset, yield it.
		//
		// NOTE: No tensor should have an offset of 0 because the
		// offset is the start of the tensor data, which is always
		// after the magic bytes, version, numTensors, and
		// numMetaValues, which MUST all be non-zero bytes as per the
		// GGUF spec.
		if last.Offset > 0 {
			if !yield(last, err) {
				return
			}
		}
		if err != nil {
			yield(TensorInfo{}, err)
			return
		}

		// Tensor records do not include a size, so we compute it
		// from the previous tensor's offset to the current one.
		offset0 := last.Offset
		last = info
		last.Size = info.Offset - offset0
	}
	if last.Offset > 0 {
		yield(last, nil)
	}
}

var magicBytes = []byte{0x47, 0x47, 0x55, 0x46}

func readFile(r io.Reader) (*File, error) {
	gr := &ggufReader{r: &reader{r: r}}
	magic, err := gr.next(4)
	if err != nil {
		return nil, errors.Join(err, ErrBadMagic)
	}
	if !bytes.Equal(magic, magicBytes) {
		return nil, ErrBadMagic
	}
	version, err := gr.readUint32()
	if err != nil {
		return nil, err
	}
	if version != 3 {
		return nil, fmt.Errorf("%w: %d", ErrUnsupportedVersion, version)
	}
	numTensors, err := gr.readUint64()
	if err != nil {
		return nil, err
	}
	numMetaValues, err := gr.readUint64()
	if err != nil {
		return nil, err
	}
	info := &File{
		version:       version,
		numMetaValues: numMetaValues,
		numTensors:    numTensors,
		gr:            gr,
	}
	return info, nil
}
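A minimal sketch of consuming this API end to end, assuming a local file named model.gguf: Stat for the quick header check, then ReadFile with the single-pass Metadata and Tensors iterators, in that order:

package main

import (
	"fmt"
	"log"
	"os"

	"github.com/ollama/ollama/x/encoding/gguf"
)

func main() {
	// Quick check: just version and file type.
	info, err := gguf.Stat("model.gguf") // hypothetical path
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("version:", info.Version, "file type:", info.FileType)

	// Full scan: metadata first, then tensors. The iterators are
	// single-pass and not resumable.
	f, err := os.Open("model.gguf")
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()

	gf, err := gguf.ReadFile(f)
	if err != nil {
		log.Fatal(err)
	}
	for m, err := range gf.Metadata {
		if err != nil {
			log.Fatal(err)
		}
		fmt.Println("meta:", m.Key)
	}
	for t, err := range gf.Tensors {
		if err != nil {
			log.Fatal(err)
		}
		fmt.Printf("tensor %s: %s %v (%d bytes)\n", t.Name, t.Type, t.Dimensions, t.Size)
	}
}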
345 x/encoding/gguf/gguf_test.go Normal file
@@ -0,0 +1,345 @@
package gguf

import (
	"errors"
	"io"
	"strings"
	"testing"

	"kr.dev/diff"
)

func TestStat(t *testing.T) {
	cases := []struct {
		name     string
		data     string
		wantInfo Info
		wantErr  error
	}{
		{
			name:    "empty",
			wantErr: ErrBadMagic,
		},
		{
			name:    "bad magic",
			data:    "\xBB\xAA\xDD\x00",
			wantErr: ErrBadMagic,
		},
		{
			name: "bad version",
			data: string(magicBytes) +
				"\x02\x00\x00\x00" + // version
				"",
			wantErr: ErrUnsupportedVersion,
		},
		{
			name: "valid general.file_type",
			data: string(magicBytes) + // magic
				"\x03\x00\x00\x00" + // version
				"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues

				// general.file_type key
				"\x11\x00\x00\x00\x00\x00\x00\x00" + // key length
				"general.file_type" + // key
				"\x04\x00\x00\x00" + // type (uint32)
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // uint32 value
				"",
			wantInfo: Info{
				Version:  3,
				FileType: 1,
			},
		},
	}

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			info, err := StatReader(strings.NewReader(tt.data))
			if tt.wantErr != nil {
				if !errors.Is(err, tt.wantErr) {
					t.Fatalf("err = %v; want %q", err, tt.wantErr)
				}
				return
			}
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}
			diff.Test(t, t.Errorf, info, tt.wantInfo)
		})
	}
}

func TestReadInfo(t *testing.T) {
	cases := []struct {
		name string
		data string

		wantMeta      []MetaEntry
		wantTensor    []TensorInfo
		wantReadErr   error
		wantMetaErr   error
		wantTensorErr error
		wantInfo      Info
	}{
		{
			name:        "empty",
			wantReadErr: io.ErrUnexpectedEOF,
		},
		{
			name:        "bad magic",
			data:        "\xBB\xAA\xDD\x00",
			wantReadErr: ErrBadMagic,
		},
		{
			name: "bad version",
			data: string(magicBytes) +
				"\x02\x00\x00\x00" + // version
				"",
			wantReadErr: ErrUnsupportedVersion,
		},
		{
			name: "no metadata or tensors",
			data: string(magicBytes) + // magic
				"\x03\x00\x00\x00" + // version
				"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
				"\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
				"",
			wantReadErr: nil,
		},
		{
			name: "good metadata",
			data: string(magicBytes) + // magic
				"\x03\x00\x00\x00" + // version
				"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // key length
				"K" + // key
				"\x08\x00\x00\x00" + // type (string)
				"\x02\x00\x00\x00\x00\x00\x00\x00" + // string length
				"VV" + // string value
				"",
			wantMeta: []MetaEntry{
				{Key: "K", Type: ValueTypeString, Values: []MetaValue{{Type: ValueTypeString, Value: []byte("VV")}}},
			},
		},
		{
			name: "good metadata with multiple values",
			data: string(magicBytes) + // magic
				"\x03\x00\x00\x00" + // version
				"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
				"\x02\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues

				// MetaEntry 1
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // key length
				"x" + // key
				"\x08\x00\x00\x00" + // type (string)
				"\x02\x00\x00\x00\x00\x00\x00\x00" + // string length
				"XX" + // string value

				// MetaEntry 2
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // key length
				"y" + // key
				"\x04\x00\x00\x00" + // type (uint32)
				"\x99\x88\x77\x66" + // uint32 value
				"",
			wantMeta: []MetaEntry{
				{Key: "x", Type: ValueTypeString, Values: []MetaValue{{
					Type:  ValueTypeString,
					Value: []byte("XX"),
				}}},
				{Key: "y", Type: ValueTypeUint32, Values: []MetaValue{{
					Type:  ValueTypeUint32,
					Value: []byte{0x99, 0x88, 0x77, 0x66},
				}}},
			},
		},
		{
			name: "negative string length in meta key",
			data: string(magicBytes) + // magic
				"\x03\x00\x00\x00" + // version
				"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
				"\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF" + // key length
				"K" + // key
				"\x08\x00\x00\x00" + // type (string)
				"\x02\x00\x00\x00\x00\x00\x00\x00" + // string length
				"VV" + // string value
				"",
			wantMetaErr: ErrMangled,
		},

		// Tensor tests
		{
			name: "good tensor",
			data: string(magicBytes) + // magic
				"\x03\x00\x00\x00" + // version
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // numTensors
				"\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues

				// Tensor 1
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
				"t" +

				// dimensions
				"\x01\x00\x00\x00" + // dimensions length
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // dimension[0]

				"\x03\x00\x00\x00" + // type (Q4_1)
				"\x00\x01\x00\x00\x00\x00\x00\x00" + // offset
				"",
			wantTensor: []TensorInfo{
				{
					Name:       "t",
					Dimensions: []uint64{1},
					Type:       TypeQ4_1,
					Offset:     256,
					Size:       256,
				},
			},
		},
		{
			name: "too many dimensions",
			data: string(magicBytes) + // magic
				"\x03\x00\x00\x00" + // version
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // numTensors
				"\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues

				// Tensor 1
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
				"t" +

				"\x00\x00\x00\x01" + // dimensions length
				"",
			wantTensorErr: ErrMangled,
		},
		{
			name: "size computed",
			data: string(magicBytes) + // magic
				"\x03\x00\x00\x00" + // version
				"\x02\x00\x00\x00\x00\x00\x00\x00" + // numTensors
				"\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues

				// Tensor 1
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
				"t" +
				"\x01\x00\x00\x00" + // dimensions length
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // dimension[0]
				"\x03\x00\x00\x00" + // type (Q4_1)
				"\x00\x01\x00\x00\x00\x00\x00\x00" + // offset

				// Tensor 2
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
				"t" +
				"\x01\x00\x00\x00" + // dimensions length
				"\x01\x00\x00\x00\x00\x00\x00\x00" + // dimension[0]
				"\x03\x00\x00\x00" + // type (Q4_1)
				"\x00\x03\x00\x00\x00\x00\x00\x00" + // offset
				"",
			wantTensor: []TensorInfo{
				{
					Name:       "t",
					Dimensions: []uint64{1},
					Type:       TypeQ4_1,
					Offset:     256,
					Size:       256,
				},
				{
					Name:       "t",
					Dimensions: []uint64{1},
					Type:       TypeQ4_1,
					Offset:     768,
					Size:       512,
				},
			},
		},
	}

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			f, err := ReadFile(strings.NewReader(tt.data))
			if err != nil {
				if !errors.Is(err, tt.wantReadErr) {
					t.Fatalf("unexpected ReadFile error: %v", err)
				}
				return
			}

			var got []MetaEntry
			for meta, err := range f.Metadata {
				if !errors.Is(err, tt.wantMetaErr) {
					t.Fatalf("err = %v; want %v", err, tt.wantMetaErr)
				}
				if err != nil {
					return
				}
				got = append(got, meta)
			}
			diff.Test(t, t.Errorf, got, tt.wantMeta)

			var gotT []TensorInfo
			for tinfo, err := range f.Tensors {
				if !errors.Is(err, tt.wantTensorErr) {
					t.Fatalf("err = %v; want %v", err, tt.wantTensorErr)
				}
				if err != nil {
					return
				}
				gotT = append(gotT, tinfo)
			}
			diff.Test(t, t.Errorf, gotT, tt.wantTensor)
		})
	}
}

func FuzzReadInfo(f *testing.F) {
	f.Add(string(magicBytes))
	f.Add(string(magicBytes) +
		"\x03\x00\x00\x00" + // version
		"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
		"\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
		"")
	f.Add(string(magicBytes) +
		"\x03\x00\x00\x00" + // version
		"\x01\x00\x00\x00\x00\x00\x00\x00" + // numTensors
		"\x01\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
		"\x01\x00\x00\x00\x00\x00\x00\x00" + // key length
		"K" + // key
		"\x08\x00\x00\x00" + // type (string)
		"\x02\x00\x00\x00\x00\x00\x00\x00" + // string length
		"VV" + // string value
		"\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
		"t" +
		"\x01\x00\x00\x00" + // dimensions length
		"\x01\x00\x00\x00\x00\x00\x00\x00" + // dimension[0]
		"\x03\x00\x00\x00" + // type (Q4_1)
		"\x05\x00\x00\x00\x00\x00\x00\x00" + // offset
		"")

	f.Fuzz(func(t *testing.T, data string) {
		gf, err := ReadFile(strings.NewReader(data))
		if err != nil {
			t.Logf("ReadFile error: %v", err)
			t.Skip()
		}
		for _, err := range gf.Metadata {
			if err != nil {
				t.Logf("metadata error: %v", err)
				t.Skip()
			}
		}
		for tinfo, err := range gf.Tensors {
			if err != nil {
				t.Logf("tensor error: %v", err)
				t.Skip()
			}
			if tinfo.Offset <= 0 {
				t.Logf("invalid tensor offset: %+v", tinfo)
				t.Skip()
			}
			if tinfo.Size <= 0 {
				t.Logf("invalid tensor size: %+v", tinfo)
				t.Skip()
			}
		}
	})
}
195 x/encoding/gguf/ggufio.go Normal file
@@ -0,0 +1,195 @@
package gguf
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"iter"
|
||||
)
|
||||
|
||||
type ggufReader struct {
|
||||
r *reader
|
||||
n int
|
||||
}
|
||||
|
||||
func (r *ggufReader) readMetaEntry() (MetaEntry, error) {
|
||||
key, err := r.readString()
|
||||
if err != nil {
|
||||
return MetaEntry{}, err
|
||||
}
|
||||
typ, err := r.readValueType()
|
||||
if err != nil {
|
||||
return MetaEntry{}, err
|
||||
}
|
||||
var values []MetaValue
|
||||
for v, err := range r.readMetaValues(typ) {
|
||||
if err != nil {
|
||||
err = fmt.Errorf("(key=%q type=%s): %w", key, typ, err)
|
||||
return MetaEntry{}, err
|
||||
}
|
||||
values = append(values, v)
|
||||
}
|
||||
return MetaEntry{
|
||||
Key: string(key),
|
||||
Type: typ,
|
||||
Values: values,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *ggufReader) readMetaValue(typ ValueType) (MetaValue, error) {
|
||||
var value []byte
|
||||
var err error
|
||||
switch typ {
|
||||
case ValueTypeUint8, ValueTypeInt8:
|
||||
value, err = r.next(1)
|
||||
case ValueTypeUint16, ValueTypeInt16:
|
||||
value, err = r.next(2)
|
||||
case ValueTypeUint32, ValueTypeInt32, ValueTypeFloat32:
|
||||
value, err = r.next(4)
|
||||
case ValueTypeUint64, ValueTypeInt64, ValueTypeFloat64:
|
||||
value, err = r.next(8)
|
||||
case ValueTypeBool:
|
||||
value, err = r.next(1)
|
||||
case ValueTypeString:
|
||||
value, err = r.readString()
|
||||
case ValueTypeArray:
|
||||
err = fmt.Errorf("nested arrays are not supported")
|
||||
default:
|
||||
err = fmt.Errorf("unsupported metadata type: %d", typ)
|
||||
}
|
||||
if err != nil {
|
||||
return MetaValue{}, err
|
||||
}
|
||||
return MetaValue{
|
||||
Type: typ,
|
||||
Value: bytes.Clone(value),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *ggufReader) readMetaValues(typ ValueType) iter.Seq2[MetaValue, error] {
|
||||
return func(yield func(MetaValue, error) bool) {
|
||||
if typ == ValueTypeArray {
|
||||
atyp, err := r.readValueType()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("invalid type: %w", err)
|
||||
yield(MetaValue{}, err)
|
||||
return
|
||||
}
|
||||
n, err := r.readUint64()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("invalid length: %w", err)
|
||||
yield(MetaValue{}, err)
|
||||
return
|
||||
}
|
||||
for i := range n {
|
||||
v, err := r.readMetaValue(atyp)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("invalid entry (type=%s) %d: %w", atyp, i, err)
|
||||
yield(MetaValue{}, err)
|
||||
return
|
||||
}
|
||||
if !yield(v, nil) {
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
v, err := r.readMetaValue(typ)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("error reading metadata value: %w", err)
|
||||
yield(MetaValue{}, err)
|
||||
return
|
||||
}
|
||||
yield(v, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *ggufReader) readValueType() (ValueType, error) {
|
||||
typ, err := r.readUint32()
|
||||
return ValueType(typ), err
|
||||
}
|
||||
|
||||
func (r *ggufReader) readTensorInfo() (TensorInfo, error) {
|
||||
name, err := r.readString()
|
||||
if err != nil {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
|
||||
numDimensions, err := r.readUint32()
|
||||
if err != nil {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
if numDimensions > MaxDimensions {
|
||||
return TensorInfo{}, fmt.Errorf("%w: dimensions length (%d) exceeds %d", ErrMangled, numDimensions, MaxDimensions)
|
||||
}
|
||||
|
||||
dims := make([]uint64, numDimensions)
|
||||
for i := range dims {
|
||||
d, err := r.readUint64()
|
||||
if err != nil {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
dims[i] = d
|
||||
}
|
||||
typ, err := r.readUint32()
|
||||
if err != nil {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
offset, err := r.readUint64()
|
||||
if err != nil {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
|
||||
// TODO(bmizerany): check offset is multiple of ALIGNMENT
|
||||
return TensorInfo{
|
||||
Name: string(name),
|
||||
Dimensions: dims,
|
||||
Type: Type(typ),
|
||||
Offset: offset,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *ggufReader) next(n int) ([]byte, error) {
|
||||
if n < 0 {
|
||||
return nil, errors.Join(fmt.Errorf("invalid read length: %d", n), ErrMangled)
|
||||
}
|
||||
w := r.r.window()
|
||||
for len(w) < n {
|
||||
if r.r.extend() == 0 {
|
||||
return nil, io.ErrUnexpectedEOF
|
||||
}
|
||||
w = r.r.window()
|
||||
}
|
||||
r.r.release(n)
|
||||
r.n += n
|
||||
return w[:n], nil
|
||||
}
|
||||
|
||||
func (r *ggufReader) readString() ([]byte, error) {
|
||||
n, err := r.readUint64()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// TODO(bmizerany): limit max string length
|
||||
return r.next(int(n))
|
||||
}
|
||||
|
||||
func (r *ggufReader) readUint32() (uint32, error) {
|
||||
b, err := r.next(4)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
n := binary.LittleEndian.Uint32(b)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (r *ggufReader) readUint64() (uint64, error) {
|
||||
b, err := r.next(8)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
n := binary.LittleEndian.Uint64(b)
|
||||
return n, nil
|
||||
}
|
||||
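readMetaValues above uses the error-carrying iterator idiom: each step yields a value-error pair, and the sequence stops on the first error or when the consumer breaks. A self-contained sketch of the same idiom (not part of the diff; requires range-over-func support):

package main

import (
	"errors"
	"fmt"
	"iter"
)

// countTo yields 0..n-1 and then a terminal error, mirroring how
// readMetaValues reports a failure mid-sequence and returns.
func countTo(n int) iter.Seq2[int, error] {
	return func(yield func(int, error) bool) {
		for i := 0; i < n; i++ {
			if !yield(i, nil) {
				return // consumer broke out of the range loop
			}
		}
		yield(0, errors.New("ran out of input"))
	}
}

func main() {
	for v, err := range countTo(3) {
		if err != nil {
			fmt.Println("stop:", err)
			break
		}
		fmt.Println(v)
	}
}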
70
x/encoding/gguf/reader.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package gguf
|
||||
|
||||
import "io"
|
||||
|
||||
// A reader implements a sliding window over an io.Reader.
|
||||
type reader struct {
|
||||
data []byte
|
||||
offset int
|
||||
r io.Reader
|
||||
err error
|
||||
}
|
||||
|
||||
// release discards n bytes from the front of the window.
|
||||
func (b *reader) release(n int) {
|
||||
b.offset += n
|
||||
}
|
||||
|
||||
// window returns the current window.
|
||||
// The window is invalidated by calls to release or extend.
|
||||
func (b *reader) window() []byte {
|
||||
return b.data[b.offset:]
|
||||
}
|
||||
|
||||
// tuning constants for reader.extend.
|
||||
const (
|
||||
newBufferSize = 8 << 10
|
||||
minReadSize = newBufferSize >> 2
|
||||
)
|
||||
|
||||
// extend extends the window with data from the underlying reader.
|
||||
func (b *reader) extend() int {
|
||||
if b.err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
remaining := len(b.data) - b.offset
|
||||
if remaining == 0 {
|
||||
b.data = b.data[:0]
|
||||
b.offset = 0
|
||||
}
|
||||
if cap(b.data)-len(b.data) >= minReadSize {
|
||||
// nothing to do, enough space exists between len and cap.
|
||||
} else if cap(b.data)-remaining >= minReadSize {
|
||||
// buffer has enough space if we move the data to the front.
|
||||
b.compact()
|
||||
} else {
|
||||
// otherwise, we must allocate/extend a new buffer
|
||||
b.grow()
|
||||
}
|
||||
remaining += b.offset
|
||||
n, err := b.r.Read(b.data[remaining:cap(b.data)])
|
||||
// reduce length to the existing plus the data we read.
|
||||
b.data = b.data[:remaining+n]
|
||||
b.err = err
|
||||
return n
|
||||
}
|
||||
|
||||
// grow grows the buffer, moving the active data to the front.
|
||||
func (b *reader) grow() {
|
||||
buf := make([]byte, max(cap(b.data)*2, newBufferSize))
|
||||
copy(buf, b.data[b.offset:])
|
||||
b.data = buf
|
||||
b.offset = 0
|
||||
}
|
||||
|
||||
// compact moves the active data to the front of the buffer.
|
||||
func (b *reader) compact() {
|
||||
copy(b.data, b.data[b.offset:])
|
||||
b.offset = 0
|
||||
}
|
||||
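window, extend, and release combine into a simple consume loop; ggufReader.next above is exactly this shape. A sketch of the pattern, written as if in the same package since reader is unexported:

package gguf

import "io"

// readN consumes and returns the next n bytes from r, extending the
// window as needed. The returned slice is only valid until the next
// call to extend, per the window comment above.
func readN(r *reader, n int) ([]byte, error) {
	w := r.window()
	for len(w) < n {
		if r.extend() == 0 {
			return nil, io.ErrUnexpectedEOF
		}
		w = r.window()
	}
	r.release(n) // advance the window past the bytes we return
	return w[:n], nil
}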
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/787da6e90e4be491
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
go test fuzz v1
|
||||
string("GGUF\x03\x00\x00\x00\x00\x00\x800\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\xa6\x00\x00\x00\x00\x00\x00\x00\x02\x000\x00\x10\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\xe3\xe3\xe3\xe3\x00")
|
||||
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/8b42c37d144cd2c6
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
go test fuzz v1
|
||||
string("GGUF\x03\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\xe3\xe3\xe3\xe3\x00")
|
||||
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/92b890e394a77cfc
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
go test fuzz v1
|
||||
string("GGUF\x03\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\xfd\xff\xff\xff\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\xe3\xe3\xe3\xe3\x00")
|
||||
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/9cfd6a48931a2753
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
go test fuzz v1
|
||||
string("GGUF\x03\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00K\b\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00VV\x01\x00\x00\x00\x00\\x00\\x00\\x00\\x00")
|
||||
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/a8c5454e2a164af2
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
go test fuzz v1
|
||||
string("GGUF\x03\x00\x00\x00\x00\x00\x800\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\xa6\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x10\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\xe3\xe3\xe3\xe3\x00")
|
||||
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/a931e37cb6f932d4
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
go test fuzz v1
|
||||
string("GGUF\x03\x00\x00\x00\x00\x00\x800\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\xe3\xe3\xe3\xe3\x00")
|
||||
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/bcd20fa73e7351a2
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
go test fuzz v1
|
||||
string("GGUF\x03\x00\x00\x0000000000000000000000000\xe5")
|
||||
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/d29846a68e32052d
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
go test fuzz v1
|
||||
string("GGUF\x03\x00\x00\x0000000000\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x000\b\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x0000\x01\x00\x00\x00\x00\x00\x00\x000\x01\x00\x001\x01\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x05\x00\x00\x00\x00\x00\x00\a")
|
||||
134
x/model/digest.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// Digest represents a digest of a model Manifest. It is a comparable value
|
||||
// type and is immutable.
|
||||
//
|
||||
// The zero Digest is not a valid digest.
|
||||
type Digest struct {
|
||||
s string
|
||||
}
|
||||
|
||||
// Type returns the digest type of the digest.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// ParseDigest("sha256-1234").Type() // returns "sha256"
|
||||
func (d Digest) Type() string {
|
||||
typ, _, _ := strings.Cut(d.s, "-")
|
||||
return typ
|
||||
}
|
||||
|
||||
// String returns the digest in the form of "<digest-type>-<digest>", or the
|
||||
// empty string if the digest is invalid.
|
||||
func (d Digest) String() string { return d.s }
|
||||
|
||||
// IsValid returns true if the digest is valid (not zero).
|
||||
//
|
||||
// A valid digest may be created only by ParseDigest, or
|
||||
// ParseName(name).Digest().
|
||||
func (d Digest) IsValid() bool { return d.s != "" }
|
||||
|
||||
// MarshalText implements encoding.TextMarshaler.
|
||||
func (d Digest) MarshalText() ([]byte, error) {
|
||||
return []byte(d.String()), nil
|
||||
}
|
||||
|
||||
// UnmarshalText implements encoding.TextUnmarshaler.
|
||||
func (d *Digest) UnmarshalText(text []byte) error {
|
||||
if d.IsValid() {
|
||||
return errors.New("model.Digest: illegal UnmarshalText on valid Digest")
|
||||
}
|
||||
*d = ParseDigest(string(text))
|
||||
return nil
|
||||
}
|
||||
|
||||
// LogValue implements slog.Value.
|
||||
func (d Digest) LogValue() slog.Value {
|
||||
return slog.StringValue(d.String())
|
||||
}
|
||||
|
||||
var (
|
||||
_ driver.Valuer = Digest{}
|
||||
_ sql.Scanner = (*Digest)(nil)
|
||||
_ slog.LogValuer = Digest{}
|
||||
)
|
||||
|
||||
// Scan implements the sql.Scanner interface.
|
||||
func (d *Digest) Scan(src any) error {
|
||||
if d.IsValid() {
|
||||
return errors.New("model.Digest: illegal Scan on valid Digest")
|
||||
}
|
||||
switch v := src.(type) {
|
||||
case string:
|
||||
*d = ParseDigest(v)
|
||||
return nil
|
||||
case []byte:
|
||||
*d = ParseDigest(string(v))
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("model.Digest: invalid Scan source %T", src)
|
||||
}
|
||||
|
||||
// Value implements the driver.Valuer interface.
|
||||
func (d Digest) Value() (driver.Value, error) {
|
||||
return d.String(), nil
|
||||
}
|
||||
|
||||
// ParseDigest parses a string in the form of "<digest-type>-<digest>" into a
|
||||
// Digest.
|
||||
func ParseDigest(s string) Digest {
|
||||
typ, digest, ok := strings.Cut(s, "-")
|
||||
if ok && isValidDigestType(typ) && isValidHex(digest) {
|
||||
return Digest{s: s}
|
||||
}
|
||||
return Digest{}
|
||||
}
|
||||
|
||||
// isValidDigest returns true if the given string is in the form of
|
||||
// "<digest-type>-<digest>", and <digest-type> is in the form of [a-z0-9]+
|
||||
// and <digest> is a valid hex string.
|
||||
//
|
||||
// It does not check if the digest is a valid hash for the given digest
|
||||
// type, or restrict the digest type to a known set of types. This is left
|
||||
// up to users of this package.
|
||||
func isValidDigest(s string) bool {
|
||||
typ, digest, ok := strings.Cut(s, "-")
|
||||
res := ok && isValidDigestType(typ) && isValidHex(digest)
|
||||
fmt.Printf("DEBUG: %q: typ: %s, digest: %s, ok: %v res: %v\n", s, typ, digest, ok, res)
|
||||
return res
|
||||
}
|
||||
|
||||
func isValidDigestType(s string) bool {
|
||||
if len(s) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, r := range s {
|
||||
if !unicode.IsLower(r) && !unicode.IsDigit(r) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func isValidHex(s string) bool {
|
||||
if len(s) == 0 {
|
||||
return false
|
||||
}
|
||||
for i := range s {
|
||||
c := s[i]
|
||||
if c < '0' || c > '9' && c < 'a' || c > 'f' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
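A quick illustration of the parsing rules above; a hedged sketch assuming the import path github.com/ollama/ollama/x/model implied by the file paths in this diff:

package main

import (
	"fmt"

	"github.com/ollama/ollama/x/model"
)

func main() {
	// "<digest-type>-<digest>": type must be [a-z0-9]+, digest lowercase hex.
	d := model.ParseDigest("sha256-abc123")
	fmt.Println(d.IsValid(), d.Type()) // true sha256

	// Uppercase type or non-hex digest yields the zero (invalid) Digest.
	fmt.Println(model.ParseDigest("SHA256-XYZ").IsValid()) // false
}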
@@ -2,6 +2,18 @@ package model
|
||||
|
||||
import "testing"
|
||||
|
||||
// - test scan
|
||||
// - test marshal text
|
||||
// - test unmarshal text
|
||||
// - test log value
|
||||
// - test string
|
||||
// - test type
|
||||
// - test digest
|
||||
// - test valid
|
||||
// - test driver valuer
|
||||
// - test sql scanner
|
||||
// - test parse digest
|
||||
|
||||
var testDigests = map[string]Digest{
|
||||
"": {},
|
||||
"sha256-1234": {s: "sha256-1234"},
|
||||
@@ -44,3 +56,28 @@ func TestDigestString(t *testing.T) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDigestUnmarshalText(t *testing.T) {
|
||||
const testDigest = "sha256-1234"
|
||||
t.Run("UnmarshalText (into Valid)", func(t *testing.T) {
|
||||
d := ParseDigest(testDigest)
|
||||
if !d.IsValid() {
|
||||
panic("invalid test")
|
||||
}
|
||||
if err := d.UnmarshalText(nil); err == nil {
|
||||
t.Errorf("UnmarshalText on valid Digest did not return error")
|
||||
}
|
||||
if d.String() != testDigest {
|
||||
t.Errorf("UnmarshalText on valid Digest changed Digest: %q", d.String())
|
||||
}
|
||||
})
|
||||
t.Run("UnmarshalText make safe copy", func(t *testing.T) {
|
||||
data := []byte(testDigest)
|
||||
var d Digest
|
||||
d.UnmarshalText(data)
|
||||
data[0] = 'x'
|
||||
if d.String() != testDigest {
|
||||
t.Errorf("UnmarshalText did not make a safe copy")
|
||||
}
|
||||
})
|
||||
}
|
||||
132
x/model/file.go
Normal file
@@ -0,0 +1,132 @@
|
||||
// Package model implements the File and Name types for working with and
|
||||
// representing Modelfiles and model Names.
|
||||
//
|
||||
// The Name type should be used when working with model names, and the File
|
||||
// type should be used when working with Modelfiles.
|
||||
package model
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"iter"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ParamPragma struct {
|
||||
Key string
|
||||
Value string
|
||||
}
|
||||
|
||||
type MessagePragma struct {
|
||||
Role string
|
||||
Content string
|
||||
}
|
||||
|
||||
type File struct {
|
||||
// From is a required pragma that specifies the source of the model,
|
||||
// either on disk, or by reference (see model.ParseName).
|
||||
From string
|
||||
|
||||
// Optional
|
||||
Params []ParamPragma
|
||||
Template string
|
||||
System string
|
||||
Adapter string
|
||||
Messages []MessagePragma
|
||||
|
||||
License string
|
||||
}
|
||||
|
||||
type FileError struct {
|
||||
Pragma string
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *FileError) Error() string {
|
||||
return e.Pragma + ": " + e.Message
|
||||
}
|
||||
|
||||
// Pragma represents a single pragma in a Modelfile.
|
||||
type Pragma struct {
|
||||
// The pragma name
|
||||
Name string
|
||||
|
||||
// Args contains the user-defined arguments for the pragma. If no
|
||||
// arguments were provided, it is nil.
|
||||
Args []string
|
||||
}
|
||||
|
||||
func (p Pragma) Arg(i int) string {
|
||||
if i >= len(p.Args) {
|
||||
return ""
|
||||
}
|
||||
return p.Args[i]
|
||||
}
|
||||
|
||||
func FilePragmas(r io.Reader) iter.Seq2[Pragma, error] {
|
||||
return func(yield func(Pragma, error) bool) {
|
||||
sc := bufio.NewScanner(r)
|
||||
for sc.Scan() {
|
||||
line := sc.Text()
|
||||
|
||||
// TODO(bmizerany): set a max num fields/args to
|
||||
// prevent mem bloat
|
||||
args := strings.Fields(line)
|
||||
if len(args) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
p := Pragma{
|
||||
Name: strings.ToUpper(args[0]),
|
||||
}
|
||||
if p.Name == "MESSAGE" {
|
||||
// handle special case where message content
|
||||
// is space separated on the _rest_ of the
|
||||
// line like: `MESSAGE user Is Ontario in
|
||||
// Canada?`
|
||||
panic("TODO")
|
||||
}
|
||||
if len(args) > 1 {
|
||||
p.Args = args[1:]
|
||||
}
|
||||
if !yield(p, nil) {
|
||||
return
|
||||
}
|
||||
}
|
||||
if sc.Err() != nil {
|
||||
yield(Pragma{}, sc.Err())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ParseFile(r io.Reader) (File, error) {
|
||||
var f File
|
||||
for p, err := range FilePragmas(r) {
|
||||
if err != nil {
|
||||
return File{}, err
|
||||
}
|
||||
switch p.Name {
|
||||
case "FROM":
|
||||
f.From = p.Arg(0)
|
||||
case "PARAMETER":
|
||||
f.Params = append(f.Params, ParamPragma{
|
||||
Key: strings.ToLower(p.Arg(0)),
|
||||
Value: p.Arg(1),
|
||||
})
|
||||
case "TEMPLATE":
|
||||
f.Template = p.Arg(0)
|
||||
case "SYSTEM":
|
||||
f.System = p.Arg(0)
|
||||
case "ADAPTER":
|
||||
f.Adapter = p.Arg(0)
|
||||
case "MESSAGE":
|
||||
f.Messages = append(f.Messages, MessagePragma{
|
||||
Role: p.Arg(0),
|
||||
Content: p.Arg(1),
|
||||
})
|
||||
case "LICENSE":
|
||||
f.License = p.Arg(0)
|
||||
}
|
||||
}
|
||||
return f, nil
|
||||
}
|
||||
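A sketch of running a Modelfile through ParseFile (same assumed import path as above; MESSAGE is avoided since that branch still panics with TODO):

package main

import (
	"fmt"
	"log"
	"strings"

	"github.com/ollama/ollama/x/model"
)

func main() {
	const modelfile = `
FROM mistral:latest
PARAMETER temperature 0.7
LICENSE MIT
`
	f, err := model.ParseFile(strings.NewReader(modelfile))
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(f.From)      // mistral:latest
	fmt.Println(f.Params[0]) // {temperature 0.7}
	fmt.Println(f.License)   // MIT
}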
@@ -1,45 +1,31 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/maphash"
|
||||
"io"
|
||||
"iter"
|
||||
"log/slog"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/ollama/ollama/types/structs"
|
||||
"github.com/ollama/ollama/x/types/structs"
|
||||
)
|
||||
|
||||
// Errors
|
||||
var (
|
||||
// ErrInvalidName, ErrIncompleteName, and ErrInvalidDigest are not
|
||||
// used by this package, but are exported so that other packages can
|
||||
// use them, instead of defining their own errors for them.
|
||||
ErrInvalidName = errors.New("invalid model name")
|
||||
// ErrInvalidName is not used by this package, but is exported so that
|
||||
// other packages do not need to invent their own error type when they
|
||||
// need to return an error for an invalid name.
|
||||
ErrIncompleteName = errors.New("incomplete model name")
|
||||
ErrInvalidDigest = errors.New("invalid digest")
|
||||
)
|
||||
|
||||
// Defaults
|
||||
const (
|
||||
// MaskDefault is the default mask used by [Name.DisplayShortest].
|
||||
MaskDefault = "registry.ollama.ai/library/?:latest"
|
||||
|
||||
// MaskNothing is a mask that masks nothing.
|
||||
MaskNothing = "?/?/?:?"
|
||||
|
||||
// DefaultFill is the default fill used by [ParseName].
|
||||
FillDefault = "registry.ollama.ai/library/?:latest+Q4_0"
|
||||
|
||||
// FillNothing is a fill that fills nothing.
|
||||
FillNothing = "?/?/?:?+?"
|
||||
)
|
||||
|
||||
const MaxNamePartLen = 128
|
||||
|
||||
type PartKind int
|
||||
@@ -55,11 +41,11 @@ const (
|
||||
PartBuild
|
||||
PartDigest
|
||||
|
||||
// NumParts is the number of parts in a Name. In this list, it must
|
||||
// follow the final part.
|
||||
NumParts
|
||||
|
||||
PartExtraneous = -1
|
||||
// Invalid is a special part that is used to indicate that a part is
|
||||
// invalid. It is not a valid part of a Name.
|
||||
//
|
||||
// It should be kept as the last part in the list.
|
||||
PartInvalid
|
||||
)
|
||||
|
||||
var kindNames = map[PartKind]string{
|
||||
@@ -69,6 +55,7 @@ var kindNames = map[PartKind]string{
|
||||
PartTag: "Tag",
|
||||
PartBuild: "Build",
|
||||
PartDigest: "Digest",
|
||||
PartInvalid: "Invalid",
|
||||
}
|
||||
|
||||
func (k PartKind) String() string {
|
||||
@@ -77,7 +64,7 @@ func (k PartKind) String() string {
|
||||
|
||||
// Name is an opaque reference to a model. It holds the parts of a model
|
||||
// with the case preserved, but is not directly comparable with other Names
|
||||
// since model names can be represented with different casing depending on
|
||||
// the use case. For instance, "Mistral" and "mistral" are the same model
|
||||
// but each version may have come from different sources (e.g. copied from a
|
||||
// Web page, or from a file path).
|
||||
@@ -103,19 +90,20 @@ func (k PartKind) String() string {
|
||||
// The parts can be obtained in their original form by calling [Name.Parts].
|
||||
//
|
||||
// To check if a Name has at minimum a valid model part, use [Name.IsValid].
|
||||
//
|
||||
// To make a Name by filling in missing parts from another Name, use [Fill].
|
||||
type Name struct {
|
||||
_ structs.Incomparable
|
||||
parts [NumParts]string // host, namespace, model, tag, build, digest
|
||||
parts [6]string // host, namespace, model, tag, build, digest
|
||||
|
||||
// TODO(bmizerany): track offsets and hold s (raw string) here? We
|
||||
// could pack the offsets all into a single uint64 since the first
|
||||
// parts take less bits since their max offset is less than the max
|
||||
// offset of the next part. This would save a ton of bytes per Name
|
||||
// and mean zero allocations for String.
|
||||
}
|
||||
|
||||
// ParseName parses s into a Name, and returns the result of filling it with
|
||||
// defaults. The input string must be a valid string
|
||||
// ParseName parses s into a Name. The input string must be a valid string
|
||||
// representation of a model name in the form:
|
||||
//
|
||||
// [host/][namespace/]<model>[:tag][+build][@<digest-type>-<digest>]
|
||||
@@ -133,7 +121,7 @@ type Name struct {
|
||||
// "mistral:7b+x"
|
||||
// "example.com/mike/mistral:latest+Q4_0"
|
||||
// "example.com/bruce/mistral:latest"
|
||||
// "example.com/pdevine/thisisfine:7b+Q4_0@sha256-1234567890abcdef"
|
||||
// "example.com/mistral:7b+Q4_0@sha256-1234567890abcdef"
|
||||
//
|
||||
// Examples of invalid paths:
|
||||
//
|
||||
@@ -144,71 +132,41 @@ type Name struct {
|
||||
//
|
||||
// It returns the zero value if any part is invalid.
|
||||
//
|
||||
// # Fills
|
||||
//
|
||||
// For any valid s, the fill string is used to fill in missing parts of the
|
||||
// Name. The fill string must be a valid Name with the exception that any part
|
||||
// may be the string ("?"), which will not be considered for filling.
|
||||
func ParseName(s, fill string) Name {
|
||||
// As a rule of thumb, a valid name is one that can be round-tripped with
|
||||
// the [Name.String] method. That means ("x+") is invalid because
|
||||
// [Name.String] will not print a "+" if the build is empty.
|
||||
func ParseName(s string) Name {
|
||||
var r Name
|
||||
parts(s)(func(kind PartKind, part string) bool {
|
||||
if kind == PartDigest && !ParseDigest(part).IsValid() {
|
||||
r = Name{}
|
||||
return false
|
||||
for kind, part := range Parts(s) {
|
||||
if kind == PartInvalid {
|
||||
return Name{}
|
||||
}
|
||||
if kind == PartExtraneous || !isValidPart(kind, part) {
|
||||
r = Name{}
|
||||
return false
|
||||
if kind == PartDigest && !ParseDigest(part).IsValid() {
|
||||
return Name{}
|
||||
}
|
||||
r.parts[kind] = part
|
||||
return true
|
||||
})
|
||||
}
|
||||
if r.IsValid() || r.IsResolved() {
|
||||
return fillName(r, fill)
|
||||
return r
|
||||
}
|
||||
return Name{}
|
||||
}
|
||||
|
||||
func parseMask(s string) Name {
|
||||
var r Name
|
||||
parts(s)(func(kind PartKind, part string) bool {
|
||||
if part == "?" {
|
||||
// mask part; treat as empty but valid
|
||||
return true
|
||||
}
|
||||
if !isValidPart(kind, part) {
|
||||
panic(fmt.Errorf("invalid mask part %s: %q", kind, part))
|
||||
}
|
||||
r.parts[kind] = part
|
||||
return true
|
||||
})
|
||||
return r
|
||||
}
|
||||
|
||||
func MustParseName(s, fill string) Name {
|
||||
r := ParseName(s, fill)
|
||||
func MustParseName(s string) Name {
|
||||
r := ParseName(s)
|
||||
if !r.IsValid() {
|
||||
panic("invalid Name: " + s)
|
||||
panic("model.MustParseName: invalid name: " + s)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// fillName fills in the missing parts of dst with the parts of src.
|
||||
// Fill fills in the missing parts of dst with the parts of src.
|
||||
//
|
||||
// The returned Name will only be valid if dst is valid.
|
||||
//
|
||||
// It skips fill parts that are "?".
|
||||
func fillName(r Name, fill string) Name {
|
||||
fill = cmp.Or(fill, FillDefault)
|
||||
f := parseMask(fill)
|
||||
if fill != FillNothing && f.IsZero() {
|
||||
panic("invalid fill")
|
||||
}
|
||||
func Fill(dst, src Name) Name {
|
||||
var r Name
|
||||
for i := range r.parts {
|
||||
if f.parts[i] == "?" {
|
||||
continue
|
||||
}
|
||||
r.parts[i] = cmp.Or(r.parts[i], f.parts[i])
|
||||
r.parts[i] = cmp.Or(dst.parts[i], src.parts[i])
|
||||
}
|
||||
return r
|
||||
}
|
||||
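Fill is now part-wise: for each part it keeps dst's value when present and falls back to src's. A hedged sketch (same assumed import path; compare TestFill below):

package main

import (
	"fmt"

	"github.com/ollama/ollama/x/model"
)

func main() {
	// dst supplies only the model; src supplies host, namespace, tag, build.
	dst := model.ParseName("mistral")
	src := model.ParseName("registry.ollama.ai/library/placeholder:latest+Q4_0")
	fmt.Println(model.Fill(dst, src)) // registry.ollama.ai/library/mistral:latest+Q4_0
}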
@@ -227,14 +185,12 @@ func (r Name) WithDigest(digest Digest) Name {
|
||||
var mapHashSeed = maphash.MakeSeed()
|
||||
|
||||
// MapHash returns a case insensitive hash for use in maps and equality
|
||||
// checks. For a convenient way to compare names, use [Name.EqualFold].
|
||||
//
|
||||
//nolint:errcheck
|
||||
// checks. For a convenient way to compare names, use [Name.EqualFold].
|
||||
func (r Name) MapHash() uint64 {
|
||||
// correctly hash the parts with case insensitive comparison
|
||||
var h maphash.Hash
|
||||
h.SetSeed(mapHashSeed)
|
||||
for _, part := range r.parts {
|
||||
for _, part := range r.Parts() {
|
||||
// downcase the part for hashing
|
||||
for i := range part {
|
||||
c := part[i]
|
||||
@@ -253,59 +209,39 @@ func (r Name) slice(from, to PartKind) Name {
|
||||
return v
|
||||
}
|
||||
|
||||
// DisplayShortest returns the shortest possible, masked display string in form:
|
||||
//
|
||||
// [host/][<namespace>/]<model>[:<tag>]
|
||||
//
|
||||
// # Masks
|
||||
//
|
||||
// The mask is a string that specifies which parts of the name to omit based
|
||||
// on case-insensitive comparison. [Name.DisplayShortest] omits parts of the name
|
||||
// that are the same as the mask, moving from left to right until the first
|
||||
// unequal part is found. It then moves right to left until the first unequal
|
||||
// part is found. The result is the shortest possible display string.
|
||||
//
|
||||
// Unlike a [Name] the mask can contain "?" characters which are treated as
|
||||
// wildcards. A "?" will never match a part of the name, since a valid name
|
||||
// can never contain a "?" character.
|
||||
//
|
||||
// For example: Given a Name ("registry.ollama.ai/library/mistral:latest") masked
|
||||
// with ("registry.ollama.ai/library/?:latest") will produce the display string
|
||||
// ("mistral").
|
||||
//
|
||||
// If mask is the empty string, then [MaskDefault] is used.
|
||||
//
|
||||
// DisplayShortest panics if the mask is not the empty string, MaskNothing, and
|
||||
// invalid.
|
||||
//
|
||||
// # Builds
|
||||
//
|
||||
// For now, DisplayShortest does not consider the build or return one in the
|
||||
// result. We can lift this restriction when needed.
|
||||
func (r Name) DisplayShortest(mask string) string {
|
||||
mask = cmp.Or(mask, MaskDefault)
|
||||
d := parseMask(mask)
|
||||
if mask != MaskNothing && r.IsZero() {
|
||||
panic("invalid Name")
|
||||
}
|
||||
for i := range PartTag {
|
||||
if !strings.EqualFold(r.parts[i], d.parts[i]) {
|
||||
break
|
||||
}
|
||||
r.parts[i] = ""
|
||||
}
|
||||
for i := PartTag; i >= 0; i-- {
|
||||
if !strings.EqualFold(r.parts[i], d.parts[i]) {
|
||||
break
|
||||
}
|
||||
r.parts[i] = ""
|
||||
}
|
||||
return r.slice(PartHost, PartTag).DisplayLong()
|
||||
// DisplayModel returns a display string composed of the model only.
|
||||
func (r Name) DisplayModel() string {
|
||||
return r.parts[PartModel]
|
||||
}
|
||||
|
||||
// DisplayLongest returns the result of r.DisplayShortest(MaskNothing).
|
||||
func (r Name) DisplayLongest() string {
|
||||
return r.DisplayShortest(MaskNothing)
|
||||
// DisplayFullest returns the fullest possible display string in form:
|
||||
//
|
||||
// <host>/<namespace>/<model>:<tag>
|
||||
//
|
||||
// If any part is missing, it is omitted from the display string.
|
||||
//
|
||||
// It does not include the build part. For the fullest possible display
|
||||
// string with the build, use [Name.String].
|
||||
func (r Name) DisplayFullest() string {
|
||||
return r.slice(PartHost, PartTag).String()
|
||||
}
|
||||
|
||||
// DisplayShort returns the fullest possible display string in form:
|
||||
//
|
||||
// <model>:<tag>
|
||||
//
|
||||
// If any part is missing, it is omitted from the display string.
|
||||
func (r Name) DisplayShort() string {
|
||||
return r.slice(PartModel, PartTag).String()
|
||||
}
|
||||
|
||||
// DisplayLong returns the fullest possible display string in form:
|
||||
//
|
||||
// <namespace>/<model>:<tag>
|
||||
//
|
||||
// If any part is missing, it is omitted from the display string.
|
||||
func (r Name) DisplayLong() string {
|
||||
return r.slice(PartNamespace, PartTag).String()
|
||||
}
|
||||
|
||||
var seps = [...]string{
|
||||
@@ -322,28 +258,23 @@ var seps = [...]string{
|
||||
//
|
||||
// <host>/<namespace>/<model>:<tag>+<build>@<digest-type>-<digest>
|
||||
//
|
||||
// Missing parts and their separators are not written.
|
||||
//
|
||||
// The full digest is always prefixed with "@". That is if [Name.IsValid]
|
||||
// reports false and [Name.IsResolved] reports true, then the string is
|
||||
// returned as "@<digest-type>-<digest>".
|
||||
func (r Name) writeTo(w io.StringWriter) error {
|
||||
func (r Name) writeTo(w io.StringWriter) {
|
||||
var partsWritten int
|
||||
for i := range r.parts {
|
||||
if r.parts[i] == "" {
|
||||
continue
|
||||
}
|
||||
if partsWritten > 0 || i == int(PartDigest) {
|
||||
if _, err := w.WriteString(seps[i-1]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if _, err := w.WriteString(r.parts[i]); err != nil {
|
||||
return err
|
||||
w.WriteString(seps[i-1])
|
||||
}
|
||||
w.WriteString(r.parts[i])
|
||||
partsWritten++
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var builderPool = sync.Pool{
|
||||
@@ -352,29 +283,32 @@ var builderPool = sync.Pool{
|
||||
},
|
||||
}
|
||||
|
||||
// DisplayLong returns the fullest possible display string in form:
|
||||
// String returns the fullest possible display string in form:
|
||||
//
|
||||
// <host>/<namespace>/<model>:<tag>+<build>
|
||||
//
|
||||
// If any part is missing, it is omitted from the display string.
|
||||
func (r Name) DisplayLong() string {
|
||||
//
|
||||
// For the fullest possible display string without the build, use
|
||||
// [Name.DisplayFullest].
|
||||
func (r Name) String() string {
|
||||
b := builderPool.Get().(*strings.Builder)
|
||||
defer builderPool.Put(b)
|
||||
b.Reset()
|
||||
b.Grow(50) // arbitrarily long enough for most names
|
||||
_ = r.writeTo(b)
|
||||
r.writeTo(b)
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// GoString implements fmt.GoStringer. It returns a string suitable for
|
||||
// debugging and logging. It is similar to [Name.DisplayLong] but it always
|
||||
// debugging and logging. It is similar to [Name.String] but it always
|
||||
// returns a string that includes all parts of the Name, with missing parts
|
||||
// replaced with a ("?").
|
||||
func (r Name) GoString() string {
|
||||
for i := range r.parts {
|
||||
r.parts[i] = cmp.Or(r.parts[i], "?")
|
||||
}
|
||||
return r.DisplayLong()
|
||||
return r.String()
|
||||
}
|
||||
|
||||
// LogValue implements slog.Valuer.
|
||||
@@ -382,6 +316,71 @@ func (r Name) LogValue() slog.Value {
|
||||
return slog.StringValue(r.GoString())
|
||||
}
|
||||
|
||||
var bufPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return new(bytes.Buffer)
|
||||
},
|
||||
}
|
||||
|
||||
// MarshalText implements [encoding.TextMarshaler].
|
||||
func (r Name) MarshalText() ([]byte, error) {
|
||||
b := bufPool.Get().(*bytes.Buffer)
|
||||
b.Reset()
|
||||
b.Grow(50) // arbitrarily long enough for most names
|
||||
defer bufPool.Put(b)
|
||||
r.writeTo(b)
|
||||
// TODO: We can remove this alloc if/when
|
||||
// https://github.com/golang/go/issues/62384 lands.
|
||||
return b.Bytes(), nil
|
||||
}
|
||||
|
||||
// UnmarshalText implements [encoding.TextUnmarshaler].
|
||||
//
|
||||
// It is an error to call UnmarshalText on a valid Name.
|
||||
func (r *Name) UnmarshalText(text []byte) error {
|
||||
if r.IsValid() {
|
||||
// The invariant of UnmarshalText is that it should only be
|
||||
// called on an invalid/zero Name. If we allow UnmarshalText
|
||||
// on a valid Name, then the Name will be mutated, breaking
|
||||
// the immutability of the Name.
|
||||
return errors.New("model.Name: illegal UnmarshalText on valid Name")
|
||||
}
|
||||
|
||||
// The contract of UnmarshalText is that we copy to keep the text.
|
||||
*r = ParseName(string(text))
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
_ driver.Valuer = Name{}
|
||||
_ sql.Scanner = (*Name)(nil)
|
||||
)
|
||||
|
||||
// Scan implements [database/sql.Scanner].
|
||||
func (r *Name) Scan(src any) error {
|
||||
if r.IsValid() {
|
||||
// The invariant of Scan is that it should only be called on an
|
||||
// invalid/zero Name. If we allow Scan on a valid Name, then the
|
||||
// Name will be mutated, breaking the immutability of the Name.
|
||||
return errors.New("model.Name: illegal Scan on valid Name")
|
||||
|
||||
}
|
||||
switch v := src.(type) {
|
||||
case string:
|
||||
*r = ParseName(v)
|
||||
return nil
|
||||
case []byte:
|
||||
*r = ParseName(string(v))
|
||||
return nil
|
||||
}
|
||||
return errors.New("model.Name: invalid Scan source")
|
||||
}
|
||||
|
||||
// Value implements [database/sql/driver.Valuer].
|
||||
func (r Name) Value() (driver.Value, error) {
|
||||
return r.String(), nil
|
||||
}
|
||||
|
||||
// IsComplete reports whether the Name is fully qualified. That is, it has a
|
||||
// domain, namespace, name, tag, and build.
|
||||
func (r Name) IsComplete() bool {
|
||||
@@ -439,39 +438,41 @@ func downcase(r rune) rune {
|
||||
return r
|
||||
}
|
||||
|
||||
func (r Name) Host() string { return r.parts[PartHost] }
|
||||
func (r Name) Namespace() string { return r.parts[PartNamespace] }
|
||||
func (r Name) Model() string { return r.parts[PartModel] }
|
||||
func (r Name) Build() string { return r.parts[PartBuild] }
|
||||
func (r Name) Tag() string { return r.parts[PartTag] }
|
||||
// TODO(bmizerany): driver.Value? (MarshalText etc should be enough)
|
||||
|
||||
// iter_Seq2 is an iter.Seq2 defined here to avoid the current build
|
||||
// restrictions in the go1.22 iter package requiring the
|
||||
// goexperiment.rangefunc tag to be set via the GOEXPERIMENT=rangefunc flag,
|
||||
// which we are not yet ready to support.
|
||||
// Parts returns the parts of the Name in order of concreteness.
|
||||
//
|
||||
// Once we are ready to support rangefunc, this can be removed and replaced
|
||||
// with the iter.Seq2 type.
|
||||
type iter_Seq2[A, B any] func(func(A, B) bool)
|
||||
// The length of the returned slice is always 6.
|
||||
func (r Name) Parts() []string {
|
||||
return slices.Clone(r.parts[:])
|
||||
}
|
||||
|
||||
// Parts returns a sequence of the parts of a Name string from most specific
|
||||
// to least specific.
|
||||
//
|
||||
// It normalizes the input string by removing "http://" and "https://" only.
|
||||
// No other normalizations are performed.
|
||||
func parts(s string) iter_Seq2[PartKind, string] {
|
||||
// No other normalization is done.
|
||||
func Parts(s string) iter.Seq2[PartKind, string] {
|
||||
return func(yield func(PartKind, string) bool) {
|
||||
if strings.HasPrefix(s, "http://") {
|
||||
s = strings.TrimPrefix(s, "http://")
|
||||
} else {
|
||||
s = strings.TrimPrefix(s, "https://")
|
||||
s = s[len("http://"):]
|
||||
}
|
||||
if strings.HasPrefix(s, "https://") {
|
||||
s = s[len("https://"):]
|
||||
}
|
||||
|
||||
if len(s) > MaxNamePartLen || len(s) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
numConsecutiveDots := 0
|
||||
yieldValid := func(kind PartKind, part string) bool {
|
||||
if !isValidPart(kind, part) {
|
||||
yield(PartInvalid, "")
|
||||
return false
|
||||
}
|
||||
return yield(kind, part)
|
||||
}
|
||||
|
||||
partLen := 0
|
||||
state, j := PartDigest, len(s)
|
||||
for i := len(s) - 1; i >= 0; i-- {
|
||||
@@ -480,7 +481,7 @@ func parts(s string) iter_Seq2[PartKind, string] {
|
||||
// we don't keep spinning on it, waiting for
|
||||
// an isInValidPart check which would scan
|
||||
// over it again.
|
||||
yield(state, s[i+1:j])
|
||||
yield(PartInvalid, "")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -488,7 +489,7 @@ func parts(s string) iter_Seq2[PartKind, string] {
|
||||
case '@':
|
||||
switch state {
|
||||
case PartDigest:
|
||||
if !yield(PartDigest, s[i+1:j]) {
|
||||
if !yieldValid(PartDigest, s[i+1:j]) {
|
||||
return
|
||||
}
|
||||
if i == 0 {
|
||||
@@ -500,181 +501,85 @@ func parts(s string) iter_Seq2[PartKind, string] {
|
||||
}
|
||||
state, j, partLen = PartBuild, i, 0
|
||||
default:
|
||||
yield(PartExtraneous, s[i+1:j])
|
||||
yield(PartInvalid, "")
|
||||
return
|
||||
}
|
||||
case '+':
|
||||
switch state {
|
||||
case PartBuild, PartDigest:
|
||||
if !yield(PartBuild, s[i+1:j]) {
|
||||
if !yieldValid(PartBuild, s[i+1:j]) {
|
||||
return
|
||||
}
|
||||
state, j, partLen = PartTag, i, 0
|
||||
default:
|
||||
yield(PartExtraneous, s[i+1:j])
|
||||
yield(PartInvalid, "")
|
||||
return
|
||||
}
|
||||
case ':':
|
||||
switch state {
|
||||
case PartTag, PartBuild, PartDigest:
|
||||
if !yield(PartTag, s[i+1:j]) {
|
||||
if !yieldValid(PartTag, s[i+1:j]) {
|
||||
return
|
||||
}
|
||||
state, j, partLen = PartModel, i, 0
|
||||
default:
|
||||
yield(PartExtraneous, s[i+1:j])
|
||||
yield(PartInvalid, "")
|
||||
return
|
||||
}
|
||||
case '/':
|
||||
switch state {
|
||||
case PartModel, PartTag, PartBuild, PartDigest:
|
||||
if !yield(PartModel, s[i+1:j]) {
|
||||
if !yieldValid(PartModel, s[i+1:j]) {
|
||||
return
|
||||
}
|
||||
state, j = PartNamespace, i
|
||||
case PartNamespace:
|
||||
if !yield(PartNamespace, s[i+1:j]) {
|
||||
if !yieldValid(PartNamespace, s[i+1:j]) {
|
||||
return
|
||||
}
|
||||
state, j, partLen = PartHost, i, 0
|
||||
default:
|
||||
yield(PartExtraneous, s[i+1:j])
|
||||
yield(PartInvalid, "")
|
||||
return
|
||||
}
|
||||
default:
|
||||
if s[i] == '.' {
|
||||
if numConsecutiveDots++; numConsecutiveDots > 1 {
|
||||
yield(state, "")
|
||||
return
|
||||
}
|
||||
} else {
|
||||
numConsecutiveDots = 0
|
||||
if !isValidByte(state, s[i]) {
|
||||
yield(PartInvalid, "")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if state <= PartNamespace {
|
||||
yield(state, s[:j])
|
||||
yieldValid(state, s[:j])
|
||||
} else {
|
||||
yield(PartModel, s[:j])
|
||||
yieldValid(PartModel, s[:j])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
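Because the state machine above scans s right to left, Parts yields the most specific part first. A sketch (same assumed import path; requires range-over-func support):

package main

import (
	"fmt"

	"github.com/ollama/ollama/x/model"
)

func main() {
	for kind, part := range model.Parts("example.com/ns/mistral:7b+Q4_0") {
		fmt.Printf("%s = %q\n", kind, part)
	}
	// Output:
	// Build = "Q4_0"
	// Tag = "7b"
	// Model = "mistral"
	// Namespace = "ns"
	// Host = "example.com"
}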
func (r Name) IsZero() bool {
|
||||
return r.parts == [NumParts]string{}
|
||||
}
|
||||
|
||||
// IsValid reports if a model has at minimum a valid model part.
|
||||
// IsValid returns true if the Name has a valid model part. To know if a Name is
|
||||
// "complete", use [Name.IsComplete].
|
||||
func (r Name) IsValid() bool {
|
||||
// Parts ensures we only have valid parts, so no need to validate
|
||||
// them here, only check if we have a name or not.
|
||||
return r.parts[PartModel] != ""
|
||||
}
|
||||
|
||||
// ParseNameFromURLPath parses forms of a URL path into a Name. Specifically,
|
||||
// it trims any leading "/" and then calls [ParseName] with fill.
|
||||
func ParseNameFromURLPath(s, fill string) Name {
|
||||
s = strings.TrimPrefix(s, "/")
|
||||
return ParseName(s, fill)
|
||||
}
|
||||
|
||||
// URLPath returns a complete, canonicalized, relative URL path using the parts of a
|
||||
// complete Name.
|
||||
//
|
||||
// The parts maintain their original case.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// ParseName("example.com/namespace/model:tag+build").URLPath() // returns "/example.com/namespace/model:tag"
|
||||
func (r Name) URLPath() string {
|
||||
return r.DisplayShortest(MaskNothing)
|
||||
}
|
||||
|
||||
// ParseNameFromFilepath parses a file path into a Name. The input string must be a
|
||||
// valid file path representation of a model name in the form:
|
||||
//
|
||||
// host/namespace/model/tag/build
|
||||
//
|
||||
// The zero value is returned if s does not contain all path elements
|
||||
// leading up to the model part, or if any path element is an invalid part
|
||||
// for its corresponding part kind.
|
||||
//
|
||||
// The fill string is used to fill in missing parts of any constructed Name.
|
||||
// See [ParseName] for more information on the fill string.
|
||||
func ParseNameFromFilepath(s, fill string) Name {
|
||||
var r Name
|
||||
for i := range PartBuild + 1 {
|
||||
part, rest, _ := strings.Cut(s, string(filepath.Separator))
|
||||
if !isValidPart(i, part) {
|
||||
return Name{}
|
||||
}
|
||||
r.parts[i] = part
|
||||
s = rest
|
||||
if s == "" {
|
||||
break
|
||||
}
|
||||
}
|
||||
if s != "" {
|
||||
return Name{}
|
||||
}
|
||||
if !r.IsValid() {
|
||||
return Name{}
|
||||
}
|
||||
return fillName(r, fill)
|
||||
}
|
||||
|
||||
// Filepath returns a complete, canonicalized, relative file path using the
|
||||
// parts of a complete Name.
|
||||
//
|
||||
// Each part is downcased, except for the build part, which is upcased.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// ParseName("example.com/namespace/model:tag+build").Filepath() // returns "example.com/namespace/model/tag/BUILD"
|
||||
func (r Name) Filepath() string {
|
||||
for i := range r.parts {
|
||||
if PartKind(i) == PartBuild {
|
||||
r.parts[i] = strings.ToUpper(r.parts[i])
|
||||
} else {
|
||||
r.parts[i] = strings.ToLower(r.parts[i])
|
||||
}
|
||||
}
|
||||
return filepath.Join(r.parts[:]...)
|
||||
}
|
||||
|
||||
// FilepathNoBuild returns a complete, canonicalized, relative file path using
|
||||
// the parts of a complete Name, but without the build part.
|
||||
func (r Name) FilepathNoBuild() string {
|
||||
for i := range PartBuild {
|
||||
r.parts[i] = strings.ToLower(r.parts[i])
|
||||
}
|
||||
return filepath.Join(r.parts[:PartBuild]...)
|
||||
}
|
||||
|
||||
// isValidPart reports if s contains all valid characters for the given
|
||||
// part kind.
|
||||
// isValidPart returns true if the given part is valid ASCII [a-zA-Z0-9_\.-]
|
||||
func isValidPart(kind PartKind, s string) bool {
|
||||
if s == "" {
|
||||
return false
|
||||
}
|
||||
var consecutiveDots int
|
||||
for _, c := range []byte(s) {
|
||||
if c == '.' {
|
||||
if consecutiveDots++; consecutiveDots >= 2 {
|
||||
return false
|
||||
}
|
||||
} else {
|
||||
consecutiveDots = 0
|
||||
}
|
||||
if !isValidByteFor(kind, c) {
|
||||
if !isValidByte(kind, c) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func isValidByteFor(kind PartKind, c byte) bool {
|
||||
func isValidByte(kind PartKind, c byte) bool {
|
||||
if kind == PartNamespace && c == '.' {
|
||||
return false
|
||||
}
|
||||
572
x/model/name_test.go
Normal file
@@ -0,0 +1,572 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type fields struct {
|
||||
host, namespace, model, tag, build string
|
||||
digest string
|
||||
}
|
||||
|
||||
func fieldsFromName(p Name) fields {
|
||||
return fields{
|
||||
host: p.parts[PartHost],
|
||||
namespace: p.parts[PartNamespace],
|
||||
model: p.parts[PartModel],
|
||||
tag: p.parts[PartTag],
|
||||
build: p.parts[PartBuild],
|
||||
digest: p.parts[PartDigest],
|
||||
}
|
||||
}
|
||||
|
||||
var testNames = map[string]fields{
|
||||
"mistral:latest": {model: "mistral", tag: "latest"},
|
||||
"mistral": {model: "mistral"},
|
||||
"mistral:30B": {model: "mistral", tag: "30B"},
|
||||
"mistral:7b": {model: "mistral", tag: "7b"},
|
||||
"mistral:7b+Q4_0": {model: "mistral", tag: "7b", build: "Q4_0"},
|
||||
"mistral+KQED": {model: "mistral", build: "KQED"},
|
||||
"mistral.x-3:7b+Q4_0": {model: "mistral.x-3", tag: "7b", build: "Q4_0"},
|
||||
"mistral:7b+q4_0": {model: "mistral", tag: "7b", build: "q4_0"},
|
||||
"llama2": {model: "llama2"},
|
||||
"user/model": {namespace: "user", model: "model"},
|
||||
"example.com/ns/mistral:7b+Q4_0": {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "Q4_0"},
|
||||
"example.com/ns/mistral:7b+X": {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "X"},
|
||||
|
||||
// invalid digest
|
||||
"mistral:latest@invalid256-": {},
|
||||
"mistral:latest@-123": {},
|
||||
"mistral:latest@!-123": {},
|
||||
"mistral:latest@1-!": {},
|
||||
"mistral:latest@": {},
|
||||
|
||||
// resolved
|
||||
"x@sha123-1": {model: "x", digest: "sha123-1"},
|
||||
"@sha456-2": {digest: "sha456-2"},
|
||||
|
||||
"@@sha123-1": {},
|
||||
|
||||
// preserves case for build
|
||||
"x+b": {model: "x", build: "b"},
|
||||
|
||||
// invalid (includes fuzzing trophies)
|
||||
" / / : + ": {},
|
||||
" / : + ": {},
|
||||
" : + ": {},
|
||||
" + ": {},
|
||||
" : ": {},
|
||||
" / ": {},
|
||||
" /": {},
|
||||
"/ ": {},
|
||||
"/": {},
|
||||
":": {},
|
||||
"+": {},
|
||||
|
||||
// (".") in namepsace is not allowed
|
||||
"invalid.com/7b+x": {},
|
||||
|
||||
"invalid:7b+Q4_0:latest": {},
|
||||
"in valid": {},
|
||||
"invalid/y/z/foo": {},
|
||||
"/0": {},
|
||||
"0 /0": {},
|
||||
"0 /": {},
|
||||
"0/": {},
|
||||
":/0": {},
|
||||
"+0/00000": {},
|
||||
"0+.\xf2\x80\xf6\x9d00000\xe5\x99\xe6\xd900\xd90\xa60\x91\xdc0\xff\xbf\x99\xe800\xb9\xdc\xd6\xc300\x970\xfb\xfd0\xe0\x8a\xe1\xad\xd40\x9700\xa80\x980\xdd0000\xb00\x91000\xfe0\x89\x9b\x90\x93\x9f0\xe60\xf7\x84\xb0\x87\xa5\xff0\xa000\x9a\x85\xf6\x85\xfe\xa9\xf9\xe9\xde00\xf4\xe0\x8f\x81\xad\xde00\xd700\xaa\xe000000\xb1\xee0\x91": {},
|
||||
"0//0": {},
|
||||
"m+^^^": {},
|
||||
"file:///etc/passwd": {},
|
||||
"file:///etc/passwd:latest": {},
|
||||
"file:///etc/passwd:latest+u": {},
|
||||
|
||||
strings.Repeat("a", MaxNamePartLen): {model: strings.Repeat("a", MaxNamePartLen)},
|
||||
strings.Repeat("a", MaxNamePartLen+1): {},
|
||||
}
|
||||
|
||||
func TestNameParts(t *testing.T) {
|
||||
var p Name
|
||||
if w, g := int(PartDigest+1), len(p.Parts()); w != g {
|
||||
t.Errorf("Parts() = %d; want %d", g, w)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNamePartString(t *testing.T) {
|
||||
if g := PartKind(-2).String(); g != "Unknown" {
|
||||
t.Errorf("Unknown part = %q; want %q", g, "Unknown")
|
||||
}
|
||||
for kind, name := range kindNames {
|
||||
if g := kind.String(); g != name {
|
||||
t.Errorf("%s = %q; want %q", kind, g, name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseName(t *testing.T) {
|
||||
for baseName, want := range testNames {
|
||||
for _, prefix := range []string{"", "https://", "http://"} {
|
||||
// We should get the same results with or without the
|
||||
// http(s) prefixes
|
||||
s := prefix + baseName
|
||||
|
||||
t.Run(s, func(t *testing.T) {
|
||||
for kind, part := range Parts(s) {
|
||||
t.Logf("Part: %s: %q", kind, part)
|
||||
}
|
||||
|
||||
name := ParseName(s)
|
||||
got := fieldsFromName(name)
|
||||
if got != want {
|
||||
t.Errorf("ParseName(%q) = %q; want %q", s, got, want)
|
||||
}
|
||||
|
||||
// test round-trip
|
||||
if !ParseName(name.String()).EqualFold(name) {
|
||||
t.Errorf("ParseName(%q).String() = %s; want %s", s, name.String(), baseName)
|
||||
}
|
||||
|
||||
if name.IsValid() && name.DisplayModel() == "" {
|
||||
t.Errorf("Valid() = true; Model() = %q; want non-empty name", got.model)
|
||||
} else if !name.IsValid() && name.DisplayModel() != "" {
|
||||
t.Errorf("Valid() = false; Model() = %q; want empty name", got.model)
|
||||
}
|
||||
|
||||
if name.IsResolved() && !name.Digest().IsValid() {
|
||||
t.Errorf("Resolved() = true; Digest() = %q; want non-empty digest", got.digest)
|
||||
} else if !name.IsResolved() && name.Digest().IsValid() {
|
||||
t.Errorf("Resolved() = false; Digest() = %q; want empty digest", got.digest)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompleteWithAndWithoutBuild(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
complete bool
|
||||
completeNoBuild bool
|
||||
}{
|
||||
{"", false, false},
|
||||
{"incomplete/mistral:7b+x", false, false},
|
||||
{"incomplete/mistral:7b+Q4_0", false, false},
|
||||
{"incomplete:7b+x", false, false},
|
||||
{"complete.com/x/mistral:latest+Q4_0", true, true},
|
||||
{"complete.com/x/mistral:latest", false, true},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.in, func(t *testing.T) {
|
||||
p := ParseName(tt.in)
|
||||
t.Logf("ParseName(%q) = %#v", tt.in, p)
|
||||
if g := p.IsComplete(); g != tt.complete {
|
||||
t.Errorf("Complete(%q) = %v; want %v", tt.in, g, tt.complete)
|
||||
}
|
||||
if g := p.IsCompleteNoBuild(); g != tt.completeNoBuild {
|
||||
t.Errorf("CompleteNoBuild(%q) = %v; want %v", tt.in, g, tt.completeNoBuild)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Complete uses Parts which returns a slice, but it should be
|
||||
// inlined when used in Complete, preventing any allocations or
|
||||
// escaping to the heap.
|
||||
allocs := testing.AllocsPerRun(1000, func() {
|
||||
keep(ParseName("complete.com/x/mistral:latest+Q4_0").IsComplete())
|
||||
})
|
||||
if allocs > 0 {
|
||||
t.Errorf("Complete allocs = %v; want 0", allocs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNameLogValue(t *testing.T) {
|
||||
cases := []string{
|
||||
"example.com/library/mistral:latest+Q4_0",
|
||||
"mistral:latest",
|
||||
"mistral:7b+Q4_0",
|
||||
}
|
||||
for _, s := range cases {
|
||||
t.Run(s, func(t *testing.T) {
|
||||
var b bytes.Buffer
|
||||
log := slog.New(slog.NewTextHandler(&b, nil))
|
||||
name := ParseName(s)
|
||||
log.Info("", "name", name)
|
||||
want := fmt.Sprintf("name=%s", name.GoString())
|
||||
got := b.String()
|
||||
if !strings.Contains(got, want) {
|
||||
t.Errorf("expected log output to contain %q; got %q", want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNameDisplay(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
in string
|
||||
wantShort string
|
||||
wantLong string
|
||||
wantComplete string
|
||||
wantString string
|
||||
wantModel string
|
||||
wantGoString string // default is tt.in
|
||||
}{
|
||||
{
|
||||
name: "Complete Name",
|
||||
in: "example.com/library/mistral:latest+Q4_0",
|
||||
wantShort: "mistral:latest",
|
||||
wantLong: "library/mistral:latest",
|
||||
wantComplete: "example.com/library/mistral:latest",
|
||||
wantModel: "mistral",
|
||||
wantGoString: "example.com/library/mistral:latest+Q4_0@?",
|
||||
},
|
||||
{
|
||||
name: "Short Name",
|
||||
in: "mistral:latest",
|
||||
wantShort: "mistral:latest",
|
||||
wantLong: "mistral:latest",
|
||||
wantComplete: "mistral:latest",
|
||||
wantModel: "mistral",
|
||||
wantGoString: "?/?/mistral:latest+?@?",
|
||||
},
|
||||
{
|
||||
name: "Long Name",
|
||||
in: "library/mistral:latest",
|
||||
wantShort: "mistral:latest",
|
||||
wantLong: "library/mistral:latest",
|
||||
		wantComplete: "library/mistral:latest",
		wantModel:    "mistral",
		wantGoString: "?/library/mistral:latest+?@?",
	},
	{
		name:         "Case Preserved",
		in:           "Library/Mistral:Latest",
		wantShort:    "Mistral:Latest",
		wantLong:     "Library/Mistral:Latest",
		wantComplete: "Library/Mistral:Latest",
		wantModel:    "Mistral",
		wantGoString: "?/Library/Mistral:Latest+?@?",
	},
	{
		name:         "With digest",
		in:           "Library/Mistral:Latest@sha256-123456",
		wantShort:    "Mistral:Latest",
		wantLong:     "Library/Mistral:Latest",
		wantComplete: "Library/Mistral:Latest",
		wantModel:    "Mistral",
		wantGoString: "?/Library/Mistral:Latest+?@sha256-123456",
	},
}

	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			p := ParseName(tt.in)
			if g := p.DisplayShort(); g != tt.wantShort {
				t.Errorf("DisplayShort = %q; want %q", g, tt.wantShort)
			}
			if g := p.DisplayLong(); g != tt.wantLong {
				t.Errorf("DisplayLong = %q; want %q", g, tt.wantLong)
			}
			if g := p.DisplayFullest(); g != tt.wantComplete {
				t.Errorf("DisplayFullest = %q; want %q", g, tt.wantComplete)
			}
			if g := p.String(); g != tt.in {
				t.Errorf("String(%q) = %q; want %q", tt.in, g, tt.in)
			}
			if g := p.DisplayModel(); g != tt.wantModel {
				t.Errorf("Model = %q; want %q", g, tt.wantModel)
			}

			tt.wantGoString = cmp.Or(tt.wantGoString, tt.in)
			if g := fmt.Sprintf("%#v", p); g != tt.wantGoString {
				t.Errorf("GoString() = %q; want %q", g, tt.wantGoString)
			}
		})
	}
}

func TestParseNameAllocs(t *testing.T) {
	allocs := testing.AllocsPerRun(1000, func() {
		keep(ParseName("example.com/mistral:7b+Q4_0"))
	})
	if allocs > 0 {
		t.Errorf("ParseName allocs = %v; want 0", allocs)
	}
}

func BenchmarkParseName(b *testing.B) {
	b.ReportAllocs()

	for range b.N {
		keep(ParseName("example.com/mistral:7b+Q4_0"))
	}
}

func BenchmarkNameDisplay(b *testing.B) {
	b.ReportAllocs()

	r := ParseName("example.com/mistral:7b+Q4_0")
	b.Run("Short", func(b *testing.B) {
		for range b.N {
			keep(r.DisplayShort())
		}
	})
}

func FuzzParseName(f *testing.F) {
	f.Add("example.com/mistral:7b+Q4_0")
	f.Add("example.com/mistral:7b+q4_0")
	f.Add("example.com/mistral:7b+x")
	f.Add("x/y/z:8n+I")
	f.Fuzz(func(t *testing.T, s string) {
		r0 := ParseName(s)
		if !r0.IsValid() {
			if !r0.EqualFold(Name{}) {
				t.Errorf("expected invalid path to be zero value; got %#v", r0)
			}
			t.Skipf("invalid path: %q", s)
		}

		for _, p := range r0.Parts() {
			if len(p) > MaxNamePartLen {
				t.Errorf("part too long: %q", p)
			}
		}

		if !strings.EqualFold(r0.String(), s) {
			t.Errorf("String() did not round-trip with case insensitivity: %q\ngot = %q\nwant = %q", s, r0.String(), s)
		}

		r1 := ParseName(r0.String())
		if !r0.EqualFold(r1) {
			t.Errorf("round-trip mismatch: %+v != %+v", r0, r1)
		}
	})
}

func TestFill(t *testing.T) {
	cases := []struct {
		dst  string
		src  string
		want string
	}{
		{"mistral", "o.com/library/PLACEHOLDER:latest+Q4_0", "o.com/library/mistral:latest+Q4_0"},
		{"o.com/library/mistral", "PLACEHOLDER:latest+Q4_0", "o.com/library/mistral:latest+Q4_0"},
		{"", "o.com/library/mistral:latest+Q4_0", "o.com/library/mistral:latest+Q4_0"},
	}

	for _, tt := range cases {
		t.Run(tt.dst, func(t *testing.T) {
			r := Fill(ParseName(tt.dst), ParseName(tt.src))
			if r.String() != tt.want {
				t.Errorf("Fill(%q, %q) = %q; want %q", tt.dst, tt.src, r, tt.want)
			}
		})
	}
}

func TestNameTextMarshal(t *testing.T) {
	cases := []struct {
		in      string
		want    string
		wantErr error
	}{
		{"example.com/mistral:latest+Q4_0", "", nil},
		{"mistral:latest+Q4_0", "mistral:latest+Q4_0", nil},
		{"mistral:latest", "mistral:latest", nil},
		{"mistral", "mistral", nil},
		{"mistral:7b", "mistral:7b", nil},
		{"example.com/library/mistral:latest+Q4_0", "example.com/library/mistral:latest+Q4_0", nil},
	}

	for _, tt := range cases {
		t.Run(tt.in, func(t *testing.T) {
			p := ParseName(tt.in)
			got, err := p.MarshalText()
			if !errors.Is(err, tt.wantErr) {
				t.Fatalf("MarshalText() error = %v; want %v", err, tt.wantErr)
			}
			if string(got) != tt.want {
				t.Errorf("MarshalText() = %q; want %q", got, tt.want)
			}

			var r Name
			if err := r.UnmarshalText(got); err != nil {
				t.Fatalf("UnmarshalText() error = %v; want nil", err)
			}
			if !r.EqualFold(p) {
				t.Errorf("UnmarshalText() = %q; want %q", r, p)
			}
		})
	}

	t.Run("UnmarshalText into valid Name", func(t *testing.T) {
		// UnmarshalText should not be called on a valid Name.
		p := MustParseName("x")
		if err := p.UnmarshalText([]byte("mistral:latest+Q4_0")); err == nil {
			t.Error("UnmarshalText() = nil; want error")
		}
	})

	t.Run("TextMarshal allocs", func(t *testing.T) {
		var data []byte
		name := ParseName("example.com/ns/mistral:latest+Q4_0")
		if !name.IsComplete() {
			// sanity check
			panic("sanity check failed")
		}

		allocs := testing.AllocsPerRun(1000, func() {
			var err error
			data, err = name.MarshalText()
			if err != nil {
				t.Fatal(err)
			}
			if len(data) == 0 {
				t.Fatal("MarshalText() returned empty data; want non-empty")
			}
		})
		if allocs > 1 {
			// TODO: Update when/if this lands:
			// https://github.com/golang/go/issues/62384
			//
			// Currently, the best we can do is 1 alloc.
			t.Errorf("MarshalText allocs = %v; want <= 1", allocs)
		}
	})

	t.Run("UnmarshalText makes safe copy", func(t *testing.T) {
		// UnmarshalText should make a copy of the data.
		data := []byte("mistral:latest+Q4_0")
		p := Name{}
		if err := p.UnmarshalText(data); err != nil {
			t.Fatal(err)
		}
		data[0] = 'x'
		if p.String() != "mistral:latest+Q4_0" {
			t.Errorf("UnmarshalText() did not make a copy")
		}
	})
}

func TestSQL(t *testing.T) {
	t.Run("Scan for already valid Name", func(t *testing.T) {
		p := MustParseName("x")
		if err := p.Scan("mistral:latest+Q4_0"); err == nil {
			t.Error("Scan() = nil; want error")
		}
	})
	t.Run("Scan for invalid Name", func(t *testing.T) {
		p := Name{}
		if err := p.Scan("mistral:latest+Q4_0"); err != nil {
			t.Errorf("Scan() = %v; want nil", err)
		}
		if p.String() != "mistral:latest+Q4_0" {
			t.Errorf("String() = %q; want %q", p, "mistral:latest+Q4_0")
		}
	})
	t.Run("Value", func(t *testing.T) {
		p := MustParseName("x")
		if g, err := p.Value(); err != nil {
			t.Errorf("Value() error = %v; want nil", err)
		} else if g != "x" {
			t.Errorf("Value() = %q; want %q", g, "x")
		}
	})
}

func TestNameStringAllocs(t *testing.T) {
	name := ParseName("example.com/ns/mistral:latest+Q4_0")
	allocs := testing.AllocsPerRun(1000, func() {
		keep(name.String())
	})
	if allocs > 1 {
		t.Errorf("String allocs = %v; want <= 1", allocs)
	}
}

func ExampleFill() {
	defaults := ParseName("registry.ollama.com/library/PLACEHOLDER:latest+Q4_0")
	r := Fill(ParseName("mistral"), defaults)
	fmt.Println(r)

	// Output:
	// registry.ollama.com/library/mistral:latest+Q4_0
}

func ExampleName_MapHash() {
	m := map[uint64]bool{}

	// key 1
	m[ParseName("mistral:latest+q4").MapHash()] = true
	m[ParseName("miSTRal:latest+Q4").MapHash()] = true
	m[ParseName("mistral:LATest+Q4").MapHash()] = true

	// key 2
	m[ParseName("mistral:LATest").MapHash()] = true

	fmt.Println(len(m))
	// Output:
	// 2
}

func ExampleName_CompareFold_sort() {
	names := []Name{
		ParseName("mistral:latest"),
		ParseName("mistRal:7b+q4"),
		ParseName("MIstral:7b"),
	}

	slices.SortFunc(names, Name.CompareFold)

	for _, n := range names {
		fmt.Println(n)
	}

	// Output:
	// MIstral:7b
	// mistRal:7b+q4
	// mistral:latest
}

func ExampleName_completeAndResolved() {
	for _, s := range []string{
		"x/y/z:latest+q4_0@sha123-1",
		"x/y/z:latest+q4_0",
		"@sha123-1",
	} {
		p := ParseName(s)
		fmt.Printf("complete:%v resolved:%v digest:%s\n", p.IsComplete(), p.IsResolved(), p.Digest())
	}

	// Output:
	// complete:true resolved:true digest:sha123-1
	// complete:true resolved:false digest:
	// complete:false resolved:true digest:sha123-1
}

func ExampleName_DisplayFullest() {
	for _, s := range []string{
		"example.com/jmorganca/mistral:latest+Q4_0",
		"mistral:latest+Q4_0",
		"mistral:latest",
	} {
		fmt.Println(ParseName(s).DisplayFullest())
	}

	// Output:
	// example.com/jmorganca/mistral:latest
	// mistral:latest
	// mistral:latest
}

func keep[T any](v T) T { return v }
89 x/oweb/oweb.go Normal file
@@ -0,0 +1,89 @@
package oweb

import (
	"cmp"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"net/http"

	"github.com/ollama/ollama/x/client/ollama"
)

func Missing(field string) error {
	return &ollama.Error{
		Status:  400,
		Code:    "missing",
		Field:   field,
		Message: fmt.Sprintf("%s is required", field),
	}
}

func Invalid(field, value, format string, args ...any) error {
	return &ollama.Error{
		Status:  400,
		Code:    "invalid",
		Field:   field,
		Value:   value,
		Message: fmt.Sprintf(format, args...),
	}
}

// Convenience errors
var (
	ErrNotFound         = &ollama.Error{Status: 404, Code: "not_found"}
	ErrInternal         = &ollama.Error{Status: 500, Code: "internal_error"}
	ErrMethodNotAllowed = &ollama.Error{Status: 405, Code: "method_not_allowed"}
)

type HandlerFunc func(w http.ResponseWriter, r *http.Request) error

func Serve(h HandlerFunc, w http.ResponseWriter, r *http.Request) {
	if err := h(w, r); err != nil {
		// TODO: take a slog.Logger
		log.Printf("error: %v", err)
		var oe *ollama.Error
		if !errors.As(err, &oe) {
			oe = ErrInternal
		}
		oe.Status = cmp.Or(oe.Status, 400)
		w.WriteHeader(oe.Status)
		if err := EncodeJSON(w, oe); err != nil {
			log.Printf("error encoding error: %v", err)
		}
	}
}

func DecodeUserJSON[T any](field string, r io.Reader) (*T, error) {
	v, err := DecodeJSON[T](r)

	// Handle common JSON syntax errors
	var e *json.SyntaxError
	if errors.As(err, &e) {
		return nil, Invalid(field, "", e.Error())
	}

	// Handle type errors
	var se *json.UnmarshalTypeError
	if errors.As(err, &se) {
		return nil, Invalid(field, se.Value, "expected %s", se.Type)
	}

	// Return v and err as they were.
	return v, err
}

func DecodeJSON[T any](r io.Reader) (*T, error) {
	var v *T
	if err := json.NewDecoder(r).Decode(&v); err != nil {
		var zero T
		return &zero, err
	}
	return v, nil
}

func EncodeJSON(w io.Writer, v any) error {
	return json.NewEncoder(w).Encode(v)
}
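Serve, the error constructors, and DecodeUserJSON combine into a small errors-as-values handler pattern: handlers return an error and Serve renders it as JSON. A minimal usage sketch, assuming a hypothetical /v1/hello route and helloRequest type that are not part of this diff:

package main

import (
	"net/http"

	"github.com/ollama/ollama/x/oweb"
)

type helloRequest struct {
	Name string `json:"name"`
}

func handleHello(w http.ResponseWriter, r *http.Request) error {
	req, err := oweb.DecodeUserJSON[helloRequest]("", r.Body)
	if err != nil {
		return err // malformed JSON comes back as a 400 "invalid" error
	}
	if req == nil || req.Name == "" {
		return oweb.Missing("name") // 400 {"code":"missing","field":"name",...}
	}
	return oweb.EncodeJSON(w, map[string]string{"greeting": "hello, " + req.Name})
}

func main() {
	http.HandleFunc("/v1/hello", func(w http.ResponseWriter, r *http.Request) {
		oweb.Serve(handleHello, w, r)
	})
	http.ListenAndServe("localhost:8080", nil)
}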
46 x/registry/apitype/apitype.go Normal file
@@ -0,0 +1,46 @@
package apitype

import "encoding/json"

type Manifest struct {
	Layers []Layer `json:"layers"`
}

type CompletePart struct {
	URL  string `json:"url"` // contains partNumber and uploadId from server
	ETag string `json:"etag"`
}

type Layer struct {
	Digest    string `json:"digest"`
	MediaType string `json:"mediaType"`
	Size      int64  `json:"size"`
}

type PushRequest struct {
	Name     string          `json:"ref"`
	Manifest json.RawMessage `json:"manifest"`

	// CompleteParts is a list of upload parts that the client uploaded in
	// the previous push.
	CompleteParts []CompletePart `json:"part_uploads"`
}

type Requirement struct {
	Digest string `json:"digest"`
	Offset int64  `json:"offset"`
	Size   int64  `json:"Size"`

	// URL is the URL to PUT the layer to.
	//
	// Clients must include it as the URL, along with the ETag from the
	// response headers of the PUT request, in the CompleteParts field of
	// the next push request.
	URL string `json:"url"`
}

type PushResponse struct {
	// Requirements is a list of digests that the client needs to push
	// before re-pushing the manifest.
	Requirements []Requirement `json:"requirements,omitempty"`
}
102 x/registry/client.go Normal file
@@ -0,0 +1,102 @@
package registry

import (
	"cmp"
	"context"
	"encoding/xml"
	"errors"
	"fmt"
	"io"
	"net/http"
	"strings"

	"github.com/ollama/ollama/x/client/ollama"
	"github.com/ollama/ollama/x/registry/apitype"
)

type Client struct {
	BaseURL    string
	HTTPClient *http.Client
}

func (c *Client) oclient() *ollama.Client {
	return (*ollama.Client)(c)
}

type PushParams struct {
	CompleteParts []apitype.CompletePart
}

// Push pushes a manifest to the server.
func (c *Client) Push(ctx context.Context, ref string, manifest []byte, p *PushParams) ([]apitype.Requirement, error) {
	p = cmp.Or(p, &PushParams{})
	// TODO(bmizerany): backoff
	v, err := ollama.Do[apitype.PushResponse](ctx, c.oclient(), "POST", "/v1/push", &apitype.PushRequest{
		Name:          ref,
		Manifest:      manifest,
		CompleteParts: p.CompleteParts,
	})
	if err != nil {
		return nil, err
	}
	return v.Requirements, nil
}

func PushLayer(ctx context.Context, body io.ReaderAt, url string, off, n int64) (apitype.CompletePart, error) {
	var zero apitype.CompletePart
	if off < 0 {
		return zero, errors.New("off must be >= 0")
	}

	file := io.NewSectionReader(body, off, n)
	req, err := http.NewRequestWithContext(ctx, "PUT", url, file)
	if err != nil {
		return zero, err
	}
	req.ContentLength = n

	// TODO(bmizerany): take content type param
	req.Header.Set("Content-Type", "text/plain")

	if n >= 0 {
		req.Header.Set("x-amz-copy-source-range", fmt.Sprintf("bytes=%d-%d", off, off+n-1))
	}

	res, err := http.DefaultClient.Do(req)
	if err != nil {
		return zero, err
	}
	defer res.Body.Close()
	if res.StatusCode != 200 {
		e := parseS3Error(res)
		return zero, fmt.Errorf("unexpected status code: %d; %w", res.StatusCode, e)
	}
	etag := strings.Trim(res.Header.Get("ETag"), `"`)
	cp := apitype.CompletePart{
		URL:  url,
		ETag: etag,
		// TODO(bmizerany): checksum
	}
	return cp, nil
}

type s3Error struct {
	XMLName   xml.Name `xml:"Error"`
	Code      string   `xml:"Code"`
	Message   string   `xml:"Message"`
	Resource  string   `xml:"Resource"`
	RequestId string   `xml:"RequestId"`
}

func (e *s3Error) Error() string {
	return fmt.Sprintf("S3 (%s): %s: %s: %s", e.RequestId, e.Resource, e.Code, e.Message)
}

// parseS3Error parses an XML error response from S3.
func parseS3Error(res *http.Response) error {
	var se *s3Error
	if err := xml.NewDecoder(res.Body).Decode(&se); err != nil {
		return err
	}
	return se
}
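Taken together, Push and PushLayer imply a two-round handshake: push once to learn what is missing, PUT the required byte ranges to the presigned URLs, then push again with the completed parts so the server can commit the manifest. A minimal client-side sketch, assuming it sits alongside client.go in package registry; pushAll and the single blobs ReaderAt (mirroring the test's abcReader) are illustrative, not part of this diff:

package registry

import (
	"context"
	"fmt"
	"io"

	"github.com/ollama/ollama/x/registry/apitype"
)

func pushAll(ctx context.Context, c *Client, ref string, manifest []byte, blobs io.ReaderAt) error {
	// Round one: the response lists the byte ranges the server is missing.
	reqs, err := c.Push(ctx, ref, manifest, nil)
	if err != nil {
		return err
	}
	var done []apitype.CompletePart
	for _, r := range reqs {
		// Each requirement names a digest, a byte range, and a
		// presigned URL to PUT that range to.
		cp, err := PushLayer(ctx, blobs, r.URL, r.Offset, r.Size)
		if err != nil {
			return err
		}
		done = append(done, cp)
	}
	// Round two: no requirements back means the manifest was committed.
	reqs, err = c.Push(ctx, ref, manifest, &PushParams{CompleteParts: done})
	if err != nil {
		return err
	}
	if len(reqs) != 0 {
		return fmt.Errorf("still missing %d parts", len(reqs))
	}
	return nil
}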
256 x/registry/server.go Normal file
@@ -0,0 +1,256 @@
// Package registry implements an Ollama registry client and server.
package registry

import (
	"bytes"
	"cmp"
	"context"
	"errors"
	"fmt"
	"log"
	"net/http"
	"net/url"
	"path"
	"strconv"
	"strings"
	"time"

	"github.com/minio/minio-go/v7"
	"github.com/minio/minio-go/v7/pkg/credentials"
	"github.com/ollama/ollama/x/client/ollama"
	"github.com/ollama/ollama/x/model"
	"github.com/ollama/ollama/x/oweb"
	"github.com/ollama/ollama/x/registry/apitype"
	"github.com/ollama/ollama/x/utils/upload"
)

// Defaults
const (
	DefaultUploadChunkSize = 50 * 1024 * 1024
)

type Server struct {
	UploadChunkSize int64 // default is DefaultUploadChunkSize
	S3Client        *minio.Client
}

func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	if err := s.serveHTTP(w, r); err != nil {
		log.Printf("error: %v", err) // TODO(bmizerany): take a slog.Logger
		var e *ollama.Error
		if !errors.As(err, &e) {
			e = oweb.ErrInternal
		}
		w.WriteHeader(cmp.Or(e.Status, 400))
		if err := oweb.EncodeJSON(w, e); err != nil {
			log.Printf("error encoding error: %v", err)
		}
	}
}

func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) error {
	switch r.URL.Path {
	case "/v1/push":
		return s.handlePush(w, r)
	case "/v1/pull":
		return s.handlePull(w, r)
	default:
		return oweb.ErrNotFound
	}
}

func (s *Server) uploadChunkSize() int64 {
	return cmp.Or(s.UploadChunkSize, DefaultUploadChunkSize)
}

func (s *Server) handlePush(w http.ResponseWriter, r *http.Request) error {
	const bucketTODO = "test"
	const minimumMultipartSize = 5 * 1024 * 1024 // S3 spec

	pr, err := oweb.DecodeUserJSON[apitype.PushRequest]("", r.Body)
	if err != nil {
		return err
	}

	mp := model.ParseName(pr.Name)
	if !mp.IsComplete() {
		return oweb.Invalid("name", pr.Name, "must be complete")
	}

	m, err := oweb.DecodeUserJSON[apitype.Manifest]("manifest", bytes.NewReader(pr.Manifest))
	if err != nil {
		return err
	}

	mcc := &minio.Core{Client: s.s3()}
	// TODO(bmizerany): complete uploads before stats for any with ETag

	type completeParts struct {
		key   string
		parts []minio.CompletePart
	}

	completePartsByUploadID := make(map[string]completeParts)
	for _, mcp := range pr.CompleteParts {
		// parse the URL
		u, err := url.Parse(mcp.URL)
		if err != nil {
			return err
		}

		q := u.Query()

		// Check if this is a part upload; if not, skip.
		uploadID := q.Get("uploadId")
		if uploadID == "" {
			// not a part upload
			continue
		}

		// PartNumber is required
		queryPartNumber := q.Get("partNumber")
		partNumber, err := strconv.Atoi(queryPartNumber)
		if err != nil {
			return oweb.Invalid("partNumber", queryPartNumber, "")
		}
		if partNumber < 1 {
			return oweb.Invalid("partNumber", queryPartNumber, "must be >= 1")
		}

		// ETag is required
		if mcp.ETag == "" {
			return oweb.Missing("etag")
		}

		cp := completePartsByUploadID[uploadID]
		cp.key = u.Path
		cp.parts = append(cp.parts, minio.CompletePart{
			PartNumber: partNumber,
			ETag:       mcp.ETag,
		})
		completePartsByUploadID[uploadID] = cp
	}

	for uploadID, cp := range completePartsByUploadID {
		var zeroOpts minio.PutObjectOptions

		// TODO: gross fix!!!!!!!!!!!!!!!
		key := strings.TrimPrefix(cp.key, "/"+bucketTODO+"/")

		fmt.Printf("Completing multipart upload %s %s %v\n", bucketTODO, key, cp.parts)
		_, err := mcc.CompleteMultipartUpload(r.Context(), bucketTODO, key, uploadID, cp.parts, zeroOpts)
		if err != nil {
			var e minio.ErrorResponse
			if errors.As(err, &e) && e.Code == "NoSuchUpload" {
				return oweb.Invalid("uploadId", uploadID, "")
			}
			return err
		}
	}

	var requirements []apitype.Requirement
	for _, l := range m.Layers {
		// TODO(bmizerany): do in parallel
		if l.Size == 0 {
			continue
		}

		// TODO(bmizerany): "global" throttle of rate of transfer
		pushed, err := s.statObject(r.Context(), l.Digest)
		if err != nil {
			println("ERROR:", "statObject", err)
			return err
		}
		if !pushed {
			key := path.Join("blobs", l.Digest)
			if l.Size < minimumMultipartSize {
				// single part upload
				fmt.Printf("Presigning single %s %s\n", bucketTODO, key)
				signedURL, err := s.s3().PresignedPutObject(r.Context(), bucketTODO, key, 15*time.Minute)
				if err != nil {
					println("ERROR:", "presign single", err)
					return err
				}
				requirements = append(requirements, apitype.Requirement{
					Digest: l.Digest,
					Size:   l.Size,
					URL:    signedURL.String(),
				})
			} else {
				uploadID, err := mcc.NewMultipartUpload(r.Context(), bucketTODO, key, minio.PutObjectOptions{})
				if err != nil {
					return err
				}
				fmt.Printf("Presigning multi %s %s %s\n", bucketTODO, key, uploadID)
				for partNumber, c := range upload.Chunks(l.Size, s.uploadChunkSize()) {
					const timeToStartUpload = 15 * time.Minute

					signedURL, err := s.s3().Presign(r.Context(), "PUT", bucketTODO, key, timeToStartUpload, url.Values{
						"partNumber": []string{strconv.Itoa(partNumber)},
						"uploadId":   []string{uploadID},
					})
					if err != nil {
						println("ERROR:", "presign multi", err)
						return err
					}

					requirements = append(requirements, apitype.Requirement{
						Digest: l.Digest,
						Offset: c.Offset,
						Size:   c.N,
						URL:    signedURL.String(),
					})
				}
			}
		}
	}

	if len(requirements) == 0 {
		// Commit the manifest
		body := bytes.NewReader(pr.Manifest)
		path := path.Join("manifests", path.Join(mp.Parts()...))
		_, err := s.s3().PutObject(r.Context(), bucketTODO, path, body, int64(len(pr.Manifest)), minio.PutObjectOptions{})
		if err != nil {
			return err
		}
	}

	return oweb.EncodeJSON(w, &apitype.PushResponse{Requirements: requirements})
}

func (s *Server) handlePull(w http.ResponseWriter, r *http.Request) error {
	// lookup manifest
	panic("TODO")
}

func (s *Server) statObject(ctx context.Context, digest string) (pushed bool, err error) {
	// HEAD the object
	path := path.Join("blobs", digest)
	_, err = s.s3().StatObject(ctx, "test", path, minio.StatObjectOptions{})
	if err != nil {
		if isNoSuchKey(err) {
			err = nil
		}
		return false, err
	}
	return true, nil
}

func isNoSuchKey(err error) bool {
	var e minio.ErrorResponse
	return errors.As(err, &e) && e.Code == "NoSuchKey"
}

func (s *Server) s3() *minio.Client {
	if s.S3Client != nil {
		return s.S3Client
	}
	s3, err := minio.New("localhost:9000", &minio.Options{
		Creds:  credentials.NewStaticV4("minioadmin", "minioadmin", ""),
		Secure: false,
	})
	if err != nil {
		panic(err)
	}
	return s3
}
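For orientation, a minimal sketch of wiring Server into net/http. It leans on the hard-coded localhost:9000 MinIO fallback in s3() above; the main package and the listen address are hypothetical, not part of this diff:

package main

import (
	"log"
	"net/http"

	"github.com/ollama/ollama/x/registry"
)

func main() {
	// With S3Client nil, Server.s3() falls back to localhost:9000 with
	// the minioadmin credentials, so this assumes a local MinIO is up
	// and a "test" bucket exists.
	s := &registry.Server{}
	log.Fatal(http.ListenAndServe("localhost:8888", s))
}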
473 x/registry/server_test.go Normal file
@@ -0,0 +1,473 @@
package registry

import (
	"bufio"
	"bytes"
	"cmp"
	"context"
	"crypto/sha256"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http/httptest"
	"net/url"
	"os"
	"os/exec"
	"strconv"
	"syscall"
	"testing"
	"time"

	"github.com/minio/minio-go/v7"
	"github.com/minio/minio-go/v7/pkg/credentials"
	"github.com/ollama/ollama/x/registry/apitype"
	"github.com/ollama/ollama/x/utils/backoff"
	"github.com/ollama/ollama/x/utils/upload"
	"kr.dev/diff"
)

// const ref = "registry.ollama.ai/x/y:latest+Z"
// const manifest = `{
//	"layers": [
//		{"digest": "sha256-1", "size": 1},
//		{"digest": "sha256-2", "size": 2},
//		{"digest": "sha256-3", "size": 3}
//	]
// }`

// ts := newTestServer(t)
// ts.pushNotOK(ref, `{}`, &ollama.Error{
//	Status:  400,
//	Code:    "invalid",
//	Message: "name must be fully qualified",
// })

// ts.push(ref, `{
//	"layers": [
//		{"digest": "sha256-1", "size": 1},
//		{"digest": "sha256-2", "size": 2},
//		{"digest": "sha256-3", "size": 3}
//	]
// }`)

type tWriter struct {
	t *testing.T
}

func (w tWriter) Write(p []byte) (n int, err error) {
	w.t.Logf("%s", p)
	return len(p), nil
}

func TestPushBasic(t *testing.T) {
	const MB = 1024 * 1024

	mc := startMinio(t, true)

	defer func() {
		mcc := &minio.Core{Client: mc}
		// fail if there are any incomplete uploads
		for x := range mcc.ListIncompleteUploads(context.Background(), "test", "theKey", true) {
			t.Errorf("incomplete: %v", x)
		}
	}()

	const ref = "registry.ollama.ai/x/y:latest+Z"

	// Upload two small layers and one large layer that will
	// trigger a multipart upload.
	manifest := []byte(`{
		"layers": [
			{"digest": "sha256-1", "size": 1},
			{"digest": "sha256-2", "size": 2},
			{"digest": "sha256-3", "size": 11000000}
		]
	}`)

	hs := httptest.NewServer(&Server{
		S3Client:        mc,
		UploadChunkSize: 5 * MB,
	})
	t.Cleanup(hs.Close)
	c := &Client{BaseURL: hs.URL}

	requirements, err := c.Push(context.Background(), ref, manifest, nil)
	if err != nil {
		t.Fatal(err)
	}

	if len(requirements) < 3 {
		t.Errorf("expected at least 3 requirements; got %d", len(requirements))
		t.Logf("requirements: %v", requirements)
	}

	var uploaded []apitype.CompletePart
	for i, r := range requirements {
		t.Logf("[%d] pushing layer: offset=%d size=%d", i, r.Offset, r.Size)

		cp, err := PushLayer(context.Background(), &abcReader{}, r.URL, r.Offset, r.Size)
		if err != nil {
			t.Fatal(err)
		}
		uploaded = append(uploaded, cp)
	}

	requirements, err = c.Push(context.Background(), ref, manifest, &PushParams{
		CompleteParts: uploaded,
	})
	if err != nil {
		t.Fatal(err)
	}
	if len(requirements) != 0 {
		t.Errorf("unexpected requirements: %v", requirements)
	}

	var paths []string
	keys := mc.ListObjects(context.Background(), "test", minio.ListObjectsOptions{
		Recursive: true,
	})
	for k := range keys {
		paths = append(paths, k.Key)
	}

	t.Logf("paths: %v", paths)

	diff.Test(t, t.Errorf, paths, []string{
		"blobs/sha256-1",
		"blobs/sha256-2",
		"blobs/sha256-3",
		"manifests/registry.ollama.ai/x/y/latest/Z",
	})

	obj, err := mc.GetObject(context.Background(), "test", "manifests/registry.ollama.ai/x/y/latest/Z", minio.GetObjectOptions{})
	if err != nil {
		t.Fatal(err)
	}
	defer obj.Close()

	var gotM apitype.Manifest
	if err := json.NewDecoder(obj).Decode(&gotM); err != nil {
		t.Fatal(err)
	}

	diff.Test(t, t.Errorf, gotM, apitype.Manifest{
		Layers: []apitype.Layer{
			{Digest: "sha256-1", Size: 1},
			{Digest: "sha256-2", Size: 2},
			{Digest: "sha256-3", Size: 11000000},
		},
	})

	// checksum the blobs
	for i, l := range gotM.Layers {
		obj, err := mc.GetObject(context.Background(), "test", "blobs/"+l.Digest, minio.GetObjectOptions{})
		if err != nil {
			t.Fatal(err)
		}
		defer obj.Close()

		info, err := obj.Stat()
		if err != nil {
			t.Fatal(err)
		}
		t.Logf("[%d] layer info: name=%q l.Size=%d size=%d", i, info.Key, l.Size, info.Size)

		if msg := checkABCs(obj, int(l.Size)); msg != "" {
			t.Errorf("[%d] %s", i, msg)
		}
	}
}

// TestBasicPresignS3MultipartReferenceDoNotDelete tests the basic flow of
// presigning a multipart upload, uploading the parts, and completing the
// upload. It is for future reference and should not be deleted. This flow
// is tricky, and if we get it wrong in our server, we can refer back to this
// as a "back to basics" test/reference.
func TestBasicPresignS3MultipartReferenceDoNotDelete(t *testing.T) {
	t.Skip("skipping reference test; unskip when needed")

	mc := startMinio(t, true)
	mcc := &minio.Core{Client: mc}

	uploadID, err := mcc.NewMultipartUpload(context.Background(), "test", "theKey", minio.PutObjectOptions{})
	if err != nil {
		t.Fatal(err)
	}

	var completed []minio.CompletePart
	const size int64 = 10 * 1024 * 1024
	const chunkSize = 5 * 1024 * 1024

	for partNumber, c := range upload.Chunks(size, chunkSize) {
		u, err := mcc.Presign(context.Background(), "PUT", "test", "theKey", 15*time.Minute, url.Values{
			"partNumber": {strconv.Itoa(partNumber)},
			"uploadId":   {uploadID},
		})
		if err != nil {
			t.Fatalf("[partNumber=%d]: %v", partNumber, err)
		}
		t.Logf("[partNumber=%d]: %v", partNumber, u)

		var body abcReader
		cp, err := PushLayer(context.Background(), &body, u.String(), c.Offset, c.N)
		if err != nil {
			t.Fatalf("[partNumber=%d]: %v", partNumber, err)
		}
		t.Logf("completed part: %v", cp)

		// Behave like the server here (don't cheat and use partNumber);
		// instead, get partNumber from the URL.
		retPartNumber, err := strconv.Atoi(u.Query().Get("partNumber"))
		if err != nil {
			t.Fatalf("[partNumber=%d]: %v", partNumber, err)
		}

		completed = append(completed, minio.CompletePart{
			PartNumber: retPartNumber,
			ETag:       cp.ETag,
		})
	}

	defer func() {
		// fail if there are any incomplete uploads
		for x := range mcc.ListIncompleteUploads(context.Background(), "test", "theKey", true) {
			t.Errorf("incomplete: %v", x)
		}
	}()

	info, err := mcc.CompleteMultipartUpload(context.Background(), "test", "theKey", uploadID, completed, minio.PutObjectOptions{})
	if err != nil {
		t.Fatal(err)
	}

	t.Logf("completed: %v", info)

	// Check key in bucket
	obj, err := mc.GetObject(context.Background(), "test", "theKey", minio.GetObjectOptions{})
	if err != nil {
		t.Fatal(err)
	}
	defer obj.Close()

	h := sha256.New()
	if _, err := io.Copy(h, obj); err != nil {
		t.Fatal(err)
	}
	gotSum := h.Sum(nil)

	h.Reset()
	var body abcReader
	if _, err := io.CopyN(h, &body, size); err != nil {
		t.Fatal(err)
	}
	wantSum := h.Sum(nil)

	if !bytes.Equal(gotSum, wantSum) {
		t.Errorf("got sum = %x; want %x", gotSum, wantSum)
	}
}

func availableAddr() string {
	l, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		panic(err)
	}
	defer l.Close()
	return l.Addr().String()
}

// Tracing is "experimental" and may be removed in the future. I can't get it
// to work consistently, but I'm leaving it in for now.
func startMinio(t *testing.T, trace bool) *minio.Client {
	t.Helper()

	// Trace is enabled by setting the OLLAMA_MINIO_TRACE environment
	// variable or by explicitly setting trace to true.
	trace = cmp.Or(trace, os.Getenv("OLLAMA_MINIO_TRACE") != "")

	dir := t.TempDir()

	t.Cleanup(func() {
		// TODO(bmizerany): trim temp dir based on dates so that
		// future runs may be able to inspect results for some time.
	})

	waitAndMaybeLogError := func(cmd *exec.Cmd) {
		if err := cmd.Wait(); err != nil {
			var e *exec.ExitError
			if errors.As(err, &e) {
				if e.Exited() {
					return
				}
				t.Logf("startMinio: %s stderr: %s", cmd.Path, e.Stderr)
				t.Logf("startMinio: %s exit status: %v", cmd.Path, e.ExitCode())
				t.Logf("startMinio: %s exited: %v", cmd.Path, e.Exited())
			} else {
				if errors.Is(err, context.Canceled) {
					return
				}
				t.Logf("startMinio: %s exit error: %v", cmd.Path, err)
			}
		}
	}

	// Cancel must be called first, so wait to add it to the Cleanup
	// stack until last (cleanups run in LIFO order).
	ctx, cancel := context.WithCancel(context.Background())
	deadline, ok := t.Deadline()
	if ok {
		ctx, cancel = context.WithDeadline(ctx, deadline.Add(-100*time.Millisecond))
	}

	t.Logf(">> minio: minio server %s", dir)

	addr := availableAddr()
	cmd := exec.CommandContext(ctx, "minio", "server", "--address", addr, dir)
	cmd.Env = os.Environ()
	cmd.WaitDelay = 3 * time.Second
	cmd.Cancel = func() error {
		return cmd.Process.Signal(syscall.SIGQUIT)
	}
	if err := cmd.Start(); err != nil {
		t.Fatalf("startMinio: %v", err)
	}
	t.Cleanup(func() {
		cancel()
		waitAndMaybeLogError(cmd)
	})

	mc, err := minio.New(addr, &minio.Options{
		Creds:  credentials.NewStaticV4("minioadmin", "minioadmin", ""),
		Secure: false,
	})
	if err != nil {
		t.Fatalf("startMinio: %v", err)
	}

	// wait for server to start with exponential backoff
	for _, err := range backoff.Upto(ctx, 1*time.Second) {
		if err != nil {
			t.Fatalf("startMinio: %v", err)
		}
		// try list buckets to see if server is up
		if _, err := mc.ListBuckets(ctx); err == nil {
			break
		}
		t.Logf("startMinio: server is offline; retrying")
	}

	if trace {
		cmd := exec.CommandContext(ctx, "mc", "admin", "trace", "--verbose", "test")
		cmd.Env = append(os.Environ(),
			"MC_HOST_test=http://minioadmin:minioadmin@"+addr,
		)
		cmd.WaitDelay = 3 * time.Second
		cmd.Cancel = func() error {
			return cmd.Process.Signal(syscall.SIGQUIT)
		}

		stdout, err := cmd.StdoutPipe()
		if err != nil {
			t.Fatalf("startMinio: %v", err)
		}
		if err := cmd.Start(); err != nil {
			t.Fatalf("startMinio: %v", err)
		}

		doneLogging := make(chan struct{})
		sc := bufio.NewScanner(stdout)
		go func() {
			defer close(doneLogging)

			// Scan lines until the process exits.
			for sc.Scan() {
				t.Logf("startMinio: mc trace: %s", sc.Text())
			}
			_ = sc.Err() // ignore (not important)
		}()
		t.Cleanup(func() {
			cancel()
			waitAndMaybeLogError(cmd)

			// Make sure we do not log after the test exits, to
			// avoid a panic.
			<-doneLogging
		})
	}

	if err := mc.MakeBucket(context.Background(), "test", minio.MakeBucketOptions{}); err != nil {
		t.Fatalf("startMinio: %v", err)
	}
	return mc
}

// contextForTest returns a context that is canceled when the test deadline,
// if any, is reached. The returned doneLogging function should be called
// after all Log/Error/Fatalf calls are done, before the test returns.
func contextForTest(t *testing.T) (_ context.Context, doneLogging func()) {
	done := make(chan struct{})
	deadline, ok := t.Deadline()
	if !ok {
		return context.Background(), func() {}
	}

	ctx, cancel := context.WithDeadline(context.Background(), deadline.Add(-100*time.Millisecond))
	t.Cleanup(func() {
		cancel()
		<-done
	})
	return ctx, func() { close(done) }
}

// abcReader yields the lowercase alphabet, repeated infinitely.
type abcReader struct {
	pos int
}

const theABCs = "abcdefghijklmnopqrstuvwxyz"

func (r *abcReader) Read(p []byte) (n int, err error) {
	for i := range p {
		p[i] = theABCs[r.pos]
		r.pos++
		if r.pos == len(theABCs) {
			r.pos = 0
		}
	}
	return len(p), nil
}

func (r *abcReader) ReadAt(p []byte, off int64) (n int, err error) {
	for i := range p {
		p[i] = theABCs[(off+int64(i))%int64(len(theABCs))]
	}
	return len(p), nil
}

func checkABCs(r io.Reader, size int) (reason string) {
	h := sha256.New()
	n, err := io.CopyN(h, &abcReader{}, int64(size))
	if err != nil {
		return err.Error()
	}
	if n != int64(size) {
		panic("short read; should not happen")
	}
	want := h.Sum(nil)
	h = sha256.New()
	n, err = io.Copy(h, r)
	if err != nil {
		return err.Error()
	}
	if n != int64(size) {
		return fmt.Sprintf("got len(r) = %d; want %d", n, size)
	}
	got := h.Sum(nil)
	if !bytes.Equal(got, want) {
		return fmt.Sprintf("got sum = %x; want %x", got, want)
	}
	return ""
}
4 x/types/empty/message.go Normal file
@@ -0,0 +1,4 @@
package empty

// Message is a placeholder type used when encoding JSON messages.
type Message struct{}
12 x/types/they/want.go Normal file
@@ -0,0 +1,12 @@
package they

import (
	"net/http"
	"strings"
)

// Want reports whether the request method is method and the request path
// starts with pathPrefix.
func Want(r *http.Request, method string, pathPrefix string) bool {
	return r.Method == method && strings.HasPrefix(r.URL.Path, pathPrefix)
}
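A quick routing sketch showing how they.Want reads at a call site; the routes and handler below are hypothetical, not part of this diff:

package routes

import (
	"net/http"

	"github.com/ollama/ollama/x/types/they"
)

func route(w http.ResponseWriter, r *http.Request) {
	switch {
	case they.Want(r, "POST", "/v1/push"):
		// handle push
	case they.Want(r, "GET", "/v1/manifests/"):
		// handle manifest fetch
	default:
		http.NotFound(w, r)
	}
}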
58 x/utils/backoff/backoff.go Normal file
@@ -0,0 +1,58 @@
package backoff

import (
	"context"
	"errors"
	"iter"
	"math/rand"
	"time"
)

// Errors
var (
	// ErrMaxAttempts is not used by backoff but is available for use by
	// callers that want to signal that a maximum number of retries has
	// been exceeded. This should eliminate the need for callers to invent
	// their own error.
	ErrMaxAttempts = errors.New("max retries exceeded")
)

// Upto implements a backoff strategy that yields nil errors until the
// context is canceled or yield returns false.
//
// The backoff strategy is a simple exponential backoff with a maximum
// backoff of maxBackoff. The backoff is randomized between 0.5-1.5 times
// the current backoff, in order to prevent accidental "thundering herd"
// problems.
func Upto(ctx context.Context, maxBackoff time.Duration) iter.Seq2[int, error] {
	var n int
	return func(yield func(int, error) bool) {
		for {
			if ctx.Err() != nil {
				yield(n, ctx.Err())
				return
			}

			n++

			// An n^2 backoff timer is a little smoother than the
			// common choice of 2^n.
			d := time.Duration(n*n) * 10 * time.Millisecond
			if d > maxBackoff {
				d = maxBackoff
			}
			// Randomize the delay between 0.5-1.5 x d, in order
			// to prevent accidental "thundering herd" problems.
			d = time.Duration(float64(d) * (rand.Float64() + 0.5))
			t := time.NewTimer(d)
			select {
			case <-ctx.Done():
				t.Stop()
			case <-t.C:
				if !yield(n, nil) {
					return
				}
			}
		}
	}
}
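startMinio in server_test.go above already ranges over backoff.Upto to poll until MinIO answers. For a caller-imposed attempt cap, a minimal retry sketch; tryOnce stands in for a hypothetical operation, and ErrMaxAttempts is used the way its doc comment suggests:

package retryexample

import (
	"context"
	"errors"
	"time"

	"github.com/ollama/ollama/x/utils/backoff"
)

// tryOnce is a stand-in for whatever operation is being retried.
func tryOnce() error { return errors.New("not ready") }

func connectWithRetry(ctx context.Context) error {
	for n, err := range backoff.Upto(ctx, time.Second) {
		if err != nil {
			return err // context canceled or deadline exceeded
		}
		if n > 5 {
			return backoff.ErrMaxAttempts
		}
		if tryOnce() == nil {
			return nil
		}
	}
	return nil // unreachable: the loop exits only via the returns above
}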
29 x/utils/upload/upload.go Normal file
@@ -0,0 +1,29 @@
package upload

import (
	"iter"

	"golang.org/x/exp/constraints"
)

type Chunk[I constraints.Integer] struct {
	Offset I
	N      I
}

// Chunks yields a sequence of a part number and a Chunk. The Chunk is the
// offset and size of the chunk. The last chunk may be smaller than chunkSize
// if size is not a multiple of chunkSize.
//
// The first part number is 1 and increases monotonically.
func Chunks[I constraints.Integer](size, chunkSize I) iter.Seq2[int, Chunk[I]] {
	return func(yield func(int, Chunk[I]) bool) {
		var n int
		for off := I(0); off < size; off += chunkSize {
			n++
			if !yield(n, Chunk[I]{off, min(chunkSize, size-off)}) {
				return
			}
		}
	}
}
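A worked example, written as an Example test one might place in an upload_test package (not part of this diff): the 11,000,000-byte layer from server_test.go splits under a 5 MiB chunk size into three parts, the last one short.

package upload_test

import (
	"fmt"

	"github.com/ollama/ollama/x/utils/upload"
)

func ExampleChunks() {
	// 11,000,000 bytes in 5 MiB (5,242,880-byte) chunks:
	// two full parts, then a 514,240-byte remainder.
	for n, c := range upload.Chunks[int64](11_000_000, 5*1024*1024) {
		fmt.Printf("part %d: offset=%d n=%d\n", n, c.Offset, c.N)
	}
	// Output:
	// part 1: offset=0 n=5242880
	// part 2: offset=5242880 n=5242880
	// part 3: offset=10485760 n=514240
}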
Some files were not shown because too many files have changed in this diff.