Compare commits
184 Commits
modelfile-
...
bmizerany/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a2cf25caf7 | ||
|
|
aca112a308 | ||
|
|
1389e6926a | ||
|
|
a292cde2f3 | ||
|
|
ab9e476551 | ||
|
|
d721228a2b | ||
|
|
d3c6400487 | ||
|
|
06c21f00eb | ||
|
|
8b62eaf059 | ||
|
|
6c1c0f9f1a | ||
|
|
6a4b3c3823 | ||
|
|
38e7ddb39d | ||
|
|
f595dea189 | ||
|
|
be7fe0d6d8 | ||
|
|
98dbc1202a | ||
|
|
bdff89bc4c | ||
|
|
2100129e83 | ||
|
|
4eb7acf84b | ||
|
|
ff68227ca1 | ||
|
|
d2aef85dda | ||
|
|
1407fd3d4a | ||
|
|
07f27312fa | ||
|
|
6ba495d4a3 | ||
|
|
bd446a72cc | ||
|
|
5615f60bb0 | ||
|
|
348378ef56 | ||
|
|
81e8c49ac2 | ||
|
|
0172726a58 | ||
|
|
c84f9b07b0 | ||
|
|
712eaa4b09 | ||
|
|
2f241692bd | ||
|
|
d35a6a577f | ||
|
|
14a6f85e9e | ||
|
|
45d8d22785 | ||
|
|
e201627c63 | ||
|
|
5e76860c47 | ||
|
|
fb0782b7a9 | ||
|
|
0bee38f6b5 | ||
|
|
b24f1ad587 | ||
|
|
0bea2b8916 | ||
|
|
e4d65d5aef | ||
|
|
6e464ebef8 | ||
|
|
f2c17682b0 | ||
|
|
f0e6c563e2 | ||
|
|
a187851900 | ||
|
|
2e600aa398 | ||
|
|
0c78e6c23d | ||
|
|
c0d4f55f3e | ||
|
|
d67ff60643 | ||
|
|
3cb07b3ac9 | ||
|
|
fc595d89bf | ||
|
|
e1de015fbc | ||
|
|
1c04951cac | ||
|
|
95559adee3 | ||
|
|
9821ca28e8 | ||
|
|
c5768ceffe | ||
|
|
0b65220936 | ||
|
|
a5f05236f7 | ||
|
|
3cf109ec59 | ||
|
|
7c7f56a7fb | ||
|
|
bf8e0c09c9 | ||
|
|
a6b8bdf938 | ||
|
|
f51197a814 | ||
|
|
713e2feacf | ||
|
|
7f1faf6c12 | ||
|
|
ad6f020bd8 | ||
|
|
e052bd8c0f | ||
|
|
aef832b298 | ||
|
|
805a92e6f2 | ||
|
|
0c46151700 | ||
|
|
bfe89d6fa0 | ||
|
|
92b7e40fde | ||
|
|
d510a90214 | ||
|
|
a4fd06d603 | ||
|
|
cfe0bb6bb6 | ||
|
|
6aa9795c4f | ||
|
|
5041000a28 | ||
|
|
7cd939690a | ||
|
|
42cda9dd46 | ||
|
|
6917865bf3 | ||
|
|
2633fcb149 | ||
|
|
58de2b8d4a | ||
|
|
de72688b35 | ||
|
|
cbb367b1df | ||
|
|
18160475c4 | ||
|
|
31e9b3dd15 | ||
|
|
e28cfdf813 | ||
|
|
2751c26da7 | ||
|
|
45ca3c80e8 | ||
|
|
d85fbd0e99 | ||
|
|
acf1cb1dc4 | ||
|
|
c787b8b2dd | ||
|
|
9f2d8d2117 | ||
|
|
d42c3f6be1 | ||
|
|
4ea3e9efa6 | ||
|
|
2e1ea6ecaa | ||
|
|
6d2da77ce2 | ||
|
|
def4d902bf | ||
|
|
76a202c04e | ||
|
|
f7cfe946dc | ||
|
|
005b6373e2 | ||
|
|
d54e0fb3b2 | ||
|
|
bdd05e0ae0 | ||
|
|
1a346640db | ||
|
|
f5883070f8 | ||
|
|
adc23d5f96 | ||
|
|
a10a11b9d3 | ||
|
|
7d05a6ee8f | ||
|
|
464d817824 | ||
|
|
531324a9be | ||
|
|
6589eb8a8c | ||
|
|
a039e383cd | ||
|
|
80163ebcb5 | ||
|
|
a57818d93e | ||
|
|
94befe366a | ||
|
|
c95f97689b | ||
|
|
618eb5b909 | ||
|
|
841adda157 | ||
|
|
0035e31af8 | ||
|
|
eb75418be9 | ||
|
|
9959da05de | ||
|
|
c863c6a96d | ||
|
|
aff7970628 | ||
|
|
628f1feb36 | ||
|
|
ce3125afd5 | ||
|
|
f488652ba7 | ||
|
|
2318ed2919 | ||
|
|
b1b8be33d9 | ||
|
|
876f7eab81 | ||
|
|
7cfc8a0838 | ||
|
|
1f11b52511 | ||
|
|
526d4eb204 | ||
|
|
0a74cb31d5 | ||
|
|
10ed1b6292 | ||
|
|
4fec5816d6 | ||
|
|
0a0e9f3e0f | ||
|
|
58d95cc9bd | ||
|
|
3b6a9154dd | ||
|
|
d6dd2ff839 | ||
|
|
e57a6ba89f | ||
|
|
12ec2346ef | ||
|
|
1ec0df1069 | ||
|
|
91b3e4d282 | ||
|
|
d338d70492 | ||
|
|
011bb67351 | ||
|
|
d124627202 | ||
|
|
b0a8246a69 | ||
|
|
fd411b3cf6 | ||
|
|
04f38cf3f4 | ||
|
|
c0eddb10fd | ||
|
|
60ef0e6b4a | ||
|
|
48c60c01e2 | ||
|
|
eb2c442a01 | ||
|
|
c87fe7df48 | ||
|
|
5182a1dfb1 | ||
|
|
a32e7857b2 | ||
|
|
6acc205de0 | ||
|
|
f6e02d4bc7 | ||
|
|
e6fb39c182 | ||
|
|
e1d457c73e | ||
|
|
cd5df121a5 | ||
|
|
112ffed189 | ||
|
|
c49947dcf5 | ||
|
|
e1f1c374ea | ||
|
|
06a1508bfe | ||
|
|
5a5efee46b | ||
|
|
97ae517fbf | ||
|
|
44b813e459 | ||
|
|
539043f5e0 | ||
|
|
dbcace6847 | ||
|
|
c91a4ebcff | ||
|
|
b79c7e4528 | ||
|
|
035b274b70 | ||
|
|
9c6a254945 | ||
|
|
f31f2bedf4 | ||
|
|
756c257553 | ||
|
|
5255d0af8a | ||
|
|
af8a8a6b59 | ||
|
|
461ad25015 | ||
|
|
8838ae787d | ||
|
|
db75402ade | ||
|
|
1e85a140a3 | ||
|
|
c363282fdc | ||
|
|
5b0c48d29e |
2
.github/ISSUE_TEMPLATE/90_bug_report.yml
vendored
2
.github/ISSUE_TEMPLATE/90_bug_report.yml
vendored
@@ -19,7 +19,7 @@ body:
|
||||
label: What did you expect to see?
|
||||
description: What did you expect to see/happen instead?
|
||||
validations:
|
||||
required: true
|
||||
required: false
|
||||
- type: textarea
|
||||
id: steps
|
||||
attributes:
|
||||
|
||||
24
.github/workflows/latest.yaml
vendored
Normal file
24
.github/workflows/latest.yaml
vendored
Normal file
@@ -0,0 +1,24 @@
|
||||
name: latest
|
||||
|
||||
on:
|
||||
release:
|
||||
types: [released]
|
||||
|
||||
jobs:
|
||||
update-latest:
|
||||
environment: release
|
||||
runs-on: linux
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ vars.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_ACCESS_TOKEN }}
|
||||
- name: Tag images as latest
|
||||
env:
|
||||
PUSH: "1"
|
||||
shell: bash
|
||||
run: |
|
||||
export "VERSION=${GITHUB_REF_NAME#v}"
|
||||
./scripts/tag_latest.sh
|
||||
23
.github/workflows/release.yaml
vendored
23
.github/workflows/release.yaml
vendored
@@ -213,24 +213,31 @@ jobs:
|
||||
with:
|
||||
go-version: '1.22'
|
||||
cache: true
|
||||
# TODO - consider replacing this action with a ps1 snippet to install
|
||||
# This actions seems to fail sometimes with "no tools in cache" but a re-run of the failed job clears it
|
||||
# https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.89_win10.exe
|
||||
- name: "Install CUDA"
|
||||
uses: Jimver/cuda-toolkit@v0.2.14
|
||||
id: cuda-toolkit
|
||||
with:
|
||||
cuda: '11.3.1'
|
||||
run: |
|
||||
$ErrorActionPreference = "Stop"
|
||||
write-host "downloading CUDA Installer"
|
||||
Invoke-WebRequest -Uri "https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.89_win10.exe" -OutFile "${env:RUNNER_TEMP}\cuda-install.exe"
|
||||
write-host "Installing CUDA"
|
||||
Start-Process "${env:RUNNER_TEMP}\cuda-install.exe" -ArgumentList '-s' -NoNewWindow -Wait
|
||||
write-host "Completed CUDA"
|
||||
$cudaPath=((resolve-path "c:\Program Files\NVIDIA*\CUDA\v*\bin\nvcc.exe")[0].path | split-path | split-path)
|
||||
$cudaVer=($cudaPath | split-path -leaf ) -replace 'v(\d+).(\d+)', '$1_$2'
|
||||
echo "$cudaPath\bin" >> $env:GITHUB_PATH
|
||||
echo "CUDA_PATH=$cudaPath" >> $env:GITHUB_ENV
|
||||
echo "CUDA_PATH_V${cudaVer}=$cudaPath" >> $env:GITHUB_ENV
|
||||
echo "CUDA_PATH_VX_Y=CUDA_PATH_V${cudaVer}" >> $env:GITHUB_ENV
|
||||
- name: "Verify CUDA"
|
||||
run: nvcc -V
|
||||
- run: go get ./...
|
||||
- name: go generate
|
||||
run: |
|
||||
$gopath=(get-command go).source | split-path -parent
|
||||
$cudabin=(get-command nvcc).source | split-path
|
||||
& "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Launch-VsDevShell.ps1"
|
||||
cd $env:GITHUB_WORKSPACE
|
||||
$env:CMAKE_SYSTEM_VERSION="10.0.22621.0"
|
||||
$env:PATH="$gopath;$env:PATH"
|
||||
$env:PATH="$gopath;$cudabin;$env:PATH"
|
||||
$env:OLLAMA_SKIP_CPU_GENERATE="1"
|
||||
go generate -x ./...
|
||||
- name: "gather cuda dependencies"
|
||||
|
||||
174
.github/workflows/test.yaml
vendored
174
.github/workflows/test.yaml
vendored
@@ -9,7 +9,32 @@ on:
|
||||
- '!README.md'
|
||||
|
||||
jobs:
|
||||
changes:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
GENERATE: ${{ steps.changes.outputs.GENERATE }}
|
||||
GENERATE_CUDA: ${{ steps.changes.outputs.GENERATE_CUDA }}
|
||||
GENERATE_ROCM: ${{ steps.changes.outputs.GENERATE_ROCM }}
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
- id: changes
|
||||
run: |
|
||||
changed() {
|
||||
git diff-tree -r --no-commit-id --name-only ${{ github.event.pull_request.base.sha }} ${{ github.event.pull_request.head.sha }} \
|
||||
| xargs python3 -c "import sys; print(any([x.startswith('$1') for x in sys.argv[1:]]))"
|
||||
}
|
||||
|
||||
{
|
||||
echo GENERATE=$(changed llm/)
|
||||
echo GENERATE_CUDA=$(changed llm/)
|
||||
echo GENERATE_ROCM=$(changed llm/)
|
||||
} >>$GITHUB_OUTPUT
|
||||
|
||||
generate:
|
||||
needs: [changes]
|
||||
if: ${{ needs.changes.outputs.GENERATE == 'True' }}
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest, macos-latest, windows-2019]
|
||||
@@ -31,10 +56,12 @@ jobs:
|
||||
- run: go get ./...
|
||||
- run: |
|
||||
$gopath=(get-command go).source | split-path -parent
|
||||
$gccpath=(get-command gcc).source | split-path -parent
|
||||
& "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Launch-VsDevShell.ps1"
|
||||
cd $env:GITHUB_WORKSPACE
|
||||
$env:CMAKE_SYSTEM_VERSION="10.0.22621.0"
|
||||
$env:PATH="$gopath;$env:PATH"
|
||||
$env:PATH="$gopath;$gccpath;$env:PATH"
|
||||
echo $env:PATH
|
||||
go generate -x ./...
|
||||
if: ${{ startsWith(matrix.os, 'windows-') }}
|
||||
name: "Windows Go Generate"
|
||||
@@ -44,8 +71,12 @@ jobs:
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: ${{ matrix.os }}-${{ matrix.arch }}-libraries
|
||||
path: llm/llama.cpp/build/**/lib/*
|
||||
path: |
|
||||
llm/build/**/bin/*
|
||||
llm/build/**/*.a
|
||||
generate-cuda:
|
||||
needs: [changes]
|
||||
if: ${{ needs.changes.outputs.GENERATE_CUDA == 'True' }}
|
||||
strategy:
|
||||
matrix:
|
||||
cuda-version:
|
||||
@@ -73,12 +104,14 @@ jobs:
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: cuda-${{ matrix.cuda-version }}-libraries
|
||||
path: llm/llama.cpp/build/**/lib/*
|
||||
path: llm/build/**/bin/*
|
||||
generate-rocm:
|
||||
needs: [changes]
|
||||
if: ${{ needs.changes.outputs.GENERATE_ROCM == 'True' }}
|
||||
strategy:
|
||||
matrix:
|
||||
rocm-version:
|
||||
- '6.0'
|
||||
- '6.0.2'
|
||||
runs-on: linux
|
||||
container: rocm/dev-ubuntu-20.04:${{ matrix.rocm-version }}
|
||||
steps:
|
||||
@@ -102,7 +135,88 @@ jobs:
|
||||
- uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: rocm-${{ matrix.rocm-version }}-libraries
|
||||
path: llm/llama.cpp/build/**/lib/*
|
||||
path: llm/build/**/lib/*
|
||||
|
||||
# ROCm generation step
|
||||
generate-windows-rocm:
|
||||
needs: [changes]
|
||||
if: ${{ needs.changes.outputs.GENERATE_ROCM == 'True' }}
|
||||
runs-on: windows
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.22'
|
||||
cache: true
|
||||
- name: "Install ROCm"
|
||||
run: |
|
||||
$ErrorActionPreference = "Stop"
|
||||
write-host "downloading AMD HIP Installer"
|
||||
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-23.Q4-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
|
||||
write-host "Installing AMD HIP"
|
||||
Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -Wait
|
||||
write-host "Completed AMD HIP"
|
||||
- name: "Verify ROCm"
|
||||
run: |
|
||||
& 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' --version
|
||||
- run: go get ./...
|
||||
- run: |
|
||||
$gopath=(get-command go).source | split-path -parent
|
||||
& "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Launch-VsDevShell.ps1"
|
||||
cd $env:GITHUB_WORKSPACE
|
||||
$env:CMAKE_SYSTEM_VERSION="10.0.22621.0"
|
||||
$env:PATH="$gopath;$env:PATH"
|
||||
$env:OLLAMA_SKIP_CPU_GENERATE="1"
|
||||
$env:HIP_PATH=$(Resolve-Path 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | split-path | split-path)
|
||||
go generate -x ./...
|
||||
name: go generate
|
||||
env:
|
||||
OLLAMA_SKIP_CPU_GENERATE: '1'
|
||||
# TODO - do we need any artifacts?
|
||||
|
||||
# CUDA generation step
|
||||
generate-windows-cuda:
|
||||
needs: [changes]
|
||||
if: ${{ needs.changes.outputs.GENERATE_CUDA == 'True' }}
|
||||
runs-on: windows
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version: '1.22'
|
||||
cache: true
|
||||
- name: "Install CUDA"
|
||||
run: |
|
||||
$ErrorActionPreference = "Stop"
|
||||
write-host "downloading CUDA Installer"
|
||||
Invoke-WebRequest -Uri "https://developer.download.nvidia.com/compute/cuda/11.3.1/local_installers/cuda_11.3.1_465.89_win10.exe" -OutFile "${env:RUNNER_TEMP}\cuda-install.exe"
|
||||
write-host "Installing CUDA"
|
||||
Start-Process "${env:RUNNER_TEMP}\cuda-install.exe" -ArgumentList '-s' -NoNewWindow -Wait
|
||||
write-host "Completed CUDA"
|
||||
$cudaPath=((resolve-path "c:\Program Files\NVIDIA*\CUDA\v*\bin\nvcc.exe")[0].path | split-path | split-path)
|
||||
$cudaVer=($cudaPath | split-path -leaf ) -replace 'v(\d+).(\d+)', '$1_$2'
|
||||
echo "$cudaPath\bin" >> $env:GITHUB_PATH
|
||||
echo "CUDA_PATH=$cudaPath" >> $env:GITHUB_ENV
|
||||
echo "CUDA_PATH_V${cudaVer}=$cudaPath" >> $env:GITHUB_ENV
|
||||
echo "CUDA_PATH_VX_Y=CUDA_PATH_V${cudaVer}" >> $env:GITHUB_ENV
|
||||
- name: "Verify CUDA"
|
||||
run: nvcc -V
|
||||
- run: go get ./...
|
||||
- name: go generate
|
||||
run: |
|
||||
$gopath=(get-command go).source | split-path -parent
|
||||
$cudabin=(get-command nvcc).source | split-path
|
||||
& "C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\Launch-VsDevShell.ps1"
|
||||
cd $env:GITHUB_WORKSPACE
|
||||
$env:CMAKE_SYSTEM_VERSION="10.0.22621.0"
|
||||
$env:PATH="$gopath;$cudabin;$env:PATH"
|
||||
$env:OLLAMA_SKIP_CPU_GENERATE="1"
|
||||
go generate -x ./...
|
||||
env:
|
||||
OLLAMA_SKIP_CPU_GENERATE: '1'
|
||||
# TODO - do we need any artifacts?
|
||||
|
||||
|
||||
lint:
|
||||
strategy:
|
||||
matrix:
|
||||
@@ -128,21 +242,28 @@ jobs:
|
||||
go-version: '1.22'
|
||||
cache: false
|
||||
- run: |
|
||||
mkdir -p llm/llama.cpp/build/linux/${{ matrix.arch }}/stub/lib/
|
||||
touch llm/llama.cpp/build/linux/${{ matrix.arch }}/stub/lib/stub.so
|
||||
case ${{ matrix.arch }} in
|
||||
amd64) echo ARCH=x86_64 ;;
|
||||
arm64) echo ARCH=arm64 ;;
|
||||
esac >>$GITHUB_ENV
|
||||
shell: bash
|
||||
- run: |
|
||||
mkdir -p llm/build/linux/$ARCH/stub/bin/
|
||||
touch llm/build/linux/$ARCH/stub/bin/stub.so
|
||||
if: ${{ startsWith(matrix.os, 'ubuntu-') }}
|
||||
- run: |
|
||||
mkdir -p llm/llama.cpp/build/darwin/${{ matrix.arch }}/stub/lib/
|
||||
touch llm/llama.cpp/build/darwin/${{ matrix.arch }}/stub/lib/stub.dylib
|
||||
touch llm/llama.cpp/ggml-metal.metal
|
||||
mkdir -p llm/build/darwin/$ARCH/stub/bin/
|
||||
touch llm/build/darwin/$ARCH/stub/bin/stub.dylib
|
||||
touch llm/ggml-metal.metal
|
||||
if: ${{ startsWith(matrix.os, 'macos-') }}
|
||||
- run: |
|
||||
mkdir -p llm/llama.cpp/build/windows/${{ matrix.arch }}/stub/lib/
|
||||
touch llm/llama.cpp/build/windows/${{ matrix.arch }}/stub/lib/stub.dll
|
||||
mkdir -p llm/build/windows/$ARCH/stub/stub/bin/
|
||||
touch llm/build/windows/$ARCH/stub/stub/bin/stub.dll
|
||||
if: ${{ startsWith(matrix.os, 'windows-') }}
|
||||
- uses: golangci/golangci-lint-action@v3
|
||||
- uses: golangci/golangci-lint-action@v4
|
||||
with:
|
||||
args: --timeout 8m0s
|
||||
test:
|
||||
needs: generate
|
||||
strategy:
|
||||
matrix:
|
||||
os: [ubuntu-latest, macos-latest, windows-2019]
|
||||
@@ -156,6 +277,7 @@ jobs:
|
||||
env:
|
||||
GOARCH: ${{ matrix.arch }}
|
||||
CGO_ENABLED: '1'
|
||||
OLLAMA_CPU_TARGET: "static"
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
@@ -165,10 +287,26 @@ jobs:
|
||||
go-version: '1.22'
|
||||
cache: true
|
||||
- run: go get
|
||||
- uses: actions/download-artifact@v4
|
||||
with:
|
||||
name: ${{ matrix.os }}-${{ matrix.arch }}-libraries
|
||||
path: llm/llama.cpp/build
|
||||
- run: |
|
||||
case ${{ matrix.arch }} in
|
||||
amd64) echo ARCH=x86_64 ;;
|
||||
arm64) echo ARCH=arm64 ;;
|
||||
esac >>$GITHUB_ENV
|
||||
shell: bash
|
||||
- run: |
|
||||
mkdir -p llm/build/linux/$ARCH/stub/bin/
|
||||
touch llm//build/linux/$ARCH/stub/bin/stub.so
|
||||
if: ${{ startsWith(matrix.os, 'ubuntu-') }}
|
||||
- run: |
|
||||
mkdir -p llm/build/darwin/$ARCH/stub/bin/
|
||||
touch llm/build/darwin/$ARCH/stub/bin/stub.dylib
|
||||
touch llm/ggml-metal.metal
|
||||
if: ${{ startsWith(matrix.os, 'macos-') }}
|
||||
- run: |
|
||||
mkdir -p llm/build/windows/$ARCH/stub/stub/bin/
|
||||
touch llm/build/windows/$ARCH/stub/stub/bin/stub.dll
|
||||
if: ${{ startsWith(matrix.os, 'windows-') }}
|
||||
- run: go generate ./...
|
||||
- run: go build
|
||||
- run: go test -v ./...
|
||||
- uses: actions/upload-artifact@v4
|
||||
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -10,4 +10,5 @@ ggml-metal.metal
|
||||
*.exe
|
||||
.idea
|
||||
test_data
|
||||
*.crt
|
||||
*.crt
|
||||
llm/build
|
||||
@@ -15,13 +15,3 @@ linters:
|
||||
- misspell
|
||||
- nilerr
|
||||
- unused
|
||||
linters-settings:
|
||||
errcheck:
|
||||
# exclude the following functions since we don't generally
|
||||
# need to be concerned with the returned errors
|
||||
exclude-functions:
|
||||
- encoding/binary.Read
|
||||
- (*os.File).Seek
|
||||
- (*bufio.Writer).WriteString
|
||||
- (*github.com/spf13/pflag.FlagSet).Set
|
||||
- (*github.com/ollama/ollama/llm.readSeekOffset).Seek
|
||||
|
||||
27
Dockerfile
27
Dockerfile
@@ -2,7 +2,7 @@ ARG GOLANG_VERSION=1.22.1
|
||||
ARG CMAKE_VERSION=3.22.1
|
||||
# this CUDA_VERSION corresponds with the one specified in docs/gpu.md
|
||||
ARG CUDA_VERSION=11.3.1
|
||||
ARG ROCM_VERSION=6.0
|
||||
ARG ROCM_VERSION=6.0.2
|
||||
|
||||
# Copy the minimal context we need to run the generate scripts
|
||||
FROM scratch AS llm-code
|
||||
@@ -61,6 +61,8 @@ ARG OLLAMA_CUSTOM_CPU_DEFS
|
||||
ARG CGO_CFLAGS
|
||||
WORKDIR /go/src/github.com/ollama/ollama/llm/generate
|
||||
|
||||
FROM --platform=linux/amd64 cpu-builder-amd64 AS static-build-amd64
|
||||
RUN OLLAMA_CPU_TARGET="static" sh gen_linux.sh
|
||||
FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu-build-amd64
|
||||
RUN OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh
|
||||
FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu_avx-build-amd64
|
||||
@@ -68,28 +70,33 @@ RUN OLLAMA_CPU_TARGET="cpu_avx" sh gen_linux.sh
|
||||
FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu_avx2-build-amd64
|
||||
RUN OLLAMA_CPU_TARGET="cpu_avx2" sh gen_linux.sh
|
||||
|
||||
FROM --platform=linux/arm64 centos:7 AS cpu-build-arm64
|
||||
FROM --platform=linux/arm64 centos:7 AS cpu-builder-arm64
|
||||
ARG CMAKE_VERSION
|
||||
ARG GOLANG_VERSION
|
||||
COPY ./scripts/rh_linux_deps.sh /
|
||||
RUN CMAKE_VERSION=${CMAKE_VERSION} GOLANG_VERSION=${GOLANG_VERSION} sh /rh_linux_deps.sh
|
||||
ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
|
||||
COPY --from=llm-code / /go/src/github.com/ollama/ollama/
|
||||
WORKDIR /go/src/github.com/ollama/ollama/llm/generate
|
||||
# Note, we only build the "base" CPU variant on arm since avx/avx2 are x86 features
|
||||
ARG OLLAMA_CUSTOM_CPU_DEFS
|
||||
ARG CGO_CFLAGS
|
||||
WORKDIR /go/src/github.com/ollama/ollama/llm/generate
|
||||
|
||||
FROM --platform=linux/arm64 cpu-builder-arm64 AS static-build-arm64
|
||||
RUN OLLAMA_CPU_TARGET="static" sh gen_linux.sh
|
||||
FROM --platform=linux/arm64 cpu-builder-arm64 AS cpu-build-arm64
|
||||
RUN OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh
|
||||
|
||||
|
||||
# Intermediate stage used for ./scripts/build_linux.sh
|
||||
FROM --platform=linux/amd64 cpu-build-amd64 AS build-amd64
|
||||
ENV CGO_ENABLED 1
|
||||
WORKDIR /go/src/github.com/ollama/ollama
|
||||
COPY . .
|
||||
COPY --from=cpu_avx-build-amd64 /go/src/github.com/ollama/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
|
||||
COPY --from=cpu_avx2-build-amd64 /go/src/github.com/ollama/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
|
||||
COPY --from=cuda-build-amd64 /go/src/github.com/ollama/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
|
||||
COPY --from=rocm-build-amd64 /go/src/github.com/ollama/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
|
||||
COPY --from=static-build-amd64 /go/src/github.com/ollama/ollama/llm/build/linux/ llm/build/linux/
|
||||
COPY --from=cpu_avx-build-amd64 /go/src/github.com/ollama/ollama/llm/build/linux/ llm/build/linux/
|
||||
COPY --from=cpu_avx2-build-amd64 /go/src/github.com/ollama/ollama/llm/build/linux/ llm/build/linux/
|
||||
COPY --from=cuda-build-amd64 /go/src/github.com/ollama/ollama/llm/build/linux/ llm/build/linux/
|
||||
COPY --from=rocm-build-amd64 /go/src/github.com/ollama/ollama/llm/build/linux/ llm/build/linux/
|
||||
COPY --from=rocm-build-amd64 /go/src/github.com/ollama/ollama/dist/deps/ ./dist/deps/
|
||||
ARG GOFLAGS
|
||||
ARG CGO_CFLAGS
|
||||
@@ -101,8 +108,8 @@ ENV CGO_ENABLED 1
|
||||
ARG GOLANG_VERSION
|
||||
WORKDIR /go/src/github.com/ollama/ollama
|
||||
COPY . .
|
||||
COPY --from=cuda-build-arm64 /go/src/github.com/ollama/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
|
||||
RUN mkdir -p /go/src/github.com/ollama/ollama/dist/deps/
|
||||
COPY --from=static-build-arm64 /go/src/github.com/ollama/ollama/llm/build/linux/ llm/build/linux/
|
||||
COPY --from=cuda-build-arm64 /go/src/github.com/ollama/ollama/llm/build/linux/ llm/build/linux/
|
||||
ARG GOFLAGS
|
||||
ARG CGO_CFLAGS
|
||||
RUN go build -trimpath .
|
||||
|
||||
@@ -259,6 +259,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
|
||||
### Web & Desktop
|
||||
|
||||
- [Lollms-Webui](https://github.com/ParisNeo/lollms-webui)
|
||||
- [LibreChat](https://github.com/danny-avila/LibreChat)
|
||||
- [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
|
||||
- [Enchanted (macOS native)](https://github.com/AugustDev/enchanted)
|
||||
@@ -289,6 +290,8 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [AnythingLLM (Docker + MacOs/Windows/Linux native app)](https://github.com/Mintplex-Labs/anything-llm)
|
||||
- [Ollama Basic Chat: Uses HyperDiv Reactive UI](https://github.com/rapidarchitect/ollama_basic_chat)
|
||||
- [Ollama-chats RPG](https://github.com/drazdra/ollama-chats)
|
||||
- [ChatOllama: Open Source Chatbot based on Ollama with Knowledge Bases](https://github.com/sugarforever/chat-ollama)
|
||||
- [CRAG Ollama Chat: Simple Web Search with Corrective RAG](https://github.com/Nagi-ovo/CRAG-Ollama-Chat)
|
||||
|
||||
### Terminal
|
||||
|
||||
@@ -313,6 +316,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
### Database
|
||||
|
||||
- [MindsDB](https://github.com/mindsdb/mindsdb/blob/staging/mindsdb/integrations/handlers/ollama_handler/README.md)
|
||||
- [chromem-go](https://github.com/philippgille/chromem-go/blob/v0.5.0/embed_ollama.go) with [example](https://github.com/philippgille/chromem-go/tree/v0.5.0/examples/rag-wikipedia-ollama)
|
||||
|
||||
### Package managers
|
||||
|
||||
@@ -371,3 +375,4 @@ See the [API documentation](./docs/api.md) for all endpoints.
|
||||
- [Wingman-AI](https://github.com/RussellCanfield/wingman-ai) (Copilot code and chat alternative using Ollama and HuggingFace)
|
||||
- [Page Assist](https://github.com/n4ze3m/page-assist) (Chrome Extension)
|
||||
- [AI Telegram Bot](https://github.com/tusharhero/aitelegrambot) (Telegram bot using Ollama in backend)
|
||||
- [AI ST Completion](https://github.com/yaroslavyaroslav/OpenAI-sublime-text) (Sublime Text 4 AI assistant plugin with Ollama support)
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
@@ -83,6 +84,28 @@ func SpawnServer(ctx context.Context, command string) (chan int, error) {
|
||||
io.Copy(logFile, stderr) //nolint:errcheck
|
||||
}()
|
||||
|
||||
// Re-wire context done behavior to attempt a graceful shutdown of the server
|
||||
cmd.Cancel = func() error {
|
||||
if cmd.Process != nil {
|
||||
cmd.Process.Signal(os.Interrupt) //nolint:errcheck
|
||||
tick := time.NewTicker(10 * time.Millisecond)
|
||||
defer tick.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-tick.C:
|
||||
// OS agnostic "is it still running"
|
||||
if proc, err := os.FindProcess(int(cmd.Process.Pid)); err != nil || errors.Is(proc.Signal(syscall.Signal(0)), os.ErrProcessDone) {
|
||||
return nil //nolint:nilerr
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
slog.Warn("graceful server shutdown timeout, killing", "pid", cmd.Process.Pid)
|
||||
cmd.Process.Kill() //nolint:errcheck
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// run the command and wait for it to finish
|
||||
if err := cmd.Start(); err != nil {
|
||||
return done, fmt.Errorf("failed to start server %w", err)
|
||||
@@ -105,7 +128,7 @@ func SpawnServer(ctx context.Context, command string) (chan int, error) {
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
slog.Debug(fmt.Sprintf("server shutdown with exit code %d", code))
|
||||
slog.Info(fmt.Sprintf("server shutdown with exit code %d", code))
|
||||
done <- code
|
||||
return
|
||||
default:
|
||||
|
||||
23
cmd/cmd.go
23
cmd/cmd.go
@@ -213,7 +213,10 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
|
||||
if _, err := io.Copy(hash, bin); err != nil {
|
||||
return "", err
|
||||
}
|
||||
bin.Seek(0, io.SeekStart)
|
||||
|
||||
if _, err := bin.Seek(0, io.SeekStart); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
digest := fmt.Sprintf("sha256:%x", hash.Sum(nil))
|
||||
if err = client.CreateBlob(cmd.Context(), digest, bin); err != nil {
|
||||
@@ -223,6 +226,14 @@ func createBlob(cmd *cobra.Command, client *api.Client, path string) (string, er
|
||||
}
|
||||
|
||||
func RunHandler(cmd *cobra.Command, args []string) error {
|
||||
if os.Getenv("OLLAMA_MODELS") != "" {
|
||||
return errors.New("OLLAMA_MODELS must only be set for 'ollama serve'")
|
||||
}
|
||||
|
||||
if err := checkServerHeartbeat(cmd, args); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client, err := api.ClientFromEnvironment()
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -948,11 +959,10 @@ func NewCLI() *cobra.Command {
|
||||
showCmd.Flags().Bool("system", false, "Show system message of a model")
|
||||
|
||||
runCmd := &cobra.Command{
|
||||
Use: "run MODEL [PROMPT]",
|
||||
Short: "Run a model",
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
PreRunE: checkServerHeartbeat,
|
||||
RunE: RunHandler,
|
||||
Use: "run MODEL [PROMPT]",
|
||||
Short: "Run a model",
|
||||
Args: cobra.MinimumNArgs(1),
|
||||
RunE: RunHandler,
|
||||
}
|
||||
|
||||
runCmd.Flags().Bool("verbose", false, "Show timings for response")
|
||||
@@ -973,6 +983,7 @@ Environment Variables:
|
||||
OLLAMA_ORIGINS A comma separated list of allowed origins.
|
||||
OLLAMA_MODELS The path to the models directory (default is "~/.ollama/models")
|
||||
OLLAMA_KEEP_ALIVE The duration that models stay loaded in memory (default is "5m")
|
||||
OLLAMA_DEBUG Set to 1 to enable additional debug logging
|
||||
`)
|
||||
|
||||
pullCmd := &cobra.Command{
|
||||
|
||||
@@ -295,10 +295,14 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
|
||||
opts.WordWrap = false
|
||||
fmt.Println("Set 'nowordwrap' mode.")
|
||||
case "verbose":
|
||||
cmd.Flags().Set("verbose", "true")
|
||||
if err := cmd.Flags().Set("verbose", "true"); err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Println("Set 'verbose' mode.")
|
||||
case "quiet":
|
||||
cmd.Flags().Set("verbose", "false")
|
||||
if err := cmd.Flags().Set("verbose", "false"); err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Println("Set 'quiet' mode.")
|
||||
case "format":
|
||||
if len(args) < 3 || args[2] != "json" {
|
||||
|
||||
@@ -13,7 +13,9 @@ import (
|
||||
"regexp"
|
||||
"slices"
|
||||
|
||||
"github.com/d4l3k/go-bfloat16"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/x448/float16"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/ollama/ollama/convert/sentencepiece"
|
||||
@@ -33,6 +35,15 @@ type Params struct {
|
||||
RopeFreqBase float64 `json:"rope_theta"`
|
||||
BoSTokenID int `json:"bos_token_id"`
|
||||
EoSTokenID int `json:"eos_token_id"`
|
||||
HeadDimension int `json:"head_dim"`
|
||||
PaddingTokenID int `json:"pad_token_id"`
|
||||
|
||||
ByteOrder
|
||||
}
|
||||
|
||||
type ByteOrder interface {
|
||||
binary.ByteOrder
|
||||
binary.AppendByteOrder
|
||||
}
|
||||
|
||||
type MetaData struct {
|
||||
@@ -41,27 +52,43 @@ type MetaData struct {
|
||||
Offsets []int `mapstructure:"data_offsets"`
|
||||
}
|
||||
|
||||
func ReadSafeTensors(fn string, offset uint64) ([]llm.Tensor, uint64, error) {
|
||||
type ModelArch interface {
|
||||
GetTensors() error
|
||||
LoadVocab() error
|
||||
WriteGGUF() (string, error)
|
||||
}
|
||||
|
||||
type ModelData struct {
|
||||
Path string
|
||||
Name string
|
||||
Params *Params
|
||||
Vocab *Vocab
|
||||
Tensors []llm.Tensor
|
||||
}
|
||||
|
||||
func ReadSafeTensors(fn string, offset uint64, params *Params) ([]llm.Tensor, uint64, error) {
|
||||
f, err := os.Open(fn)
|
||||
if err != nil {
|
||||
return []llm.Tensor{}, 0, err
|
||||
return nil, 0, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
var jsonSize uint64
|
||||
binary.Read(f, binary.LittleEndian, &jsonSize)
|
||||
if err := binary.Read(f, binary.LittleEndian, &jsonSize); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
buf := make([]byte, jsonSize)
|
||||
_, err = io.ReadFull(f, buf)
|
||||
if err != nil {
|
||||
return []llm.Tensor{}, 0, err
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
d := json.NewDecoder(bytes.NewBuffer(buf))
|
||||
d.UseNumber()
|
||||
var parsed map[string]interface{}
|
||||
if err = d.Decode(&parsed); err != nil {
|
||||
return []llm.Tensor{}, 0, err
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
var keys []string
|
||||
@@ -78,7 +105,7 @@ func ReadSafeTensors(fn string, offset uint64) ([]llm.Tensor, uint64, error) {
|
||||
vals := parsed[k].(map[string]interface{})
|
||||
var data MetaData
|
||||
if err = mapstructure.Decode(vals, &data); err != nil {
|
||||
return []llm.Tensor{}, 0, err
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
var size uint64
|
||||
@@ -100,7 +127,7 @@ func ReadSafeTensors(fn string, offset uint64) ([]llm.Tensor, uint64, error) {
|
||||
ggufName, err := GetTensorName(k)
|
||||
if err != nil {
|
||||
slog.Error("%v", err)
|
||||
return []llm.Tensor{}, 0, err
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
shape := []uint64{0, 0, 0, 0}
|
||||
@@ -109,14 +136,22 @@ func ReadSafeTensors(fn string, offset uint64) ([]llm.Tensor, uint64, error) {
|
||||
}
|
||||
|
||||
t := llm.Tensor{
|
||||
Name: ggufName,
|
||||
Kind: kind,
|
||||
Offset: offset,
|
||||
Shape: shape[:],
|
||||
FileName: fn,
|
||||
OffsetPadding: 8 + jsonSize,
|
||||
FileOffsets: []uint64{uint64(data.Offsets[0]), uint64(data.Offsets[1])},
|
||||
Name: ggufName,
|
||||
Kind: kind,
|
||||
Offset: offset,
|
||||
Shape: shape[:],
|
||||
}
|
||||
|
||||
t.WriterTo = safetensorWriterTo{
|
||||
t: &t,
|
||||
params: params,
|
||||
bo: params.ByteOrder,
|
||||
filename: fn,
|
||||
start: uint64(data.Offsets[0]),
|
||||
end: uint64(data.Offsets[1]),
|
||||
padding: 8 + jsonSize,
|
||||
}
|
||||
|
||||
slog.Debug(fmt.Sprintf("%v", t))
|
||||
tensors = append(tensors, t)
|
||||
offset += size
|
||||
@@ -124,21 +159,21 @@ func ReadSafeTensors(fn string, offset uint64) ([]llm.Tensor, uint64, error) {
|
||||
return tensors, offset, nil
|
||||
}
|
||||
|
||||
func GetSafeTensors(dirpath string) ([]llm.Tensor, error) {
|
||||
func GetSafeTensors(dirpath string, params *Params) ([]llm.Tensor, error) {
|
||||
var tensors []llm.Tensor
|
||||
files, err := filepath.Glob(filepath.Join(dirpath, "/model-*.safetensors"))
|
||||
if err != nil {
|
||||
return []llm.Tensor{}, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var offset uint64
|
||||
for _, f := range files {
|
||||
var t []llm.Tensor
|
||||
var err error
|
||||
t, offset, err = ReadSafeTensors(f, offset)
|
||||
t, offset, err = ReadSafeTensors(f, offset, params)
|
||||
if err != nil {
|
||||
slog.Error("%v", err)
|
||||
return []llm.Tensor{}, err
|
||||
return nil, err
|
||||
}
|
||||
tensors = append(tensors, t...)
|
||||
}
|
||||
@@ -160,6 +195,7 @@ func GetParams(dirpath string) (*Params, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params.ByteOrder = binary.LittleEndian
|
||||
return ¶ms, nil
|
||||
}
|
||||
|
||||
@@ -171,7 +207,7 @@ type Vocab struct {
|
||||
Types []int32
|
||||
}
|
||||
|
||||
func LoadTokens(dirpath string) (*Vocab, error) {
|
||||
func LoadSentencePieceTokens(dirpath string, vocabSize int) (*Vocab, error) {
|
||||
slog.Info(fmt.Sprintf("reading vocab from %s", filepath.Join(dirpath, "tokenizer.model")))
|
||||
in, err := os.ReadFile(filepath.Join(dirpath, "tokenizer.model"))
|
||||
if err != nil {
|
||||
@@ -196,6 +232,14 @@ func LoadTokens(dirpath string) (*Vocab, error) {
|
||||
v.Tokens = append(v.Tokens, p.GetPiece())
|
||||
v.Scores = append(v.Scores, p.GetScore())
|
||||
t := p.GetType()
|
||||
switch t {
|
||||
case sentencepiece.ModelProto_SentencePiece_UNKNOWN:
|
||||
case sentencepiece.ModelProto_SentencePiece_CONTROL:
|
||||
case sentencepiece.ModelProto_SentencePiece_UNUSED:
|
||||
case sentencepiece.ModelProto_SentencePiece_BYTE:
|
||||
default:
|
||||
t = sentencepiece.ModelProto_SentencePiece_NORMAL
|
||||
}
|
||||
v.Types = append(v.Types, int32(t))
|
||||
}
|
||||
|
||||
@@ -243,6 +287,16 @@ func LoadTokens(dirpath string) (*Vocab, error) {
|
||||
}
|
||||
slog.Info(fmt.Sprintf("vocab size w/ extra tokens: %d", len(v.Tokens)))
|
||||
|
||||
if vocabSize > len(v.Tokens) {
|
||||
missingTokens := vocabSize - len(v.Tokens)
|
||||
slog.Warn(fmt.Sprintf("vocab is missing %d tokens", missingTokens))
|
||||
for cnt := 0; cnt < missingTokens; cnt++ {
|
||||
v.Tokens = append(v.Tokens, fmt.Sprintf("<dummy%05d>", cnt+1))
|
||||
v.Scores = append(v.Scores, -1)
|
||||
v.Types = append(v.Types, int32(llm.GGUFTokenUserDefined))
|
||||
}
|
||||
}
|
||||
|
||||
return v, nil
|
||||
}
|
||||
|
||||
@@ -279,53 +333,102 @@ func GetTensorName(n string) (string, error) {
|
||||
return "", fmt.Errorf("couldn't find a layer name for '%s'", n)
|
||||
}
|
||||
|
||||
func WriteGGUF(name string, tensors []llm.Tensor, params *Params, vocab *Vocab) (string, error) {
|
||||
c := llm.ContainerGGUF{
|
||||
ByteOrder: binary.LittleEndian,
|
||||
}
|
||||
type safetensorWriterTo struct {
|
||||
t *llm.Tensor
|
||||
|
||||
m := llm.NewGGUFModel(&c)
|
||||
m.Tensors = tensors
|
||||
m.KV["general.architecture"] = "llama"
|
||||
m.KV["general.name"] = name
|
||||
m.KV["llama.context_length"] = uint32(params.ContextSize)
|
||||
m.KV["llama.embedding_length"] = uint32(params.HiddenSize)
|
||||
m.KV["llama.block_count"] = uint32(params.HiddenLayers)
|
||||
m.KV["llama.feed_forward_length"] = uint32(params.IntermediateSize)
|
||||
m.KV["llama.rope.dimension_count"] = uint32(128)
|
||||
m.KV["llama.attention.head_count"] = uint32(params.AttentionHeads)
|
||||
m.KV["llama.attention.head_count_kv"] = uint32(params.KeyValHeads)
|
||||
m.KV["llama.attention.layer_norm_rms_epsilon"] = float32(params.NormEPS)
|
||||
m.KV["llama.rope.freq_base"] = float32(params.RopeFreqBase)
|
||||
m.KV["general.file_type"] = uint32(1)
|
||||
m.KV["tokenizer.ggml.model"] = "llama"
|
||||
params *Params
|
||||
bo ByteOrder
|
||||
|
||||
m.KV["tokenizer.ggml.tokens"] = vocab.Tokens
|
||||
m.KV["tokenizer.ggml.scores"] = vocab.Scores
|
||||
m.KV["tokenizer.ggml.token_type"] = vocab.Types
|
||||
filename string
|
||||
|
||||
m.KV["tokenizer.ggml.bos_token_id"] = uint32(params.BoSTokenID)
|
||||
m.KV["tokenizer.ggml.eos_token_id"] = uint32(params.EoSTokenID)
|
||||
m.KV["tokenizer.ggml.unknown_token_id"] = uint32(0)
|
||||
m.KV["tokenizer.ggml.add_bos_token"] = true
|
||||
m.KV["tokenizer.ggml.add_eos_token"] = false
|
||||
start, end, padding uint64
|
||||
handler func(w io.Writer, r safetensorWriterTo, f *os.File) error
|
||||
}
|
||||
|
||||
// llamacpp sets the chat template, however we don't need to set it since we pass it in through a layer
|
||||
// m.KV["tokenizer.chat_template"] = "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}" // XXX removeme
|
||||
|
||||
c.V3.NumTensor = uint64(len(tensors))
|
||||
c.V3.NumKV = uint64(len(m.KV))
|
||||
|
||||
f, err := os.CreateTemp("", "ollama-gguf")
|
||||
func (r safetensorWriterTo) WriteTo(w io.Writer) (n int64, err error) {
|
||||
f, err := os.Open(r.filename)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return 0, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
err = m.Encode(f)
|
||||
if err != nil {
|
||||
return "", err
|
||||
if _, err = f.Seek(int64(r.padding+r.start), 0); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return f.Name(), nil
|
||||
// use the handler if one is present
|
||||
if r.handler != nil {
|
||||
return 0, r.handler(w, r, f)
|
||||
}
|
||||
|
||||
remaining := r.end - r.start
|
||||
|
||||
bufSize := uint64(10240)
|
||||
var finished bool
|
||||
for {
|
||||
data := make([]byte, min(bufSize, remaining))
|
||||
|
||||
b, err := io.ReadFull(f, data)
|
||||
remaining -= uint64(b)
|
||||
|
||||
if err == io.EOF || remaining <= 0 {
|
||||
finished = true
|
||||
} else if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// convert bfloat16 -> ieee float32
|
||||
tDataF32 := bfloat16.DecodeFloat32(data)
|
||||
|
||||
switch r.t.Kind {
|
||||
case 0:
|
||||
if err := binary.Write(w, r.bo, tDataF32); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
case 1:
|
||||
// convert float32 -> float16
|
||||
tempBuf := make([]uint16, len(data)/2)
|
||||
for cnt, v := range tDataF32 {
|
||||
tDataF16 := float16.Fromfloat32(v)
|
||||
tempBuf[cnt] = uint16(tDataF16)
|
||||
}
|
||||
if err := binary.Write(w, binary.LittleEndian, tempBuf); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
if finished {
|
||||
break
|
||||
}
|
||||
}
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func GetModelArchFromParams(name, dirPath string, params *Params) (ModelArch, error) {
|
||||
switch len(params.Architectures) {
|
||||
case 0:
|
||||
return nil, fmt.Errorf("No architecture specified to convert")
|
||||
case 1:
|
||||
switch params.Architectures[0] {
|
||||
case "MistralForCausalLM":
|
||||
return &MistralModel{
|
||||
ModelData{
|
||||
Name: name,
|
||||
Path: dirPath,
|
||||
Params: params,
|
||||
},
|
||||
}, nil
|
||||
case "GemmaForCausalLM":
|
||||
return &GemmaModel{
|
||||
ModelData{
|
||||
Name: name,
|
||||
Path: dirPath,
|
||||
Params: params,
|
||||
},
|
||||
}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("Models based on '%s' are not yet supported", params.Architectures[0])
|
||||
}
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("Unknown error")
|
||||
}
|
||||
|
||||
136
convert/gemma.go
Normal file
136
convert/gemma.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"log/slog"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/d4l3k/go-bfloat16"
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
|
||||
"github.com/ollama/ollama/llm"
|
||||
)
|
||||
|
||||
type GemmaModel struct {
|
||||
ModelData
|
||||
}
|
||||
|
||||
func gemmaLayerHandler(w io.Writer, r safetensorWriterTo, f *os.File) error {
|
||||
slog.Debug(fmt.Sprintf("converting '%s'", r.t.Name))
|
||||
|
||||
data := make([]byte, r.end-r.start)
|
||||
if err := binary.Read(f, r.bo, data); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tDataF32 := bfloat16.DecodeFloat32(data)
|
||||
|
||||
var err error
|
||||
tDataF32, err = addOnes(tDataF32, int(r.t.Shape[0]))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := binary.Write(w, r.bo, tDataF32); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func addOnes(data []float32, vectorSize int) ([]float32, error) {
|
||||
n := tensor.New(tensor.WithShape(vectorSize), tensor.WithBacking(data))
|
||||
ones := tensor.Ones(tensor.Float32, vectorSize)
|
||||
|
||||
var err error
|
||||
n, err = n.Add(ones)
|
||||
if err != nil {
|
||||
return []float32{}, err
|
||||
}
|
||||
|
||||
newN, err := native.SelectF32(n, 0)
|
||||
if err != nil {
|
||||
return []float32{}, err
|
||||
}
|
||||
|
||||
var fullTensor []float32
|
||||
for _, v := range newN {
|
||||
fullTensor = append(fullTensor, v...)
|
||||
}
|
||||
|
||||
return fullTensor, nil
|
||||
}
|
||||
|
||||
func (m *GemmaModel) GetTensors() error {
|
||||
t, err := GetSafeTensors(m.Path, m.Params)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.Tensors = []llm.Tensor{}
|
||||
|
||||
for _, l := range t {
|
||||
if strings.HasSuffix(l.Name, "norm.weight") {
|
||||
wt := l.WriterTo.(safetensorWriterTo)
|
||||
wt.handler = gemmaLayerHandler
|
||||
l.WriterTo = wt
|
||||
}
|
||||
m.Tensors = append(m.Tensors, l)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *GemmaModel) LoadVocab() error {
|
||||
v, err := LoadSentencePieceTokens(m.Path, m.Params.VocabSize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.Vocab = v
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *GemmaModel) WriteGGUF() (string, error) {
|
||||
kv := llm.KV{
|
||||
"general.architecture": "gemma",
|
||||
"general.name": m.Name,
|
||||
"gemma.context_length": uint32(m.Params.ContextSize),
|
||||
"gemma.embedding_length": uint32(m.Params.HiddenSize),
|
||||
"gemma.block_count": uint32(m.Params.HiddenLayers),
|
||||
"gemma.feed_forward_length": uint32(m.Params.IntermediateSize),
|
||||
"gemma.attention.head_count": uint32(m.Params.AttentionHeads),
|
||||
"gemma.attention.head_count_kv": uint32(m.Params.KeyValHeads),
|
||||
"gemma.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS),
|
||||
"gemma.attention.key_length": uint32(m.Params.HeadDimension),
|
||||
"gemma.attention.value_length": uint32(m.Params.HeadDimension),
|
||||
"general.file_type": uint32(1),
|
||||
"tokenizer.ggml.model": "llama",
|
||||
|
||||
"tokenizer.ggml.tokens": m.Vocab.Tokens,
|
||||
"tokenizer.ggml.scores": m.Vocab.Scores,
|
||||
"tokenizer.ggml.token_type": m.Vocab.Types,
|
||||
|
||||
"tokenizer.ggml.bos_token_id": uint32(m.Params.BoSTokenID),
|
||||
"tokenizer.ggml.eos_token_id": uint32(m.Params.EoSTokenID),
|
||||
"tokenizer.ggml.padding_token_id": uint32(m.Params.PaddingTokenID),
|
||||
"tokenizer.ggml.unknown_token_id": uint32(3),
|
||||
"tokenizer.ggml.add_bos_token": true,
|
||||
"tokenizer.ggml.add_eos_token": false,
|
||||
}
|
||||
|
||||
f, err := os.CreateTemp("", "ollama-gguf")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
mod := llm.NewGGUFV3(m.Params.ByteOrder)
|
||||
if err := mod.Encode(f, kv, m.Tensors); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return f.Name(), nil
|
||||
}
|
||||
174
convert/mistral.go
Normal file
174
convert/mistral.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package convert
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/d4l3k/go-bfloat16"
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
"github.com/x448/float16"
|
||||
|
||||
"github.com/ollama/ollama/llm"
|
||||
)
|
||||
|
||||
type MistralModel struct {
|
||||
ModelData
|
||||
}
|
||||
|
||||
func mistralLayerHandler(w io.Writer, r safetensorWriterTo, f *os.File) error {
|
||||
layerSize := r.end - r.start
|
||||
|
||||
var err error
|
||||
tData := make([]uint16, layerSize/2)
|
||||
if err = binary.Read(f, r.bo, tData); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var heads uint32
|
||||
if strings.Contains(r.t.Name, "attn_q") {
|
||||
heads = uint32(r.params.AttentionHeads)
|
||||
} else if strings.Contains(r.t.Name, "attn_k") {
|
||||
heads = uint32(r.params.KeyValHeads)
|
||||
if heads == 0 {
|
||||
heads = uint32(r.params.AttentionHeads)
|
||||
}
|
||||
} else {
|
||||
return fmt.Errorf("unknown layer type")
|
||||
}
|
||||
|
||||
tData, err = repack(tData, int(heads), r.t.Shape)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var buf []byte
|
||||
for _, n := range tData {
|
||||
buf = r.bo.AppendUint16(buf, n)
|
||||
}
|
||||
|
||||
tempBuf := make([]uint16, len(tData))
|
||||
tDataF32 := bfloat16.DecodeFloat32(buf)
|
||||
for cnt, v := range tDataF32 {
|
||||
tDataF16 := float16.Fromfloat32(v)
|
||||
tempBuf[cnt] = uint16(tDataF16)
|
||||
}
|
||||
|
||||
if err = binary.Write(w, r.bo, tempBuf); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func repack(data []uint16, heads int, shape []uint64) ([]uint16, error) {
|
||||
n := tensor.New(tensor.WithShape(int(shape[0]), int(shape[1])), tensor.WithBacking(data))
|
||||
origShape := n.Shape().Clone()
|
||||
|
||||
// reshape the tensor and swap axes 1 and 2 to unpack the layer for gguf
|
||||
if err := n.Reshape(heads, 2, origShape[0]/heads/2, origShape[1]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := n.T(0, 2, 1, 3); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := n.Reshape(origShape...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := n.Transpose(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newN, err := native.SelectU16(n, 1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var fullTensor []uint16
|
||||
for _, v := range newN {
|
||||
fullTensor = append(fullTensor, v...)
|
||||
}
|
||||
return fullTensor, nil
|
||||
}
|
||||
|
||||
func (m *MistralModel) GetTensors() error {
|
||||
t, err := GetSafeTensors(m.Path, m.Params)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.Tensors = []llm.Tensor{}
|
||||
|
||||
pattern := `^blk\.[0-9]+\.attn_(?P<layer>q|k)\.weight$`
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, l := range t {
|
||||
matches := re.FindAllStringSubmatch(l.Name, -1)
|
||||
if len(matches) > 0 {
|
||||
wt := l.WriterTo.(safetensorWriterTo)
|
||||
wt.handler = mistralLayerHandler
|
||||
l.WriterTo = wt
|
||||
}
|
||||
m.Tensors = append(m.Tensors, l)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MistralModel) LoadVocab() error {
|
||||
v, err := LoadSentencePieceTokens(m.Path, m.Params.VocabSize)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.Vocab = v
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *MistralModel) WriteGGUF() (string, error) {
|
||||
kv := llm.KV{
|
||||
"general.architecture": "llama",
|
||||
"general.name": m.Name,
|
||||
"llama.context_length": uint32(m.Params.ContextSize),
|
||||
"llama.embedding_length": uint32(m.Params.HiddenSize),
|
||||
"llama.block_count": uint32(m.Params.HiddenLayers),
|
||||
"llama.feed_forward_length": uint32(m.Params.IntermediateSize),
|
||||
"llama.rope.dimension_count": uint32(m.Params.HiddenSize / m.Params.AttentionHeads),
|
||||
"llama.attention.head_count": uint32(m.Params.AttentionHeads),
|
||||
"llama.attention.head_count_kv": uint32(m.Params.KeyValHeads),
|
||||
"llama.attention.layer_norm_rms_epsilon": float32(m.Params.NormEPS),
|
||||
"llama.rope.freq_base": float32(m.Params.RopeFreqBase),
|
||||
"general.file_type": uint32(1),
|
||||
"tokenizer.ggml.model": "llama",
|
||||
|
||||
"tokenizer.ggml.tokens": m.Vocab.Tokens,
|
||||
"tokenizer.ggml.scores": m.Vocab.Scores,
|
||||
"tokenizer.ggml.token_type": m.Vocab.Types,
|
||||
|
||||
"tokenizer.ggml.bos_token_id": uint32(m.Params.BoSTokenID),
|
||||
"tokenizer.ggml.eos_token_id": uint32(m.Params.EoSTokenID),
|
||||
"tokenizer.ggml.add_bos_token": true,
|
||||
"tokenizer.ggml.add_eos_token": false,
|
||||
"tokenizer.ggml.unknown_token_id": uint32(0),
|
||||
}
|
||||
|
||||
f, err := os.CreateTemp("", "ollama-gguf")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
mod := llm.NewGGUFV3(m.Params.ByteOrder)
|
||||
if err := mod.Encode(f, kv, m.Tensors); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return f.Name(), nil
|
||||
}
|
||||
@@ -71,9 +71,12 @@ More examples are available in the [examples directory](../examples).
|
||||
|
||||
There are two ways to view `Modelfile`s underlying the models in [ollama.com/library][1]:
|
||||
|
||||
- Option 1: view a model's data:
|
||||
1. Go to a particular model page (e.g. https://ollama.com/library/llama2)
|
||||
2. There is a table that displays the model's different components
|
||||
- Option 1: view a details page from a model's tags page:
|
||||
1. Go to a particular model's tags (e.g. https://ollama.com/library/llama2/tags)
|
||||
2. Click on a tag (e.g. https://ollama.com/library/llama2:13b)
|
||||
3. Scroll down to "Layers"
|
||||
- Note: if the [`FROM` instruction](#from-required) is not present,
|
||||
it means the model was created from a local file
|
||||
- Option 2: use `ollama show` to print the `Modelfile` for any local models like so:
|
||||
|
||||
```bash
|
||||
@@ -212,6 +215,7 @@ MESSAGE <role> <message>
|
||||
| user | An example message of what the user could have asked. |
|
||||
| assistant | An example message of how the model should respond. |
|
||||
|
||||
|
||||
#### Example conversation
|
||||
|
||||
```modelfile
|
||||
@@ -223,6 +227,7 @@ MESSAGE user Is Ontario in Canada?
|
||||
MESSAGE assistant yes
|
||||
```
|
||||
|
||||
|
||||
## Notes
|
||||
|
||||
- the **`Modelfile` is not case sensitive**. In the examples, uppercase instructions are used to make it easier to distinguish it from arguments.
|
||||
|
||||
@@ -76,3 +76,10 @@ install script which version to install.
|
||||
```sh
|
||||
curl -fsSL https://ollama.com/install.sh | OLLAMA_VERSION="0.1.29" sh
|
||||
```
|
||||
|
||||
## Linux tmp noexec
|
||||
|
||||
If your system is configured with the "noexec" flag where Ollama stores its
|
||||
temporary executable files, you can specify an alternate location by setting
|
||||
OLLAMA_TMPDIR to a location writable by the user ollama runs as. For example
|
||||
OLLAMA_TMPDIR=/usr/share/ollama/
|
||||
@@ -6,11 +6,15 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
Byte = 1
|
||||
Byte = 1
|
||||
|
||||
KiloByte = Byte * 1000
|
||||
MegaByte = KiloByte * 1000
|
||||
GigaByte = MegaByte * 1000
|
||||
TeraByte = GigaByte * 1000
|
||||
|
||||
KibiByte = Byte * 1024
|
||||
MebiByte = KibiByte * 1024
|
||||
)
|
||||
|
||||
func HumanBytes(b int64) string {
|
||||
@@ -45,3 +49,14 @@ func HumanBytes(b int64) string {
|
||||
return fmt.Sprintf("%d %s", int(value), unit)
|
||||
}
|
||||
}
|
||||
|
||||
func HumanBytes2(b int64) string {
|
||||
switch {
|
||||
case b >= MebiByte:
|
||||
return fmt.Sprintf("%.1f MiB", float64(b)/MebiByte)
|
||||
case b >= KibiByte:
|
||||
return fmt.Sprintf("%.1f KiB", float64(b)/KibiByte)
|
||||
default:
|
||||
return fmt.Sprintf("%d B", b)
|
||||
}
|
||||
}
|
||||
|
||||
30
go.mod
30
go.mod
@@ -9,8 +9,8 @@ require (
|
||||
github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1
|
||||
github.com/emirpasic/gods v1.18.1
|
||||
github.com/gin-gonic/gin v1.9.1
|
||||
github.com/golang/protobuf v1.5.0
|
||||
github.com/google/uuid v1.0.0
|
||||
github.com/golang/protobuf v1.5.0 // indirect
|
||||
github.com/google/uuid v1.6.0
|
||||
github.com/mitchellh/mapstructure v1.5.0
|
||||
github.com/olekukonko/tablewriter v0.0.5
|
||||
github.com/spf13/cobra v1.7.0
|
||||
@@ -19,23 +19,35 @@ require (
|
||||
golang.org/x/sync v0.3.0
|
||||
)
|
||||
|
||||
require github.com/pdevine/tensor v0.0.0-20240228013915-64ccaa8d9ca9
|
||||
require (
|
||||
github.com/minio/minio-go/v7 v7.0.69
|
||||
github.com/pdevine/tensor v0.0.0-20240228013915-64ccaa8d9ca9
|
||||
kr.dev/diff v0.3.0
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/apache/arrow/go/arrow v0.0.0-20201229220542-30ce2eb5d4dc // indirect
|
||||
github.com/chewxy/hm v1.0.0 // indirect
|
||||
github.com/chewxy/math32 v1.0.8 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dustin/go-humanize v1.0.1 // indirect
|
||||
github.com/gogo/protobuf v1.3.2 // indirect
|
||||
github.com/google/flatbuffers v1.12.0 // indirect
|
||||
github.com/klauspost/compress v1.17.6 // indirect
|
||||
github.com/mattn/go-runewidth v0.0.14 // indirect
|
||||
github.com/minio/md5-simd v1.1.2 // indirect
|
||||
github.com/minio/sha256-simd v1.0.1 // indirect
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e // indirect
|
||||
github.com/pkg/errors v0.9.1 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/rivo/uniseg v0.2.0 // indirect
|
||||
github.com/rogpeppe/go-internal v1.8.1 // indirect
|
||||
github.com/rs/xid v1.5.0 // indirect
|
||||
github.com/xtgo/set v1.0.0 // indirect
|
||||
go4.org/unsafe/assume-no-moving-gc v0.0.0-20231121144256-b99613f794b6 // indirect
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 // indirect
|
||||
gonum.org/v1/gonum v0.8.2 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
gorgonia.org/vecf32 v0.9.0 // indirect
|
||||
gorgonia.org/vecf64 v0.9.0 // indirect
|
||||
)
|
||||
@@ -53,7 +65,7 @@ require (
|
||||
github.com/google/go-cmp v0.5.9 // indirect
|
||||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/json-iterator/go v1.1.12 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.4 // indirect
|
||||
github.com/klauspost/cpuid/v2 v2.2.6 // indirect
|
||||
github.com/leodido/go-urn v1.2.4 // indirect
|
||||
github.com/mattn/go-isatty v0.0.19 // indirect
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
|
||||
@@ -63,12 +75,12 @@ require (
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.11 // indirect
|
||||
golang.org/x/arch v0.3.0 // indirect
|
||||
golang.org/x/crypto v0.14.0
|
||||
golang.org/x/crypto v0.19.0
|
||||
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63
|
||||
golang.org/x/net v0.17.0 // indirect
|
||||
golang.org/x/sys v0.13.0
|
||||
golang.org/x/term v0.13.0
|
||||
golang.org/x/text v0.13.0 // indirect
|
||||
golang.org/x/net v0.21.0 // indirect
|
||||
golang.org/x/sys v0.17.0
|
||||
golang.org/x/term v0.17.0
|
||||
golang.org/x/text v0.14.0 // indirect
|
||||
google.golang.org/protobuf v1.30.0
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
||||
51
go.sum
51
go.sum
@@ -26,6 +26,8 @@ github.com/d4l3k/go-bfloat16 v0.0.0-20211005043715-690c3bdd05f1/go.mod h1:uw2gLc
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
|
||||
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
|
||||
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
|
||||
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
|
||||
github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4=
|
||||
@@ -86,8 +88,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
|
||||
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
|
||||
github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/uuid v1.0.0 h1:b4Gk+7WdP/d3HZH8EJsZpvV7EtDOgaZLtnaNGIu1adA=
|
||||
github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
|
||||
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
@@ -95,9 +97,12 @@ github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHm
|
||||
github.com/jung-kurt/gofpdf v1.0.3-0.20190309125859-24315acbbda5/go.mod h1:7Id9E/uU8ce6rXgefFLlgrJj/GYY22cpxn+r32jIOes=
|
||||
github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8=
|
||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||
github.com/klauspost/compress v1.17.6 h1:60eq2E/jlfwQXtvZEeBUYADs+BwKBWURIY+Gj2eRGjI=
|
||||
github.com/klauspost/compress v1.17.6/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
|
||||
github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||
github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk=
|
||||
github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY=
|
||||
github.com/klauspost/cpuid/v2 v2.2.6 h1:ndNyv040zDGIDh8thGkXYjnFtiN02M1PVVF+JE/48xc=
|
||||
github.com/klauspost/cpuid/v2 v2.2.6/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||
github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0=
|
||||
@@ -115,6 +120,12 @@ github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D
|
||||
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
|
||||
github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU=
|
||||
github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
|
||||
github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34=
|
||||
github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM=
|
||||
github.com/minio/minio-go/v7 v7.0.69 h1:l8AnsQFyY1xiwa/DaQskY4NXSLA2yrGsW5iD9nRPVS0=
|
||||
github.com/minio/minio-go/v7 v7.0.69/go.mod h1:XAvOPJQ5Xlzk5o3o/ArO2NMbhSGkimC+bpW/ngRKDmQ=
|
||||
github.com/minio/sha256-simd v1.0.1 h1:6kaan5IFmwTNynnKKpDHe6FWHohJOHhCPchzK49dzMM=
|
||||
github.com/minio/sha256-simd v1.0.1/go.mod h1:Pz6AKMiUdngCLpeTL/RJY1M9rUuPMYujV5xJjtbRSN8=
|
||||
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
|
||||
github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
@@ -129,6 +140,7 @@ github.com/pdevine/tensor v0.0.0-20240228013915-64ccaa8d9ca9/go.mod h1:nR7l3gM6u
|
||||
github.com/pelletier/go-toml/v2 v2.0.1/go.mod h1:r9LEWfGN8R5k0VXJ+0BkIe7MYkRdwZOjgMj2KwnJFUo=
|
||||
github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ=
|
||||
github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4=
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e h1:aoZm08cpOy4WuID//EZDgcC4zIxODThtZNPirFr42+A=
|
||||
github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
@@ -138,8 +150,11 @@ github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:
|
||||
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
||||
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
|
||||
github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc=
|
||||
github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUAtL9R8=
|
||||
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
|
||||
github.com/rogpeppe/go-internal v1.8.1 h1:geMPLpDpQOgVyCg5z5GoRwLHepNdb71NXb67XFkP+Eg=
|
||||
github.com/rogpeppe/go-internal v1.8.1/go.mod h1:JeRgkft04UBgHMgCIwADu4Pn6Mtm5d4nPKWu0nJ5d+o=
|
||||
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
|
||||
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
github.com/spf13/cobra v1.7.0 h1:hyqWnYt1ZQShIddO5kBpj3vu05/++x6tJ6dg8EC572I=
|
||||
github.com/spf13/cobra v1.7.0/go.mod h1:uLxZILRyS/50WlhOIKD7W6V5bgeIt+4sICxh6uRMrb0=
|
||||
@@ -181,8 +196,8 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACk
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc=
|
||||
golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4=
|
||||
golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo=
|
||||
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
|
||||
golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||
@@ -205,8 +220,8 @@ golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLL
|
||||
golang.org/x/net v0.0.0-20200904194848-62affa334b73/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA=
|
||||
golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.17.0 h1:pVaXccu2ozPjCXewfr1S7xza/zcXTity9cCdXQYSjIM=
|
||||
golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE=
|
||||
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
|
||||
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
|
||||
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
@@ -226,18 +241,18 @@ golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7w
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.13.0 h1:Af8nKPmuFypiUBjVoU9V20FiaFXOcuZI21p0ycVYYGE=
|
||||
golang.org/x/sys v0.13.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
|
||||
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.13.0 h1:bb+I9cTfFazGW51MZqBVmZy7+JEJMouUHTUSKVQLBek=
|
||||
golang.org/x/term v0.13.0/go.mod h1:LTmsnFJwVN6bCy1rVCoS+qHT1HhALEFxKncY3WNNh4U=
|
||||
golang.org/x/term v0.17.0 h1:mkTF7LCd6WGJNL3K1Ad7kwxNfYAW6a8a8QqtMblp/4U=
|
||||
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
|
||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||
golang.org/x/tools v0.0.0-20180525024113-a5b4c53f6e8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
@@ -292,6 +307,8 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
|
||||
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
|
||||
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
@@ -303,4 +320,6 @@ gorgonia.org/vecf64 v0.9.0 h1:bgZDP5x0OzBF64PjMGC3EvTdOoMEcmfAh1VCUnZFm1A=
|
||||
gorgonia.org/vecf64 v0.9.0/go.mod h1:hp7IOWCnRiVQKON73kkC/AUMtEXyf9kGlVrtPQ9ccVA=
|
||||
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
|
||||
kr.dev/diff v0.3.0 h1:o/T8/tkAq9IuRIuFqCupyKPC5iSY3WXpVZ2p6ZK3Emw=
|
||||
kr.dev/diff v0.3.0/go.mod h1:XiTaLOg2/PD0cmXY7WQXUR8RAF3RwWpqIQEj910J2NY=
|
||||
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
|
||||
|
||||
@@ -100,6 +100,8 @@ func AMDGetGPUInfo(resp *GpuInfo) {
|
||||
return
|
||||
}
|
||||
|
||||
updateLibPath(libDir)
|
||||
|
||||
gfxOverride := os.Getenv("HSA_OVERRIDE_GFX_VERSION")
|
||||
if gfxOverride == "" {
|
||||
supported, err := GetSupportedGFX(libDir)
|
||||
@@ -113,7 +115,7 @@ func AMDGetGPUInfo(resp *GpuInfo) {
|
||||
if !slices.Contains[[]string, string](supported, v.ToGFXString()) {
|
||||
slog.Warn(fmt.Sprintf("amdgpu [%d] %s is not supported by %s %v", i, v.ToGFXString(), libDir, supported))
|
||||
// TODO - consider discrete markdown just for ROCM troubleshooting?
|
||||
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/troubleshooting.md for HSA_OVERRIDE_GFX_VERSION usage")
|
||||
slog.Warn("See https://github.com/ollama/ollama/blob/main/docs/gpu.md#overrides for HSA_OVERRIDE_GFX_VERSION usage")
|
||||
skip[i] = struct{}{}
|
||||
} else {
|
||||
slog.Info(fmt.Sprintf("amdgpu [%d] %s is supported", i, v.ToGFXString()))
|
||||
@@ -143,6 +145,21 @@ func AMDGetGPUInfo(resp *GpuInfo) {
|
||||
}
|
||||
}
|
||||
|
||||
func updateLibPath(libDir string) {
|
||||
ldPaths := []string{}
|
||||
if val, ok := os.LookupEnv("LD_LIBRARY_PATH"); ok {
|
||||
ldPaths = strings.Split(val, ":")
|
||||
}
|
||||
for _, d := range ldPaths {
|
||||
if d == libDir {
|
||||
return
|
||||
}
|
||||
}
|
||||
val := strings.Join(append(ldPaths, libDir), ":")
|
||||
slog.Debug("updated lib path", "LD_LIBRARY_PATH", val)
|
||||
os.Setenv("LD_LIBRARY_PATH", val)
|
||||
}
|
||||
|
||||
// Walk the sysfs nodes for the available GPUs and gather information from them
|
||||
// skipping over any devices in the skip map
|
||||
func amdProcMemLookup(resp *GpuInfo, skip map[int]interface{}, ids []int) {
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -21,11 +22,20 @@ var (
|
||||
func PayloadsDir() (string, error) {
|
||||
lock.Lock()
|
||||
defer lock.Unlock()
|
||||
var err error
|
||||
if payloadsDir == "" {
|
||||
cleanupTmpDirs()
|
||||
tmpDir, err := os.MkdirTemp("", "ollama")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate tmp dir: %w", err)
|
||||
tmpDir := os.Getenv("OLLAMA_TMPDIR")
|
||||
if tmpDir == "" {
|
||||
tmpDir, err = os.MkdirTemp("", "ollama")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate tmp dir: %w", err)
|
||||
}
|
||||
} else {
|
||||
err = os.MkdirAll(tmpDir, 0755)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate tmp dir %s: %w", tmpDir, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Track our pid so we can clean up orphaned tmpdirs
|
||||
@@ -84,7 +94,12 @@ func Cleanup() {
|
||||
slog.Debug("cleaning up", "dir", tmpDir)
|
||||
err := os.RemoveAll(tmpDir)
|
||||
if err != nil {
|
||||
slog.Warn("failed to clean up", "dir", tmpDir, "err", err)
|
||||
// On windows, if we remove too quickly the llama.dll may still be in-use and fail to remove
|
||||
time.Sleep(1000 * time.Millisecond)
|
||||
err = os.RemoveAll(tmpDir)
|
||||
if err != nil {
|
||||
slog.Warn("failed to clean up", "dir", tmpDir, "err", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
51
gpu/gpu.go
51
gpu/gpu.go
@@ -20,6 +20,8 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ollama/ollama/format"
|
||||
)
|
||||
|
||||
type handles struct {
|
||||
@@ -27,8 +29,12 @@ type handles struct {
|
||||
cudart *C.cudart_handle_t
|
||||
}
|
||||
|
||||
const (
|
||||
cudaMinimumMemory = 457 * format.MebiByte
|
||||
rocmMinimumMemory = 457 * format.MebiByte
|
||||
)
|
||||
|
||||
var gpuMutex sync.Mutex
|
||||
var gpuHandles *handles = nil
|
||||
|
||||
// With our current CUDA compile flags, older than 5.0 will not work properly
|
||||
var CudaComputeMin = [2]C.int{5, 0}
|
||||
@@ -78,11 +84,11 @@ var CudartWindowsGlobs = []string{
|
||||
var CudaTegra string = os.Getenv("JETSON_JETPACK")
|
||||
|
||||
// Note: gpuMutex must already be held
|
||||
func initGPUHandles() {
|
||||
func initGPUHandles() *handles {
|
||||
|
||||
// TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing
|
||||
|
||||
gpuHandles = &handles{nil, nil}
|
||||
gpuHandles := &handles{nil, nil}
|
||||
var nvmlMgmtName string
|
||||
var nvmlMgmtPatterns []string
|
||||
var cudartMgmtName string
|
||||
@@ -109,7 +115,7 @@ func initGPUHandles() {
|
||||
}
|
||||
cudartMgmtPatterns = append(cudartMgmtPatterns, CudartLinuxGlobs...)
|
||||
default:
|
||||
return
|
||||
return gpuHandles
|
||||
}
|
||||
|
||||
slog.Info("Detecting GPU type")
|
||||
@@ -119,7 +125,7 @@ func initGPUHandles() {
|
||||
if cudart != nil {
|
||||
slog.Info("Nvidia GPU detected via cudart")
|
||||
gpuHandles.cudart = cudart
|
||||
return
|
||||
return gpuHandles
|
||||
}
|
||||
}
|
||||
|
||||
@@ -130,10 +136,10 @@ func initGPUHandles() {
|
||||
if nvml != nil {
|
||||
slog.Info("Nvidia GPU detected via nvidia-ml")
|
||||
gpuHandles.nvml = nvml
|
||||
return
|
||||
return gpuHandles
|
||||
}
|
||||
}
|
||||
|
||||
return gpuHandles
|
||||
}
|
||||
|
||||
func GetGPUInfo() GpuInfo {
|
||||
@@ -141,9 +147,16 @@ func GetGPUInfo() GpuInfo {
|
||||
// GPUs so we can report warnings if we see Nvidia/AMD but fail to load the libraries
|
||||
gpuMutex.Lock()
|
||||
defer gpuMutex.Unlock()
|
||||
if gpuHandles == nil {
|
||||
initGPUHandles()
|
||||
}
|
||||
|
||||
gpuHandles := initGPUHandles()
|
||||
defer func() {
|
||||
if gpuHandles.nvml != nil {
|
||||
C.nvml_release(*gpuHandles.nvml)
|
||||
}
|
||||
if gpuHandles.cudart != nil {
|
||||
C.cudart_release(*gpuHandles.cudart)
|
||||
}
|
||||
}()
|
||||
|
||||
// All our GPU builds on x86 have AVX enabled, so fallback to CPU if we don't detect at least AVX
|
||||
cpuVariant := GetCPUVariant()
|
||||
@@ -168,6 +181,7 @@ func GetGPUInfo() GpuInfo {
|
||||
} else if cc.major > CudaComputeMin[0] || (cc.major == CudaComputeMin[0] && cc.minor >= CudaComputeMin[1]) {
|
||||
slog.Info(fmt.Sprintf("[nvidia-ml] NVML CUDA Compute Capability detected: %d.%d", cc.major, cc.minor))
|
||||
resp.Library = "cuda"
|
||||
resp.MinimumMemory = cudaMinimumMemory
|
||||
} else {
|
||||
slog.Info(fmt.Sprintf("[nvidia-ml] CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor))
|
||||
}
|
||||
@@ -187,6 +201,7 @@ func GetGPUInfo() GpuInfo {
|
||||
} else if cc.major > CudaComputeMin[0] || (cc.major == CudaComputeMin[0] && cc.minor >= CudaComputeMin[1]) {
|
||||
slog.Info(fmt.Sprintf("[cudart] CUDART CUDA Compute Capability detected: %d.%d", cc.major, cc.minor))
|
||||
resp.Library = "cuda"
|
||||
resp.MinimumMemory = cudaMinimumMemory
|
||||
} else {
|
||||
slog.Info(fmt.Sprintf("[cudart] CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor))
|
||||
}
|
||||
@@ -194,6 +209,7 @@ func GetGPUInfo() GpuInfo {
|
||||
} else {
|
||||
AMDGetGPUInfo(&resp)
|
||||
if resp.Library != "" {
|
||||
resp.MinimumMemory = rocmMinimumMemory
|
||||
return resp
|
||||
}
|
||||
}
|
||||
@@ -239,20 +255,7 @@ func CheckVRAM() (int64, error) {
|
||||
}
|
||||
gpuInfo := GetGPUInfo()
|
||||
if gpuInfo.FreeMemory > 0 && (gpuInfo.Library == "cuda" || gpuInfo.Library == "rocm") {
|
||||
// leave 10% or 1024MiB of VRAM free per GPU to handle unaccounted for overhead
|
||||
overhead := gpuInfo.FreeMemory / 10
|
||||
gpus := uint64(gpuInfo.DeviceCount)
|
||||
if overhead < gpus*1024*1024*1024 {
|
||||
overhead = gpus * 1024 * 1024 * 1024
|
||||
}
|
||||
// Assigning full reported free memory for Tegras due to OS controlled caching.
|
||||
if CudaTegra != "" {
|
||||
// Setting overhead for non-Tegra devices
|
||||
overhead = 0
|
||||
}
|
||||
avail := int64(gpuInfo.FreeMemory - overhead)
|
||||
slog.Debug(fmt.Sprintf("%s detected %d devices with %dM available memory", gpuInfo.Library, gpuInfo.DeviceCount, avail/1024/1024))
|
||||
return avail, nil
|
||||
return int64(gpuInfo.FreeMemory), nil
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("no GPU detected") // TODO - better handling of CPU based memory determiniation
|
||||
|
||||
@@ -62,6 +62,10 @@ void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp) {
|
||||
LOG(resp->ch.verbose, "cudaSetDevice err: %d\n", ret);
|
||||
UNLOAD_LIBRARY(resp->ch.handle);
|
||||
resp->ch.handle = NULL;
|
||||
if (ret == CUDA_ERROR_INSUFFICIENT_DRIVER) {
|
||||
resp->err = strdup("your nvidia driver is too old or missing, please upgrade to run ollama");
|
||||
return;
|
||||
}
|
||||
snprintf(buf, buflen, "cudart init failure: %d", ret);
|
||||
resp->err = strdup(buf);
|
||||
return;
|
||||
@@ -187,4 +191,10 @@ void cudart_compute_capability(cudart_handle_t h, cudart_compute_capability_t *r
|
||||
}
|
||||
}
|
||||
|
||||
void cudart_release(cudart_handle_t h) {
|
||||
LOG(h.verbose, "releasing cudart library\n");
|
||||
UNLOAD_LIBRARY(h.handle);
|
||||
h.handle = NULL;
|
||||
}
|
||||
|
||||
#endif // __APPLE__
|
||||
@@ -7,6 +7,7 @@
|
||||
typedef enum cudartReturn_enum {
|
||||
CUDART_SUCCESS = 0,
|
||||
CUDART_UNSUPPORTED = 1,
|
||||
CUDA_ERROR_INSUFFICIENT_DRIVER = 35,
|
||||
// Other values omitted for now...
|
||||
} cudartReturn_t;
|
||||
|
||||
@@ -54,6 +55,7 @@ typedef struct cudart_compute_capability {
|
||||
void cudart_init(char *cudart_lib_path, cudart_init_resp_t *resp);
|
||||
void cudart_check_vram(cudart_handle_t ch, mem_info_t *resp);
|
||||
void cudart_compute_capability(cudart_handle_t th, cudart_compute_capability_t *cc);
|
||||
void cudart_release(cudart_handle_t ch);
|
||||
|
||||
#endif // __GPU_INFO_CUDART_H__
|
||||
#endif // __APPLE__
|
||||
|
||||
@@ -211,4 +211,11 @@ void nvml_compute_capability(nvml_handle_t h, nvml_compute_capability_t *resp) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void nvml_release(nvml_handle_t h) {
|
||||
LOG(h.verbose, "releasing nvml library\n");
|
||||
UNLOAD_LIBRARY(h.handle);
|
||||
h.handle = NULL;
|
||||
}
|
||||
|
||||
#endif // __APPLE__
|
||||
@@ -51,6 +51,7 @@ typedef struct nvml_compute_capability {
|
||||
void nvml_init(char *nvml_lib_path, nvml_init_resp_t *resp);
|
||||
void nvml_check_vram(nvml_handle_t ch, mem_info_t *resp);
|
||||
void nvml_compute_capability(nvml_handle_t ch, nvml_compute_capability_t *cc);
|
||||
void nvml_release(nvml_handle_t ch);
|
||||
|
||||
#endif // __GPU_INFO_NVML_H__
|
||||
#endif // __APPLE__
|
||||
@@ -14,6 +14,9 @@ type GpuInfo struct {
|
||||
// Optional variant to select (e.g. versions, cpu feature flags)
|
||||
Variant string `json:"variant,omitempty"`
|
||||
|
||||
// MinimumMemory represents the minimum memory required to use the GPU
|
||||
MinimumMemory int64 `json:"-"`
|
||||
|
||||
// TODO add other useful attributes about the card here for discovery information
|
||||
}
|
||||
|
||||
|
||||
@@ -24,5 +24,5 @@ func TestOrcaMiniBlueSky(t *testing.T) {
|
||||
"seed": 123,
|
||||
},
|
||||
}
|
||||
GenerateTestHelper(ctx, t, &http.Client{}, req, []string{"rayleigh"})
|
||||
GenerateTestHelper(ctx, t, &http.Client{}, req, []string{"rayleigh", "scattering"})
|
||||
}
|
||||
|
||||
@@ -126,7 +126,7 @@ func StartServer(ctx context.Context, ollamaHost string) error {
|
||||
}
|
||||
|
||||
func PullIfMissing(ctx context.Context, client *http.Client, scheme, testEndpoint, modelName string) error {
|
||||
slog.Debug("checking status of model", "model", modelName)
|
||||
slog.Info("checking status of model", "model", modelName)
|
||||
showReq := &api.ShowRequest{Name: modelName}
|
||||
requestJSON, err := json.Marshal(showReq)
|
||||
if err != nil {
|
||||
@@ -174,36 +174,51 @@ func PullIfMissing(ctx context.Context, client *http.Client, scheme, testEndpoin
|
||||
return nil
|
||||
}
|
||||
|
||||
var serverProcMutex sync.Mutex
|
||||
|
||||
func GenerateTestHelper(ctx context.Context, t *testing.T, client *http.Client, genReq api.GenerateRequest, anyResp []string) {
|
||||
|
||||
// TODO maybe stuff in an init routine?
|
||||
lifecycle.InitLogging()
|
||||
|
||||
requestJSON, err := json.Marshal(genReq)
|
||||
if err != nil {
|
||||
t.Fatalf("Error serializing request: %v", err)
|
||||
}
|
||||
defer func() {
|
||||
if t.Failed() && os.Getenv("OLLAMA_TEST_EXISTING") == "" {
|
||||
// TODO
|
||||
fp, err := os.Open(lifecycle.ServerLogFile)
|
||||
if err != nil {
|
||||
slog.Error("failed to open server log", "logfile", lifecycle.ServerLogFile, "error", err)
|
||||
return
|
||||
if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
|
||||
defer serverProcMutex.Unlock()
|
||||
if t.Failed() {
|
||||
fp, err := os.Open(lifecycle.ServerLogFile)
|
||||
if err != nil {
|
||||
slog.Error("failed to open server log", "logfile", lifecycle.ServerLogFile, "error", err)
|
||||
return
|
||||
}
|
||||
data, err := io.ReadAll(fp)
|
||||
if err != nil {
|
||||
slog.Error("failed to read server log", "logfile", lifecycle.ServerLogFile, "error", err)
|
||||
return
|
||||
}
|
||||
slog.Warn("SERVER LOG FOLLOWS")
|
||||
os.Stderr.Write(data)
|
||||
slog.Warn("END OF SERVER")
|
||||
}
|
||||
data, err := io.ReadAll(fp)
|
||||
if err != nil {
|
||||
slog.Error("failed to read server log", "logfile", lifecycle.ServerLogFile, "error", err)
|
||||
return
|
||||
err = os.Remove(lifecycle.ServerLogFile)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err)
|
||||
}
|
||||
slog.Warn("SERVER LOG FOLLOWS")
|
||||
os.Stderr.Write(data)
|
||||
slog.Warn("END OF SERVER")
|
||||
}
|
||||
err = os.Remove(lifecycle.ServerLogFile)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
slog.Warn("failed to cleanup", "logfile", lifecycle.ServerLogFile, "error", err)
|
||||
}
|
||||
}()
|
||||
scheme, testEndpoint := GetTestEndpoint()
|
||||
|
||||
if os.Getenv("OLLAMA_TEST_EXISTING") == "" {
|
||||
serverProcMutex.Lock()
|
||||
fp, err := os.CreateTemp("", "ollama-server-*.log")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to generate log file: %s", err)
|
||||
}
|
||||
lifecycle.ServerLogFile = fp.Name()
|
||||
fp.Close()
|
||||
assert.NoError(t, StartServer(ctx, testEndpoint))
|
||||
}
|
||||
|
||||
|
||||
@@ -1,142 +0,0 @@
|
||||
#include "dyn_ext_server.h"
|
||||
|
||||
#include <stdio.h>
|
||||
#include <string.h>
|
||||
|
||||
#ifdef __linux__
|
||||
#include <dlfcn.h>
|
||||
#define LOAD_LIBRARY(lib, flags) dlopen(lib, flags)
|
||||
#define LOAD_SYMBOL(handle, sym) dlsym(handle, sym)
|
||||
#define LOAD_ERR() strdup(dlerror())
|
||||
#define UNLOAD_LIBRARY(handle) dlclose(handle)
|
||||
#elif _WIN32
|
||||
#include <windows.h>
|
||||
#define LOAD_LIBRARY(lib, flags) LoadLibrary(lib)
|
||||
#define LOAD_SYMBOL(handle, sym) GetProcAddress(handle, sym)
|
||||
#define UNLOAD_LIBRARY(handle) FreeLibrary(handle)
|
||||
#define LOAD_ERR() ({\
|
||||
LPSTR messageBuffer = NULL; \
|
||||
size_t size = FormatMessageA(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS, \
|
||||
NULL, GetLastError(), MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), (LPSTR)&messageBuffer, 0, NULL); \
|
||||
char *resp = strdup(messageBuffer); \
|
||||
LocalFree(messageBuffer); \
|
||||
resp; \
|
||||
})
|
||||
#else
|
||||
#include <dlfcn.h>
|
||||
#define LOAD_LIBRARY(lib, flags) dlopen(lib, flags)
|
||||
#define LOAD_SYMBOL(handle, sym) dlsym(handle, sym)
|
||||
#define LOAD_ERR() strdup(dlerror())
|
||||
#define UNLOAD_LIBRARY(handle) dlclose(handle)
|
||||
#endif
|
||||
|
||||
void dyn_init(const char *libPath, struct dynamic_llama_server *s,
|
||||
ext_server_resp_t *err) {
|
||||
int i = 0;
|
||||
struct lookup {
|
||||
char *s;
|
||||
void **p;
|
||||
} l[] = {
|
||||
{"llama_server_init", (void *)&s->llama_server_init},
|
||||
{"llama_server_start", (void *)&s->llama_server_start},
|
||||
{"llama_server_stop", (void *)&s->llama_server_stop},
|
||||
{"llama_server_completion", (void *)&s->llama_server_completion},
|
||||
{"llama_server_completion_next_result",
|
||||
(void *)&s->llama_server_completion_next_result},
|
||||
{"llama_server_completion_cancel",
|
||||
(void *)&s->llama_server_completion_cancel},
|
||||
{"llama_server_release_task_result",
|
||||
(void *)&s->llama_server_release_task_result},
|
||||
{"llama_server_tokenize", (void *)&s->llama_server_tokenize},
|
||||
{"llama_server_detokenize", (void *)&s->llama_server_detokenize},
|
||||
{"llama_server_embedding", (void *)&s->llama_server_embedding},
|
||||
{"llama_server_release_json_resp",
|
||||
(void *)&s->llama_server_release_json_resp},
|
||||
{"", NULL},
|
||||
};
|
||||
|
||||
printf("loading library %s\n", libPath);
|
||||
s->handle = LOAD_LIBRARY(libPath, RTLD_LOCAL|RTLD_NOW);
|
||||
if (!s->handle) {
|
||||
err->id = -1;
|
||||
char *msg = LOAD_ERR();
|
||||
snprintf(err->msg, err->msg_len,
|
||||
"Unable to load dynamic server library: %s", msg);
|
||||
free(msg);
|
||||
return;
|
||||
}
|
||||
|
||||
for (i = 0; l[i].p != NULL; i++) {
|
||||
*l[i].p = LOAD_SYMBOL(s->handle, l[i].s);
|
||||
if (!l[i].p) {
|
||||
UNLOAD_LIBRARY(s->handle);
|
||||
err->id = -1;
|
||||
char *msg = LOAD_ERR();
|
||||
snprintf(err->msg, err->msg_len, "symbol lookup for %s failed: %s",
|
||||
l[i].s, msg);
|
||||
free(msg);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
inline void dyn_llama_server_init(struct dynamic_llama_server s,
|
||||
ext_server_params_t *sparams,
|
||||
ext_server_resp_t *err) {
|
||||
s.llama_server_init(sparams, err);
|
||||
}
|
||||
|
||||
inline void dyn_llama_server_start(struct dynamic_llama_server s) {
|
||||
s.llama_server_start();
|
||||
}
|
||||
|
||||
inline void dyn_llama_server_stop(struct dynamic_llama_server s) {
|
||||
s.llama_server_stop();
|
||||
}
|
||||
|
||||
inline void dyn_llama_server_completion(struct dynamic_llama_server s,
|
||||
const char *json_req,
|
||||
ext_server_resp_t *resp) {
|
||||
s.llama_server_completion(json_req, resp);
|
||||
}
|
||||
|
||||
inline void dyn_llama_server_completion_next_result(
|
||||
struct dynamic_llama_server s, const int task_id,
|
||||
ext_server_task_result_t *result) {
|
||||
s.llama_server_completion_next_result(task_id, result);
|
||||
}
|
||||
|
||||
inline void dyn_llama_server_completion_cancel(
|
||||
struct dynamic_llama_server s, const int task_id, ext_server_resp_t *err) {
|
||||
s.llama_server_completion_cancel(task_id, err);
|
||||
}
|
||||
inline void dyn_llama_server_release_task_result(
|
||||
struct dynamic_llama_server s, ext_server_task_result_t *result) {
|
||||
s.llama_server_release_task_result(result);
|
||||
}
|
||||
|
||||
inline void dyn_llama_server_tokenize(struct dynamic_llama_server s,
|
||||
const char *json_req,
|
||||
char **json_resp,
|
||||
ext_server_resp_t *err) {
|
||||
s.llama_server_tokenize(json_req, json_resp, err);
|
||||
}
|
||||
|
||||
inline void dyn_llama_server_detokenize(struct dynamic_llama_server s,
|
||||
const char *json_req,
|
||||
char **json_resp,
|
||||
ext_server_resp_t *err) {
|
||||
s.llama_server_detokenize(json_req, json_resp, err);
|
||||
}
|
||||
|
||||
inline void dyn_llama_server_embedding(struct dynamic_llama_server s,
|
||||
const char *json_req,
|
||||
char **json_resp,
|
||||
ext_server_resp_t *err) {
|
||||
s.llama_server_embedding(json_req, json_resp, err);
|
||||
}
|
||||
|
||||
inline void dyn_llama_server_release_json_resp(
|
||||
struct dynamic_llama_server s, char **json_resp) {
|
||||
s.llama_server_release_json_resp(json_resp);
|
||||
}
|
||||
@@ -1,388 +0,0 @@
|
||||
package llm
|
||||
|
||||
/*
|
||||
#cgo CFLAGS: -I${SRCDIR}/ext_server -I${SRCDIR}/llama.cpp -I${SRCDIR}/llama.cpp/common -I${SRCDIR}/llama.cpp/examples/server
|
||||
#cgo CFLAGS: -DNDEBUG -DLLAMA_SERVER_LIBRARY=1 -D_XOPEN_SOURCE=600 -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
|
||||
#cgo CFLAGS: -Wmissing-noreturn -Wextra -Wcast-qual -Wno-unused-function -Wno-array-bounds
|
||||
#cgo CPPFLAGS: -Ofast -Wextra -Wno-unused-function -Wno-unused-variable -Wno-deprecated-declarations
|
||||
#cgo darwin CFLAGS: -D_DARWIN_C_SOURCE
|
||||
#cgo darwin CPPFLAGS: -DGGML_USE_ACCELERATE
|
||||
#cgo darwin CPPFLAGS: -DGGML_USE_METAL -DGGML_METAL_NDEBUG
|
||||
#cgo darwin LDFLAGS: -lc++ -framework Accelerate
|
||||
#cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders
|
||||
#cgo linux CFLAGS: -D_GNU_SOURCE
|
||||
#cgo linux LDFLAGS: -lrt -ldl -lstdc++ -lm
|
||||
#cgo linux windows LDFLAGS: -lpthread
|
||||
|
||||
#include <stdlib.h>
|
||||
#include "dyn_ext_server.h"
|
||||
|
||||
*/
|
||||
import "C"
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/gpu"
|
||||
)
|
||||
|
||||
type dynExtServer struct {
|
||||
s C.struct_dynamic_llama_server
|
||||
options api.Options
|
||||
}
|
||||
|
||||
// Note: current implementation does not support concurrent instantiations
|
||||
var mutex sync.Mutex
|
||||
|
||||
func newExtServerResp(len C.size_t) C.ext_server_resp_t {
|
||||
var resp C.ext_server_resp_t
|
||||
resp.msg_len = len
|
||||
bytes := make([]byte, len)
|
||||
resp.msg = (*C.char)(C.CBytes(bytes))
|
||||
return resp
|
||||
}
|
||||
|
||||
func freeExtServerResp(resp C.ext_server_resp_t) {
|
||||
if resp.msg_len == 0 {
|
||||
return
|
||||
}
|
||||
C.free(unsafe.Pointer(resp.msg))
|
||||
}
|
||||
|
||||
func extServerResponseToErr(resp C.ext_server_resp_t) error {
|
||||
return fmt.Errorf(C.GoString(resp.msg))
|
||||
}
|
||||
|
||||
func newDynExtServer(library, model string, adapters, projectors []string, opts api.Options) (LLM, error) {
|
||||
if !mutex.TryLock() {
|
||||
slog.Info("concurrent llm servers not yet supported, waiting for prior server to complete")
|
||||
mutex.Lock()
|
||||
}
|
||||
gpu.UpdatePath(filepath.Dir(library))
|
||||
libPath := C.CString(library)
|
||||
defer C.free(unsafe.Pointer(libPath))
|
||||
resp := newExtServerResp(512)
|
||||
defer freeExtServerResp(resp)
|
||||
var srv C.struct_dynamic_llama_server
|
||||
C.dyn_init(libPath, &srv, &resp)
|
||||
if resp.id < 0 {
|
||||
mutex.Unlock()
|
||||
return nil, fmt.Errorf("Unable to load dynamic library: %s", C.GoString(resp.msg))
|
||||
}
|
||||
llm := dynExtServer{
|
||||
s: srv,
|
||||
options: opts,
|
||||
}
|
||||
slog.Info(fmt.Sprintf("Loading Dynamic llm server: %s", library))
|
||||
|
||||
var sparams C.ext_server_params_t
|
||||
sparams.model = C.CString(model)
|
||||
defer C.free(unsafe.Pointer(sparams.model))
|
||||
|
||||
sparams.embedding = true
|
||||
sparams.n_ctx = C.uint(opts.NumCtx)
|
||||
sparams.n_batch = C.uint(opts.NumBatch)
|
||||
sparams.n_gpu_layers = C.int(opts.NumGPU)
|
||||
sparams.main_gpu = C.int(opts.MainGPU)
|
||||
sparams.n_parallel = 1 // TODO - wire up concurrency
|
||||
|
||||
// Always use the value encoded in the model
|
||||
sparams.rope_freq_base = 0.0
|
||||
sparams.rope_freq_scale = 0.0
|
||||
sparams.memory_f16 = C.bool(opts.F16KV)
|
||||
sparams.use_mlock = C.bool(opts.UseMLock)
|
||||
sparams.use_mmap = C.bool(opts.UseMMap)
|
||||
|
||||
if opts.UseNUMA {
|
||||
sparams.numa = C.int(1)
|
||||
} else {
|
||||
sparams.numa = C.int(0)
|
||||
}
|
||||
|
||||
sparams.lora_adapters = nil
|
||||
for i := 0; i < len(adapters); i++ {
|
||||
la := (*C.ext_server_lora_adapter_t)(C.malloc(C.sizeof_ext_server_lora_adapter_t))
|
||||
defer C.free(unsafe.Pointer(la))
|
||||
la.adapter = C.CString(adapters[i])
|
||||
defer C.free(unsafe.Pointer(la.adapter))
|
||||
la.scale = C.float(1.0) // TODO expose scale/weights up through ollama UX
|
||||
la.next = nil
|
||||
if i == 0 {
|
||||
sparams.lora_adapters = la
|
||||
} else {
|
||||
tmp := sparams.lora_adapters
|
||||
for ; tmp.next != nil; tmp = tmp.next {
|
||||
}
|
||||
tmp.next = la
|
||||
}
|
||||
}
|
||||
|
||||
if len(projectors) > 0 {
|
||||
// TODO: applying multiple projectors is not supported by the llama.cpp server yet
|
||||
sparams.mmproj = C.CString(projectors[0])
|
||||
defer C.free(unsafe.Pointer(sparams.mmproj))
|
||||
} else {
|
||||
sparams.mmproj = nil
|
||||
}
|
||||
|
||||
sparams.n_threads = C.uint(opts.NumThread)
|
||||
|
||||
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
|
||||
sparams.verbose_logging = C.bool(true)
|
||||
} else {
|
||||
sparams.verbose_logging = C.bool(false)
|
||||
}
|
||||
|
||||
slog.Info("Initializing llama server")
|
||||
slog.Debug(fmt.Sprintf("server params: %+v", sparams))
|
||||
initResp := newExtServerResp(512)
|
||||
defer freeExtServerResp(initResp)
|
||||
C.dyn_llama_server_init(llm.s, &sparams, &initResp)
|
||||
if initResp.id < 0 {
|
||||
mutex.Unlock()
|
||||
err := extServerResponseToErr(initResp)
|
||||
slog.Debug(fmt.Sprintf("failure during initialization: %s", err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
slog.Info("Starting llama main loop")
|
||||
C.dyn_llama_server_start(llm.s)
|
||||
return &llm, nil
|
||||
}
|
||||
|
||||
func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
|
||||
resp := newExtServerResp(128)
|
||||
defer freeExtServerResp(resp)
|
||||
|
||||
if len(predict.Images) > 0 {
|
||||
slog.Info(fmt.Sprintf("loaded %d images", len(predict.Images)))
|
||||
}
|
||||
|
||||
request := map[string]any{
|
||||
"prompt": predict.Prompt,
|
||||
"stream": true,
|
||||
"n_predict": predict.Options.NumPredict,
|
||||
"n_keep": predict.Options.NumKeep,
|
||||
"temperature": predict.Options.Temperature,
|
||||
"top_k": predict.Options.TopK,
|
||||
"top_p": predict.Options.TopP,
|
||||
"tfs_z": predict.Options.TFSZ,
|
||||
"typical_p": predict.Options.TypicalP,
|
||||
"repeat_last_n": predict.Options.RepeatLastN,
|
||||
"repeat_penalty": predict.Options.RepeatPenalty,
|
||||
"presence_penalty": predict.Options.PresencePenalty,
|
||||
"frequency_penalty": predict.Options.FrequencyPenalty,
|
||||
"mirostat": predict.Options.Mirostat,
|
||||
"mirostat_tau": predict.Options.MirostatTau,
|
||||
"mirostat_eta": predict.Options.MirostatEta,
|
||||
"penalize_nl": predict.Options.PenalizeNewline,
|
||||
"seed": predict.Options.Seed,
|
||||
"stop": predict.Options.Stop,
|
||||
"image_data": predict.Images,
|
||||
"cache_prompt": true,
|
||||
}
|
||||
|
||||
if predict.Format == "json" {
|
||||
request["grammar"] = jsonGrammar
|
||||
if !strings.Contains(strings.ToLower(predict.Prompt), "json") {
|
||||
slog.Warn("Prompt does not specify that the LLM should response in JSON, but JSON format is expected. For best results specify that JSON is expected in the system prompt.")
|
||||
}
|
||||
}
|
||||
|
||||
retryDelay := 100 * time.Microsecond
|
||||
for retries := 0; retries < maxRetries; retries++ {
|
||||
if retries > 0 {
|
||||
time.Sleep(retryDelay) // wait before retrying
|
||||
retryDelay *= 2 // exponential backoff
|
||||
}
|
||||
|
||||
// Handling JSON marshaling with special characters unescaped.
|
||||
buffer := &bytes.Buffer{}
|
||||
enc := json.NewEncoder(buffer)
|
||||
enc.SetEscapeHTML(false)
|
||||
|
||||
if err := enc.Encode(request); err != nil {
|
||||
return fmt.Errorf("failed to marshal data: %w", err)
|
||||
}
|
||||
|
||||
req := C.CString(buffer.String())
|
||||
defer C.free(unsafe.Pointer(req))
|
||||
|
||||
C.dyn_llama_server_completion(llm.s, req, &resp)
|
||||
if resp.id < 0 {
|
||||
return extServerResponseToErr(resp)
|
||||
}
|
||||
|
||||
retryNeeded := false
|
||||
// keep track of the last token generated, this is used to abort if the model starts looping
|
||||
var lastToken string
|
||||
var tokenRepeat int
|
||||
out:
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return cancelCompletion(llm, resp)
|
||||
default:
|
||||
var result C.ext_server_task_result_t
|
||||
C.dyn_llama_server_completion_next_result(llm.s, resp.id, &result)
|
||||
json_resp := C.GoString(result.json_resp)
|
||||
C.dyn_llama_server_release_task_result(llm.s, &result)
|
||||
|
||||
var p prediction
|
||||
if err := json.Unmarshal([]byte(json_resp), &p); err != nil {
|
||||
C.dyn_llama_server_completion_cancel(llm.s, resp.id, &resp)
|
||||
if resp.id < 0 {
|
||||
return fmt.Errorf("error unmarshaling llm prediction response: %w and cancel %s", err, C.GoString(resp.msg))
|
||||
} else {
|
||||
return fmt.Errorf("error unmarshaling llm prediction response: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if bool(result.error) && strings.Contains(json_resp, "slot unavailable") {
|
||||
retryNeeded = true
|
||||
// task will already be canceled
|
||||
break out
|
||||
}
|
||||
|
||||
switch {
|
||||
case strings.TrimSpace(p.Content) == lastToken:
|
||||
tokenRepeat++
|
||||
default:
|
||||
lastToken = strings.TrimSpace(p.Content)
|
||||
tokenRepeat = 0
|
||||
}
|
||||
|
||||
// 30 picked as an arbitrary max token repeat limit, modify as needed
|
||||
if tokenRepeat > 30 {
|
||||
slog.Debug("prediction aborted, token repeat limit reached")
|
||||
return cancelCompletion(llm, resp)
|
||||
}
|
||||
|
||||
if p.Content != "" {
|
||||
fn(PredictResult{
|
||||
Content: p.Content,
|
||||
})
|
||||
}
|
||||
|
||||
if p.Stop || bool(result.stop) {
|
||||
fn(PredictResult{
|
||||
Done: true,
|
||||
PromptEvalCount: p.Timings.PromptN,
|
||||
PromptEvalDuration: parseDurationMs(p.Timings.PromptMS),
|
||||
EvalCount: p.Timings.PredictedN,
|
||||
EvalDuration: parseDurationMs(p.Timings.PredictedMS),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
if !retryNeeded {
|
||||
return nil // success
|
||||
}
|
||||
}
|
||||
|
||||
// should never reach here ideally
|
||||
return fmt.Errorf("max retries exceeded")
|
||||
}
|
||||
|
||||
func cancelCompletion(llm *dynExtServer, resp C.ext_server_resp_t) error {
|
||||
C.dyn_llama_server_completion_cancel(llm.s, resp.id, &resp)
|
||||
if resp.id < 0 {
|
||||
return extServerResponseToErr(resp)
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (llm *dynExtServer) Encode(ctx context.Context, prompt string) ([]int, error) {
|
||||
data, err := json.Marshal(TokenizeRequest{Content: prompt})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshaling encode data: %w", err)
|
||||
}
|
||||
req := C.CString(string(data))
|
||||
defer C.free(unsafe.Pointer(req))
|
||||
var json_resp *C.char
|
||||
resp := newExtServerResp(128)
|
||||
defer freeExtServerResp(resp)
|
||||
C.dyn_llama_server_tokenize(llm.s, req, &json_resp, &resp)
|
||||
if resp.id < 0 {
|
||||
return nil, extServerResponseToErr(resp)
|
||||
}
|
||||
defer C.dyn_llama_server_release_json_resp(llm.s, &json_resp)
|
||||
|
||||
var encoded TokenizeResponse
|
||||
if err2 := json.Unmarshal([]byte(C.GoString(json_resp)), &encoded); err2 != nil {
|
||||
return nil, fmt.Errorf("unmarshal encode response: %w", err2)
|
||||
}
|
||||
|
||||
return encoded.Tokens, err
|
||||
}
|
||||
|
||||
func (llm *dynExtServer) Decode(ctx context.Context, tokens []int) (string, error) {
|
||||
if len(tokens) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshaling decode data: %w", err)
|
||||
}
|
||||
|
||||
req := C.CString(string(data))
|
||||
defer C.free(unsafe.Pointer(req))
|
||||
var json_resp *C.char
|
||||
resp := newExtServerResp(128)
|
||||
defer freeExtServerResp(resp)
|
||||
C.dyn_llama_server_detokenize(llm.s, req, &json_resp, &resp)
|
||||
if resp.id < 0 {
|
||||
return "", extServerResponseToErr(resp)
|
||||
}
|
||||
defer C.dyn_llama_server_release_json_resp(llm.s, &json_resp)
|
||||
|
||||
var decoded DetokenizeResponse
|
||||
if err2 := json.Unmarshal([]byte(C.GoString(json_resp)), &decoded); err2 != nil {
|
||||
return "", fmt.Errorf("unmarshal encode response: %w", err2)
|
||||
}
|
||||
|
||||
return decoded.Content, err
|
||||
}
|
||||
|
||||
func (llm *dynExtServer) Embedding(ctx context.Context, input string) ([]float64, error) {
|
||||
data, err := json.Marshal(TokenizeRequest{Content: input})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshaling embed data: %w", err)
|
||||
}
|
||||
|
||||
req := C.CString(string(data))
|
||||
defer C.free(unsafe.Pointer(req))
|
||||
var json_resp *C.char
|
||||
resp := newExtServerResp(128)
|
||||
defer freeExtServerResp(resp)
|
||||
C.dyn_llama_server_embedding(llm.s, req, &json_resp, &resp)
|
||||
if resp.id < 0 {
|
||||
return nil, extServerResponseToErr(resp)
|
||||
}
|
||||
defer C.dyn_llama_server_release_json_resp(llm.s, &json_resp)
|
||||
|
||||
var embedding EmbeddingResponse
|
||||
if err := json.Unmarshal([]byte(C.GoString(json_resp)), &embedding); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
|
||||
}
|
||||
|
||||
return embedding.Embedding, nil
|
||||
}
|
||||
|
||||
func (llm *dynExtServer) Close() {
|
||||
C.dyn_llama_server_stop(llm.s)
|
||||
mutex.Unlock()
|
||||
}
|
||||
@@ -1,74 +0,0 @@
|
||||
#include <stdlib.h>
|
||||
|
||||
#include "ext_server.h"
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
struct dynamic_llama_server {
|
||||
void *handle;
|
||||
void (*llama_server_init)(ext_server_params_t *sparams,
|
||||
ext_server_resp_t *err);
|
||||
void (*llama_server_start)();
|
||||
void (*llama_server_stop)();
|
||||
void (*llama_server_completion)(const char *json_req,
|
||||
ext_server_resp_t *resp);
|
||||
void (*llama_server_completion_next_result)(const int task_id,
|
||||
ext_server_task_result_t *result);
|
||||
void (*llama_server_completion_cancel)(const int task_id,
|
||||
ext_server_resp_t *err);
|
||||
void (*llama_server_release_task_result)(ext_server_task_result_t *result);
|
||||
void (*llama_server_tokenize)(const char *json_req, char **json_resp,
|
||||
ext_server_resp_t *err);
|
||||
void (*llama_server_detokenize)(const char *json_req, char **json_resp,
|
||||
ext_server_resp_t *err);
|
||||
void (*llama_server_embedding)(const char *json_req, char **json_resp,
|
||||
ext_server_resp_t *err);
|
||||
void (*llama_server_release_json_resp)(char **json_resp);
|
||||
};
|
||||
|
||||
void dyn_init(const char *libPath, struct dynamic_llama_server *s,
|
||||
ext_server_resp_t *err);
|
||||
|
||||
// No good way to call C function pointers from Go so inline the indirection
|
||||
void dyn_llama_server_init(struct dynamic_llama_server s,
|
||||
ext_server_params_t *sparams,
|
||||
ext_server_resp_t *err);
|
||||
|
||||
void dyn_llama_server_start(struct dynamic_llama_server s);
|
||||
|
||||
void dyn_llama_server_stop(struct dynamic_llama_server s);
|
||||
|
||||
void dyn_llama_server_completion(struct dynamic_llama_server s,
|
||||
const char *json_req,
|
||||
ext_server_resp_t *resp);
|
||||
|
||||
void dyn_llama_server_completion_next_result(
|
||||
struct dynamic_llama_server s, const int task_id,
|
||||
ext_server_task_result_t *result);
|
||||
|
||||
void dyn_llama_server_completion_cancel(struct dynamic_llama_server s,
|
||||
const int task_id,
|
||||
ext_server_resp_t *err);
|
||||
|
||||
void dyn_llama_server_release_task_result(
|
||||
struct dynamic_llama_server s, ext_server_task_result_t *result);
|
||||
|
||||
void dyn_llama_server_tokenize(struct dynamic_llama_server s,
|
||||
const char *json_req, char **json_resp,
|
||||
ext_server_resp_t *err);
|
||||
|
||||
void dyn_llama_server_detokenize(struct dynamic_llama_server s,
|
||||
const char *json_req,
|
||||
char **json_resp,
|
||||
ext_server_resp_t *err);
|
||||
|
||||
void dyn_llama_server_embedding(struct dynamic_llama_server s,
|
||||
const char *json_req, char **json_resp,
|
||||
ext_server_resp_t *err);
|
||||
void dyn_llama_server_release_json_resp(struct dynamic_llama_server s,
|
||||
char **json_resp);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
27
llm/ext_server/CMakeLists.txt
vendored
27
llm/ext_server/CMakeLists.txt
vendored
@@ -1,21 +1,14 @@
|
||||
|
||||
set(TARGET ext_server)
|
||||
set(TARGET ollama_llama_server)
|
||||
option(LLAMA_SERVER_VERBOSE "Build verbose logging option for Server" ON)
|
||||
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
|
||||
add_executable(${TARGET} server.cpp utils.hpp json.hpp httplib.h)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_compile_definitions(${TARGET} PRIVATE
|
||||
SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}>
|
||||
)
|
||||
target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
|
||||
if (WIN32)
|
||||
add_library(${TARGET} SHARED ext_server.cpp ../llama.cpp/llama.cpp)
|
||||
else()
|
||||
add_library(${TARGET} STATIC ext_server.cpp ../llama.cpp/llama.cpp)
|
||||
TARGET_LINK_LIBRARIES(${TARGET} PRIVATE ws2_32)
|
||||
endif()
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_11)
|
||||
target_compile_definitions(${TARGET} PUBLIC LLAMA_SERVER_LIBRARY=1)
|
||||
target_link_libraries(${TARGET} PRIVATE ggml llava common )
|
||||
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
target_compile_definitions(${TARGET} PRIVATE SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}>)
|
||||
install(TARGETS ext_server LIBRARY)
|
||||
|
||||
if (CUDAToolkit_FOUND)
|
||||
target_include_directories(${TARGET} PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
|
||||
if (WIN32)
|
||||
target_link_libraries(${TARGET} PRIVATE nvml)
|
||||
endif()
|
||||
endif()
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_11)
|
||||
18
llm/ext_server/README.md
vendored
18
llm/ext_server/README.md
vendored
@@ -1,18 +0,0 @@
|
||||
# Extern C Server
|
||||
|
||||
This directory contains a thin facade we layer on top of the Llama.cpp server to
|
||||
expose `extern C` interfaces to access the functionality through direct API
|
||||
calls in-process. The llama.cpp code uses compile time macros to configure GPU
|
||||
type along with other settings. During the `go generate ./...` execution, the
|
||||
build will generate one or more copies of the llama.cpp `extern C` server based
|
||||
on what GPU libraries are detected to support multiple GPU types as well as CPU
|
||||
only support. The Ollama go build then embeds these different servers to support
|
||||
different GPUs and settings at runtime.
|
||||
|
||||
If you are making changes to the code in this directory, make sure to disable
|
||||
caching during your go build to ensure you pick up your changes. A typical
|
||||
iteration cycle from the top of the source tree looks like:
|
||||
|
||||
```
|
||||
go generate ./... && go build -a .
|
||||
```
|
||||
377
llm/ext_server/ext_server.cpp
vendored
377
llm/ext_server/ext_server.cpp
vendored
@@ -1,377 +0,0 @@
|
||||
#include "ext_server.h"
|
||||
#include <atomic>
|
||||
|
||||
// Necessary evil since the server types are not defined in a header
|
||||
#include "server.cpp"
|
||||
|
||||
// Low level API access to verify GPU access
|
||||
#if defined(GGML_USE_CUBLAS)
|
||||
#if defined(GGML_USE_HIPBLAS)
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <hipblas/hipblas.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
// for rocblas_initialize()
|
||||
#include "rocblas/rocblas.h"
|
||||
#endif // __HIP_PLATFORM_AMD__
|
||||
#define cudaGetDevice hipGetDevice
|
||||
#define cudaError_t hipError_t
|
||||
#define cudaSuccess hipSuccess
|
||||
#define cudaGetErrorString hipGetErrorString
|
||||
#else
|
||||
#include <cuda_runtime.h>
|
||||
#include <cublas_v2.h>
|
||||
#include <cuda_fp16.h>
|
||||
#endif // defined(GGML_USE_HIPBLAS)
|
||||
#endif // GGML_USE_CUBLAS
|
||||
|
||||
// Expose the llama server as a callable extern "C" API
|
||||
llama_server_context *llama = NULL;
|
||||
std::thread ext_server_thread;
|
||||
bool shutting_down = false;
|
||||
std::atomic_int recv_counter;
|
||||
|
||||
// RAII wrapper for tracking in-flight recv calls
|
||||
class atomicRecv {
|
||||
public:
|
||||
atomicRecv(std::atomic<int> &atomic) : atomic(atomic) {
|
||||
++this->atomic;
|
||||
}
|
||||
~atomicRecv() {
|
||||
--this->atomic;
|
||||
}
|
||||
private:
|
||||
std::atomic<int> &atomic;
|
||||
};
|
||||
|
||||
void llama_server_init(ext_server_params *sparams, ext_server_resp_t *err) {
|
||||
recv_counter = 0;
|
||||
assert(err != NULL && sparams != NULL);
|
||||
log_set_target(stderr);
|
||||
if (!sparams->verbose_logging) {
|
||||
server_verbose = true;
|
||||
log_disable();
|
||||
}
|
||||
|
||||
LOG_TEE("system info: %s\n", llama_print_system_info());
|
||||
err->id = 0;
|
||||
err->msg[0] = '\0';
|
||||
try {
|
||||
llama = new llama_server_context;
|
||||
gpt_params params;
|
||||
params.n_ctx = sparams->n_ctx;
|
||||
params.n_batch = sparams->n_batch;
|
||||
if (sparams->n_threads > 0) {
|
||||
params.n_threads = sparams->n_threads;
|
||||
}
|
||||
params.n_parallel = sparams->n_parallel;
|
||||
params.rope_freq_base = sparams->rope_freq_base;
|
||||
params.rope_freq_scale = sparams->rope_freq_scale;
|
||||
|
||||
if (sparams->memory_f16) {
|
||||
params.cache_type_k = "f16";
|
||||
params.cache_type_v = "f16";
|
||||
} else {
|
||||
params.cache_type_k = "f32";
|
||||
params.cache_type_v = "f32";
|
||||
}
|
||||
|
||||
params.n_gpu_layers = sparams->n_gpu_layers;
|
||||
params.main_gpu = sparams->main_gpu;
|
||||
params.use_mlock = sparams->use_mlock;
|
||||
params.use_mmap = sparams->use_mmap;
|
||||
params.numa = (ggml_numa_strategy)sparams->numa;
|
||||
params.embedding = sparams->embedding;
|
||||
if (sparams->model != NULL) {
|
||||
params.model = sparams->model;
|
||||
}
|
||||
|
||||
if (sparams->lora_adapters != NULL) {
|
||||
for (ext_server_lora_adapter *la = sparams->lora_adapters; la != NULL;
|
||||
la = la->next) {
|
||||
params.lora_adapter.push_back(std::make_tuple(la->adapter, la->scale));
|
||||
}
|
||||
|
||||
params.use_mmap = false;
|
||||
}
|
||||
|
||||
if (sparams->mmproj != NULL) {
|
||||
params.mmproj = std::string(sparams->mmproj);
|
||||
}
|
||||
|
||||
#if defined(GGML_USE_CUBLAS)
|
||||
// Before attempting to init the backend which will assert on error, verify the CUDA/ROCM GPU is accessible
|
||||
LOG_TEE("Performing pre-initialization of GPU\n");
|
||||
int id;
|
||||
cudaError_t cudaErr = cudaGetDevice(&id);
|
||||
if (cudaErr != cudaSuccess) {
|
||||
err->id = -1;
|
||||
snprintf(err->msg, err->msg_len, "Unable to init GPU: %s", cudaGetErrorString(cudaErr));
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
if (!llama->load_model(params)) {
|
||||
// an error occurred that was not thrown
|
||||
err->id = -1;
|
||||
snprintf(err->msg, err->msg_len, "error loading model %s", params.model.c_str());
|
||||
return;
|
||||
}
|
||||
|
||||
llama->initialize();
|
||||
} catch (std::exception &e) {
|
||||
err->id = -1;
|
||||
snprintf(err->msg, err->msg_len, "exception %s", e.what());
|
||||
} catch (...) {
|
||||
err->id = -1;
|
||||
snprintf(err->msg, err->msg_len,
|
||||
"Unknown exception initializing llama server");
|
||||
}
|
||||
}
|
||||
|
||||
void llama_server_start() {
|
||||
assert(llama != NULL);
|
||||
// TODO mutex to protect thread creation
|
||||
ext_server_thread = std::thread([&]() {
|
||||
try {
|
||||
LOG_TEE("llama server main loop starting\n");
|
||||
ggml_time_init();
|
||||
llama->queue_tasks.on_new_task(std::bind(
|
||||
&llama_server_context::process_single_task, llama, std::placeholders::_1));
|
||||
llama->queue_tasks.on_finish_multitask(std::bind(
|
||||
&llama_server_context::on_finish_multitask, llama, std::placeholders::_1));
|
||||
llama->queue_tasks.on_run_slots(std::bind(
|
||||
&llama_server_context::update_slots, llama));
|
||||
llama->queue_results.on_multitask_update(std::bind(
|
||||
&llama_server_queue::update_multitask,
|
||||
&llama->queue_tasks,
|
||||
std::placeholders::_1,
|
||||
std::placeholders::_2,
|
||||
std::placeholders::_3
|
||||
));
|
||||
llama->queue_tasks.start_loop();
|
||||
} catch (std::exception &e) {
|
||||
LOG_TEE("caught exception in llama server main loop: %s\n", e.what());
|
||||
} catch (...) {
|
||||
LOG_TEE("caught unknown exception in llama server main loop\n");
|
||||
}
|
||||
LOG_TEE("\nllama server shutting down\n");
|
||||
llama_backend_free();
|
||||
});
|
||||
}
|
||||
|
||||
void llama_server_stop() {
|
||||
assert(llama != NULL);
|
||||
// Shutdown any in-flight requests and block incoming requests.
|
||||
LOG_TEE("\ninitiating shutdown - draining remaining tasks...\n");
|
||||
shutting_down = true;
|
||||
|
||||
while (recv_counter.load() > 0) {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(50));
|
||||
}
|
||||
|
||||
// This may take a while for any pending tasks to drain
|
||||
// TODO - consider a timeout to cancel tasks if it's taking too long
|
||||
llama->queue_tasks.terminate();
|
||||
ext_server_thread.join();
|
||||
delete llama;
|
||||
llama = NULL;
|
||||
LOG_TEE("llama server shutdown complete\n");
|
||||
shutting_down = false;
|
||||
}
|
||||
|
||||
void llama_server_completion(const char *json_req, ext_server_resp_t *resp) {
|
||||
assert(llama != NULL && json_req != NULL && resp != NULL);
|
||||
resp->id = -1;
|
||||
resp->msg[0] = '\0';
|
||||
try {
|
||||
if (shutting_down) {
|
||||
throw std::runtime_error("server shutting down");
|
||||
}
|
||||
json data = json::parse(json_req);
|
||||
resp->id = llama->queue_tasks.get_new_id();
|
||||
llama->queue_results.add_waiting_task_id(resp->id);
|
||||
llama->request_completion(resp->id, data, false, false, -1);
|
||||
} catch (std::exception &e) {
|
||||
snprintf(resp->msg, resp->msg_len, "exception %s", e.what());
|
||||
} catch (...) {
|
||||
snprintf(resp->msg, resp->msg_len, "Unknown exception during completion");
|
||||
}
|
||||
}
|
||||
|
||||
void llama_server_completion_next_result(const int task_id,
|
||||
ext_server_task_result_t *resp) {
|
||||
assert(llama != NULL && resp != NULL);
|
||||
resp->id = -1;
|
||||
resp->stop = false;
|
||||
resp->error = false;
|
||||
resp->json_resp = NULL;
|
||||
std::string result_json;
|
||||
try {
|
||||
atomicRecv ar(recv_counter);
|
||||
task_result result = llama->queue_results.recv(task_id);
|
||||
result_json =
|
||||
result.result_json.dump(-1, ' ', false, json::error_handler_t::replace);
|
||||
resp->id = result.id;
|
||||
resp->stop = result.stop;
|
||||
resp->error = result.error;
|
||||
if (result.error) {
|
||||
LOG_TEE("next result cancel on error\n");
|
||||
llama->request_cancel(task_id);
|
||||
LOG_TEE("next result removing waiting tak ID: %d\n", task_id);
|
||||
llama->queue_results.remove_waiting_task_id(task_id);
|
||||
} else if (result.stop) {
|
||||
LOG_TEE("next result cancel on stop\n");
|
||||
llama->request_cancel(task_id);
|
||||
LOG_TEE("next result removing waiting task ID: %d\n", task_id);
|
||||
llama->queue_results.remove_waiting_task_id(task_id);
|
||||
} else if (shutting_down) {
|
||||
LOG_TEE("aborting completion due to shutdown %d\n", task_id);
|
||||
llama->request_cancel(task_id);
|
||||
llama->queue_results.remove_waiting_task_id(task_id);
|
||||
resp->stop = true;
|
||||
}
|
||||
} catch (std::exception &e) {
|
||||
resp->error = true;
|
||||
resp->id = -1;
|
||||
result_json = "{\"error\":\"exception " + std::string(e.what()) + "\"}";
|
||||
LOG_TEE("llama server completion exception %s\n", e.what());
|
||||
} catch (...) {
|
||||
resp->error = true;
|
||||
resp->id = -1;
|
||||
result_json = "{\"error\":\"Unknown exception during completion\"}";
|
||||
LOG_TEE("llama server completion unknown exception\n");
|
||||
}
|
||||
const std::string::size_type size = result_json.size() + 1;
|
||||
resp->json_resp = new char[size];
|
||||
snprintf(resp->json_resp, size, "%s", result_json.c_str());
|
||||
}
|
||||
|
||||
void llama_server_release_task_result(ext_server_task_result_t *result) {
|
||||
if (result == NULL || result->json_resp == NULL) {
|
||||
return;
|
||||
}
|
||||
delete[] result->json_resp;
|
||||
}
|
||||
|
||||
void llama_server_completion_cancel(const int task_id, ext_server_resp_t *err) {
|
||||
assert(llama != NULL && err != NULL);
|
||||
err->id = 0;
|
||||
err->msg[0] = '\0';
|
||||
try {
|
||||
llama->request_cancel(task_id);
|
||||
llama->queue_results.remove_waiting_task_id(task_id);
|
||||
} catch (std::exception &e) {
|
||||
err->id = -1;
|
||||
snprintf(err->msg, err->msg_len, "exception %s", e.what());
|
||||
} catch (...) {
|
||||
err->id = -1;
|
||||
snprintf(err->msg, err->msg_len,
|
||||
"Unknown exception completion cancel in llama server");
|
||||
}
|
||||
}
|
||||
|
||||
void llama_server_tokenize(const char *json_req, char **json_resp,
|
||||
ext_server_resp_t *err) {
|
||||
assert(llama != NULL && json_req != NULL && json_resp != NULL && err != NULL);
|
||||
*json_resp = NULL;
|
||||
err->id = 0;
|
||||
err->msg[0] = '\0';
|
||||
try {
|
||||
if (shutting_down) {
|
||||
throw std::runtime_error("server shutting down");
|
||||
}
|
||||
const json body = json::parse(json_req);
|
||||
std::vector<llama_token> tokens;
|
||||
if (body.count("content") != 0) {
|
||||
tokens = llama->tokenize(body["content"], false);
|
||||
}
|
||||
const json data = format_tokenizer_response(tokens);
|
||||
std::string result_json = data.dump();
|
||||
const std::string::size_type size = result_json.size() + 1;
|
||||
*json_resp = new char[size];
|
||||
snprintf(*json_resp, size, "%s", result_json.c_str());
|
||||
} catch (std::exception &e) {
|
||||
err->id = -1;
|
||||
snprintf(err->msg, err->msg_len, "exception %s", e.what());
|
||||
} catch (...) {
|
||||
err->id = -1;
|
||||
snprintf(err->msg, err->msg_len, "Unknown exception during tokenize");
|
||||
}
|
||||
}
|
||||
|
||||
void llama_server_release_json_resp(char **json_resp) {
|
||||
if (json_resp == NULL || *json_resp == NULL) {
|
||||
return;
|
||||
}
|
||||
delete[] *json_resp;
|
||||
}
|
||||
|
||||
void llama_server_detokenize(const char *json_req, char **json_resp,
|
||||
ext_server_resp_t *err) {
|
||||
assert(llama != NULL && json_req != NULL && json_resp != NULL && err != NULL);
|
||||
*json_resp = NULL;
|
||||
err->id = 0;
|
||||
err->msg[0] = '\0';
|
||||
try {
|
||||
if (shutting_down) {
|
||||
throw std::runtime_error("server shutting down");
|
||||
}
|
||||
const json body = json::parse(json_req);
|
||||
std::string content;
|
||||
if (body.count("tokens") != 0) {
|
||||
const std::vector<llama_token> tokens = body["tokens"];
|
||||
content = tokens_to_str(llama->ctx, tokens.cbegin(), tokens.cend());
|
||||
}
|
||||
const json data = format_detokenized_response(content);
|
||||
std::string result_json = data.dump();
|
||||
const std::string::size_type size = result_json.size() + 1;
|
||||
*json_resp = new char[size];
|
||||
snprintf(*json_resp, size, "%s", result_json.c_str());
|
||||
} catch (std::exception &e) {
|
||||
err->id = -1;
|
||||
snprintf(err->msg, err->msg_len, "exception %s", e.what());
|
||||
} catch (...) {
|
||||
err->id = -1;
|
||||
snprintf(err->msg, err->msg_len, "Unknown exception during detokenize");
|
||||
}
|
||||
}
|
||||
|
||||
void llama_server_embedding(const char *json_req, char **json_resp,
|
||||
ext_server_resp_t *err) {
|
||||
assert(llama != NULL && json_req != NULL && json_resp != NULL && err != NULL);
|
||||
*json_resp = NULL;
|
||||
err->id = 0;
|
||||
err->msg[0] = '\0';
|
||||
try {
|
||||
if (shutting_down) {
|
||||
throw std::runtime_error("server shutting down");
|
||||
}
|
||||
const json body = json::parse(json_req);
|
||||
json prompt;
|
||||
if (body.count("content") != 0) {
|
||||
prompt = body["content"];
|
||||
} else {
|
||||
prompt = "";
|
||||
}
|
||||
const int task_id = llama->queue_tasks.get_new_id();
|
||||
llama->queue_results.add_waiting_task_id(task_id);
|
||||
llama->request_completion(task_id, {{"prompt", prompt}, {"n_predict", 0}}, false, true, -1);
|
||||
atomicRecv ar(recv_counter);
|
||||
task_result result = llama->queue_results.recv(task_id);
|
||||
std::string result_json = result.result_json.dump();
|
||||
const std::string::size_type size = result_json.size() + 1;
|
||||
*json_resp = new char[size];
|
||||
snprintf(*json_resp, size, "%s", result_json.c_str());
|
||||
llama->queue_results.remove_waiting_task_id(task_id);
|
||||
} catch (std::exception &e) {
|
||||
err->id = -1;
|
||||
snprintf(err->msg, err->msg_len, "exception %s", e.what());
|
||||
} catch (...) {
|
||||
err->id = -1;
|
||||
snprintf(err->msg, err->msg_len, "Unknown exception during embedding");
|
||||
}
|
||||
}
|
||||
95
llm/ext_server/ext_server.h
vendored
95
llm/ext_server/ext_server.h
vendored
@@ -1,95 +0,0 @@
|
||||
#if defined(LLAMA_SERVER_LIBRARY)
|
||||
#ifndef LLAMA_SERVER_H
|
||||
#define LLAMA_SERVER_H
|
||||
#include <stdbool.h>
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
int __main(int argc, char **argv);
|
||||
|
||||
// This exposes extern C entrypoints into the llama_server
|
||||
// To enable the server compile with LLAMA_SERVER_LIBRARY
|
||||
|
||||
#ifdef __cplusplus
|
||||
extern "C" {
|
||||
#endif
|
||||
typedef struct ext_server_resp {
|
||||
int id; // < 0 on error
|
||||
size_t msg_len; // caller must allocate msg and set msg_len
|
||||
char *msg;
|
||||
} ext_server_resp_t;
|
||||
|
||||
// Allocated and freed by caller
|
||||
typedef struct ext_server_lora_adapter {
|
||||
char *adapter;
|
||||
float scale;
|
||||
struct ext_server_lora_adapter *next;
|
||||
} ext_server_lora_adapter_t;
|
||||
|
||||
// Allocated and freed by caller
|
||||
typedef struct ext_server_params {
|
||||
char *model;
|
||||
uint32_t n_ctx; // token context window, 0 = from model
|
||||
uint32_t n_batch; // prompt processing maximum batch size
|
||||
uint32_t n_threads; // number of threads to use for generation
|
||||
int32_t n_parallel; // number of parallel sequences to decodewra
|
||||
float rope_freq_base; // RoPE base frequency, 0 = from model
|
||||
float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model
|
||||
bool memory_f16; // use f16 instead of f32 for memory kv
|
||||
int32_t n_gpu_layers; // number of layers to store in VRAM (-1 - use default)
|
||||
int32_t main_gpu; // the GPU that is used for scratch and small tensors
|
||||
bool use_mlock; // force system to keep model in RAM
|
||||
bool use_mmap; // use mmap if possible
|
||||
int numa; // attempt optimizations that help on some NUMA systems
|
||||
bool embedding; // get only sentence embedding
|
||||
ext_server_lora_adapter_t *lora_adapters;
|
||||
char *mmproj;
|
||||
bool verbose_logging; // Enable verbose logging of the server
|
||||
} ext_server_params_t;
|
||||
|
||||
typedef struct ext_server_task_result {
|
||||
int id;
|
||||
bool stop;
|
||||
bool error;
|
||||
char *json_resp; // null terminated, memory managed by ext_server
|
||||
} ext_server_task_result_t;
|
||||
|
||||
// Initialize the server once per process
|
||||
// err->id = 0 for success and err->msg[0] = NULL
|
||||
// err->id != 0 for failure, and err->msg contains error message
|
||||
void llama_server_init(ext_server_params_t *sparams, ext_server_resp_t *err);
|
||||
|
||||
// Run the main loop, called once per init
|
||||
void llama_server_start();
|
||||
// Stop the main loop and free up resources allocated in init and start. Init
|
||||
// must be called again to reuse
|
||||
void llama_server_stop();
|
||||
|
||||
// json_req null terminated string, memory managed by caller
|
||||
// resp->id >= 0 on success (task ID)
|
||||
// resp->id < 0 on error, and resp->msg contains error message
|
||||
void llama_server_completion(const char *json_req, ext_server_resp_t *resp);
|
||||
|
||||
// Caller must call llama_server_release_task_result to free resp->json_resp
|
||||
void llama_server_completion_next_result(const int task_id,
|
||||
ext_server_task_result_t *result);
|
||||
void llama_server_completion_cancel(const int task_id, ext_server_resp_t *err);
|
||||
void llama_server_release_task_result(ext_server_task_result_t *result);
|
||||
|
||||
// Caller must call llama_server_releaes_json_resp to free json_resp if err.id <
|
||||
// 0
|
||||
void llama_server_tokenize(const char *json_req, char **json_resp,
|
||||
ext_server_resp_t *err);
|
||||
void llama_server_detokenize(const char *json_req, char **json_resp,
|
||||
ext_server_resp_t *err);
|
||||
void llama_server_embedding(const char *json_req, char **json_resp,
|
||||
ext_server_resp_t *err);
|
||||
void llama_server_release_json_resp(char **json_resp);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
#endif // LLAMA_SERVER_LIBRARY
|
||||
6
llm/ext_server/server.cpp
vendored
6
llm/ext_server/server.cpp
vendored
@@ -1007,13 +1007,15 @@ struct llama_server_context
|
||||
slot.n_sent_text += result.text_to_send.size();
|
||||
// add the token to slot queue and cache
|
||||
}
|
||||
slot.add_token_string(result);
|
||||
|
||||
if (slot.params.stream)
|
||||
{
|
||||
send_partial_response(slot, result);
|
||||
}
|
||||
}
|
||||
|
||||
slot.add_token_string(result);
|
||||
|
||||
if (incomplete)
|
||||
{
|
||||
slot.has_next_token = true;
|
||||
@@ -2768,7 +2770,7 @@ inline void signal_handler(int signal) {
|
||||
shutdown_handler(signal);
|
||||
}
|
||||
|
||||
int _main(int argc, char **argv)
|
||||
int main(int argc, char **argv)
|
||||
{
|
||||
#if SERVER_VERBOSE != 1
|
||||
log_disable();
|
||||
|
||||
@@ -14,7 +14,7 @@ init_vars() {
|
||||
|
||||
LLAMACPP_DIR=../llama.cpp
|
||||
CMAKE_DEFS=""
|
||||
CMAKE_TARGETS="--target ext_server"
|
||||
CMAKE_TARGETS="--target ollama_llama_server"
|
||||
if echo "${CGO_CFLAGS}" | grep -- '-g' >/dev/null; then
|
||||
CMAKE_DEFS="-DCMAKE_BUILD_TYPE=RelWithDebInfo -DCMAKE_VERBOSE_MAKEFILE=on -DLLAMA_GPROF=on -DLLAMA_SERVER_VERBOSE=on ${CMAKE_DEFS}"
|
||||
else
|
||||
@@ -81,27 +81,24 @@ apply_patches() {
|
||||
build() {
|
||||
cmake -S ${LLAMACPP_DIR} -B ${BUILD_DIR} ${CMAKE_DEFS}
|
||||
cmake --build ${BUILD_DIR} ${CMAKE_TARGETS} -j8
|
||||
mkdir -p ${BUILD_DIR}/lib/
|
||||
ls ${BUILD_DIR}
|
||||
g++ -fPIC -g -shared -o ${BUILD_DIR}/lib/libext_server.${LIB_EXT} \
|
||||
${GCC_ARCH} \
|
||||
${WHOLE_ARCHIVE} ${BUILD_DIR}/ext_server/libext_server.a ${NO_WHOLE_ARCHIVE} \
|
||||
${BUILD_DIR}/common/libcommon.a \
|
||||
${BUILD_DIR}/libllama.a \
|
||||
-Wl,-rpath,\$ORIGIN \
|
||||
-lpthread -ldl -lm \
|
||||
${EXTRA_LIBS}
|
||||
}
|
||||
|
||||
compress_libs() {
|
||||
compress() {
|
||||
echo "Compressing payloads to reduce overall binary size..."
|
||||
pids=""
|
||||
rm -rf ${BUILD_DIR}/lib/*.${LIB_EXT}*.gz
|
||||
for lib in ${BUILD_DIR}/lib/*.${LIB_EXT}* ; do
|
||||
gzip -n --best -f ${lib} &
|
||||
rm -rf ${BUILD_DIR}/bin/*.gz
|
||||
for f in ${BUILD_DIR}/bin/* ; do
|
||||
gzip -n --best -f ${f} &
|
||||
pids+=" $!"
|
||||
done
|
||||
echo
|
||||
# check for lib directory
|
||||
if [ -d ${BUILD_DIR}/lib ]; then
|
||||
for f in ${BUILD_DIR}/lib/* ; do
|
||||
gzip -n --best -f ${f} &
|
||||
pids+=" $!"
|
||||
done
|
||||
fi
|
||||
echo
|
||||
for pid in ${pids}; do
|
||||
wait $pid
|
||||
done
|
||||
|
||||
@@ -18,21 +18,31 @@ sign() {
|
||||
fi
|
||||
}
|
||||
|
||||
COMMON_DARWIN_DEFS="-DCMAKE_OSX_DEPLOYMENT_TARGET=11.0 -DCMAKE_SYSTEM_NAME=Darwin"
|
||||
COMMON_DARWIN_DEFS="-DCMAKE_OSX_DEPLOYMENT_TARGET=11.0 -DCMAKE_SYSTEM_NAME=Darwin -DLLAMA_METAL_EMBED_LIBRARY=on"
|
||||
|
||||
case "${GOARCH}" in
|
||||
"amd64")
|
||||
COMMON_CPU_DEFS="${COMMON_DARWIN_DEFS} -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=off -DLLAMA_NATIVE=off"
|
||||
|
||||
# Static build for linking into the Go binary
|
||||
init_vars
|
||||
CMAKE_TARGETS="--target llama --target ggml"
|
||||
CMAKE_DEFS="${COMMON_CPU_DEFS} -DBUILD_SHARED_LIBS=off -DLLAMA_ACCELERATE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
|
||||
BUILD_DIR="../build/darwin/${ARCH}_static"
|
||||
echo "Building static library"
|
||||
build
|
||||
|
||||
|
||||
#
|
||||
# CPU first for the default library, set up as lowest common denominator for maximum compatibility (including Rosetta)
|
||||
#
|
||||
init_vars
|
||||
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
|
||||
BUILD_DIR="${LLAMACPP_DIR}/build/darwin/${ARCH}/cpu"
|
||||
BUILD_DIR="../build/darwin/${ARCH}/cpu"
|
||||
echo "Building LCD CPU"
|
||||
build
|
||||
sign ${LLAMACPP_DIR}/build/darwin/${ARCH}/cpu/lib/libext_server.dylib
|
||||
compress_libs
|
||||
sign ${BUILD_DIR}/lib/libext_server.dylib
|
||||
compress
|
||||
|
||||
#
|
||||
# ~2011 CPU Dynamic library with more capabilities turned on to optimize performance
|
||||
@@ -40,11 +50,11 @@ case "${GOARCH}" in
|
||||
#
|
||||
init_vars
|
||||
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=off -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
|
||||
BUILD_DIR="${LLAMACPP_DIR}/build/darwin/${ARCH}/cpu_avx"
|
||||
BUILD_DIR="../build/darwin/${ARCH}/cpu_avx"
|
||||
echo "Building AVX CPU"
|
||||
build
|
||||
sign ${LLAMACPP_DIR}/build/darwin/${ARCH}/cpu_avx/lib/libext_server.dylib
|
||||
compress_libs
|
||||
sign ${BUILD_DIR}/lib/libext_server.dylib
|
||||
compress
|
||||
|
||||
#
|
||||
# ~2013 CPU Dynamic library
|
||||
@@ -52,20 +62,30 @@ case "${GOARCH}" in
|
||||
#
|
||||
init_vars
|
||||
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=on -DLLAMA_AVX=on -DLLAMA_AVX2=on -DLLAMA_AVX512=off -DLLAMA_FMA=on -DLLAMA_F16C=on ${CMAKE_DEFS}"
|
||||
BUILD_DIR="${LLAMACPP_DIR}/build/darwin/${ARCH}/cpu_avx2"
|
||||
BUILD_DIR="../build/darwin/${ARCH}/cpu_avx2"
|
||||
echo "Building AVX2 CPU"
|
||||
EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation"
|
||||
build
|
||||
sign ${LLAMACPP_DIR}/build/darwin/${ARCH}/cpu_avx2/lib/libext_server.dylib
|
||||
compress_libs
|
||||
sign ${BUILD_DIR}/lib/libext_server.dylib
|
||||
compress
|
||||
;;
|
||||
"arm64")
|
||||
|
||||
# Static build for linking into the Go binary
|
||||
init_vars
|
||||
CMAKE_TARGETS="--target llama --target ggml"
|
||||
CMAKE_DEFS="${COMMON_CPU_DEFS} -DBUILD_SHARED_LIBS=off -DLLAMA_ACCELERATE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
|
||||
BUILD_DIR="../build/darwin/${ARCH}_static"
|
||||
echo "Building static library"
|
||||
build
|
||||
|
||||
init_vars
|
||||
CMAKE_DEFS="${COMMON_DARWIN_DEFS} -DLLAMA_METAL_EMBED_LIBRARY=on -DLLAMA_ACCELERATE=on -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=on ${CMAKE_DEFS}"
|
||||
BUILD_DIR="${LLAMACPP_DIR}/build/darwin/${ARCH}/metal"
|
||||
BUILD_DIR="../build/darwin/${ARCH}/metal"
|
||||
EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders"
|
||||
build
|
||||
sign ${LLAMACPP_DIR}/build/darwin/${ARCH}/metal/lib/libext_server.dylib
|
||||
compress_libs
|
||||
sign ${BUILD_DIR}/lib/libext_server.dylib
|
||||
compress
|
||||
;;
|
||||
*)
|
||||
echo "GOARCH must be set"
|
||||
@@ -75,3 +95,4 @@ case "${GOARCH}" in
|
||||
esac
|
||||
|
||||
cleanup
|
||||
echo "go generate completed. LLM runners: $(cd ${BUILD_DIR}/..; echo *)"
|
||||
|
||||
@@ -57,16 +57,31 @@ init_vars
|
||||
git_module_setup
|
||||
apply_patches
|
||||
|
||||
|
||||
init_vars
|
||||
if [ -z "${OLLAMA_SKIP_CPU_GENERATE}" ]; then
|
||||
|
||||
if [ -z "${OLLAMA_CPU_TARGET}" -o "${OLLAMA_CPU_TARGET}" = "static" ]; then
|
||||
# Static build for linking into the Go binary
|
||||
init_vars
|
||||
CMAKE_TARGETS="--target llama --target ggml"
|
||||
CMAKE_DEFS="-DBUILD_SHARED_LIBS=off -DLLAMA_NATIVE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
|
||||
BUILD_DIR="../build/linux/${ARCH}_static"
|
||||
echo "Building static library"
|
||||
build
|
||||
fi
|
||||
|
||||
|
||||
# Users building from source can tune the exact flags we pass to cmake for configuring
|
||||
# llama.cpp, and we'll build only 1 CPU variant in that case as the default.
|
||||
if [ -n "${OLLAMA_CUSTOM_CPU_DEFS}" ]; then
|
||||
init_vars
|
||||
echo "OLLAMA_CUSTOM_CPU_DEFS=\"${OLLAMA_CUSTOM_CPU_DEFS}\""
|
||||
CMAKE_DEFS="${OLLAMA_CUSTOM_CPU_DEFS} -DCMAKE_POSITION_INDEPENDENT_CODE=on ${CMAKE_DEFS}"
|
||||
BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/cpu"
|
||||
BUILD_DIR="../build/linux/${ARCH}/cpu"
|
||||
echo "Building custom CPU"
|
||||
build
|
||||
compress_libs
|
||||
compress
|
||||
else
|
||||
# Darwin Rosetta x86 emulation does NOT support AVX, AVX2, AVX512
|
||||
# -DLLAMA_AVX -- 2011 Intel Sandy Bridge & AMD Bulldozer
|
||||
@@ -83,11 +98,12 @@ if [ -z "${OLLAMA_SKIP_CPU_GENERATE}" ]; then
|
||||
#
|
||||
# CPU first for the default library, set up as lowest common denominator for maximum compatibility (including Rosetta)
|
||||
#
|
||||
init_vars
|
||||
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
|
||||
BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/cpu"
|
||||
BUILD_DIR="../build/linux/${ARCH}/cpu"
|
||||
echo "Building LCD CPU"
|
||||
build
|
||||
compress_libs
|
||||
compress
|
||||
fi
|
||||
|
||||
if [ "${ARCH}" == "x86_64" ]; then
|
||||
@@ -101,10 +117,10 @@ if [ -z "${OLLAMA_SKIP_CPU_GENERATE}" ]; then
|
||||
#
|
||||
init_vars
|
||||
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
|
||||
BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/cpu_avx"
|
||||
BUILD_DIR="../build/linux/${ARCH}/cpu_avx"
|
||||
echo "Building AVX CPU"
|
||||
build
|
||||
compress_libs
|
||||
compress
|
||||
fi
|
||||
|
||||
if [ -z "${OLLAMA_CPU_TARGET}" -o "${OLLAMA_CPU_TARGET}" = "cpu_avx2" ]; then
|
||||
@@ -114,10 +130,10 @@ if [ -z "${OLLAMA_SKIP_CPU_GENERATE}" ]; then
|
||||
#
|
||||
init_vars
|
||||
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_AVX=on -DLLAMA_AVX2=on -DLLAMA_AVX512=off -DLLAMA_FMA=on -DLLAMA_F16C=on ${CMAKE_DEFS}"
|
||||
BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/cpu_avx2"
|
||||
BUILD_DIR="../build/linux/${ARCH}/cpu_avx2"
|
||||
echo "Building AVX2 CPU"
|
||||
build
|
||||
compress_libs
|
||||
compress
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
@@ -157,7 +173,7 @@ if [ -d "${CUDA_LIB_DIR}" ]; then
|
||||
ARM64_DEFS="-DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_CUDA_F16=off"
|
||||
fi
|
||||
CMAKE_DEFS="-DLLAMA_CUBLAS=on -DLLAMA_CUDA_FORCE_MMQ=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} ${ARM64_DEFS}"
|
||||
BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/cuda${CUDA_VARIANT}"
|
||||
BUILD_DIR="../build/linux/${ARCH}/cuda${CUDA_VARIANT}"
|
||||
EXTRA_LIBS="-L${CUDA_LIB_DIR} -lcudart -lcublas -lcublasLt -lcuda"
|
||||
build
|
||||
|
||||
@@ -165,20 +181,20 @@ if [ -d "${CUDA_LIB_DIR}" ]; then
|
||||
#
|
||||
# TODO - in the future we may shift to packaging these separately and conditionally
|
||||
# downloading them in the install script.
|
||||
DEPS="$(ldd ${BUILD_DIR}/lib/libext_server.so )"
|
||||
DEPS="$(ldd ${BUILD_DIR}/bin/ollama_llama_server )"
|
||||
for lib in libcudart.so libcublas.so libcublasLt.so ; do
|
||||
DEP=$(echo "${DEPS}" | grep ${lib} | cut -f1 -d' ' | xargs || true)
|
||||
if [ -n "${DEP}" -a -e "${CUDA_LIB_DIR}/${DEP}" ]; then
|
||||
cp "${CUDA_LIB_DIR}/${DEP}" "${BUILD_DIR}/lib/"
|
||||
cp "${CUDA_LIB_DIR}/${DEP}" "${BUILD_DIR}/bin/"
|
||||
elif [ -e "${CUDA_LIB_DIR}/${lib}.${CUDA_MAJOR}" ]; then
|
||||
cp "${CUDA_LIB_DIR}/${lib}.${CUDA_MAJOR}" "${BUILD_DIR}/lib/"
|
||||
cp "${CUDA_LIB_DIR}/${lib}.${CUDA_MAJOR}" "${BUILD_DIR}/bin/"
|
||||
elif [ -e "${CUDART_LIB_DIR}/${lib}" ]; then
|
||||
cp -d ${CUDART_LIB_DIR}/${lib}* "${BUILD_DIR}/lib/"
|
||||
cp -d ${CUDART_LIB_DIR}/${lib}* "${BUILD_DIR}/bin/"
|
||||
else
|
||||
cp -d "${CUDA_LIB_DIR}/${lib}*" "${BUILD_DIR}/lib/"
|
||||
cp -d "${CUDA_LIB_DIR}/${lib}*" "${BUILD_DIR}/bin/"
|
||||
fi
|
||||
done
|
||||
compress_libs
|
||||
compress
|
||||
|
||||
fi
|
||||
|
||||
@@ -201,23 +217,24 @@ if [ -d "${ROCM_PATH}" ]; then
|
||||
fi
|
||||
init_vars
|
||||
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DLLAMA_HIPBLAS=on -DCMAKE_C_COMPILER=$ROCM_PATH/llvm/bin/clang -DCMAKE_CXX_COMPILER=$ROCM_PATH/llvm/bin/clang++ -DAMDGPU_TARGETS=$(amdGPUs) -DGPU_TARGETS=$(amdGPUs)"
|
||||
BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/rocm${ROCM_VARIANT}"
|
||||
BUILD_DIR="../build/linux/${ARCH}/rocm${ROCM_VARIANT}"
|
||||
EXTRA_LIBS="-L${ROCM_PATH}/lib -L/opt/amdgpu/lib/x86_64-linux-gnu/ -Wl,-rpath,\$ORIGIN/../../rocm/ -lhipblas -lrocblas -lamdhip64 -lrocsolver -lamd_comgr -lhsa-runtime64 -lrocsparse -ldrm -ldrm_amdgpu"
|
||||
build
|
||||
|
||||
# Record the ROCM dependencies
|
||||
rm -f "${BUILD_DIR}/lib/deps.txt"
|
||||
touch "${BUILD_DIR}/lib/deps.txt"
|
||||
for dep in $(ldd "${BUILD_DIR}/lib/libext_server.so" | grep "=>" | cut -f2 -d= | cut -f2 -d' ' | grep -e rocm -e amdgpu -e libtinfo ); do
|
||||
echo "${dep}" >> "${BUILD_DIR}/lib/deps.txt"
|
||||
rm -f "${BUILD_DIR}/bin/deps.txt"
|
||||
touch "${BUILD_DIR}/bin/deps.txt"
|
||||
for dep in $(ldd "${BUILD_DIR}/bin/ollama_llama_server" | grep "=>" | cut -f2 -d= | cut -f2 -d' ' | grep -e rocm -e amdgpu -e libtinfo ); do
|
||||
echo "${dep}" >> "${BUILD_DIR}/bin/deps.txt"
|
||||
done
|
||||
# bomb out if for some reason we didn't get a few deps
|
||||
if [ $(cat "${BUILD_DIR}/lib/deps.txt" | wc -l ) -lt 8 ] ; then
|
||||
cat "${BUILD_DIR}/lib/deps.txt"
|
||||
if [ $(cat "${BUILD_DIR}/bin/deps.txt" | wc -l ) -lt 8 ] ; then
|
||||
cat "${BUILD_DIR}/bin/deps.txt"
|
||||
echo "ERROR: deps file short"
|
||||
exit 1
|
||||
fi
|
||||
compress_libs
|
||||
compress
|
||||
fi
|
||||
|
||||
cleanup
|
||||
echo "go generate completed. LLM runners: $(cd ${BUILD_DIR}/..; echo *)"
|
||||
|
||||
@@ -33,7 +33,7 @@ function init_vars {
|
||||
"-DBUILD_SHARED_LIBS=on",
|
||||
"-DLLAMA_NATIVE=off"
|
||||
)
|
||||
$script:cmakeTargets = @("ext_server")
|
||||
$script:cmakeTargets = @("ollama_llama_server")
|
||||
$script:ARCH = "amd64" # arm not yet supported.
|
||||
if ($env:CGO_CFLAGS -contains "-g") {
|
||||
$script:cmakeDefs += @("-DCMAKE_VERBOSE_MAKEFILE=on", "-DLLAMA_SERVER_VERBOSE=on", "-DCMAKE_BUILD_TYPE=RelWithDebInfo")
|
||||
@@ -97,16 +97,14 @@ function apply_patches {
|
||||
}
|
||||
|
||||
# Checkout each file
|
||||
Set-Location -Path ${script:llamacppDir}
|
||||
foreach ($file in $filePaths) {
|
||||
git checkout $file
|
||||
git -C "${script:llamacppDir}" checkout $file
|
||||
}
|
||||
}
|
||||
|
||||
# Apply each patch
|
||||
foreach ($patch in $patches) {
|
||||
Set-Location -Path ${script:llamacppDir}
|
||||
git apply $patch.FullName
|
||||
git -C "${script:llamacppDir}" apply $patch.FullName
|
||||
}
|
||||
}
|
||||
|
||||
@@ -115,41 +113,41 @@ function build {
|
||||
& cmake --version
|
||||
& cmake -S "${script:llamacppDir}" -B $script:buildDir $script:cmakeDefs
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
write-host "building with: cmake --build $script:buildDir --config $script:config ($script:cmakeTargets | ForEach-Object { "--target", $_ })"
|
||||
write-host "building with: cmake --build $script:buildDir --config $script:config $($script:cmakeTargets | ForEach-Object { `"--target`", $_ })"
|
||||
& cmake --build $script:buildDir --config $script:config ($script:cmakeTargets | ForEach-Object { "--target", $_ })
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
}
|
||||
|
||||
function install {
|
||||
rm -ea 0 -recurse -force -path "${script:buildDir}/lib"
|
||||
md "${script:buildDir}/lib" -ea 0 > $null
|
||||
cp "${script:buildDir}/bin/${script:config}/ext_server.dll" "${script:buildDir}/lib"
|
||||
cp "${script:buildDir}/bin/${script:config}/llama.dll" "${script:buildDir}/lib"
|
||||
# Display the dll dependencies in the build log
|
||||
if ($script:DUMPBIN -ne $null) {
|
||||
& "$script:DUMPBIN" /dependents "${script:buildDir}/bin/${script:config}/ext_server.dll" | select-string ".dll"
|
||||
# Rearrange output to be consistent between different generators
|
||||
if ($null -ne ${script:config} -And (test-path -path "${script:buildDir}/bin/${script:config}" ) ) {
|
||||
mv -force "${script:buildDir}/bin/${script:config}/*" "${script:buildDir}/bin/"
|
||||
remove-item "${script:buildDir}/bin/${script:config}"
|
||||
}
|
||||
}
|
||||
|
||||
function sign {
|
||||
if ("${env:KEY_CONTAINER}") {
|
||||
write-host "Signing ${script:buildDir}/lib/*.dll"
|
||||
foreach ($file in (get-childitem "${script:buildDir}/lib/*.dll")){
|
||||
& "${script:SignTool}" sign /v /debug /fd sha256 /t http://timestamp.digicert.com /f "${script:OLLAMA_CERT}" `
|
||||
write-host "Signing ${script:buildDir}/bin/*.exe ${script:buildDir}/bin/*.dll"
|
||||
foreach ($file in @(get-childitem "${script:buildDir}/bin/*.exe") + @(get-childitem "${script:buildDir}/bin/*.dll")){
|
||||
& "${script:SignTool}" sign /v /fd sha256 /t http://timestamp.digicert.com /f "${script:OLLAMA_CERT}" `
|
||||
/csp "Google Cloud KMS Provider" /kc "${env:KEY_CONTAINER}" $file
|
||||
if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function compress_libs {
|
||||
function compress {
|
||||
if ($script:GZIP -eq $null) {
|
||||
write-host "gzip not installed, not compressing files"
|
||||
return
|
||||
}
|
||||
write-host "Compressing binaries..."
|
||||
$binaries = dir "${script:buildDir}/bin/*.exe"
|
||||
foreach ($file in $binaries) {
|
||||
& "$script:GZIP" --best -f $file
|
||||
}
|
||||
|
||||
write-host "Compressing dlls..."
|
||||
$libs = dir "${script:buildDir}/lib/*.dll"
|
||||
foreach ($file in $libs) {
|
||||
$binaries = dir "${script:buildDir}/bin/*.dll"
|
||||
foreach ($file in $dlls) {
|
||||
& "$script:GZIP" --best -f $file
|
||||
}
|
||||
}
|
||||
@@ -164,14 +162,11 @@ function cleanup {
|
||||
}
|
||||
|
||||
# Checkout each file
|
||||
Set-Location -Path ${script:llamacppDir}
|
||||
foreach ($file in $filePaths) {
|
||||
git checkout $file
|
||||
git -C "${script:llamacppDir}" checkout $file
|
||||
}
|
||||
git -C "${script:llamacppDir}" checkout CMakeLists.txt
|
||||
}
|
||||
Set-Location "${script:llamacppDir}/"
|
||||
git checkout CMakeLists.txt
|
||||
|
||||
}
|
||||
|
||||
init_vars
|
||||
@@ -179,7 +174,6 @@ git_module_setup
|
||||
apply_patches
|
||||
|
||||
# -DLLAMA_AVX -- 2011 Intel Sandy Bridge & AMD Bulldozer
|
||||
# -DLLAMA_F16C -- 2012 Intel Ivy Bridge & AMD 2011 Bulldozer (No significant improvement over just AVX)
|
||||
# -DLLAMA_AVX2 -- 2013 Intel Haswell & 2015 AMD Excavator / 2017 AMD Zen
|
||||
# -DLLAMA_FMA (FMA3) -- 2013 Intel Haswell & 2012 AMD Piledriver
|
||||
|
||||
@@ -187,32 +181,46 @@ $script:commonCpuDefs = @("-DCMAKE_POSITION_INDEPENDENT_CODE=on")
|
||||
|
||||
if ($null -eq ${env:OLLAMA_SKIP_CPU_GENERATE}) {
|
||||
|
||||
# GCC build for direct linking into the Go binary
|
||||
init_vars
|
||||
$script:cmakeTargets = @("llama", "ggml")
|
||||
$script:cmakeDefs = @(
|
||||
"-G", "MinGW Makefiles"
|
||||
"-DBUILD_SHARED_LIBS=off",
|
||||
"-DLLAMA_NATIVE=off",
|
||||
"-DLLAMA_AVX=off",
|
||||
"-DLLAMA_AVX2=off",
|
||||
"-DLLAMA_AVX512=off",
|
||||
"-DLLAMA_F16C=off",
|
||||
"-DLLAMA_FMA=off")
|
||||
$script:buildDir="../build/windows/${script:ARCH}_static"
|
||||
write-host "Building static library"
|
||||
build
|
||||
|
||||
# remaining llama.cpp builds use MSVC
|
||||
init_vars
|
||||
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=off", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
|
||||
$script:buildDir="${script:llamacppDir}/build/windows/${script:ARCH}/cpu"
|
||||
$script:buildDir="../build/windows/${script:ARCH}/cpu"
|
||||
write-host "Building LCD CPU"
|
||||
build
|
||||
install
|
||||
sign
|
||||
compress_libs
|
||||
compress
|
||||
|
||||
init_vars
|
||||
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
|
||||
$script:buildDir="${script:llamacppDir}/build/windows/${script:ARCH}/cpu_avx"
|
||||
$script:buildDir="../build/windows/${script:ARCH}/cpu_avx"
|
||||
write-host "Building AVX CPU"
|
||||
build
|
||||
install
|
||||
sign
|
||||
compress_libs
|
||||
compress
|
||||
|
||||
init_vars
|
||||
$script:cmakeDefs = $script:commonCpuDefs + @("-A", "x64", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=on", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=on", "-DLLAMA_F16C=on") + $script:cmakeDefs
|
||||
$script:buildDir="${script:llamacppDir}/build/windows/${script:ARCH}/cpu_avx2"
|
||||
$script:buildDir="../build/windows/${script:ARCH}/cpu_avx2"
|
||||
write-host "Building AVX2 CPU"
|
||||
build
|
||||
install
|
||||
sign
|
||||
compress_libs
|
||||
compress
|
||||
} else {
|
||||
write-host "Skipping CPU generation step as requested"
|
||||
}
|
||||
@@ -225,13 +233,11 @@ if ($null -ne $script:CUDA_LIB_DIR) {
|
||||
$script:CUDA_VARIANT="_"+$script:CUDA_VERSION
|
||||
}
|
||||
init_vars
|
||||
$script:buildDir="${script:llamacppDir}/build/windows/${script:ARCH}/cuda$script:CUDA_VARIANT"
|
||||
$script:buildDir="../build/windows/${script:ARCH}/cuda$script:CUDA_VARIANT"
|
||||
$script:cmakeDefs += @("-A", "x64", "-DLLAMA_CUBLAS=ON", "-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DCUDAToolkit_INCLUDE_DIR=$script:CUDA_INCLUDE_DIR", "-DCMAKE_CUDA_ARCHITECTURES=${script:CMAKE_CUDA_ARCHITECTURES}")
|
||||
write-host "Building CUDA"
|
||||
build
|
||||
install
|
||||
sign
|
||||
compress_libs
|
||||
compress
|
||||
}
|
||||
|
||||
if ($null -ne $env:HIP_PATH) {
|
||||
@@ -241,7 +247,7 @@ if ($null -ne $env:HIP_PATH) {
|
||||
}
|
||||
|
||||
init_vars
|
||||
$script:buildDir="${script:llamacppDir}/build/windows/${script:ARCH}/rocm$script:ROCM_VARIANT"
|
||||
$script:buildDir="../build/windows/${script:ARCH}/rocm$script:ROCM_VARIANT"
|
||||
$script:cmakeDefs += @(
|
||||
"-G", "Ninja",
|
||||
"-DCMAKE_C_COMPILER=clang.exe",
|
||||
@@ -264,13 +270,13 @@ if ($null -ne $env:HIP_PATH) {
|
||||
build
|
||||
# Ninja doesn't prefix with config name
|
||||
${script:config}=""
|
||||
install
|
||||
if ($null -ne $script:DUMPBIN) {
|
||||
& "$script:DUMPBIN" /dependents "${script:buildDir}/bin/${script:config}/ext_server.dll" | select-string ".dll"
|
||||
& "$script:DUMPBIN" /dependents "${script:buildDir}/bin/ollama_llama_server.exe" | select-string ".dll"
|
||||
}
|
||||
sign
|
||||
compress_libs
|
||||
compress
|
||||
}
|
||||
|
||||
|
||||
cleanup
|
||||
write-host "`ngo generate completed. LLM runners: $(get-childitem -path ${script:SRC_DIR}\llm\llama.cpp\build\windows\${script:ARCH})"
|
||||
write-host "`ngo generate completed. LLM runners: $(get-childitem -path ${script:SRC_DIR}\llm\build\windows\${script:ARCH})"
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
package generate
|
||||
|
||||
//go:generate sh ./gen_darwin.sh
|
||||
//go:generate bash ./gen_darwin.sh
|
||||
|
||||
86
llm/ggla.go
86
llm/ggla.go
@@ -7,16 +7,18 @@ import (
|
||||
"slices"
|
||||
)
|
||||
|
||||
type ContainerGGLA struct {
|
||||
type containerGGLA struct {
|
||||
version uint32
|
||||
}
|
||||
|
||||
func (c *ContainerGGLA) Name() string {
|
||||
func (c *containerGGLA) Name() string {
|
||||
return "ggla"
|
||||
}
|
||||
|
||||
func (c *ContainerGGLA) Decode(rs io.ReadSeeker) (model, error) {
|
||||
binary.Read(rs, binary.LittleEndian, &c.version)
|
||||
func (c *containerGGLA) Decode(rs io.ReadSeeker) (model, error) {
|
||||
if err := binary.Read(rs, binary.LittleEndian, &c.version); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch c.version {
|
||||
case 1:
|
||||
@@ -24,37 +26,45 @@ func (c *ContainerGGLA) Decode(rs io.ReadSeeker) (model, error) {
|
||||
return nil, errors.New("invalid version")
|
||||
}
|
||||
|
||||
model := newModelGGLA(c)
|
||||
model := newGGLA(c)
|
||||
err := model.decode(rs)
|
||||
return model, err
|
||||
}
|
||||
|
||||
type ModelGGLA struct {
|
||||
*ContainerGGLA
|
||||
type ggla struct {
|
||||
*containerGGLA
|
||||
|
||||
kv KV
|
||||
tensors []Tensor
|
||||
tensors []*Tensor
|
||||
}
|
||||
|
||||
func newModelGGLA(container *ContainerGGLA) *ModelGGLA {
|
||||
return &ModelGGLA{
|
||||
ContainerGGLA: container,
|
||||
func newGGLA(container *containerGGLA) *ggla {
|
||||
return &ggla{
|
||||
containerGGLA: container,
|
||||
kv: make(KV),
|
||||
}
|
||||
}
|
||||
|
||||
func (m *ModelGGLA) decode(rs io.ReadSeeker) error {
|
||||
func (llm *ggla) KV() KV {
|
||||
return llm.kv
|
||||
}
|
||||
|
||||
func (llm *ggla) Tensors() []*Tensor {
|
||||
return llm.tensors
|
||||
}
|
||||
|
||||
func (llm *ggla) decode(rs io.ReadSeeker) error {
|
||||
var r uint32
|
||||
if err := binary.Read(rs, binary.LittleEndian, &r); err != nil {
|
||||
return err
|
||||
}
|
||||
m.kv["r"] = r
|
||||
llm.kv["r"] = r
|
||||
|
||||
var alpha uint32
|
||||
if err := binary.Read(rs, binary.LittleEndian, &alpha); err != nil {
|
||||
return err
|
||||
}
|
||||
m.kv["alpha"] = alpha
|
||||
llm.kv["alpha"] = alpha
|
||||
|
||||
for {
|
||||
var dims uint32
|
||||
@@ -109,54 +119,10 @@ func (m *ModelGGLA) decode(rs io.ReadSeeker) error {
|
||||
|
||||
t.Offset = uint64(offset)
|
||||
|
||||
if _, err := rs.Seek(int64(t.Size()), io.SeekCurrent); err != nil {
|
||||
if _, err := rs.Seek(int64(t.size()), io.SeekCurrent); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
m.tensors = append(m.tensors, t)
|
||||
llm.tensors = append(llm.tensors, &t)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *ModelGGLA) KV() KV {
|
||||
return m.kv
|
||||
}
|
||||
|
||||
func (m *ModelGGLA) Tensor() []Tensor {
|
||||
return m.tensors
|
||||
}
|
||||
|
||||
func (*ModelGGLA) ModelFamily() string {
|
||||
return "ggla"
|
||||
}
|
||||
|
||||
func (*ModelGGLA) ModelType() string {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (*ModelGGLA) FileType() string {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (*ModelGGLA) NumLayers() uint32 {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (*ModelGGLA) NumGQA() uint32 {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (*ModelGGLA) NumEmbed() uint32 {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (*ModelGGLA) NumHead() uint32 {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (*ModelGGLA) NumHeadKv() uint32 {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (*ModelGGLA) NumCtx() uint32 {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
184
llm/ggml.go
184
llm/ggml.go
@@ -3,14 +3,24 @@ package llm
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type GGML struct {
|
||||
container
|
||||
model
|
||||
}
|
||||
|
||||
Size int64
|
||||
func (ggml *GGML) LayerSize(prefix string) (n int64) {
|
||||
for _, t := range ggml.Tensors() {
|
||||
if strings.HasPrefix(t.Name, prefix) {
|
||||
n += int64(t.size())
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
const (
|
||||
@@ -90,15 +100,148 @@ func fileType(fileType uint32) string {
|
||||
}
|
||||
|
||||
type model interface {
|
||||
ModelFamily() string
|
||||
ModelType() string
|
||||
FileType() string
|
||||
NumLayers() uint32
|
||||
NumGQA() uint32
|
||||
NumEmbed() uint32
|
||||
NumHead() uint32
|
||||
NumHeadKv() uint32
|
||||
NumCtx() uint32
|
||||
KV() KV
|
||||
Tensors() []*Tensor
|
||||
}
|
||||
|
||||
type KV map[string]any
|
||||
|
||||
func (kv KV) u64(key string) uint64 {
|
||||
switch v := kv[key].(type) {
|
||||
case uint64:
|
||||
return v
|
||||
case uint32:
|
||||
return uint64(v)
|
||||
case float64:
|
||||
return uint64(v)
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func (kv KV) Architecture() string {
|
||||
if s, ok := kv["general.architecture"].(string); ok {
|
||||
return s
|
||||
}
|
||||
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
func (kv KV) ParameterCount() uint64 {
|
||||
return kv.u64("general.parameter_count")
|
||||
}
|
||||
|
||||
func (kv KV) FileType() string {
|
||||
if u64 := kv.u64("general.file_type"); u64 > 0 {
|
||||
return fileType(uint32(u64))
|
||||
}
|
||||
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
func (kv KV) BlockCount() uint64 {
|
||||
return kv.u64(fmt.Sprintf("%s.block_count", kv.Architecture()))
|
||||
}
|
||||
|
||||
func (kv KV) HeadCount() uint64 {
|
||||
return kv.u64(fmt.Sprintf("%s.attention.head_count", kv.Architecture()))
|
||||
}
|
||||
|
||||
func (kv KV) HeadCountKV() uint64 {
|
||||
return kv.u64(fmt.Sprintf("%s.attention.head_count_kv", kv.Architecture()))
|
||||
}
|
||||
|
||||
func (kv KV) GQA() uint64 {
|
||||
if headCountKV := kv.HeadCountKV(); headCountKV > 0 {
|
||||
return kv.HeadCount() / headCountKV
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
func (kv KV) EmbeddingLength() uint64 {
|
||||
return kv.u64(fmt.Sprintf("%s.embedding_length", kv.Architecture()))
|
||||
}
|
||||
|
||||
func (kv KV) ContextLength() uint64 {
|
||||
return kv.u64(fmt.Sprintf("%s.context_length", kv.Architecture()))
|
||||
}
|
||||
|
||||
type Tensor struct {
|
||||
Name string `json:"name"`
|
||||
Kind uint32 `json:"kind"`
|
||||
Offset uint64 `json:"-"`
|
||||
|
||||
// Shape is the number of elements in each dimension
|
||||
Shape []uint64 `json:"shape"`
|
||||
|
||||
io.WriterTo `json:"-"`
|
||||
}
|
||||
|
||||
func (t Tensor) blockSize() uint64 {
|
||||
switch {
|
||||
case t.Kind < 2:
|
||||
return 1
|
||||
case t.Kind < 10:
|
||||
return 32
|
||||
default:
|
||||
return 256
|
||||
}
|
||||
}
|
||||
|
||||
func (t Tensor) typeSize() uint64 {
|
||||
blockSize := t.blockSize()
|
||||
|
||||
switch t.Kind {
|
||||
case 0: // FP32
|
||||
return 4
|
||||
case 1: // FP16
|
||||
return 2
|
||||
case 2: // Q4_0
|
||||
return 2 + blockSize/2
|
||||
case 3: // Q4_1
|
||||
return 2 + 2 + blockSize/2
|
||||
case 6: // Q5_0
|
||||
return 2 + 4 + blockSize/2
|
||||
case 7: // Q5_1
|
||||
return 2 + 2 + 4 + blockSize/2
|
||||
case 8: // Q8_0
|
||||
return 2 + blockSize
|
||||
case 9: // Q8_1
|
||||
return 4 + 4 + blockSize
|
||||
case 10: // Q2_K
|
||||
return blockSize/16 + blockSize/4 + 2 + 2
|
||||
case 11: // Q3_K
|
||||
return blockSize/8 + blockSize/4 + 12 + 2
|
||||
case 12: // Q4_K
|
||||
return 2 + 2 + 12 + blockSize/2
|
||||
case 13: // Q5_K
|
||||
return 2 + 2 + 12 + blockSize/8 + blockSize/2
|
||||
case 14: // Q6_K
|
||||
return blockSize/2 + blockSize/4 + blockSize/16 + 2
|
||||
case 15: // Q8_K
|
||||
return 2 + blockSize + 2*blockSize/16
|
||||
case 16: // IQ2_XXS
|
||||
return 2 + 2*blockSize/8
|
||||
case 17: // IQ2_XS
|
||||
return 2 + 2*blockSize/8 + blockSize/32
|
||||
case 18: // IQ3_XXS
|
||||
return 2 + 3*blockSize/8
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
func (t Tensor) parameters() uint64 {
|
||||
var count uint64 = 1
|
||||
for _, n := range t.Shape {
|
||||
count *= n
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func (t Tensor) size() uint64 {
|
||||
return t.parameters() * t.typeSize() / t.blockSize()
|
||||
}
|
||||
|
||||
type container interface {
|
||||
@@ -122,42 +265,41 @@ const (
|
||||
|
||||
var ErrUnsupportedFormat = errors.New("unsupported model format")
|
||||
|
||||
func DecodeGGML(rs io.ReadSeeker) (*GGML, error) {
|
||||
func DecodeGGML(rs io.ReadSeeker) (*GGML, int64, error) {
|
||||
var magic uint32
|
||||
if err := binary.Read(rs, binary.LittleEndian, &magic); err != nil {
|
||||
return nil, err
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
var c container
|
||||
switch magic {
|
||||
case FILE_MAGIC_GGML, FILE_MAGIC_GGMF, FILE_MAGIC_GGJT:
|
||||
return nil, ErrUnsupportedFormat
|
||||
return nil, 0, ErrUnsupportedFormat
|
||||
case FILE_MAGIC_GGLA:
|
||||
c = &ContainerGGLA{}
|
||||
c = &containerGGLA{}
|
||||
case FILE_MAGIC_GGUF_LE:
|
||||
c = &ContainerGGUF{ByteOrder: binary.LittleEndian}
|
||||
c = &containerGGUF{ByteOrder: binary.LittleEndian}
|
||||
case FILE_MAGIC_GGUF_BE:
|
||||
c = &ContainerGGUF{ByteOrder: binary.BigEndian}
|
||||
c = &containerGGUF{ByteOrder: binary.BigEndian}
|
||||
default:
|
||||
return nil, errors.New("invalid file magic")
|
||||
return nil, 0, errors.New("invalid file magic")
|
||||
}
|
||||
|
||||
model, err := c.Decode(rs)
|
||||
if errors.Is(err, io.EOF) {
|
||||
// noop
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
offset, err := rs.Seek(0, io.SeekCurrent)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// final model type
|
||||
return &GGML{
|
||||
container: c,
|
||||
model: model,
|
||||
Size: offset,
|
||||
}, nil
|
||||
}, offset, nil
|
||||
}
|
||||
|
||||
1243
llm/gguf.go
1243
llm/gguf.go
File diff suppressed because it is too large
Load Diff
Submodule llm/llama.cpp updated: ad3a0505e3...37e7854c10
100
llm/llama.go
100
llm/llama.go
@@ -1,100 +0,0 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
)
|
||||
|
||||
const jsonGrammar = `
|
||||
root ::= object
|
||||
value ::= object | array | string | number | ("true" | "false" | "null") ws
|
||||
|
||||
object ::=
|
||||
"{" ws (
|
||||
string ":" ws value
|
||||
("," ws string ":" ws value)*
|
||||
)? "}" ws
|
||||
|
||||
array ::=
|
||||
"[" ws (
|
||||
value
|
||||
("," ws value)*
|
||||
)? "]" ws
|
||||
|
||||
string ::=
|
||||
"\"" (
|
||||
[^"\\] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
|
||||
)* "\"" ws
|
||||
|
||||
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
|
||||
|
||||
# Optional space: by convention, applied in this grammar after literal chars when allowed
|
||||
ws ::= ([ \t\n] ws)?
|
||||
`
|
||||
|
||||
type ImageData struct {
|
||||
Data []byte `json:"data"`
|
||||
ID int `json:"id"`
|
||||
}
|
||||
|
||||
var payloadMissing = fmt.Errorf("expected dynamic library payloads not included in this build of ollama")
|
||||
|
||||
type prediction struct {
|
||||
Content string `json:"content"`
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
Stop bool `json:"stop"`
|
||||
|
||||
Timings struct {
|
||||
PredictedN int `json:"predicted_n"`
|
||||
PredictedMS float64 `json:"predicted_ms"`
|
||||
PromptN int `json:"prompt_n"`
|
||||
PromptMS float64 `json:"prompt_ms"`
|
||||
}
|
||||
}
|
||||
|
||||
const maxRetries = 3
|
||||
|
||||
type PredictOpts struct {
|
||||
Prompt string
|
||||
Format string
|
||||
Images []ImageData
|
||||
Options api.Options
|
||||
}
|
||||
|
||||
type PredictResult struct {
|
||||
Content string
|
||||
Done bool
|
||||
PromptEvalCount int
|
||||
PromptEvalDuration time.Duration
|
||||
EvalCount int
|
||||
EvalDuration time.Duration
|
||||
}
|
||||
|
||||
type TokenizeRequest struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type TokenizeResponse struct {
|
||||
Tokens []int `json:"tokens"`
|
||||
}
|
||||
|
||||
type DetokenizeRequest struct {
|
||||
Tokens []int `json:"tokens"`
|
||||
}
|
||||
|
||||
type DetokenizeResponse struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type EmbeddingRequest struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type EmbeddingResponse struct {
|
||||
Embedding []float64 `json:"embedding"`
|
||||
}
|
||||
182
llm/llm.go
182
llm/llm.go
@@ -1,175 +1,15 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"os"
|
||||
"runtime"
|
||||
"slices"
|
||||
// #cgo CFLAGS: -Illama.cpp
|
||||
// #cgo darwin,arm64 LDFLAGS: ${SRCDIR}/build/darwin/arm64_static/libllama.a -lstdc++
|
||||
// #cgo darwin,amd64 LDFLAGS: ${SRCDIR}/build/darwin/x86_64_static/libllama.a -lstdc++
|
||||
// #cgo windows,amd64 LDFLAGS: ${SRCDIR}/build/windows/amd64_static/libllama.a -static -lstdc++
|
||||
// #cgo linux,amd64 LDFLAGS: ${SRCDIR}/build/linux/x86_64_static/libllama.a -lstdc++
|
||||
// #cgo linux,arm64 LDFLAGS: ${SRCDIR}/build/linux/arm64_static/libllama.a -lstdc++
|
||||
// #include "llama.h"
|
||||
import "C"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/gpu"
|
||||
)
|
||||
|
||||
type LLM interface {
|
||||
Predict(context.Context, PredictOpts, func(PredictResult)) error
|
||||
Embedding(context.Context, string) ([]float64, error)
|
||||
Encode(context.Context, string) ([]int, error)
|
||||
Decode(context.Context, []int) (string, error)
|
||||
Close()
|
||||
}
|
||||
|
||||
var cpuOnlyFamilies = []string{
|
||||
"mamba",
|
||||
}
|
||||
|
||||
func New(model string, adapters, projectors []string, opts api.Options) (LLM, error) {
|
||||
if _, err := os.Stat(model); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f, err := os.Open(model)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
ggml, err := DecodeGGML(f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if opts.NumCtx > int(ggml.NumCtx()) {
|
||||
slog.Warn(fmt.Sprintf("requested context length is greater than model's max context length (%d > %d), using %d instead", opts.NumCtx, ggml.NumCtx(), ggml.NumCtx()))
|
||||
opts.NumCtx = int(ggml.NumCtx())
|
||||
}
|
||||
|
||||
if opts.NumCtx < 4 {
|
||||
opts.NumCtx = 4
|
||||
}
|
||||
|
||||
vram, _ := gpu.CheckVRAM()
|
||||
size := ggml.Size
|
||||
|
||||
// fp16 k,v matrices require = n_ctx * n_layer * n_embd / n_head * n_head_kv * 2 bytes each * 2 key and value
|
||||
kv := 2 * 2 * int64(opts.NumCtx) * int64(ggml.NumLayers()) * int64(ggml.NumEmbed()) * int64(ggml.NumHeadKv()) / int64(max(ggml.NumHead(), 1))
|
||||
|
||||
// this amount is the overhead + tensors in memory
|
||||
// TODO: get this from the llama.cpp's graph calculations instead of
|
||||
// estimating it's 1/6 * kv_cache_size * num_gqa
|
||||
graph := int64(ggml.NumGQA()) * kv / 6
|
||||
|
||||
// certain model architectures don't support gpu inference yet
|
||||
if slices.Contains(cpuOnlyFamilies, ggml.ModelFamily()) {
|
||||
opts.NumGPU = 0
|
||||
}
|
||||
|
||||
info := gpu.GetGPUInfo()
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
if opts.NumGPU == 0 {
|
||||
break
|
||||
}
|
||||
|
||||
if size+kv+graph > vram {
|
||||
slog.Info("not enough vram available, setting num_gpu=0")
|
||||
opts.NumGPU = 0
|
||||
break
|
||||
}
|
||||
|
||||
// TODO: implement layer splitting on macOS
|
||||
opts.NumGPU = 999
|
||||
default:
|
||||
if info.Library == "cpu" {
|
||||
slog.Info("GPU not available, falling back to CPU")
|
||||
opts.NumGPU = 0
|
||||
break
|
||||
}
|
||||
|
||||
// don't use GPU at all if no layers are loaded
|
||||
if opts.NumGPU == 0 {
|
||||
info.Library = "cpu"
|
||||
info.Variant = gpu.GetCPUVariant()
|
||||
break
|
||||
}
|
||||
|
||||
// user-defined GPU count
|
||||
if opts.NumGPU != -1 {
|
||||
break
|
||||
}
|
||||
|
||||
// the "main" GPU needs the most memory and determines the limit
|
||||
// of how many layers can be loaded. It needs to fit:
|
||||
// 1. the full compute graph allocation for all devices (graph)
|
||||
// 2. the proportional kv cache for all devices (kv * % layers)
|
||||
// 3. the proportional model (size * % layers / # devices)
|
||||
// This estimates the number of layers
|
||||
maxlayers := int64(ggml.NumLayers()) + 1
|
||||
devices := int64(info.DeviceCount)
|
||||
avg := vram / devices
|
||||
layers := maxlayers * (avg - graph) / (kv + size/devices)
|
||||
if layers > maxlayers {
|
||||
layers = maxlayers
|
||||
}
|
||||
|
||||
// 1 + 2 must fit on the main gpu
|
||||
min := graph + kv*layers/maxlayers
|
||||
if layers <= 0 || min > avg {
|
||||
slog.Info("not enough vram available, falling back to CPU only")
|
||||
info.Library = "cpu"
|
||||
info.Variant = gpu.GetCPUVariant()
|
||||
opts.NumGPU = 0
|
||||
break
|
||||
}
|
||||
|
||||
opts.NumGPU = int(layers)
|
||||
}
|
||||
|
||||
opts.RopeFrequencyBase = 0.0
|
||||
opts.RopeFrequencyScale = 0.0
|
||||
return newLlmServer(info, model, adapters, projectors, opts)
|
||||
}
|
||||
|
||||
// Give any native cgo implementations an opportunity to initialize
|
||||
func Init() error {
|
||||
return nativeInit()
|
||||
}
|
||||
|
||||
func newLlmServer(gpuInfo gpu.GpuInfo, model string, adapters, projectors []string, opts api.Options) (LLM, error) {
|
||||
dynLibs := getDynLibs(gpuInfo)
|
||||
|
||||
// Check to see if the user has requested a specific library instead of auto-detecting
|
||||
demandLib := os.Getenv("OLLAMA_LLM_LIBRARY")
|
||||
if demandLib != "" {
|
||||
libPath := availableDynLibs[demandLib]
|
||||
if libPath == "" {
|
||||
slog.Info(fmt.Sprintf("Invalid OLLAMA_LLM_LIBRARY %s - not found", demandLib))
|
||||
} else {
|
||||
slog.Info(fmt.Sprintf("Loading OLLAMA_LLM_LIBRARY=%s", demandLib))
|
||||
dynLibs = []string{libPath}
|
||||
}
|
||||
}
|
||||
|
||||
// We stage into a temp directory, and if we've been idle for a while, it may have been reaped
|
||||
_, err := os.Stat(dynLibs[0])
|
||||
if err != nil {
|
||||
slog.Info(fmt.Sprintf("%s has disappeared, reloading libraries", dynLibs[0]))
|
||||
err = nativeInit()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
err2 := fmt.Errorf("unable to locate suitable llm library")
|
||||
for _, dynLib := range dynLibs {
|
||||
srv, err := newDynExtServer(dynLib, model, adapters, projectors, opts)
|
||||
if err == nil {
|
||||
return srv, nil
|
||||
}
|
||||
slog.Warn(fmt.Sprintf("Failed to load dynamic library %s %s", dynLib, err))
|
||||
err2 = err
|
||||
}
|
||||
|
||||
return nil, err2
|
||||
// SystemInfo is an unused example of calling llama.cpp functions using CGo
|
||||
func SystemInfo() string {
|
||||
return C.GoString(C.llama_print_system_info())
|
||||
}
|
||||
|
||||
@@ -4,5 +4,5 @@ import (
|
||||
"embed"
|
||||
)
|
||||
|
||||
//go:embed llama.cpp/build/linux/*/*/lib/*
|
||||
//go:embed build/darwin/x86_64/*/bin/*
|
||||
var libEmbed embed.FS
|
||||
@@ -4,5 +4,5 @@ import (
|
||||
"embed"
|
||||
)
|
||||
|
||||
//go:embed llama.cpp/build/windows/*/*/lib/*.dll*
|
||||
//go:embed build/darwin/arm64/*/bin/*
|
||||
var libEmbed embed.FS
|
||||
6
llm/llm_linux.go
Normal file
6
llm/llm_linux.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package llm
|
||||
|
||||
import "embed"
|
||||
|
||||
//go:embed build/linux/*/*/bin/*
|
||||
var libEmbed embed.FS
|
||||
6
llm/llm_windows.go
Normal file
6
llm/llm_windows.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package llm
|
||||
|
||||
import "embed"
|
||||
|
||||
//go:embed build/windows/*/*/bin/*
|
||||
var libEmbed embed.FS
|
||||
@@ -1,13 +0,0 @@
|
||||
diff --git a/llama.cpp b/llama.cpp
|
||||
index b27aa272..99372f9c 100644
|
||||
--- a/llama.cpp
|
||||
+++ b/llama.cpp
|
||||
@@ -9360,7 +9360,7 @@ struct llm_tokenizer_wpm {
|
||||
}
|
||||
|
||||
uint32_t to_lower(uint32_t code) {
|
||||
- static const std::locale locale("en_US.UTF-8");
|
||||
+ static const std::locale locale("");
|
||||
#if defined(_WIN32)
|
||||
if (code > 0xFFFF) {
|
||||
return code;
|
||||
211
llm/payload.go
Normal file
211
llm/payload.go
Normal file
@@ -0,0 +1,211 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/exp/slices"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/ollama/ollama/gpu"
|
||||
)
|
||||
|
||||
var errPayloadMissing = fmt.Errorf("expected payloads not included in this build of ollama")
|
||||
|
||||
func Init() error {
|
||||
payloadsDir, err := gpu.PayloadsDir()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
slog.Info("extracting embedded files", "dir", payloadsDir)
|
||||
binGlob := "build/*/*/*/bin/*"
|
||||
|
||||
// extract server libraries
|
||||
err = extractFiles(payloadsDir, binGlob)
|
||||
if err != nil {
|
||||
return fmt.Errorf("extract binaries: %v", err)
|
||||
}
|
||||
|
||||
var variants []string
|
||||
for v := range availableServers() {
|
||||
variants = append(variants, v)
|
||||
}
|
||||
slog.Info(fmt.Sprintf("Dynamic LLM libraries %v", variants))
|
||||
slog.Debug("Override detection logic by setting OLLAMA_LLM_LIBRARY")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// binary names may contain an optional variant separated by '_'
|
||||
// For example, "ollama_rocm_v6" and "ollama_rocm_v5" or "ollama_cpu" and "ollama_cpu_avx2"
|
||||
// Any library without a variant is the lowest common denominator
|
||||
func availableServers() map[string]string {
|
||||
payloadsDir, err := gpu.PayloadsDir()
|
||||
if err != nil {
|
||||
slog.Error("payload lookup error", "error", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
// glob payloadsDir for files that start with ollama_
|
||||
pattern := filepath.Join(payloadsDir, "*")
|
||||
|
||||
files, err := filepath.Glob(pattern)
|
||||
if err != nil {
|
||||
slog.Debug("could not glob", "pattern", pattern, "error", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
servers := make(map[string]string)
|
||||
for _, file := range files {
|
||||
slog.Debug("availableServers : found", "file", file)
|
||||
servers[filepath.Base(file)] = file
|
||||
}
|
||||
|
||||
return servers
|
||||
}
|
||||
|
||||
// serversForGpu returns a list of compatible servers give the provided GPU
|
||||
// info, ordered by performance. assumes Init() has been called
|
||||
// TODO - switch to metadata based mapping
|
||||
func serversForGpu(info gpu.GpuInfo) []string {
|
||||
// glob workDir for files that start with ollama_
|
||||
availableServers := availableServers()
|
||||
requested := info.Library
|
||||
if info.Variant != "" {
|
||||
requested += "_" + info.Variant
|
||||
}
|
||||
|
||||
servers := []string{}
|
||||
|
||||
// exact match first
|
||||
for a := range availableServers {
|
||||
if a == requested {
|
||||
servers = []string{a}
|
||||
|
||||
if a == "metal" {
|
||||
return servers
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
alt := []string{}
|
||||
|
||||
// Then for GPUs load alternates and sort the list for consistent load ordering
|
||||
if info.Library != "cpu" {
|
||||
for a := range availableServers {
|
||||
if info.Library == strings.Split(a, "_")[0] && a != requested {
|
||||
alt = append(alt, a)
|
||||
}
|
||||
}
|
||||
|
||||
slices.Sort(alt)
|
||||
servers = append(servers, alt...)
|
||||
}
|
||||
|
||||
// Load up the best CPU variant if not primary requested
|
||||
if info.Library != "cpu" {
|
||||
variant := gpu.GetCPUVariant()
|
||||
// If no variant, then we fall back to default
|
||||
// If we have a variant, try that if we find an exact match
|
||||
// Attempting to run the wrong CPU instructions will panic the
|
||||
// process
|
||||
if variant != "" {
|
||||
for cmp := range availableServers {
|
||||
if cmp == "cpu_"+variant {
|
||||
servers = append(servers, cmp)
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
servers = append(servers, "cpu")
|
||||
}
|
||||
}
|
||||
|
||||
if len(servers) == 0 {
|
||||
servers = []string{"cpu"}
|
||||
}
|
||||
|
||||
return servers
|
||||
}
|
||||
|
||||
// extract extracts the embedded files to the target directory
|
||||
func extractFiles(targetDir string, glob string) error {
|
||||
files, err := fs.Glob(libEmbed, glob)
|
||||
if err != nil || len(files) == 0 {
|
||||
return errPayloadMissing
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(targetDir, 0o755); err != nil {
|
||||
return fmt.Errorf("extractFiles could not mkdir %s: %v", targetDir, err)
|
||||
}
|
||||
|
||||
g := new(errgroup.Group)
|
||||
|
||||
// build/$OS/$GOARCH/$VARIANT/{bin,lib}/$FILE
|
||||
for _, file := range files {
|
||||
filename := file
|
||||
|
||||
variant := filepath.Base(filepath.Dir(filepath.Dir(filename)))
|
||||
|
||||
slog.Debug("extracting", "variant", variant, "file", filename)
|
||||
|
||||
g.Go(func() error {
|
||||
srcf, err := libEmbed.Open(filename)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer srcf.Close()
|
||||
|
||||
src := io.Reader(srcf)
|
||||
if strings.HasSuffix(filename, ".gz") {
|
||||
src, err = gzip.NewReader(src)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decompress payload %s: %v", filename, err)
|
||||
}
|
||||
filename = strings.TrimSuffix(filename, ".gz")
|
||||
}
|
||||
|
||||
variantDir := filepath.Join(targetDir, variant)
|
||||
if err := os.MkdirAll(variantDir, 0o755); err != nil {
|
||||
return fmt.Errorf("extractFiles could not mkdir %s: %v", variantDir, err)
|
||||
}
|
||||
|
||||
base := filepath.Base(filename)
|
||||
destFilename := filepath.Join(variantDir, base)
|
||||
|
||||
_, err = os.Stat(destFilename)
|
||||
switch {
|
||||
case errors.Is(err, os.ErrNotExist):
|
||||
destFile, err := os.OpenFile(destFilename, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
|
||||
if err != nil {
|
||||
return fmt.Errorf("write payload %s: %v", filename, err)
|
||||
}
|
||||
defer destFile.Close()
|
||||
if _, err := io.Copy(destFile, src); err != nil {
|
||||
return fmt.Errorf("copy payload %s: %v", filename, err)
|
||||
}
|
||||
case err != nil:
|
||||
return fmt.Errorf("stat payload %s: %v", filename, err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
err = g.Wait()
|
||||
if err != nil {
|
||||
// If we fail to extract, the payload dir is unusable, so cleanup whatever we extracted
|
||||
gpu.Cleanup()
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,233 +0,0 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/exp/slices"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/ollama/ollama/gpu"
|
||||
)
|
||||
|
||||
// Libraries names may contain an optional variant separated by '_'
|
||||
// For example, "rocm_v6" and "rocm_v5" or "cpu" and "cpu_avx2"
|
||||
// Any library without a variant is the lowest common denominator
|
||||
var availableDynLibs = map[string]string{}
|
||||
|
||||
const pathComponentCount = 7
|
||||
|
||||
// getDynLibs returns an ordered list of LLM libraries to try, starting with the best
|
||||
func getDynLibs(gpuInfo gpu.GpuInfo) []string {
|
||||
// Short circuit if we know we're using the default built-in (darwin only)
|
||||
if gpuInfo.Library == "default" {
|
||||
return []string{"default"}
|
||||
}
|
||||
// TODO - temporary until we have multiple CPU variations for Darwin
|
||||
// Short circuit on darwin with metal only
|
||||
if len(availableDynLibs) == 1 {
|
||||
if _, onlyMetal := availableDynLibs["metal"]; onlyMetal {
|
||||
return []string{availableDynLibs["metal"]}
|
||||
}
|
||||
}
|
||||
|
||||
exactMatch := ""
|
||||
dynLibs := []string{}
|
||||
altDynLibs := []string{}
|
||||
requested := gpuInfo.Library
|
||||
if gpuInfo.Variant != "" {
|
||||
requested += "_" + gpuInfo.Variant
|
||||
}
|
||||
// Try to find an exact match
|
||||
for cmp := range availableDynLibs {
|
||||
if requested == cmp {
|
||||
exactMatch = cmp
|
||||
dynLibs = []string{availableDynLibs[cmp]}
|
||||
break
|
||||
}
|
||||
}
|
||||
// Then for GPUs load alternates and sort the list for consistent load ordering
|
||||
if gpuInfo.Library != "cpu" {
|
||||
for cmp := range availableDynLibs {
|
||||
if gpuInfo.Library == strings.Split(cmp, "_")[0] && cmp != exactMatch {
|
||||
altDynLibs = append(altDynLibs, cmp)
|
||||
}
|
||||
}
|
||||
slices.Sort(altDynLibs)
|
||||
for _, altDynLib := range altDynLibs {
|
||||
dynLibs = append(dynLibs, availableDynLibs[altDynLib])
|
||||
}
|
||||
}
|
||||
|
||||
// Load up the best CPU variant if not primary requested
|
||||
if gpuInfo.Library != "cpu" {
|
||||
variant := gpu.GetCPUVariant()
|
||||
// If no variant, then we fall back to default
|
||||
// If we have a variant, try that if we find an exact match
|
||||
// Attempting to run the wrong CPU instructions will panic the
|
||||
// process
|
||||
if variant != "" {
|
||||
for cmp := range availableDynLibs {
|
||||
if cmp == "cpu_"+variant {
|
||||
dynLibs = append(dynLibs, availableDynLibs[cmp])
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
dynLibs = append(dynLibs, availableDynLibs["cpu"])
|
||||
}
|
||||
}
|
||||
|
||||
// Finally, if we didn't find any matches, LCD CPU FTW
|
||||
if len(dynLibs) == 0 {
|
||||
dynLibs = []string{availableDynLibs["cpu"]}
|
||||
}
|
||||
slog.Debug(fmt.Sprintf("ordered list of LLM libraries to try %v", dynLibs))
|
||||
return dynLibs
|
||||
}
|
||||
|
||||
func rocmDynLibPresent() bool {
|
||||
for dynLibName := range availableDynLibs {
|
||||
if strings.HasPrefix(dynLibName, "rocm") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func nativeInit() error {
|
||||
payloadsDir, err := gpu.PayloadsDir()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
slog.Info(fmt.Sprintf("Extracting dynamic libraries to %s ...", payloadsDir))
|
||||
|
||||
libs, err := extractDynamicLibs(payloadsDir, "llama.cpp/build/*/*/*/lib/*")
|
||||
if err != nil {
|
||||
if errors.Is(err, payloadMissing) {
|
||||
slog.Info(fmt.Sprintf("%s", payloadMissing))
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
for _, lib := range libs {
|
||||
// The last dir component is the variant name
|
||||
variant := filepath.Base(filepath.Dir(lib))
|
||||
availableDynLibs[variant] = lib
|
||||
}
|
||||
|
||||
if err := verifyDriverAccess(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Report which dynamic libraries we have loaded to assist troubleshooting
|
||||
variants := make([]string, len(availableDynLibs))
|
||||
i := 0
|
||||
for variant := range availableDynLibs {
|
||||
variants[i] = variant
|
||||
i++
|
||||
}
|
||||
slog.Info(fmt.Sprintf("Dynamic LLM libraries %v", variants))
|
||||
slog.Debug("Override detection logic by setting OLLAMA_LLM_LIBRARY")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func extractDynamicLibs(payloadsDir, glob string) ([]string, error) {
|
||||
files, err := fs.Glob(libEmbed, glob)
|
||||
if err != nil || len(files) == 0 {
|
||||
return nil, payloadMissing
|
||||
}
|
||||
|
||||
var mu sync.Mutex
|
||||
var libs []string
|
||||
var g errgroup.Group
|
||||
for _, file := range files {
|
||||
pathComps := strings.Split(file, "/")
|
||||
if len(pathComps) != pathComponentCount {
|
||||
slog.Error(fmt.Sprintf("unexpected payload components: %v", pathComps))
|
||||
continue
|
||||
}
|
||||
|
||||
file := file
|
||||
g.Go(func() error {
|
||||
// llama.cpp/build/$OS/$GOARCH/$VARIANT/lib/$LIBRARY
|
||||
// Include the variant in the path to avoid conflicts between multiple server libs
|
||||
targetDir := filepath.Join(payloadsDir, pathComps[pathComponentCount-3])
|
||||
srcFile, err := libEmbed.Open(file)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read payload %s: %v", file, err)
|
||||
}
|
||||
defer srcFile.Close()
|
||||
if err := os.MkdirAll(targetDir, 0o755); err != nil {
|
||||
return fmt.Errorf("create payload lib dir %s: %v", payloadsDir, err)
|
||||
}
|
||||
src := io.Reader(srcFile)
|
||||
filename := file
|
||||
if strings.HasSuffix(file, ".gz") {
|
||||
src, err = gzip.NewReader(src)
|
||||
if err != nil {
|
||||
return fmt.Errorf("decompress payload %s: %v", file, err)
|
||||
}
|
||||
filename = strings.TrimSuffix(filename, ".gz")
|
||||
}
|
||||
|
||||
destFile := filepath.Join(targetDir, filepath.Base(filename))
|
||||
if strings.Contains(destFile, "server") {
|
||||
mu.Lock()
|
||||
libs = append(libs, destFile)
|
||||
mu.Unlock()
|
||||
}
|
||||
|
||||
destFp, err := os.OpenFile(destFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
|
||||
if err != nil {
|
||||
return fmt.Errorf("write payload %s: %v", file, err)
|
||||
}
|
||||
defer destFp.Close()
|
||||
if _, err := io.Copy(destFp, src); err != nil {
|
||||
return fmt.Errorf("copy payload %s: %v", file, err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
err = g.Wait()
|
||||
if err != nil {
|
||||
// If we fail to extract, the payload dir is unusable, so cleanup whatever we extracted
|
||||
gpu.Cleanup()
|
||||
return nil, err
|
||||
}
|
||||
return libs, nil
|
||||
}
|
||||
|
||||
func verifyDriverAccess() error {
|
||||
if runtime.GOOS != "linux" {
|
||||
return nil
|
||||
}
|
||||
// Only check ROCm access if we have the dynamic lib loaded
|
||||
if rocmDynLibPresent() {
|
||||
// Verify we have permissions - either running as root, or we have group access to the driver
|
||||
fd, err := os.OpenFile("/dev/kfd", os.O_RDWR, 0666)
|
||||
if err != nil {
|
||||
if errors.Is(err, fs.ErrPermission) {
|
||||
return fmt.Errorf("Radeon card detected, but permissions not set up properly. Either run ollama as root, or add you user account to the render group.")
|
||||
} else if errors.Is(err, fs.ErrNotExist) {
|
||||
// expected behavior without a radeon card
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("failed to check permission on /dev/kfd: %w", err)
|
||||
}
|
||||
fd.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -1,8 +0,0 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"embed"
|
||||
)
|
||||
|
||||
//go:embed llama.cpp/build/darwin/x86_64/*/lib/*.dylib*
|
||||
var libEmbed embed.FS
|
||||
@@ -1,8 +0,0 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"embed"
|
||||
)
|
||||
|
||||
//go:embed llama.cpp/ggml-metal.metal llama.cpp/build/darwin/arm64/*/lib/*.dylib*
|
||||
var libEmbed embed.FS
|
||||
@@ -1,58 +0,0 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/gpu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestGetDynLibs(t *testing.T) {
|
||||
availableDynLibs = map[string]string{
|
||||
"cpu": "X_cpu",
|
||||
}
|
||||
assert.Equal(t, false, rocmDynLibPresent())
|
||||
res := getDynLibs(gpu.GpuInfo{Library: "cpu"})
|
||||
assert.Len(t, res, 1)
|
||||
assert.Equal(t, availableDynLibs["cpu"], res[0])
|
||||
|
||||
variant := gpu.GetCPUVariant()
|
||||
if variant != "" {
|
||||
variant = "_" + variant
|
||||
}
|
||||
availableDynLibs = map[string]string{
|
||||
"rocm_v5": "X_rocm_v5",
|
||||
"rocm_v6": "X_rocm_v6",
|
||||
"cpu" + variant: "X_cpu",
|
||||
}
|
||||
assert.Equal(t, true, rocmDynLibPresent())
|
||||
res = getDynLibs(gpu.GpuInfo{Library: "rocm"})
|
||||
assert.Len(t, res, 3)
|
||||
assert.Equal(t, availableDynLibs["rocm_v5"], res[0])
|
||||
assert.Equal(t, availableDynLibs["rocm_v6"], res[1])
|
||||
assert.Equal(t, availableDynLibs["cpu"+variant], res[2])
|
||||
|
||||
res = getDynLibs(gpu.GpuInfo{Library: "rocm", Variant: "v6"})
|
||||
assert.Len(t, res, 3)
|
||||
assert.Equal(t, availableDynLibs["rocm_v6"], res[0])
|
||||
assert.Equal(t, availableDynLibs["rocm_v5"], res[1])
|
||||
assert.Equal(t, availableDynLibs["cpu"+variant], res[2])
|
||||
|
||||
res = getDynLibs(gpu.GpuInfo{Library: "cuda"})
|
||||
assert.Len(t, res, 1)
|
||||
assert.Equal(t, availableDynLibs["cpu"+variant], res[0])
|
||||
|
||||
res = getDynLibs(gpu.GpuInfo{Library: "default"})
|
||||
assert.Len(t, res, 1)
|
||||
assert.Equal(t, "default", res[0])
|
||||
|
||||
availableDynLibs = map[string]string{
|
||||
"rocm": "X_rocm_v5",
|
||||
"cpu" + variant: "X_cpu",
|
||||
}
|
||||
assert.Equal(t, true, rocmDynLibPresent())
|
||||
res = getDynLibs(gpu.GpuInfo{Library: "rocm", Variant: "v6"})
|
||||
assert.Len(t, res, 2)
|
||||
assert.Equal(t, availableDynLibs["rocm"], res[0])
|
||||
assert.Equal(t, availableDynLibs["cpu"+variant], res[1])
|
||||
}
|
||||
854
llm/server.go
Normal file
854
llm/server.go
Normal file
@@ -0,0 +1,854 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"log/slog"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/gpu"
|
||||
)
|
||||
|
||||
// LlamaServer is an instance of the llama.cpp server
|
||||
type LlamaServer struct {
|
||||
port int
|
||||
cmd *exec.Cmd
|
||||
done chan error // Channel to signal when the process exits
|
||||
status *StatusWriter
|
||||
options api.Options
|
||||
}
|
||||
|
||||
var cpuOnlyFamilies = []string{
|
||||
"mamba",
|
||||
}
|
||||
|
||||
func NewLlamaServer(model string, adapters, projectors []string, opts api.Options) (*LlamaServer, error) {
|
||||
if _, err := os.Stat(model); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
f, err := os.Open(model)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
ggml, _, err := DecodeGGML(f)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if opts.NumCtx > int(ggml.KV().ContextLength()) {
|
||||
slog.Warn("requested context length is greater than model max context length", "requested", opts.NumCtx, "model", ggml.KV().ContextLength())
|
||||
opts.NumCtx = int(ggml.KV().ContextLength())
|
||||
}
|
||||
|
||||
if opts.NumCtx < 4 {
|
||||
opts.NumCtx = 4
|
||||
}
|
||||
|
||||
availableMemory, _ := gpu.CheckVRAM()
|
||||
info := gpu.GetGPUInfo()
|
||||
|
||||
usedMemory := info.MinimumMemory
|
||||
for _, projector := range projectors {
|
||||
usedMemory += projectorMemoryRequirements(projector)
|
||||
|
||||
// multimodal models require at least 2048 context
|
||||
opts.NumCtx = max(opts.NumCtx, 2048)
|
||||
}
|
||||
|
||||
// fp16 k,v = (1 (k) + 1 (v)) * sizeof(float16) * n_ctx * n_layer * n_embd / n_head * n_head_kv
|
||||
kv := 2 * 2 * int64(opts.NumCtx) * int64(ggml.KV().BlockCount()) * int64(ggml.KV().EmbeddingLength()) / int64(ggml.KV().HeadCount()) * int64(ggml.KV().HeadCountKV())
|
||||
|
||||
// this amount is the overhead + tensors in memory
|
||||
// TODO: get this from the llama.cpp's graph calculations instead of
|
||||
// estimating it's 1/6 * kv_cache_size * num_gqa
|
||||
graph := int64(ggml.KV().GQA()) * kv / 6
|
||||
usedMemory += graph
|
||||
|
||||
if (usedMemory > availableMemory || slices.Contains(cpuOnlyFamilies, ggml.KV().Architecture())) && info.Library != "metal" {
|
||||
info.Library = "cpu"
|
||||
}
|
||||
|
||||
requiredMemory := usedMemory
|
||||
|
||||
var layers int
|
||||
for i := 0; i < int(ggml.KV().BlockCount()); i++ {
|
||||
layerMemory := ggml.LayerSize(fmt.Sprintf("blk.%d.", i)) + kv/int64(ggml.KV().BlockCount())
|
||||
requiredMemory += layerMemory
|
||||
|
||||
if availableMemory > usedMemory+layerMemory && (opts.NumGPU < 0 || layers < opts.NumGPU) {
|
||||
usedMemory += layerMemory
|
||||
layers++
|
||||
}
|
||||
}
|
||||
|
||||
memOutputLayer := ggml.LayerSize("output.")
|
||||
requiredMemory += memOutputLayer
|
||||
|
||||
// only offload output layer if all repeating layers are offloaded
|
||||
if layers >= int(ggml.KV().BlockCount()) && availableMemory > usedMemory+memOutputLayer {
|
||||
usedMemory += memOutputLayer
|
||||
layers++
|
||||
}
|
||||
|
||||
slog.Info(
|
||||
"offload to gpu",
|
||||
"layers", layers,
|
||||
"required", format.HumanBytes2(requiredMemory),
|
||||
"used", format.HumanBytes2(usedMemory),
|
||||
"available", format.HumanBytes2(availableMemory),
|
||||
"kv", format.HumanBytes2(kv),
|
||||
"graph", format.HumanBytes2(graph),
|
||||
)
|
||||
|
||||
if opts.NumGPU < 0 && info.Library != "cpu" {
|
||||
opts.NumGPU = layers
|
||||
}
|
||||
|
||||
if len(adapters) > 1 {
|
||||
return nil, errors.New("ollama supports only one lora adapter, but multiple were provided")
|
||||
}
|
||||
|
||||
availableServers := availableServers()
|
||||
servers := serversForGpu(info)
|
||||
|
||||
demandLib := os.Getenv("OLLAMA_LLM_LIBRARY")
|
||||
if demandLib != "" {
|
||||
serverPath := availableServers[demandLib]
|
||||
if serverPath == "" {
|
||||
slog.Info(fmt.Sprintf("Invalid OLLAMA_LLM_LIBRARY %s - not found", demandLib))
|
||||
} else {
|
||||
slog.Info("user override", "OLLAMA_LLM_LIBRARY", demandLib, "path", serverPath)
|
||||
servers = []string{demandLib}
|
||||
}
|
||||
}
|
||||
|
||||
if len(servers) == 0 {
|
||||
return nil, fmt.Errorf("no servers found for %v", info)
|
||||
}
|
||||
|
||||
params := []string{
|
||||
"--model", model,
|
||||
"--ctx-size", fmt.Sprintf("%d", opts.NumCtx),
|
||||
"--batch-size", fmt.Sprintf("%d", opts.NumBatch),
|
||||
"--embedding",
|
||||
}
|
||||
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
|
||||
params = append(params, "--log-format", "json")
|
||||
} else {
|
||||
params = append(params, "--log-disable")
|
||||
}
|
||||
|
||||
if opts.NumGPU >= 0 {
|
||||
params = append(params, "--n-gpu-layers", fmt.Sprintf("%d", opts.NumGPU))
|
||||
}
|
||||
|
||||
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
|
||||
params = append(params, "--verbose")
|
||||
}
|
||||
|
||||
if opts.MainGPU > 0 {
|
||||
params = append(params, "--main-gpu", fmt.Sprintf("%d", opts.MainGPU))
|
||||
}
|
||||
|
||||
if opts.RopeFrequencyBase > 0 {
|
||||
params = append(params, "--rope-freq-base", fmt.Sprintf("%f", opts.RopeFrequencyBase))
|
||||
}
|
||||
|
||||
if opts.RopeFrequencyScale > 0 {
|
||||
params = append(params, "--rope-freq-scale", fmt.Sprintf("%f", opts.RopeFrequencyScale))
|
||||
}
|
||||
|
||||
if len(adapters) > 0 {
|
||||
// TODO: applying multiple adapters is not supported by the llama.cpp server yet
|
||||
params = append(params, "--lora", adapters[0])
|
||||
}
|
||||
|
||||
if len(projectors) > 0 {
|
||||
// TODO: applying multiple projectors is not supported by the llama.cpp server yet
|
||||
params = append(params, "--mmproj", projectors[0])
|
||||
}
|
||||
|
||||
if opts.NumThread > 0 {
|
||||
params = append(params, "--threads", fmt.Sprintf("%d", opts.NumThread))
|
||||
}
|
||||
|
||||
if !opts.F16KV {
|
||||
params = append(params, "--memory-f32")
|
||||
}
|
||||
|
||||
if opts.UseMLock {
|
||||
params = append(params, "--mlock")
|
||||
}
|
||||
|
||||
if !opts.UseMMap {
|
||||
params = append(params, "--no-mmap")
|
||||
}
|
||||
|
||||
if opts.UseNUMA {
|
||||
params = append(params, "--numa")
|
||||
}
|
||||
|
||||
// Loop through potential servers
|
||||
var finalErr error
|
||||
for i := 0; i < len(servers); i++ {
|
||||
dir := availableServers[servers[i]]
|
||||
|
||||
// Find an availableServers port, retry on each iterration in case the failure was a port conflict race
|
||||
port := 0
|
||||
if a, err := net.ResolveTCPAddr("tcp", "localhost:0"); err == nil {
|
||||
var l *net.TCPListener
|
||||
if l, err = net.ListenTCP("tcp", a); err == nil {
|
||||
port = l.Addr().(*net.TCPAddr).Port
|
||||
l.Close()
|
||||
}
|
||||
}
|
||||
if port == 0 {
|
||||
slog.Debug("ResolveTCPAddr failed ", "error", err)
|
||||
port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range
|
||||
}
|
||||
finalParams := append(params, "--port", strconv.Itoa(port))
|
||||
|
||||
pathEnv := "LD_LIBRARY_PATH"
|
||||
if runtime.GOOS == "windows" {
|
||||
pathEnv = "PATH"
|
||||
}
|
||||
// append the server directory to LD_LIBRARY_PATH/PATH
|
||||
libraryPaths := []string{dir}
|
||||
if libraryPath, ok := os.LookupEnv(pathEnv); ok {
|
||||
// Append our runner directory to the path
|
||||
// This will favor system libraries over our bundled library dependencies
|
||||
libraryPaths = append(filepath.SplitList(libraryPath), libraryPaths...)
|
||||
}
|
||||
|
||||
server := filepath.Join(dir, "ollama_llama_server")
|
||||
if runtime.GOOS == "windows" {
|
||||
server = server + ".exe"
|
||||
}
|
||||
|
||||
s := &LlamaServer{
|
||||
port: port,
|
||||
cmd: exec.Command(server, finalParams...),
|
||||
status: NewStatusWriter(os.Stderr),
|
||||
options: opts,
|
||||
}
|
||||
libEnv := fmt.Sprintf("%s=%s", pathEnv, strings.Join(libraryPaths, string(filepath.ListSeparator)))
|
||||
slog.Debug(libEnv)
|
||||
s.cmd.Env = append(os.Environ(), libEnv)
|
||||
s.cmd.Stdout = os.Stdout
|
||||
s.cmd.Stderr = s.status
|
||||
|
||||
slog.Info("starting llama server", "cmd", s.cmd.String())
|
||||
|
||||
if err = s.cmd.Start(); err != nil {
|
||||
msg := ""
|
||||
if s.status != nil && s.status.LastErrMsg != "" {
|
||||
msg = s.status.LastErrMsg
|
||||
}
|
||||
err = fmt.Errorf("error starting the external llama server: %v %s", err, msg)
|
||||
finalErr = err
|
||||
continue
|
||||
}
|
||||
|
||||
// reap subprocess when it exits
|
||||
go func() {
|
||||
// Exit status managed via getServerStatus
|
||||
_ = s.cmd.Wait()
|
||||
}()
|
||||
|
||||
if err = s.waitUntilRunning(); err != nil {
|
||||
slog.Error("error starting llama server", "server", servers[i], "error", err)
|
||||
s.Close()
|
||||
finalErr = err
|
||||
continue
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
slog.Error("unable to load any llama server", "error", finalErr)
|
||||
return nil, finalErr
|
||||
}
|
||||
|
||||
func projectorMemoryRequirements(filename string) int64 {
|
||||
file, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
ggml, _, err := DecodeGGML(file)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
prefixes := make(map[string]struct{})
|
||||
for _, layer := range ggml.Tensors() {
|
||||
parts := strings.Split(layer.Name, ".")
|
||||
prefixes[strings.Join(parts[:2], ".")] = struct{}{}
|
||||
}
|
||||
|
||||
var ask int64
|
||||
for prefix := range prefixes {
|
||||
ask += ggml.LayerSize(prefix)
|
||||
}
|
||||
|
||||
return ask
|
||||
}
|
||||
|
||||
type ServerStatus int
|
||||
|
||||
const ( // iota is reset to 0
|
||||
ServerStatusReady ServerStatus = iota
|
||||
ServerStatusNoSlotsAvaialble
|
||||
ServerStatusLoadingModel
|
||||
ServerStatusNotResponding
|
||||
ServerStatusError
|
||||
)
|
||||
|
||||
type ServerStatusResp struct {
|
||||
Status string `json:"status"`
|
||||
SlotsIdle int `json:"slots_idle"`
|
||||
SlotsProcessing int `json:"slots_processing"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
func (s *LlamaServer) getServerStatus(ctx context.Context) (ServerStatus, error) {
|
||||
// Fail fast if its exited
|
||||
if s.cmd.ProcessState != nil {
|
||||
msg := ""
|
||||
if s.status != nil && s.status.LastErrMsg != "" {
|
||||
msg = s.status.LastErrMsg
|
||||
}
|
||||
return ServerStatusError, fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, fmt.Sprintf("http://127.0.0.1:%d/health", s.port), nil)
|
||||
if err != nil {
|
||||
return ServerStatusError, fmt.Errorf("error creating GET request: %v", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
if errors.Is(err, context.DeadlineExceeded) {
|
||||
return ServerStatusNotResponding, fmt.Errorf("server not responding")
|
||||
}
|
||||
return ServerStatusError, fmt.Errorf("health resp: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return ServerStatusError, fmt.Errorf("read health request: %w", err)
|
||||
}
|
||||
|
||||
var status ServerStatusResp
|
||||
if err := json.Unmarshal(body, &status); err != nil {
|
||||
return ServerStatusError, fmt.Errorf("health unmarshal encode response: %w", err)
|
||||
}
|
||||
|
||||
switch status.Status {
|
||||
case "ok":
|
||||
return ServerStatusReady, nil
|
||||
case "no slot available":
|
||||
return ServerStatusNoSlotsAvaialble, nil
|
||||
case "loading model":
|
||||
return ServerStatusLoadingModel, nil
|
||||
default:
|
||||
return ServerStatusError, fmt.Errorf("server error: %+v", status)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *LlamaServer) Ping(ctx context.Context) error {
|
||||
_, err := s.getServerStatus(ctx)
|
||||
if err != nil {
|
||||
slog.Debug("server unhealthy", "error", err)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *LlamaServer) waitUntilRunning() error {
|
||||
start := time.Now()
|
||||
expiresAt := time.Now().Add(3 * time.Minute) // be generous with timeout, large models can take a while to load
|
||||
ticker := time.NewTicker(50 * time.Millisecond)
|
||||
defer ticker.Stop()
|
||||
|
||||
slog.Info("waiting for llama runner to start responding")
|
||||
var lastStatus ServerStatus = -1
|
||||
for {
|
||||
select {
|
||||
case err := <-s.done:
|
||||
msg := ""
|
||||
if s.status != nil && s.status.LastErrMsg != "" {
|
||||
msg = s.status.LastErrMsg
|
||||
}
|
||||
return fmt.Errorf("llama runner process has terminated: %v %s", err, msg)
|
||||
case <-ticker.C:
|
||||
if time.Now().After(expiresAt) {
|
||||
// timeout
|
||||
msg := ""
|
||||
if s.status != nil && s.status.LastErrMsg != "" {
|
||||
msg = s.status.LastErrMsg
|
||||
}
|
||||
return fmt.Errorf("timed out waiting for llama runner to start: %s", msg)
|
||||
}
|
||||
if s.cmd.ProcessState != nil {
|
||||
msg := ""
|
||||
if s.status != nil && s.status.LastErrMsg != "" {
|
||||
msg = s.status.LastErrMsg
|
||||
}
|
||||
return fmt.Errorf("llama runner process no longer running: %d %s", s.cmd.ProcessState.ExitCode(), msg)
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
|
||||
defer cancel()
|
||||
status, err := s.getServerStatus(ctx)
|
||||
if err != nil && lastStatus != status {
|
||||
slog.Debug("server not yet available", "error", err)
|
||||
lastStatus = status
|
||||
continue
|
||||
}
|
||||
|
||||
switch status {
|
||||
case ServerStatusLoadingModel:
|
||||
// TODO - this state never seems to happen with the current server.cpp code (bug?)
|
||||
// it doesn't respond to the health endpoint until after the model is loaded
|
||||
slog.Debug("loading model")
|
||||
case ServerStatusReady:
|
||||
slog.Debug(fmt.Sprintf("llama runner started in %f seconds", time.Since(start).Seconds()))
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const jsonGrammar = `
|
||||
root ::= object
|
||||
value ::= object | array | string | number | ("true" | "false" | "null") ws
|
||||
|
||||
object ::=
|
||||
"{" ws (
|
||||
string ":" ws value
|
||||
("," ws string ":" ws value)*
|
||||
)? "}" ws
|
||||
|
||||
array ::=
|
||||
"[" ws (
|
||||
value
|
||||
("," ws value)*
|
||||
)? "]" ws
|
||||
|
||||
string ::=
|
||||
"\"" (
|
||||
[^"\\] |
|
||||
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F]) # escapes
|
||||
)* "\"" ws
|
||||
|
||||
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
|
||||
|
||||
# Optional space: by convention, applied in this grammar after literal chars when allowed
|
||||
ws ::= ([ \t\n] ws)?
|
||||
`
|
||||
|
||||
const maxBufferSize = 512 * format.KiloByte
|
||||
const maxRetries = 3
|
||||
|
||||
type ImageData struct {
|
||||
Data []byte `json:"data"`
|
||||
ID int `json:"id"`
|
||||
}
|
||||
|
||||
type completion struct {
|
||||
Content string `json:"content"`
|
||||
Model string `json:"model"`
|
||||
Prompt string `json:"prompt"`
|
||||
Stop bool `json:"stop"`
|
||||
|
||||
Timings struct {
|
||||
PredictedN int `json:"predicted_n"`
|
||||
PredictedMS float64 `json:"predicted_ms"`
|
||||
PromptN int `json:"prompt_n"`
|
||||
PromptMS float64 `json:"prompt_ms"`
|
||||
}
|
||||
}
|
||||
|
||||
type CompletionRequest struct {
|
||||
Prompt string
|
||||
Format string
|
||||
Images []ImageData
|
||||
Options api.Options
|
||||
}
|
||||
|
||||
type CompletionResponse struct {
|
||||
Content string
|
||||
Done bool
|
||||
PromptEvalCount int
|
||||
PromptEvalDuration time.Duration
|
||||
EvalCount int
|
||||
EvalDuration time.Duration
|
||||
}
|
||||
|
||||
func (s *LlamaServer) Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error {
|
||||
request := map[string]any{
|
||||
"prompt": req.Prompt,
|
||||
"stream": true,
|
||||
"n_predict": req.Options.NumPredict,
|
||||
"n_keep": req.Options.NumKeep,
|
||||
"main_gpu": req.Options.MainGPU,
|
||||
"temperature": req.Options.Temperature,
|
||||
"top_k": req.Options.TopK,
|
||||
"top_p": req.Options.TopP,
|
||||
"tfs_z": req.Options.TFSZ,
|
||||
"typical_p": req.Options.TypicalP,
|
||||
"repeat_last_n": req.Options.RepeatLastN,
|
||||
"repeat_penalty": req.Options.RepeatPenalty,
|
||||
"presence_penalty": req.Options.PresencePenalty,
|
||||
"frequency_penalty": req.Options.FrequencyPenalty,
|
||||
"mirostat": req.Options.Mirostat,
|
||||
"mirostat_tau": req.Options.MirostatTau,
|
||||
"mirostat_eta": req.Options.MirostatEta,
|
||||
"penalize_nl": req.Options.PenalizeNewline,
|
||||
"seed": req.Options.Seed,
|
||||
"stop": req.Options.Stop,
|
||||
"image_data": req.Images,
|
||||
"cache_prompt": true,
|
||||
}
|
||||
|
||||
// Make sure the server is ready
|
||||
status, err := s.getServerStatus(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
} else if status != ServerStatusReady {
|
||||
return fmt.Errorf("unexpected server status: %d", status)
|
||||
}
|
||||
|
||||
if req.Format == "json" {
|
||||
request["grammar"] = jsonGrammar
|
||||
if !strings.Contains(strings.ToLower(req.Prompt), "json") {
|
||||
slog.Warn("Prompt does not specify that the LLM should response in JSON, but JSON format is expected. For best results specify that JSON is expected in the system prompt.")
|
||||
}
|
||||
}
|
||||
|
||||
retryDelay := 100 * time.Microsecond
|
||||
for retries := 0; retries < maxRetries; retries++ {
|
||||
if retries > 0 {
|
||||
time.Sleep(retryDelay) // wait before retrying
|
||||
retryDelay *= 2 // exponential backoff
|
||||
}
|
||||
|
||||
// Handling JSON marshaling with special characters unescaped.
|
||||
buffer := &bytes.Buffer{}
|
||||
enc := json.NewEncoder(buffer)
|
||||
enc.SetEscapeHTML(false)
|
||||
|
||||
if err := enc.Encode(request); err != nil {
|
||||
return fmt.Errorf("failed to marshal data: %v", err)
|
||||
}
|
||||
|
||||
endpoint := fmt.Sprintf("http://127.0.0.1:%d/completion", s.port)
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, buffer)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating POST request: %v", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("POST predict: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed reading llm error response: %w", err)
|
||||
}
|
||||
log.Printf("llm predict error: %s", bodyBytes)
|
||||
return fmt.Errorf("%s", bodyBytes)
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
buf := make([]byte, 0, maxBufferSize)
|
||||
scanner.Buffer(buf, maxBufferSize)
|
||||
|
||||
retryNeeded := false
|
||||
// keep track of the last token generated, this is used to abort if the model starts looping
|
||||
var lastToken string
|
||||
var tokenRepeat int
|
||||
|
||||
for scanner.Scan() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
// This handles the request cancellation
|
||||
return ctx.Err()
|
||||
default:
|
||||
line := scanner.Bytes()
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
// try again on slot unavailable
|
||||
if bytes.Contains(line, []byte("slot unavailable")) {
|
||||
retryNeeded = true
|
||||
break
|
||||
}
|
||||
|
||||
evt, ok := bytes.CutPrefix(line, []byte("data: "))
|
||||
if !ok {
|
||||
return fmt.Errorf("error parsing llm response stream: %s", line)
|
||||
}
|
||||
|
||||
var c completion
|
||||
if err := json.Unmarshal(evt, &c); err != nil {
|
||||
return fmt.Errorf("error unmarshaling llm prediction response: %v", err)
|
||||
}
|
||||
|
||||
switch {
|
||||
case strings.TrimSpace(c.Content) == lastToken:
|
||||
tokenRepeat++
|
||||
default:
|
||||
lastToken = strings.TrimSpace(c.Content)
|
||||
tokenRepeat = 0
|
||||
}
|
||||
|
||||
// 30 picked as an arbitrary max token repeat limit, modify as needed
|
||||
if tokenRepeat > 30 {
|
||||
slog.Debug("prediction aborted, token repeat limit reached")
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
if c.Content != "" {
|
||||
fn(CompletionResponse{
|
||||
Content: c.Content,
|
||||
})
|
||||
}
|
||||
|
||||
if c.Stop {
|
||||
fn(CompletionResponse{
|
||||
Done: true,
|
||||
PromptEvalCount: c.Timings.PromptN,
|
||||
PromptEvalDuration: parseDurationMs(c.Timings.PromptMS),
|
||||
EvalCount: c.Timings.PredictedN,
|
||||
EvalDuration: parseDurationMs(c.Timings.PredictedMS),
|
||||
})
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
if strings.Contains(err.Error(), "unexpected EOF") {
|
||||
s.Close()
|
||||
msg := ""
|
||||
if s.status != nil && s.status.LastErrMsg != "" {
|
||||
msg = s.status.LastErrMsg
|
||||
}
|
||||
|
||||
return fmt.Errorf("an unknown error was encountered while running the model %s", msg)
|
||||
}
|
||||
return fmt.Errorf("error reading llm response: %v", err)
|
||||
}
|
||||
|
||||
if !retryNeeded {
|
||||
return nil // success
|
||||
}
|
||||
}
|
||||
|
||||
// should never reach here ideally
|
||||
return fmt.Errorf("max retries exceeded")
|
||||
}
|
||||
|
||||
type EmbeddingRequest struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type EmbeddingResponse struct {
|
||||
Embedding []float64 `json:"embedding"`
|
||||
}
|
||||
|
||||
func (s *LlamaServer) Embedding(ctx context.Context, prompt string) ([]float64, error) {
|
||||
// Make sure the server is ready
|
||||
status, err := s.getServerStatus(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if status != ServerStatusReady {
|
||||
return nil, fmt.Errorf("unexpected server status: %d", status)
|
||||
}
|
||||
|
||||
data, err := json.Marshal(TokenizeRequest{Content: prompt})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error marshaling embed data: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/embedding", s.port), bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating embed request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("do embedding request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading embed response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
log.Printf("llm encode error: %s", body)
|
||||
return nil, fmt.Errorf("%s", body)
|
||||
}
|
||||
|
||||
var embedding EmbeddingResponse
|
||||
if err := json.Unmarshal(body, &embedding); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal tokenize response: %w", err)
|
||||
}
|
||||
|
||||
return embedding.Embedding, nil
|
||||
}
|
||||
|
||||
type TokenizeRequest struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type TokenizeResponse struct {
|
||||
Tokens []int `json:"tokens"`
|
||||
}
|
||||
|
||||
func (s *LlamaServer) Tokenize(ctx context.Context, content string) ([]int, error) {
|
||||
// Make sure the server is ready
|
||||
status, err := s.getServerStatus(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if status != ServerStatusReady {
|
||||
return nil, fmt.Errorf("unexpected server status: %d", status)
|
||||
}
|
||||
|
||||
data, err := json.Marshal(TokenizeRequest{Content: content})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("marshaling encode data: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/tokenize", s.port), bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("encode request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("do encode request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read encode request: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
log.Printf("llm encode error: %s", body)
|
||||
return nil, fmt.Errorf("%s", body)
|
||||
}
|
||||
|
||||
var encoded TokenizeResponse
|
||||
if err := json.Unmarshal(body, &encoded); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal encode response: %w", err)
|
||||
}
|
||||
|
||||
return encoded.Tokens, nil
|
||||
}
|
||||
|
||||
type DetokenizeRequest struct {
|
||||
Tokens []int `json:"tokens"`
|
||||
}
|
||||
|
||||
type DetokenizeResponse struct {
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
func (s *LlamaServer) Detokenize(ctx context.Context, tokens []int) (string, error) {
|
||||
// Make sure the server is ready
|
||||
status, err := s.getServerStatus(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
} else if status != ServerStatusReady {
|
||||
return "", fmt.Errorf("unexpected server status: %d", status)
|
||||
}
|
||||
|
||||
data, err := json.Marshal(DetokenizeRequest{Tokens: tokens})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshaling decode data: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/detokenize", s.port), bytes.NewBuffer(data))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("decode request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("do decode request: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("read decode request: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode >= 400 {
|
||||
log.Printf("llm decode error: %s", body)
|
||||
return "", fmt.Errorf("%s", body)
|
||||
}
|
||||
|
||||
var decoded DetokenizeResponse
|
||||
if err := json.Unmarshal(body, &decoded); err != nil {
|
||||
return "", fmt.Errorf("unmarshal encode response: %w", err)
|
||||
}
|
||||
|
||||
return decoded.Content, nil
|
||||
}
|
||||
|
||||
func (s *LlamaServer) Close() error {
|
||||
if s.cmd != nil {
|
||||
slog.Debug("stopping llama server")
|
||||
return s.cmd.Process.Kill()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func parseDurationMs(ms float64) time.Duration {
|
||||
dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return dur
|
||||
}
|
||||
42
llm/status.go
Normal file
42
llm/status.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"os"
|
||||
)
|
||||
|
||||
// StatusWriter is a writer that captures error messages from the llama runner process
|
||||
type StatusWriter struct {
|
||||
LastErrMsg string
|
||||
out *os.File
|
||||
}
|
||||
|
||||
func NewStatusWriter(out *os.File) *StatusWriter {
|
||||
return &StatusWriter{
|
||||
out: out,
|
||||
}
|
||||
}
|
||||
|
||||
// TODO - regex matching to detect errors like
|
||||
// libcublasLt.so.11: cannot open shared object file: No such file or directory
|
||||
|
||||
var errorPrefixes = []string{
|
||||
"error:",
|
||||
"CUDA error",
|
||||
"cudaMalloc failed",
|
||||
"\"ERR\"",
|
||||
}
|
||||
|
||||
func (w *StatusWriter) Write(b []byte) (int, error) {
|
||||
var errMsg string
|
||||
for _, prefix := range errorPrefixes {
|
||||
if _, after, ok := bytes.Cut(b, []byte(prefix)); ok {
|
||||
errMsg = prefix + string(bytes.TrimSpace(after))
|
||||
}
|
||||
}
|
||||
if errMsg != "" {
|
||||
w.LastErrMsg = errMsg
|
||||
}
|
||||
|
||||
return w.out.Write(b)
|
||||
}
|
||||
15
llm/utils.go
15
llm/utils.go
@@ -1,15 +0,0 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
func parseDurationMs(ms float64) time.Duration {
|
||||
dur, err := time.ParseDuration(fmt.Sprintf("%fms", ms))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return dur
|
||||
}
|
||||
@@ -142,7 +142,9 @@ func (h *History) Save() error {
|
||||
for cnt := 0; cnt < h.Size(); cnt++ {
|
||||
v, _ := h.Buf.Get(cnt)
|
||||
line, _ := v.([]rune)
|
||||
buf.WriteString(string(line) + "\n")
|
||||
if _, err := buf.WriteString(string(line) + "\n"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
buf.Flush()
|
||||
f.Close()
|
||||
|
||||
@@ -10,7 +10,7 @@ export GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=$V
|
||||
# For developers, you can override the DOCKER_ORG to generate multiarch manifests
|
||||
# DOCKER_ORG=jdoe PUSH=1 ./scripts/build_docker.sh
|
||||
DOCKER_ORG=${DOCKER_ORG:-"ollama"}
|
||||
ARCH_IMAGE_REPO=${ARCH_IMAGE_REPO:-"${DOCKER_ORG}/release"}
|
||||
RELEASE_IMAGE_REPO=${RELEASE_IMAGE_REPO:-"${DOCKER_ORG}/release"}
|
||||
FINAL_IMAGE_REPO=${FINAL_IMAGE_REPO:-"${DOCKER_ORG}/ollama"}
|
||||
|
||||
BUILD_ARCH=${BUILD_ARCH:-"amd64 arm64"}
|
||||
@@ -25,7 +25,7 @@ OLLAMA_SKIP_IMAGE_BUILD=${OLLAMA_SKIP_IMAGE_BUILD:-""}
|
||||
if [ -z "${PUSH}" ] ; then
|
||||
LOAD_OR_PUSH="--load"
|
||||
else
|
||||
echo "Will be pushing ${ARCH_IMAGE_REPO}:$VERSION for ${BUILD_ARCH}"
|
||||
echo "Will be pushing ${RELEASE_IMAGE_REPO}:$VERSION for ${BUILD_ARCH}"
|
||||
LOAD_OR_PUSH="--push"
|
||||
fi
|
||||
|
||||
@@ -37,7 +37,7 @@ if [ -z "${OLLAMA_SKIP_IMAGE_BUILD}" ]; then
|
||||
--build-arg=VERSION \
|
||||
--build-arg=GOFLAGS \
|
||||
-f Dockerfile \
|
||||
-t ${ARCH_IMAGE_REPO}:$VERSION-${TARGETARCH} \
|
||||
-t ${RELEASE_IMAGE_REPO}:$VERSION-${TARGETARCH} \
|
||||
.
|
||||
done
|
||||
|
||||
@@ -49,7 +49,7 @@ if [ -z "${OLLAMA_SKIP_IMAGE_BUILD}" ]; then
|
||||
--build-arg=GOFLAGS \
|
||||
--target runtime-rocm \
|
||||
-f Dockerfile \
|
||||
-t ${ARCH_IMAGE_REPO}:$VERSION-rocm \
|
||||
-t ${RELEASE_IMAGE_REPO}:$VERSION-rocm \
|
||||
.
|
||||
fi
|
||||
fi
|
||||
@@ -57,21 +57,21 @@ fi
|
||||
if [ -z "${OLLAMA_SKIP_MANIFEST_CREATE}" ]; then
|
||||
if [ -n "${PUSH}" ]; then
|
||||
docker manifest create ${FINAL_IMAGE_REPO}:$VERSION \
|
||||
${ARCH_IMAGE_REPO}:$VERSION-amd64 \
|
||||
${ARCH_IMAGE_REPO}:$VERSION-arm64
|
||||
${RELEASE_IMAGE_REPO}:$VERSION-amd64 \
|
||||
${RELEASE_IMAGE_REPO}:$VERSION-arm64
|
||||
docker manifest push ${FINAL_IMAGE_REPO}:$VERSION
|
||||
|
||||
# For symmetry, tag/push the rocm image
|
||||
if [ "${ARCH_IMAGE_REPO}" != "${FINAL_IMAGE_REPO}" ]; then
|
||||
if [ "${RELEASE_IMAGE_REPO}" != "${FINAL_IMAGE_REPO}" ]; then
|
||||
echo "Tagging and pushing rocm image"
|
||||
docker pull ${ARCH_IMAGE_REPO}:$VERSION-rocm
|
||||
docker tag ${ARCH_IMAGE_REPO}:$VERSION-rocm ${FINAL_IMAGE_REPO}:$VERSION-rocm
|
||||
docker pull ${RELEASE_IMAGE_REPO}:$VERSION-rocm
|
||||
docker tag ${RELEASE_IMAGE_REPO}:$VERSION-rocm ${FINAL_IMAGE_REPO}:$VERSION-rocm
|
||||
docker push ${FINAL_IMAGE_REPO}:$VERSION-rocm
|
||||
fi
|
||||
else
|
||||
echo "Skipping manifest generation when not pushing images are available locally as "
|
||||
echo " ${ARCH_IMAGE_REPO}:$VERSION-amd64"
|
||||
echo " ${ARCH_IMAGE_REPO}:$VERSION-arm64"
|
||||
echo " ${ARCH_IMAGE_REPO}:$VERSION-rocm"
|
||||
echo " ${RELEASE_IMAGE_REPO}:$VERSION-amd64"
|
||||
echo " ${RELEASE_IMAGE_REPO}:$VERSION-arm64"
|
||||
echo " ${RELEASE_IMAGE_REPO}:$VERSION-rocm"
|
||||
fi
|
||||
fi
|
||||
|
||||
33
scripts/tag_latest.sh
Executable file
33
scripts/tag_latest.sh
Executable file
@@ -0,0 +1,33 @@
|
||||
#!/bin/sh
|
||||
|
||||
set -eu
|
||||
|
||||
# We use 2 different image repositories to handle combining architecture images into multiarch manifest
|
||||
# (The ROCm image is x86 only and is not a multiarch manifest)
|
||||
# For developers, you can override the DOCKER_ORG to generate multiarch manifests
|
||||
# DOCKER_ORG=jdoe VERSION=0.1.30 PUSH=1 ./scripts/tag_latest.sh
|
||||
DOCKER_ORG=${DOCKER_ORG:-"ollama"}
|
||||
RELEASE_IMAGE_REPO=${RELEASE_IMAGE_REPO:-"${DOCKER_ORG}/release"}
|
||||
FINAL_IMAGE_REPO=${FINAL_IMAGE_REPO:-"${DOCKER_ORG}/ollama"}
|
||||
|
||||
# Set PUSH to a non-empty string to trigger push instead of load
|
||||
PUSH=${PUSH:-""}
|
||||
|
||||
echo "Assembling manifest and tagging latest"
|
||||
docker manifest rm ${FINAL_IMAGE_REPO}:latest || true
|
||||
docker manifest create ${FINAL_IMAGE_REPO}:latest \
|
||||
${RELEASE_IMAGE_REPO}:$VERSION-amd64 \
|
||||
${RELEASE_IMAGE_REPO}:$VERSION-arm64
|
||||
|
||||
docker pull ${RELEASE_IMAGE_REPO}:$VERSION-rocm
|
||||
docker tag ${RELEASE_IMAGE_REPO}:$VERSION-rocm ${FINAL_IMAGE_REPO}:rocm
|
||||
|
||||
if [ -n "${PUSH}" ]; then
|
||||
echo "Pushing latest tags up..."
|
||||
docker manifest push ${FINAL_IMAGE_REPO}:latest
|
||||
docker push ${FINAL_IMAGE_REPO}:rocm
|
||||
else
|
||||
echo "Not pushing ${FINAL_IMAGE_REPO}:latest and ${FINAL_IMAGE_REPO}:rocm"
|
||||
fi
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ import (
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/convert"
|
||||
"github.com/ollama/ollama/format"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/version"
|
||||
@@ -321,7 +322,7 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
|
||||
|
||||
pathName := realpath(modelFileDir, c.Args)
|
||||
|
||||
ggufName, err := convertSafetensors(name, pathName)
|
||||
ggufName, err := convertSafetensors(name, pathName, fn)
|
||||
if err != nil {
|
||||
var pathErr *fs.PathError
|
||||
switch {
|
||||
@@ -336,6 +337,7 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
|
||||
|
||||
if ggufName != "" {
|
||||
pathName = ggufName
|
||||
slog.Debug(fmt.Sprintf("new image layer path: %s", pathName))
|
||||
defer os.RemoveAll(ggufName)
|
||||
}
|
||||
|
||||
@@ -419,34 +421,32 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
|
||||
defer bin.Close()
|
||||
|
||||
var offset int64
|
||||
CREATE:
|
||||
for {
|
||||
fn(api.ProgressResponse{Status: "creating model layer"})
|
||||
if _, err := bin.Seek(offset, io.SeekStart); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
bin.Seek(offset, io.SeekStart)
|
||||
ggml, err := llm.DecodeGGML(bin)
|
||||
if err != nil {
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
break CREATE
|
||||
case errors.Is(err, llm.ErrUnsupportedFormat):
|
||||
return fmt.Errorf("model binary specified in FROM field is not a valid gguf format model, %w", err)
|
||||
default:
|
||||
return err
|
||||
}
|
||||
ggml, size, err := llm.DecodeGGML(bin)
|
||||
if errors.Is(err, io.EOF) {
|
||||
break
|
||||
} else if errors.Is(err, llm.ErrUnsupportedFormat) {
|
||||
return fmt.Errorf("model binary specified in FROM field is not a valid gguf format model, %w", err)
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
config.SetModelFormat(ggml.Name())
|
||||
config.SetModelFamily(ggml.ModelFamily())
|
||||
config.SetModelType(ggml.ModelType())
|
||||
config.SetFileType(ggml.FileType())
|
||||
config.SetModelFamily(ggml.KV().Architecture())
|
||||
config.SetModelType(format.HumanNumber(ggml.KV().ParameterCount()))
|
||||
config.SetFileType(ggml.KV().FileType())
|
||||
|
||||
mediatype := mediatype
|
||||
if ggml.ModelFamily() == "clip" {
|
||||
if ggml.KV().Architecture() == "clip" {
|
||||
mediatype = "application/vnd.ollama.image.projector"
|
||||
}
|
||||
|
||||
sr := io.NewSectionReader(bin, offset, ggml.Size)
|
||||
sr := io.NewSectionReader(bin, offset, size)
|
||||
layer, err := NewLayer(sr, mediatype)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -454,7 +454,7 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
|
||||
|
||||
layers.Add(layer)
|
||||
|
||||
offset += ggml.Size
|
||||
offset += size
|
||||
}
|
||||
case "adapter":
|
||||
if strings.HasPrefix(c.Args, "@") {
|
||||
@@ -473,12 +473,12 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
|
||||
}
|
||||
defer bin.Close()
|
||||
|
||||
ggml, err := llm.DecodeGGML(bin)
|
||||
_, size, err := llm.DecodeGGML(bin)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sr := io.NewSectionReader(bin, 0, ggml.Size)
|
||||
sr := io.NewSectionReader(bin, 0, size)
|
||||
layer, err := NewLayer(sr, mediatype)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -550,13 +550,6 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
|
||||
}
|
||||
}
|
||||
|
||||
// xxx - can this be removed?
|
||||
if config.ModelType == "65B" {
|
||||
if gqa, ok := formattedParams["gqa"].(int); ok && gqa == 8 {
|
||||
config.ModelType = "70B"
|
||||
}
|
||||
}
|
||||
|
||||
var b bytes.Buffer
|
||||
if err := json.NewEncoder(&b).Encode(formattedParams); err != nil {
|
||||
return err
|
||||
@@ -621,8 +614,8 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
|
||||
return nil
|
||||
}
|
||||
|
||||
func convertSafetensors(name, fn string) (string, error) {
|
||||
r, err := zip.OpenReader(fn)
|
||||
func convertSafetensors(name, path string, fn func(resp api.ProgressResponse)) (string, error) {
|
||||
r, err := zip.OpenReader(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -634,6 +627,7 @@ func convertSafetensors(name, fn string) (string, error) {
|
||||
}
|
||||
defer os.RemoveAll(tempDir)
|
||||
|
||||
fn(api.ProgressResponse{Status: "unpacking model metadata"})
|
||||
for _, f := range r.File {
|
||||
fpath := filepath.Join(tempDir, f.Name)
|
||||
outFile, err := os.OpenFile(fpath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode())
|
||||
@@ -660,32 +654,27 @@ func convertSafetensors(name, fn string) (string, error) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
SupportedArchs := []string{
|
||||
"MistralForCausalLM",
|
||||
}
|
||||
|
||||
for _, arch := range params.Architectures {
|
||||
if !slices.Contains(SupportedArchs, arch) {
|
||||
return "", fmt.Errorf("this safetensors model is not yet supported")
|
||||
}
|
||||
}
|
||||
|
||||
t, err := convert.GetSafeTensors(tempDir)
|
||||
mArch, err := convert.GetModelArchFromParams(name, tempDir, params)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
vocab, err := convert.LoadTokens(tempDir)
|
||||
fn(api.ProgressResponse{Status: "processing safetensors"})
|
||||
if err := mArch.GetTensors(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if err := mArch.LoadVocab(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
fn(api.ProgressResponse{Status: "converting model"})
|
||||
path, err = mArch.WriteGGUF()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
fn, err = convert.WriteGGUF(name, t, params, vocab)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return fn, nil
|
||||
return path, nil
|
||||
}
|
||||
|
||||
func CopyModel(src, dest string) error {
|
||||
|
||||
@@ -56,12 +56,13 @@ func init() {
|
||||
var loaded struct {
|
||||
mu sync.Mutex
|
||||
|
||||
runner llm.LLM
|
||||
llama *llm.LlamaServer
|
||||
|
||||
expireAt time.Time
|
||||
expireTimer *time.Timer
|
||||
|
||||
*Model
|
||||
model string
|
||||
adapters []string
|
||||
projectors []string
|
||||
*api.Options
|
||||
}
|
||||
|
||||
@@ -69,21 +70,28 @@ var defaultSessionDuration = 5 * time.Minute
|
||||
|
||||
// load a model into memory if it is not already loaded, it is up to the caller to lock loaded.mu before calling this function
|
||||
func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.Duration) error {
|
||||
needLoad := loaded.runner == nil || // is there a model loaded?
|
||||
loaded.ModelPath != model.ModelPath || // has the base model changed?
|
||||
!reflect.DeepEqual(loaded.AdapterPaths, model.AdapterPaths) || // have the adapters changed?
|
||||
!reflect.DeepEqual(loaded.Options.Runner, opts.Runner) // have the runner options changed?
|
||||
ctx, cancel := context.WithTimeout(c, 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
needLoad := loaded.llama == nil || // is there a model loaded?
|
||||
loaded.model != model.ModelPath || // has the base model changed?
|
||||
!reflect.DeepEqual(loaded.adapters, model.AdapterPaths) || // have the adapters changed?
|
||||
!reflect.DeepEqual(loaded.projectors, model.ProjectorPaths) || // have the adapters changed?
|
||||
!reflect.DeepEqual(loaded.Options.Runner, opts.Runner) || // have the runner options changed?
|
||||
loaded.llama.Ping(ctx) != nil
|
||||
|
||||
if needLoad {
|
||||
if loaded.runner != nil {
|
||||
if loaded.llama != nil {
|
||||
slog.Info("changing loaded model")
|
||||
loaded.runner.Close()
|
||||
loaded.runner = nil
|
||||
loaded.Model = nil
|
||||
loaded.llama.Close()
|
||||
loaded.llama = nil
|
||||
loaded.model = ""
|
||||
loaded.adapters = nil
|
||||
loaded.projectors = nil
|
||||
loaded.Options = nil
|
||||
}
|
||||
|
||||
llmRunner, err := llm.New(model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts)
|
||||
llama, err := llm.NewLlamaServer(model.ModelPath, model.AdapterPaths, model.ProjectorPaths, opts)
|
||||
if err != nil {
|
||||
// some older models are not compatible with newer versions of llama.cpp
|
||||
// show a generalized compatibility error until there is a better way to
|
||||
@@ -95,28 +103,26 @@ func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.D
|
||||
return err
|
||||
}
|
||||
|
||||
loaded.Model = model
|
||||
loaded.runner = llmRunner
|
||||
loaded.model = model.ModelPath
|
||||
loaded.adapters = model.AdapterPaths
|
||||
loaded.projectors = model.ProjectorPaths
|
||||
loaded.llama = llama
|
||||
loaded.Options = &opts
|
||||
}
|
||||
|
||||
loaded.expireAt = time.Now().Add(sessionDuration)
|
||||
|
||||
if loaded.expireTimer == nil {
|
||||
loaded.expireTimer = time.AfterFunc(sessionDuration, func() {
|
||||
loaded.mu.Lock()
|
||||
defer loaded.mu.Unlock()
|
||||
|
||||
if time.Now().Before(loaded.expireAt) {
|
||||
return
|
||||
if loaded.llama != nil {
|
||||
loaded.llama.Close()
|
||||
}
|
||||
|
||||
if loaded.runner != nil {
|
||||
loaded.runner.Close()
|
||||
}
|
||||
|
||||
loaded.runner = nil
|
||||
loaded.Model = nil
|
||||
loaded.llama = nil
|
||||
loaded.model = ""
|
||||
loaded.adapters = nil
|
||||
loaded.projectors = nil
|
||||
loaded.Options = nil
|
||||
})
|
||||
}
|
||||
@@ -265,7 +271,7 @@ func GenerateHandler(c *gin.Context) {
|
||||
|
||||
sb.Reset()
|
||||
if req.Context != nil {
|
||||
prev, err := loaded.runner.Decode(c.Request.Context(), req.Context)
|
||||
prev, err := loaded.llama.Detokenize(c.Request.Context(), req.Context)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
|
||||
return
|
||||
@@ -286,9 +292,8 @@ func GenerateHandler(c *gin.Context) {
|
||||
go func() {
|
||||
defer close(ch)
|
||||
|
||||
fn := func(r llm.PredictResult) {
|
||||
fn := func(r llm.CompletionResponse) {
|
||||
// Update model expiration
|
||||
loaded.expireAt = time.Now().Add(sessionDuration)
|
||||
loaded.expireTimer.Reset(sessionDuration)
|
||||
|
||||
// Build up the full response
|
||||
@@ -322,7 +327,7 @@ func GenerateHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
// TODO (jmorganca): encode() should not strip special tokens
|
||||
tokens, err := loaded.runner.Encode(c.Request.Context(), p)
|
||||
tokens, err := loaded.llama.Tokenize(c.Request.Context(), p)
|
||||
if err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
return
|
||||
@@ -344,13 +349,13 @@ func GenerateHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
// Start prediction
|
||||
predictReq := llm.PredictOpts{
|
||||
req := llm.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
Format: req.Format,
|
||||
Images: images,
|
||||
Options: opts,
|
||||
}
|
||||
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
|
||||
if err := loaded.llama.Completion(c.Request.Context(), req, fn); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
}()
|
||||
@@ -471,7 +476,7 @@ func EmbeddingsHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
embedding, err := loaded.runner.Embedding(c.Request.Context(), req.Prompt)
|
||||
embedding, err := loaded.llama.Embedding(c.Request.Context(), req.Prompt)
|
||||
if err != nil {
|
||||
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
|
||||
@@ -1013,16 +1018,14 @@ func allowedHostsMiddleware(addr net.Addr) gin.HandlerFunc {
|
||||
}
|
||||
|
||||
func (s *Server) GenerateRoutes() http.Handler {
|
||||
var origins []string
|
||||
if o := os.Getenv("OLLAMA_ORIGINS"); o != "" {
|
||||
origins = strings.Split(o, ",")
|
||||
}
|
||||
|
||||
config := cors.DefaultConfig()
|
||||
config.AllowWildcard = true
|
||||
config.AllowBrowserExtensions = true
|
||||
|
||||
config.AllowOrigins = origins
|
||||
if allowedOrigins := strings.Trim(os.Getenv("OLLAMA_ORIGINS"), "\"'"); allowedOrigins != "" {
|
||||
config.AllowOrigins = strings.Split(allowedOrigins, ",")
|
||||
}
|
||||
|
||||
for _, allowOrigin := range defaultAllowOrigins {
|
||||
config.AllowOrigins = append(config.AllowOrigins,
|
||||
fmt.Sprintf("http://%s", allowOrigin),
|
||||
@@ -1125,8 +1128,8 @@ func Serve(ln net.Listener) error {
|
||||
signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-signals
|
||||
if loaded.runner != nil {
|
||||
loaded.runner.Close()
|
||||
if loaded.llama != nil {
|
||||
loaded.llama.Close()
|
||||
}
|
||||
gpu.Cleanup()
|
||||
os.Exit(0)
|
||||
@@ -1198,7 +1201,7 @@ func streamResponse(c *gin.Context, ch chan any) {
|
||||
// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
|
||||
func chatPrompt(ctx context.Context, template string, messages []api.Message, numCtx int) (string, error) {
|
||||
encode := func(s string) ([]int, error) {
|
||||
return loaded.runner.Encode(ctx, s)
|
||||
return loaded.llama.Tokenize(ctx, s)
|
||||
}
|
||||
|
||||
prompt, err := ChatPrompt(template, messages, numCtx, encode)
|
||||
@@ -1328,9 +1331,8 @@ func ChatHandler(c *gin.Context) {
|
||||
go func() {
|
||||
defer close(ch)
|
||||
|
||||
fn := func(r llm.PredictResult) {
|
||||
fn := func(r llm.CompletionResponse) {
|
||||
// Update model expiration
|
||||
loaded.expireAt = time.Now().Add(sessionDuration)
|
||||
loaded.expireTimer.Reset(sessionDuration)
|
||||
|
||||
resp := api.ChatResponse{
|
||||
@@ -1354,14 +1356,12 @@ func ChatHandler(c *gin.Context) {
|
||||
ch <- resp
|
||||
}
|
||||
|
||||
// Start prediction
|
||||
predictReq := llm.PredictOpts{
|
||||
if err := loaded.llama.Completion(c.Request.Context(), llm.CompletionRequest{
|
||||
Prompt: prompt,
|
||||
Format: req.Format,
|
||||
Images: images,
|
||||
Options: opts,
|
||||
}
|
||||
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
|
||||
}, fn); err != nil {
|
||||
ch <- gin.H{"error": err.Error()}
|
||||
}
|
||||
}()
|
||||
|
||||
@@ -3,6 +3,7 @@ package server
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -16,7 +17,6 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/ollama/ollama/api"
|
||||
"github.com/ollama/ollama/llm"
|
||||
"github.com/ollama/ollama/parser"
|
||||
"github.com/ollama/ollama/version"
|
||||
)
|
||||
@@ -31,13 +31,22 @@ func Test_Routes(t *testing.T) {
|
||||
}
|
||||
|
||||
createTestFile := func(t *testing.T, name string) string {
|
||||
t.Helper()
|
||||
|
||||
f, err := os.CreateTemp(t.TempDir(), name)
|
||||
assert.Nil(t, err)
|
||||
defer f.Close()
|
||||
|
||||
_, err = f.Write([]byte("GGUF"))
|
||||
err = binary.Write(f, binary.LittleEndian, []byte("GGUF"))
|
||||
assert.Nil(t, err)
|
||||
_, err = f.Write([]byte{0x2, 0})
|
||||
|
||||
err = binary.Write(f, binary.LittleEndian, uint32(3))
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = binary.Write(f, binary.LittleEndian, uint64(0))
|
||||
assert.Nil(t, err)
|
||||
|
||||
err = binary.Write(f, binary.LittleEndian, uint64(0))
|
||||
assert.Nil(t, err)
|
||||
|
||||
return f.Name()
|
||||
@@ -201,7 +210,7 @@ func Test_Routes(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
s := Server{}
|
||||
s := &Server{}
|
||||
router := s.GenerateRoutes()
|
||||
|
||||
httpSrv := httptest.NewServer(router)
|
||||
@@ -232,27 +241,3 @@ func Test_Routes(t *testing.T) {
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
type MockLLM struct {
|
||||
encoding []int
|
||||
}
|
||||
|
||||
func (llm *MockLLM) Predict(ctx context.Context, pred llm.PredictOpts, fn func(llm.PredictResult)) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (llm *MockLLM) Encode(ctx context.Context, prompt string) ([]int, error) {
|
||||
return llm.encoding, nil
|
||||
}
|
||||
|
||||
func (llm *MockLLM) Decode(ctx context.Context, tokens []int) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (llm *MockLLM) Embedding(ctx context.Context, input string) ([]float64, error) {
|
||||
return []float64{}, nil
|
||||
}
|
||||
|
||||
func (llm *MockLLM) Close() {
|
||||
// do nothing
|
||||
}
|
||||
|
||||
113
x/api/api.go
Normal file
113
x/api/api.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/ollama/ollama/x/build"
|
||||
"github.com/ollama/ollama/x/client/ollama/apitype"
|
||||
"github.com/ollama/ollama/x/oweb"
|
||||
"github.com/ollama/ollama/x/registry"
|
||||
regtype "github.com/ollama/ollama/x/registry/apitype"
|
||||
)
|
||||
|
||||
// Common API Errors
|
||||
var (
|
||||
errUnqualifiedRef = oweb.Invalid("invalid", "name", "must be fully qualified")
|
||||
errRefNotFound = oweb.Invalid("not_found", "name", "no such model")
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
Build *build.Server
|
||||
}
|
||||
|
||||
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
oweb.Serve(s.serveHTTP, w, r)
|
||||
}
|
||||
|
||||
func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) error {
|
||||
switch r.URL.Path {
|
||||
case "/v1/push":
|
||||
return s.handlePush(w, r)
|
||||
default:
|
||||
return oweb.ErrNotFound
|
||||
}
|
||||
}
|
||||
|
||||
func want(r *http.Request, method, path string) bool {
|
||||
return r.Method == method && r.URL.Path == path
|
||||
}
|
||||
|
||||
func (s *Server) handlePush(_ http.ResponseWriter, r *http.Request) error {
|
||||
if r.Method != "POST" {
|
||||
return oweb.ErrMethodNotAllowed
|
||||
}
|
||||
|
||||
params, err := oweb.DecodeJSON[apitype.PushRequest](r.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if params.Name == "" {
|
||||
return oweb.Missing("name")
|
||||
}
|
||||
|
||||
const registryURLTODO = "http://localhost:8888"
|
||||
|
||||
man, err := s.Build.ManifestData(params.Name)
|
||||
if err != nil {
|
||||
if errors.Is(err, build.ErrNotFound) {
|
||||
return errRefNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
c := registry.Client{BaseURL: registryURLTODO}
|
||||
requirements, err := c.Push(r.Context(), params.Name, man, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var uploads []regtype.CompletePart
|
||||
for _, rq := range requirements {
|
||||
l, err := s.Build.LayerFile(rq.Digest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = func() error {
|
||||
f, err := os.Open(l)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
cp, err := registry.PushLayer(r.Context(), f, rq.URL, rq.Offset, rq.Size)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
uploads = append(uploads, cp)
|
||||
return nil
|
||||
}()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// commit the manifest to the registry
|
||||
requirements, err = c.Push(r.Context(), params.Name, man, ®istry.PushParams{
|
||||
CompleteParts: uploads,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, r := range requirements {
|
||||
err = errors.Join(err, fmt.Errorf("push failed for %q", r.Digest))
|
||||
}
|
||||
return err
|
||||
|
||||
}
|
||||
|
||||
func (s *Server) handlePull(w http.ResponseWriter, r *http.Request) error {
|
||||
return oweb.ErrNotFound
|
||||
}
|
||||
209
x/build/build.go
Normal file
209
x/build/build.go
Normal file
@@ -0,0 +1,209 @@
|
||||
package build
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/ollama/ollama/x/build/internal/blobstore"
|
||||
"github.com/ollama/ollama/x/model"
|
||||
)
|
||||
|
||||
// Errors
|
||||
var (
|
||||
ErrIncompleteRef = errors.New("unqualified ref")
|
||||
ErrBuildPresentInRef = errors.New("build present in ref")
|
||||
ErrUnsupportedModelFormat = errors.New("unsupported model format")
|
||||
ErrMissingFileType = errors.New("missing 'general.file_type' key")
|
||||
ErrNotFound = errors.New("not found")
|
||||
)
|
||||
|
||||
type mediaType string
|
||||
|
||||
// Known media types
|
||||
const (
|
||||
mediaTypeModel mediaType = "application/vnd.ollama.image.model"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
st *blobstore.Store
|
||||
}
|
||||
|
||||
// Open starts a new build server that uses dir as the base directory for all
|
||||
// build artifacts. If dir is empty, DefaultDir is used.
|
||||
//
|
||||
// It returns an error if the provided or default dir cannot be initialized.
|
||||
func Open(dir string) (*Server, error) {
|
||||
if dir == "" {
|
||||
var err error
|
||||
dir, err = DefaultDir()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
st, err := blobstore.Open(dir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Server{st: st}, nil
|
||||
}
|
||||
|
||||
func (s *Server) Build(ref string, f model.File) error {
|
||||
mp := model.ParseName(ref)
|
||||
if !mp.IsCompleteNoBuild() {
|
||||
return fmt.Errorf("%w: %q", ErrIncompleteRef, ref)
|
||||
}
|
||||
|
||||
// 1. Resolve FROM
|
||||
// a. If it's a local file (gguf), hash it and add it to the store.
|
||||
// c. If it's a remote file (http), refuse.
|
||||
// 2. Turn other pragmas into layers, and add them to the store.
|
||||
// 3. Create a manifest from the layers.
|
||||
// 4. Store the manifest in the manifest cache
|
||||
// 5. Done.
|
||||
|
||||
if f.From == "" {
|
||||
return &model.FileError{Pragma: "FROM", Message: "missing"}
|
||||
}
|
||||
|
||||
var layers []layerJSON
|
||||
|
||||
id, info, size, err := s.importModel(f.From)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
layers = append(layers, layerJSON{
|
||||
ID: id,
|
||||
MediaType: mediaTypeModel,
|
||||
Size: size,
|
||||
})
|
||||
|
||||
id, size, err = blobstore.PutString(s.st, f.License)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
layers = append(layers, layerJSON{
|
||||
ID: id,
|
||||
MediaType: "text/plain",
|
||||
Size: size,
|
||||
})
|
||||
|
||||
data, err := json.Marshal(manifestJSON{Layers: layers})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.setManifestData(
|
||||
mp.WithBuild(info.FileType.String()),
|
||||
data,
|
||||
)
|
||||
}
|
||||
|
||||
func (s *Server) LayerFile(digest string) (string, error) {
|
||||
fileName := s.st.OutputFilename(blobstore.ParseID(digest))
|
||||
_, err := os.Stat(fileName)
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
return "", fmt.Errorf("%w: %q", ErrNotFound, digest)
|
||||
}
|
||||
return fileName, nil
|
||||
}
|
||||
|
||||
func (s *Server) ManifestData(ref string) ([]byte, error) {
|
||||
data, _, err := s.resolve(model.ParseName(ref))
|
||||
return data, err
|
||||
}
|
||||
|
||||
// WeightFile returns the absolute path to the weights file for the given model ref.
|
||||
func (s *Server) WeightsFile(ref string) (string, error) {
|
||||
m, err := s.getManifest(model.ParseName(ref))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
for _, l := range m.Layers {
|
||||
if l.MediaType == mediaTypeModel {
|
||||
return s.st.OutputFilename(l.ID), nil
|
||||
}
|
||||
}
|
||||
return "", fmt.Errorf("missing weights layer for %q", ref)
|
||||
}
|
||||
|
||||
// resolve returns the data for the given ref, if any.
|
||||
//
|
||||
// TODO: This should ideally return an ID, but the current on
|
||||
// disk layout is that the actual manifest is stored in the "ref" instead of
|
||||
// a pointer to a content-addressed blob. I (bmizerany) think we should
|
||||
// change the on-disk layout to store the manifest in a content-addressed
|
||||
// blob, and then have the ref point to that blob. This would simplify the
|
||||
// code, allow us to have integrity checks on the manifest, and clean up
|
||||
// this interface.
|
||||
func (s *Server) resolve(ref model.Name) (data []byte, fileName string, err error) {
|
||||
fileName, err = s.refFileName(ref)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
data, err = os.ReadFile(fileName)
|
||||
if errors.Is(err, fs.ErrNotExist) {
|
||||
return nil, "", fmt.Errorf("%w: %q", ErrNotFound, ref)
|
||||
}
|
||||
if err != nil {
|
||||
// do not wrap the error here, as it is likely an I/O error
|
||||
// and we want to preserve the absraction since we may not
|
||||
// be on disk later.
|
||||
return nil, "", fmt.Errorf("manifest read error: %v", err)
|
||||
}
|
||||
return data, fileName, nil
|
||||
}
|
||||
|
||||
func (s *Server) SetManifestData(ref string, data []byte) error {
|
||||
return s.setManifestData(model.ParseName(ref), data)
|
||||
}
|
||||
|
||||
// Set sets the data for the given ref.
|
||||
func (s *Server) setManifestData(mp model.Name, data []byte) error {
|
||||
path, err := s.refFileName(mp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0777); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.WriteFile(path, data, 0666); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) refFileName(mp model.Name) (string, error) {
|
||||
if !mp.IsComplete() {
|
||||
return "", fmt.Errorf("ref not fully qualified: %q", mp)
|
||||
}
|
||||
return filepath.Join(s.st.Dir(), "manifests", filepath.Join(mp.Parts()...)), nil
|
||||
}
|
||||
|
||||
type manifestJSON struct {
|
||||
// Layers is the list of layers in the manifest.
|
||||
Layers []layerJSON `json:"layers"`
|
||||
}
|
||||
|
||||
// Layer is a layer in a model manifest.
|
||||
type layerJSON struct {
|
||||
// ID is the ID of the layer.
|
||||
ID blobstore.ID `json:"digest"`
|
||||
MediaType mediaType `json:"mediaType"`
|
||||
Size int64 `json:"size"`
|
||||
}
|
||||
|
||||
func (s *Server) getManifest(ref model.Name) (manifestJSON, error) {
|
||||
data, path, err := s.resolve(ref)
|
||||
if err != nil {
|
||||
return manifestJSON{}, err
|
||||
}
|
||||
var m manifestJSON
|
||||
if err := json.Unmarshal(data, &m); err != nil {
|
||||
return manifestJSON{}, &fs.PathError{Op: "unmarshal", Path: path, Err: err}
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
163
x/build/build_test.go
Normal file
163
x/build/build_test.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package build
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/ollama/ollama/x/encoding/gguf"
|
||||
"github.com/ollama/ollama/x/model"
|
||||
)
|
||||
|
||||
const qualifiedRef = "x/y/z:latest+Q4_0"
|
||||
|
||||
func TestServerBuildErrors(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
s, err := Open(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Run("unqualified ref", func(t *testing.T) {
|
||||
err := s.Build("x", model.File{})
|
||||
if !errors.Is(err, ErrIncompleteRef) {
|
||||
t.Fatalf("Build() err = %v; want unqualified ref", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("FROM pragma missing", func(t *testing.T) {
|
||||
err := s.Build(qualifiedRef, model.File{})
|
||||
var e *model.FileError
|
||||
if !errors.As(err, &e) {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if e.Pragma != "FROM" {
|
||||
t.Errorf("e.Pragma = %s; want FROM", e.Pragma)
|
||||
}
|
||||
if e.Message != "missing" {
|
||||
t.Errorf("e.Message = %s; want missing", e.Message)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("FROM file not found", func(t *testing.T) {
|
||||
err := s.Build(qualifiedRef, model.File{From: "bar"})
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
t.Fatalf("Build() err = %v; want file not found", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("FROM gguf", func(t *testing.T) {
|
||||
w := newWorkDir(t)
|
||||
// Write a gguf file without general.file_type metadata.
|
||||
w.write("gguf", ""+
|
||||
"GGUF"+ // magic
|
||||
"\x03\x00\x00\x00"+ // version
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00"+ // numMetaValues
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00"+ // numTensors
|
||||
"",
|
||||
)
|
||||
|
||||
err := s.Build(qualifiedRef, model.File{From: w.fileName("gguf")})
|
||||
if !errors.Is(err, ErrMissingFileType) {
|
||||
t.Fatalf("Build() err = %#v; want missing file type", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("FROM obscure dir", func(t *testing.T) {
|
||||
w := newWorkDir(t)
|
||||
w.mkdirAll("unknown")
|
||||
if err := s.Build(qualifiedRef, model.File{From: w.fileName("unknown")}); err != ErrUnsupportedModelFormat {
|
||||
t.Fatalf("Build() err = %#v; want unsupported model type", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("FROM unsupported model type", func(t *testing.T) {
|
||||
w := newWorkDir(t)
|
||||
from := w.write("unknown", "unknown content")
|
||||
err := s.Build(qualifiedRef, model.File{From: from})
|
||||
if !errors.Is(err, ErrUnsupportedModelFormat) {
|
||||
t.Fatalf("Build() err = %#v; want unsupported model type", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestBuildBasicGGUF(t *testing.T) {
|
||||
w := newWorkDir(t)
|
||||
w.write("gguf", ""+
|
||||
"GGUF"+ // magic
|
||||
"\x03\x00\x00\x00"+ // version
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00"+ // numTensors
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00"+ // numMetaValues
|
||||
|
||||
// general.file_type key
|
||||
"\x11\x00\x00\x00\x00\x00\x00\x00"+ // key length
|
||||
"general.file_type"+ // key
|
||||
"\x04\x00\x00\x00"+ // type (uint32)
|
||||
"\x02\x00\x00\x00\x00\x00\x00\x00"+ // uint32 value
|
||||
"",
|
||||
)
|
||||
|
||||
dir := t.TempDir()
|
||||
s, err := Open(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := s.Build(qualifiedRef, model.File{From: w.fileName("gguf")}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
filepath.Walk(dir, func(p string, info os.FileInfo, err error) error {
|
||||
t.Logf("file: %s", p)
|
||||
return nil
|
||||
})
|
||||
|
||||
_, err = s.WeightsFile("unknown/y/z:latest+Q4_0")
|
||||
if !errors.Is(err, ErrNotFound) {
|
||||
t.Fatalf("WeightsFile() err = %v; want not found", err)
|
||||
}
|
||||
|
||||
path, err := s.WeightsFile("x/y/z:latest+Q4_0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
info, err := gguf.Stat(path)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if info.FileType != gguf.TypeQ4_0 {
|
||||
t.Errorf("info.FileType = %d; want 1", info.FileType)
|
||||
}
|
||||
}
|
||||
|
||||
type work struct {
|
||||
t testing.TB
|
||||
dir string
|
||||
}
|
||||
|
||||
func newWorkDir(t *testing.T) work {
|
||||
return work{t: t, dir: t.TempDir()}
|
||||
}
|
||||
|
||||
func (w work) write(name, content string) (path string) {
|
||||
w.t.Helper()
|
||||
path = w.fileName(name)
|
||||
if err := os.WriteFile(path, []byte(content), 0644); err != nil {
|
||||
w.t.Fatal(err)
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
func (w work) fileName(name string) string {
|
||||
w.t.Helper()
|
||||
return filepath.Join(w.dir, name)
|
||||
}
|
||||
|
||||
func (w work) mkdirAll(path string) {
|
||||
w.t.Helper()
|
||||
if err := os.MkdirAll(filepath.Join(w.dir, path), 0755); err != nil {
|
||||
w.t.Fatal(err)
|
||||
}
|
||||
}
|
||||
12
x/build/convert.go
Normal file
12
x/build/convert.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package build
|
||||
|
||||
func convertSafeTensorToGGUF(path string) (ggufPath string, err error) {
|
||||
// TODO: decine on hueristic for converting safetensor to gguf and
|
||||
// the errors that can be returned. For now, we just say
|
||||
// "unsupported", however it may be intended to be a valid safe
|
||||
// tensor but we hit an error in the conversion.
|
||||
//
|
||||
// I (bmizernay) think this will naturally evolve as we implement
|
||||
// the conversion.
|
||||
return "", ErrUnsupportedModelFormat
|
||||
}
|
||||
28
x/build/default.go
Normal file
28
x/build/default.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package build
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var (
|
||||
defaultDir = sync.OnceValues(func() (string, error) {
|
||||
dir := os.Getenv("OLLAMA_MODELS")
|
||||
if dir == "" {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
dir = filepath.Join(home, ".ollama", "models")
|
||||
}
|
||||
return dir, nil
|
||||
})
|
||||
)
|
||||
|
||||
// DefaultDir returns the default directory for models. It returns the value
|
||||
// of the OLLAMA_MODELS environment variable if set; otherwise it returns
|
||||
// "$HOME/.ollama/models".
|
||||
func DefaultDir() (string, error) {
|
||||
return defaultDir()
|
||||
}
|
||||
59
x/build/import.go
Normal file
59
x/build/import.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package build
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/ollama/ollama/x/build/internal/blobstore"
|
||||
"github.com/ollama/ollama/x/encoding/gguf"
|
||||
)
|
||||
|
||||
func importError(err error) (blobstore.ID, gguf.Info, int64, error) {
|
||||
return blobstore.ID{}, gguf.Info{}, 0, err
|
||||
}
|
||||
|
||||
func (s *Server) importModel(path string) (_ blobstore.ID, _ gguf.Info, size int64, _ error) {
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return importError(err)
|
||||
}
|
||||
if info.IsDir() {
|
||||
return s.importSafeTensor(path)
|
||||
} else {
|
||||
return s.importGGUF(path)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) importGGUF(path string) (_ blobstore.ID, _ gguf.Info, size int64, _ error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return importError(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
info, err := gguf.StatReader(f)
|
||||
if errors.Is(err, gguf.ErrBadMagic) {
|
||||
return importError(ErrUnsupportedModelFormat)
|
||||
}
|
||||
if err != nil {
|
||||
return importError(err)
|
||||
}
|
||||
|
||||
if info.FileType == 0 {
|
||||
return importError(fmt.Errorf("%w: %q", ErrMissingFileType, path))
|
||||
}
|
||||
id, size, err := s.st.Put(f)
|
||||
if err != nil {
|
||||
return importError(err)
|
||||
}
|
||||
return id, info, size, nil
|
||||
}
|
||||
|
||||
func (s *Server) importSafeTensor(path string) (_ blobstore.ID, _ gguf.Info, size int64, _ error) {
|
||||
path, err := convertSafeTensorToGGUF(path)
|
||||
if err != nil {
|
||||
return importError(err)
|
||||
}
|
||||
return s.importGGUF(path)
|
||||
}
|
||||
329
x/build/internal/blobstore/blob.go
Normal file
329
x/build/internal/blobstore/blob.go
Normal file
@@ -0,0 +1,329 @@
|
||||
// Package blobstore implements a blob store.
|
||||
package blobstore
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/types/structs"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrInvalidID = errors.New("invalid ID")
|
||||
)
|
||||
|
||||
const HashSize = 32
|
||||
|
||||
// An ID is a blob output key, the hash of an output of a computation.
|
||||
type ID struct {
|
||||
a [HashSize]byte
|
||||
}
|
||||
|
||||
func (id ID) MarshalText() ([]byte, error) {
|
||||
return []byte(id.String()), nil
|
||||
}
|
||||
|
||||
func (id *ID) UnmarshalText(text []byte) error {
|
||||
*id = ParseID(string(text))
|
||||
return nil
|
||||
}
|
||||
|
||||
func ParseID(s string) ID {
|
||||
const prefix = "sha256-"
|
||||
h, ok := strings.CutPrefix(s, prefix)
|
||||
if !ok {
|
||||
return ID{}
|
||||
}
|
||||
|
||||
if len(h) != HashSize*2 {
|
||||
return ID{}
|
||||
}
|
||||
|
||||
var b []byte
|
||||
_, err := fmt.Sscanf(h, "%x", &b)
|
||||
if err != nil {
|
||||
return ID{}
|
||||
}
|
||||
|
||||
var id ID
|
||||
copy(id.a[:], b)
|
||||
return id
|
||||
}
|
||||
|
||||
func (id ID) String() string {
|
||||
if !id.Valid() {
|
||||
return ""
|
||||
}
|
||||
return fmt.Sprintf("sha256-%x", id.a[:])
|
||||
}
|
||||
|
||||
func (id ID) Valid() bool {
|
||||
return id != ID{}
|
||||
}
|
||||
|
||||
func (id ID) Match(h [HashSize]byte) bool {
|
||||
return id.a == h
|
||||
}
|
||||
|
||||
// A Store is a blob store, backed by a file system directory tree.
|
||||
type Store struct {
|
||||
dir string
|
||||
now func() time.Time
|
||||
}
|
||||
|
||||
// Open opens and returns the store in the given directory.
|
||||
//
|
||||
// It is safe for multiple processes on a single machine to use the
|
||||
// same store directory in a local file system simultaneously.
|
||||
// They will coordinate using operating system file locks and may
|
||||
// duplicate effort but will not corrupt the store.
|
||||
//
|
||||
// However, it is NOT safe for multiple processes on different machines
|
||||
// to share a store directory (for example, if the directory were stored
|
||||
// in a network file system). File locking is notoriously unreliable in
|
||||
// network file systems and may not suffice to protect the store.
|
||||
func Open(dir string) (*Store, error) {
|
||||
info, err := os.Stat(dir)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !info.IsDir() {
|
||||
return nil, &fs.PathError{Op: "open", Path: dir, Err: fmt.Errorf("not a directory")}
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Join(dir, "blobs"), 0777); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c := &Store{
|
||||
dir: dir,
|
||||
now: time.Now,
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (s *Store) Dir() string {
|
||||
return s.dir
|
||||
}
|
||||
|
||||
// fileName returns the name of the blob file corresponding to the given id.
|
||||
func (s *Store) fileName(id ID) string {
|
||||
return filepath.Join(s.dir, "blobs", fmt.Sprintf("sha256-%x", id.a[:]))
|
||||
}
|
||||
|
||||
// An entryNotFoundError indicates that a store entry was not found, with an
|
||||
// optional underlying reason.
|
||||
type entryNotFoundError struct {
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *entryNotFoundError) Error() string {
|
||||
if e.Err == nil {
|
||||
return "store entry not found"
|
||||
}
|
||||
return fmt.Sprintf("store entry not found: %v", e.Err)
|
||||
}
|
||||
|
||||
func (e *entryNotFoundError) Unwrap() error {
|
||||
return e.Err
|
||||
}
|
||||
|
||||
type Entry struct {
|
||||
_ structs.Incomparable
|
||||
|
||||
ID ID
|
||||
Size int64
|
||||
Time time.Time // when added to store
|
||||
}
|
||||
|
||||
// GetFile looks up the blob ID in the store and returns
|
||||
// the name of the corresponding data file.
|
||||
func GetFile(s *Store, id ID) (file string, entry Entry, err error) {
|
||||
entry, err = s.Get(id)
|
||||
if err != nil {
|
||||
return "", Entry{}, err
|
||||
}
|
||||
file = s.OutputFilename(entry.ID)
|
||||
info, err := os.Stat(file)
|
||||
if err != nil {
|
||||
return "", Entry{}, &entryNotFoundError{Err: err}
|
||||
}
|
||||
if info.Size() != entry.Size {
|
||||
return "", Entry{}, &entryNotFoundError{Err: errors.New("file incomplete")}
|
||||
}
|
||||
return file, entry, nil
|
||||
}
|
||||
|
||||
// GetBytes looks up the blob ID in the store and returns
|
||||
// the corresponding output bytes.
|
||||
// GetBytes should only be used for data that can be expected to fit in memory.
|
||||
func GetBytes(s *Store, id ID) ([]byte, Entry, error) {
|
||||
entry, err := s.Get(id)
|
||||
if err != nil {
|
||||
return nil, entry, err
|
||||
}
|
||||
data, _ := os.ReadFile(s.OutputFilename(entry.ID))
|
||||
if entry.ID.Match(sha256.Sum256(data)) {
|
||||
return nil, entry, &entryNotFoundError{Err: errors.New("bad checksum")}
|
||||
}
|
||||
return data, entry, nil
|
||||
}
|
||||
|
||||
// OutputFilename returns the name of the blob file for the given ID.
|
||||
func (s *Store) OutputFilename(id ID) string {
|
||||
file := s.fileName(id)
|
||||
// TODO(bmizerany): touch as "used" for cache trimming. (see
|
||||
// cache.go in cmd/go/internal/cache for the full reference implementation to go off of.
|
||||
return file
|
||||
}
|
||||
|
||||
// Get looks up the blob ID in the store,
|
||||
// returning the corresponding output ID and file size, if any.
|
||||
// Note that finding an output ID does not guarantee that the
|
||||
// saved file for that output ID is still available.
|
||||
func (s *Store) Get(id ID) (Entry, error) {
|
||||
file := s.fileName(id)
|
||||
info, err := os.Stat(file)
|
||||
if err != nil {
|
||||
return Entry{}, &entryNotFoundError{Err: err}
|
||||
}
|
||||
return Entry{
|
||||
ID: id,
|
||||
Size: info.Size(),
|
||||
Time: info.ModTime(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Store) Close() error {
|
||||
// TODO(bmizerany): return c.Trim()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Put stores the data read from the given file into the store as ID.
|
||||
//
|
||||
// It may read file twice. The content of file must not change between the
|
||||
// two passes.
|
||||
func (s *Store) Put(file io.ReadSeeker) (ID, int64, error) {
|
||||
return s.put(file)
|
||||
}
|
||||
|
||||
func PutBytes(s *Store, data []byte) (ID, int64, error) {
|
||||
return s.Put(bytes.NewReader(data))
|
||||
}
|
||||
|
||||
func PutString(s *Store, data string) (ID, int64, error) {
|
||||
return s.Put(strings.NewReader(data))
|
||||
}
|
||||
|
||||
func (s *Store) put(file io.ReadSeeker) (ID, int64, error) {
|
||||
// Compute output ID.
|
||||
h := sha256.New()
|
||||
if _, err := file.Seek(0, 0); err != nil {
|
||||
return ID{}, 0, err
|
||||
}
|
||||
size, err := io.Copy(h, file)
|
||||
if err != nil {
|
||||
return ID{}, 0, err
|
||||
}
|
||||
var out ID
|
||||
h.Sum(out.a[:0])
|
||||
|
||||
// Copy to blob file (if not already present).
|
||||
if err := s.copyFile(file, out, size); err != nil {
|
||||
return out, size, err
|
||||
}
|
||||
|
||||
// TODO: Add to manifest index.
|
||||
return out, size, nil
|
||||
}
|
||||
|
||||
// copyFile copies file into the store, expecting it to have the given
|
||||
// output ID and size, if that file is not present already.
|
||||
func (s *Store) copyFile(file io.ReadSeeker, out ID, size int64) error {
|
||||
name := s.fileName(out)
|
||||
println("name", name)
|
||||
info, err := os.Stat(name)
|
||||
if err == nil && info.Size() == size {
|
||||
// Check hash.
|
||||
if f, err := os.Open(name); err == nil {
|
||||
h := sha256.New()
|
||||
io.Copy(h, f)
|
||||
f.Close()
|
||||
var out2 ID
|
||||
h.Sum(out2.a[:0])
|
||||
if out == out2 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
// Hash did not match. Fall through and rewrite file.
|
||||
}
|
||||
|
||||
// Copy file to blobs directory.
|
||||
mode := os.O_RDWR | os.O_CREATE
|
||||
if err == nil && info.Size() > size { // shouldn't happen but fix in case
|
||||
mode |= os.O_TRUNC
|
||||
}
|
||||
f, err := os.OpenFile(name, mode, 0666)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
if size == 0 {
|
||||
// File now exists with correct size.
|
||||
// Only one possible zero-length file, so contents are OK too.
|
||||
// Early return here makes sure there's a "last byte" for code below.
|
||||
return nil
|
||||
}
|
||||
|
||||
// From here on, if any of the I/O writing the file fails,
|
||||
// we make a best-effort attempt to truncate the file f
|
||||
// before returning, to avoid leaving bad bytes in the file.
|
||||
|
||||
// Copy file to f, but also into h to double-check hash.
|
||||
if _, err := file.Seek(0, 0); err != nil {
|
||||
f.Truncate(0)
|
||||
return err
|
||||
}
|
||||
h := sha256.New()
|
||||
w := io.MultiWriter(f, h)
|
||||
if _, err := io.CopyN(w, file, size-1); err != nil {
|
||||
f.Truncate(0)
|
||||
return err
|
||||
}
|
||||
// Check last byte before writing it; writing it will make the size match
|
||||
// what other processes expect to find and might cause them to start
|
||||
// using the file.
|
||||
buf := make([]byte, 1)
|
||||
if _, err := file.Read(buf); err != nil {
|
||||
f.Truncate(0)
|
||||
return err
|
||||
}
|
||||
h.Write(buf)
|
||||
sum := h.Sum(nil)
|
||||
if !bytes.Equal(sum, out.a[:]) {
|
||||
f.Truncate(0)
|
||||
return fmt.Errorf("file content changed underfoot")
|
||||
}
|
||||
|
||||
// Commit manifest entry.
|
||||
if _, err := f.Write(buf); err != nil {
|
||||
f.Truncate(0)
|
||||
return err
|
||||
}
|
||||
if err := f.Close(); err != nil {
|
||||
// Data might not have been written,
|
||||
// but file may look like it is the right size.
|
||||
// To be extra careful, remove stored file.
|
||||
os.Remove(name)
|
||||
return err
|
||||
}
|
||||
os.Chtimes(name, s.now(), s.now()) // mainly for tests
|
||||
|
||||
return nil
|
||||
}
|
||||
54
x/build/internal/blobstore/blob_test.go
Normal file
54
x/build/internal/blobstore/blob_test.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package blobstore
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseID(t *testing.T) {
|
||||
const valid = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
|
||||
var invalid = strings.Repeat("\x00", HashSize*2)
|
||||
|
||||
cases := []struct {
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{"", invalid},
|
||||
{"sha256-", invalid},
|
||||
{"sha256-" + valid, valid},
|
||||
|
||||
{"" + valid, invalid}, // no prefix
|
||||
{"sha123-" + valid, invalid}, // invalid prefix
|
||||
{"sha256-" + valid[1:], invalid}, // too short
|
||||
{"sha256-" + valid + "a", invalid}, // too long
|
||||
{"sha256-!" + valid[1:], invalid}, // invalid hex
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run("", func(t *testing.T) {
|
||||
// sanity check
|
||||
if len(tt.want) > HashSize*2 {
|
||||
panic("invalid test")
|
||||
}
|
||||
|
||||
got := ParseID(tt.in)
|
||||
|
||||
wantValid := tt.want != invalid
|
||||
if wantValid {
|
||||
if !got.Valid() {
|
||||
t.Errorf("ParseID(%q).Valid() = false; want true", tt.in)
|
||||
}
|
||||
if got.String() != "sha256-"+tt.want {
|
||||
t.Errorf("ParseID(%q).String() = %q; want %q", tt.in, got.String(), "sha256-"+tt.want)
|
||||
}
|
||||
} else {
|
||||
if got.Valid() {
|
||||
t.Errorf("ParseID(%q).Valid() = true; want false", tt.in)
|
||||
}
|
||||
if got.String() != "" {
|
||||
t.Errorf("ParseID(%q).String() = %q; want %q", tt.in, got.String(), "")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
128
x/build/internal/blobstore/store_test.go
Normal file
128
x/build/internal/blobstore/store_test.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package blobstore
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"iter"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/x/model"
|
||||
"kr.dev/diff"
|
||||
)
|
||||
|
||||
const (
|
||||
blobNameHello = "sha256-2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824"
|
||||
)
|
||||
|
||||
func TestStoreBasicBlob(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
checkDir(t, dir, nil)
|
||||
|
||||
st, err := Open(dir)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
st.now = func() time.Time { return now }
|
||||
|
||||
checkDir(t, dir, []string{
|
||||
"blobs/",
|
||||
})
|
||||
|
||||
id, size, err := PutBytes(st, []byte("hello"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if id != ParseID(blobNameHello) {
|
||||
t.Errorf("unexpected ID: %s", id)
|
||||
}
|
||||
if size != 5 {
|
||||
t.Errorf("unexpected size: %d", size)
|
||||
}
|
||||
|
||||
checkDir(t, dir, []string{
|
||||
"blobs/",
|
||||
"blobs/" + blobNameHello,
|
||||
})
|
||||
|
||||
got, err := st.Get(id)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
diff.Test(t, t.Errorf, got, Entry{
|
||||
ID: id,
|
||||
Size: 5,
|
||||
Time: now,
|
||||
})
|
||||
|
||||
file := st.OutputFilename(id)
|
||||
wantFile := filepath.Join(dir, "blobs", blobNameHello)
|
||||
if file != wantFile {
|
||||
t.Errorf("unexpected file: %s", file)
|
||||
}
|
||||
|
||||
// Check tags
|
||||
name := model.ParseName("registry.ollama.ai/library/test:latest+KQED")
|
||||
|
||||
t.Logf("RESOLVING: %q", name.Parts())
|
||||
|
||||
}
|
||||
|
||||
// checkDir checks that the directory at dir contains the files in want. The
|
||||
// files in want must be relative to dir.
|
||||
//
|
||||
// direcotories are suffixed with a slash (e.g. "foo/" instead of "foo").
|
||||
//
|
||||
// want must be in lexicographic order.
|
||||
func checkDir(t testing.TB, dir string, want []string) {
|
||||
t.Helper()
|
||||
|
||||
var matches []string
|
||||
for path, err := range walkDir(dir) {
|
||||
t.Helper()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
t.Logf("found %s", path)
|
||||
if path == "./" {
|
||||
continue
|
||||
}
|
||||
path = filepath.ToSlash(path)
|
||||
matches = append(matches, path)
|
||||
}
|
||||
|
||||
diff.Test(t, t.Errorf, matches, want)
|
||||
}
|
||||
|
||||
var errStop = errors.New("stop")
|
||||
|
||||
func walkDir(dir string) iter.Seq2[string, error] {
|
||||
return func(yield func(string, error) bool) {
|
||||
err := filepath.WalkDir(dir, func(path string, info os.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
path, err = filepath.Rel(dir, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
path = filepath.ToSlash(path)
|
||||
if info.IsDir() {
|
||||
path += "/"
|
||||
}
|
||||
if !yield(path, nil) {
|
||||
return errStop
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if !errors.Is(err, errStop) && err != nil {
|
||||
yield("", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
31
x/client/ollama/apitype/apitype.go
Normal file
31
x/client/ollama/apitype/apitype.go
Normal file
@@ -0,0 +1,31 @@
|
||||
package apitype
|
||||
|
||||
import "time"
|
||||
|
||||
type Message struct {
|
||||
Role string `json:"role"`
|
||||
Content string `json:"content"`
|
||||
}
|
||||
|
||||
type Model struct {
|
||||
Ref string `json:"ref"`
|
||||
Digest string `json:"digest"`
|
||||
Size int64 `json:"size"`
|
||||
ModifiedAt int64 `json:"modified"`
|
||||
}
|
||||
|
||||
func (m Model) Modifed() time.Time {
|
||||
return time.Unix(0, m.ModifiedAt)
|
||||
}
|
||||
|
||||
type PushRequest struct {
|
||||
Name string `json:"name"` // Ref is the official term, "name" is for backward compatibility with exiting clients.
|
||||
Insecure bool `json:"insecure"`
|
||||
Stream bool `json:"stream"`
|
||||
}
|
||||
|
||||
type PushStatus struct {
|
||||
Status string `json:"status"`
|
||||
Digest string `json:"digest"`
|
||||
Total int64 `json:"total"`
|
||||
}
|
||||
173
x/client/ollama/ollama.go
Normal file
173
x/client/ollama/ollama.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package ollama
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/fs"
|
||||
"iter"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/client/ollama/apitype"
|
||||
"github.com/ollama/ollama/x/types/empty"
|
||||
)
|
||||
|
||||
// TODO(bmizerany): PROGRESS INDICATORS!!!!
|
||||
|
||||
const DefaultBaseURL = "http://localhost:11434"
|
||||
|
||||
var envBaseURL = cmp.Or(os.Getenv("OLLAMA_BASE_URL"), DefaultBaseURL)
|
||||
|
||||
// Default returns a new client with the default base URL.
|
||||
func Default() *Client {
|
||||
return &Client{BaseURL: envBaseURL}
|
||||
}
|
||||
|
||||
// I_Acknowledge_This_API_Is_Under_Development is a flag that must be set to
|
||||
// true for any instance of Client to work.
|
||||
var I_Acknowledge_This_API_Is_Under_Development bool
|
||||
|
||||
// Client is a client for the Ollama API.
|
||||
type Client struct {
|
||||
// BaseURL is the base URL of the Ollama API.
|
||||
BaseURL string
|
||||
|
||||
HTTPClient *http.Client // The HTTP client to use. If nil, http.DefaultClient is used.
|
||||
}
|
||||
|
||||
// Build requests the remote Ollama service to build a model. It uploads any
|
||||
// source files the server needs.
|
||||
func (c *Client) Build(ctx context.Context, ref string, modelfile []byte, source fs.FS) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
// Push requests the remote Ollama service to push a model to the server.
|
||||
func (c *Client) Push(ctx context.Context, ref string) error {
|
||||
_, err := Do[empty.Message](ctx, c, "POST", "/v1/push", apitype.PushRequest{Name: ref})
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Client) Pull(ctx context.Context, ref string) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (c *Client) List(ctx context.Context) iter.Seq2[apitype.Model, error] {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (c *Client) Show(ctx context.Context, ref string) (*apitype.Model, error) {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (c *Client) Remove(ctx context.Context, ref string) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (c *Client) Copy(ctx context.Context, dstRef, srcRef string) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
func (c *Client) Run(ctx context.Context, ref string, messages []apitype.Message) error {
|
||||
panic("not implemented")
|
||||
}
|
||||
|
||||
type Error struct {
|
||||
// Status is the HTTP status code returned by the server.
|
||||
Status int `json:"status"`
|
||||
|
||||
// Code specifies a machine readable code indicating the class of
|
||||
// error this error is. See http://docs.ollama.com/errors for a full
|
||||
// list of error codes.
|
||||
Code string `json:"code"`
|
||||
|
||||
// Message is a humage readable message that describes the error. It
|
||||
// may change across versions of the API, so it should not be used for
|
||||
// programmatic decisions.
|
||||
Message string `json:"message,omitempty"`
|
||||
|
||||
// Field is the field in the request that caused the error, if any.
|
||||
Field string `json:"field,omitempty"`
|
||||
|
||||
// Value is the value of the field that caused the error, if any.
|
||||
Value string `json:"value,omitempty"`
|
||||
}
|
||||
|
||||
func (e *Error) Error() string {
|
||||
var b strings.Builder
|
||||
b.WriteString("ollama: ")
|
||||
b.WriteString(e.Code)
|
||||
if e.Field != "" {
|
||||
b.WriteString(" ")
|
||||
b.WriteString(e.Field)
|
||||
}
|
||||
if e.Value != "" {
|
||||
b.WriteString(": ")
|
||||
b.WriteString(e.Value)
|
||||
}
|
||||
if e.Message != "" {
|
||||
b.WriteString(": ")
|
||||
b.WriteString(e.Message)
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// Do encodes in and sends it in a request to the Ollama server and decodes
|
||||
// the response into Res, or an error response (non-2xx) into an *Error, or
|
||||
// any error encounted decoding the response.
|
||||
func Do[Res any](ctx context.Context, c *Client, method, path string, in any) (*Res, error) {
|
||||
var body bytes.Buffer
|
||||
// TODO(bmizerany): pool and reuse this buffer AND the encoder
|
||||
if err := encodeJSON(&body, in); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
urlStr := c.BaseURL + path
|
||||
req, err := http.NewRequestWithContext(ctx, method, urlStr, &body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
hc := cmp.Or(c.HTTPClient, http.DefaultClient)
|
||||
res, err := hc.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode/100 != 2 {
|
||||
var buf bytes.Buffer
|
||||
body := io.TeeReader(res.Body, &buf)
|
||||
e, err := decodeJSON[Error](body)
|
||||
if err != nil {
|
||||
err := fmt.Errorf("ollama: invalid error response from server (status %d): %q", res.StatusCode, buf.String())
|
||||
return nil, err
|
||||
}
|
||||
return nil, e
|
||||
}
|
||||
|
||||
return decodeJSON[Res](res.Body)
|
||||
}
|
||||
|
||||
// decodeJSON decodes JSON from r into a new value of type T.
|
||||
//
|
||||
// NOTE: This is (and encodeJSON) are copies and paste from oweb.go, please
|
||||
// do not try and consolidate so we can keep ollama/client free from
|
||||
// dependencies which are moving targets and not pulling enough weight to
|
||||
// justify their inclusion.
|
||||
func decodeJSON[T any](r io.Reader) (*T, error) {
|
||||
var v T
|
||||
if err := json.NewDecoder(r).Decode(&v); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &v, nil
|
||||
}
|
||||
|
||||
// NOTE: see NOT above decodeJSON
|
||||
func encodeJSON(w io.Writer, v any) error {
|
||||
// TODO(bmizerany): pool and reuse encoder
|
||||
return json.NewEncoder(w).Encode(v)
|
||||
}
|
||||
100
x/cmd/bllamo/bllamo.go
Normal file
100
x/cmd/bllamo/bllamo.go
Normal file
@@ -0,0 +1,100 @@
|
||||
// Bllamo is a (new) tool for managing Ollama models.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// bllamo <command> [arguments]
|
||||
//
|
||||
// The commands are:
|
||||
//
|
||||
// build build a model from a Modelfile
|
||||
// list list all models
|
||||
// push push a model from an ollama registry
|
||||
// pull pull a model from an ollama registry
|
||||
// delete delete a model from an ollama registry
|
||||
// help display help for a command
|
||||
package main
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
|
||||
"github.com/ollama/ollama/x/api"
|
||||
"github.com/ollama/ollama/x/build"
|
||||
"github.com/ollama/ollama/x/client/ollama"
|
||||
"github.com/ollama/ollama/x/registry"
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
args := flag.Args()
|
||||
if len(args) < 1 {
|
||||
fmt.Fprintln(os.Stderr, "bllamo: no command provided")
|
||||
os.Exit(2)
|
||||
}
|
||||
if err := Main(flag.Args()...); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "%v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
var TODOUsage = fmt.Errorf("TODO: usage")
|
||||
|
||||
var commands = map[string]func(ctx context.Context, args ...string) error{
|
||||
"build": cmdBuild,
|
||||
"push": cmdPush,
|
||||
"serve": cmdServe,
|
||||
"registry": cmdRegistry,
|
||||
}
|
||||
|
||||
// Main is the entry point for the blammo command.
|
||||
func Main(args ...string) error {
|
||||
cmd := args[0]
|
||||
args = args[1:]
|
||||
if f, ok := commands[cmd]; ok {
|
||||
ctx := context.TODO()
|
||||
return f(ctx, args...)
|
||||
}
|
||||
return fmt.Errorf("blammo: unknown command %q", cmd)
|
||||
}
|
||||
|
||||
func cmdBuild(ctx context.Context, args ...string) error {
|
||||
var v struct {
|
||||
Modelfile string `flag:"f,the Modelfile to use"`
|
||||
}
|
||||
|
||||
fs := readFlags("build", args, &v)
|
||||
if fs.NArg() != 1 {
|
||||
return TODOUsage
|
||||
}
|
||||
|
||||
modelfile, err := os.ReadFile(cmp.Or(v.Modelfile, "Modelfile"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return ollama.Default().Build(ctx, args[0], modelfile, os.DirFS("."))
|
||||
}
|
||||
|
||||
func cmdRegistry(_ context.Context, _ ...string) error {
|
||||
var s registry.Server
|
||||
return http.ListenAndServe(":8888", &s)
|
||||
}
|
||||
|
||||
func cmdServe(ctx context.Context, args ...string) error {
|
||||
bs, err := build.Open("")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return http.ListenAndServe(":11434", &api.Server{Build: bs})
|
||||
}
|
||||
|
||||
func cmdPush(ctx context.Context, args ...string) error {
|
||||
fs := readFlags("push", args, nil)
|
||||
if fs.NArg() != 1 {
|
||||
return TODOUsage
|
||||
}
|
||||
return ollama.Default().Push(ctx, fs.Arg(0))
|
||||
}
|
||||
59
x/cmd/bllamo/flags.go
Normal file
59
x/cmd/bllamo/flags.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// parseArgs parses the provided args using a flag.FlagSet that is
|
||||
// dynamically build using reflection for the provided type. The type fields
|
||||
// that have a "flag" tag are used to build the flags. The flag tag should
|
||||
// include either a ('-'). Example usage:
|
||||
//
|
||||
// func main() {
|
||||
// var flags struct {
|
||||
// Modelfile string `flag:"f,path to the Modelfile"`
|
||||
// }
|
||||
//
|
||||
// fs := readFlags(os.Args[1:], &flags)
|
||||
// fs.Parse(os.Args[1:])
|
||||
// }
|
||||
func readFlags(name string, args []string, v any) *flag.FlagSet {
|
||||
fs := flag.NewFlagSet(name, flag.ExitOnError)
|
||||
defer fs.Parse(args)
|
||||
if v == nil {
|
||||
return fs
|
||||
}
|
||||
|
||||
for i := 0; i < reflect.ValueOf(v).NumField(); i++ {
|
||||
f := reflect.ValueOf(v).Field(i)
|
||||
if !f.CanSet() {
|
||||
continue
|
||||
}
|
||||
|
||||
tag := f.Type().Field(i).Tag.Get("flag")
|
||||
if tag == "" {
|
||||
continue
|
||||
}
|
||||
var name, usage string
|
||||
if i := strings.Index(tag, ","); i != -1 {
|
||||
name = tag[:i]
|
||||
usage = tag[i+1:]
|
||||
} else {
|
||||
name = tag
|
||||
}
|
||||
|
||||
// TODO(bmizerany): add more types as needed
|
||||
switch f.Kind() {
|
||||
case reflect.String:
|
||||
fs.StringVar(f.Addr().Interface().(*string), name, "", usage)
|
||||
case reflect.Bool:
|
||||
fs.BoolVar(f.Addr().Interface().(*bool), name, false, usage)
|
||||
default:
|
||||
panic(fmt.Sprintf("unsupported type %v", f.Kind()))
|
||||
}
|
||||
}
|
||||
return fs
|
||||
}
|
||||
97
x/cmd/gguf/gguf.go
Normal file
97
x/cmd/gguf/gguf.go
Normal file
@@ -0,0 +1,97 @@
|
||||
// Gguf is a tool for learning about GGUF files.
|
||||
//
|
||||
// Usage:
|
||||
//
|
||||
// gguf [flags] <file>
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"text/tabwriter"
|
||||
|
||||
"github.com/ollama/ollama/x/encoding/gguf"
|
||||
)
|
||||
|
||||
func main() {
|
||||
if err := Main(os.Stdout, os.Args[1:]...); err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func Main(stdout io.Writer, args ...string) error {
|
||||
fs := flag.NewFlagSet("gguf", flag.ExitOnError)
|
||||
flagGPU := fs.Uint64("gpu", 0, "use N bytes of GPU memory (default is 0)")
|
||||
|
||||
fs.Usage = func() {
|
||||
io.WriteString(stdout, "Gguf is a tool for learning about GGUF files.\n")
|
||||
io.WriteString(stdout, "\n")
|
||||
io.WriteString(stdout, "Usage:\n")
|
||||
io.WriteString(stdout, "\n")
|
||||
io.WriteString(stdout, "\tgguf [flags] <file>\n")
|
||||
io.WriteString(stdout, "\n")
|
||||
var numFlags int
|
||||
fs.VisitAll(func(*flag.Flag) { numFlags++ })
|
||||
if numFlags > 0 {
|
||||
io.WriteString(stdout, "Flags:\n")
|
||||
fs.PrintDefaults()
|
||||
}
|
||||
}
|
||||
fs.Parse(args)
|
||||
|
||||
if fs.NArg() != 1 {
|
||||
fs.Usage()
|
||||
os.Exit(2)
|
||||
}
|
||||
|
||||
file := fs.Arg(0)
|
||||
f, err := os.Open(file)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
g, err := gguf.ReadFile(f)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
tw := tabwriter.NewWriter(stdout, 0, 2, 2, ' ', 0)
|
||||
defer tw.Flush()
|
||||
|
||||
fmt.Fprintf(tw, "version:\t%d\n", g.Version())
|
||||
|
||||
for m, err := range g.Metadata {
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
if len(m.Values) > 5 {
|
||||
fmt.Fprintf(tw, "meta:\t%q: ... (%d values)\n", m.Key, len(m.Values))
|
||||
} else {
|
||||
fmt.Fprintf(tw, "meta:\t%q: %v\n", m.Key, m.Values)
|
||||
}
|
||||
}
|
||||
|
||||
var i int
|
||||
var totalLayerBytes uint64
|
||||
var offGPU bool
|
||||
for t, err := range g.Tensors {
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
totalLayerBytes += t.Size
|
||||
if totalLayerBytes > *flagGPU {
|
||||
offGPU = true
|
||||
}
|
||||
|
||||
const msg = "tensor (layer %000d):\t%q\t%s\tdims=%v\toffset=%d\tsize=%d\tonGPU=%v\n"
|
||||
fmt.Fprintf(tw, msg, i, t.Name, t.Type, t.Dimensions, t.Offset, t.Size, !offGPU)
|
||||
|
||||
i++
|
||||
}
|
||||
return nil
|
||||
}
|
||||
376
x/encoding/gguf/gguf.go
Normal file
376
x/encoding/gguf/gguf.go
Normal file
@@ -0,0 +1,376 @@
|
||||
package gguf
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/x/types/structs"
|
||||
)
|
||||
|
||||
// TODO(bmizerany): determine a more reasonable value for MaxDimensions.
|
||||
|
||||
// MaxDimensions is the maximum number of dimensions a tensor can have.
|
||||
const MaxDimensions uint32 = 1e6
|
||||
|
||||
// Errors
|
||||
var (
|
||||
// ErrBadMagic is returned when the magic bytes at the start of the
|
||||
// file. This is useful for detecting if the file is not a gguf
|
||||
// file.
|
||||
ErrBadMagic = errors.New("gguf: bad magic")
|
||||
|
||||
ErrUnsupportedVersion = errors.New("gguf: unsupported version")
|
||||
ErrMangled = errors.New("gguf: mangled data")
|
||||
)
|
||||
|
||||
type Type uint32
|
||||
|
||||
const (
|
||||
TypeF32 Type = 0
|
||||
TypeF16 Type = 1
|
||||
TypeQ4_0 Type = 2
|
||||
TypeQ4_1 Type = 3
|
||||
TypeQ5_0 Type = 6
|
||||
TypeQ5_1 Type = 7
|
||||
TypeQ8_0 Type = 8
|
||||
TypeQ8_1 Type = 9
|
||||
TypeQ2_K Type = 10
|
||||
TypeQ3_K Type = 11
|
||||
TypeQ4_K Type = 12
|
||||
TypeQ5_K Type = 13
|
||||
TypeQ6_K Type = 14
|
||||
TypeQ8_K Type = 15
|
||||
TypeI8 Type = 16
|
||||
TypeI16 Type = 17
|
||||
TypeI32 Type = 18
|
||||
TypeCount Type = 19
|
||||
)
|
||||
|
||||
var typeNames = map[Type]string{
|
||||
TypeF32: "F32",
|
||||
TypeF16: "F16",
|
||||
TypeQ4_0: "Q4_0",
|
||||
TypeQ4_1: "Q4_1",
|
||||
TypeQ5_0: "Q5_0",
|
||||
TypeQ5_1: "Q5_1",
|
||||
TypeQ8_0: "Q8_0",
|
||||
TypeQ8_1: "Q8_1",
|
||||
TypeQ2_K: "Q2_K",
|
||||
TypeQ3_K: "Q3_K",
|
||||
TypeQ4_K: "Q4_K",
|
||||
TypeQ5_K: "Q5_K",
|
||||
TypeQ6_K: "Q6_K",
|
||||
TypeQ8_K: "Q8_K",
|
||||
TypeI8: "I8",
|
||||
TypeI16: "I16",
|
||||
TypeI32: "I32",
|
||||
TypeCount: "COUNT",
|
||||
}
|
||||
|
||||
func (t Type) String() string {
|
||||
if name := typeNames[t]; name != "" {
|
||||
return name
|
||||
}
|
||||
return fmt.Sprintf("(!unknown_type %d!)", t)
|
||||
}
|
||||
|
||||
// ValueType is the type of a metadata value.
|
||||
type ValueType uint32
|
||||
|
||||
func (t ValueType) String() string {
|
||||
if name := metaTypeNames[t]; name != "" {
|
||||
return name
|
||||
}
|
||||
return fmt.Sprintf("(!unknown_value_type %d!)", t)
|
||||
}
|
||||
|
||||
const (
|
||||
ValueTypeUint8 ValueType = 0
|
||||
ValueTypeInt8 ValueType = 1
|
||||
ValueTypeUint16 ValueType = 2
|
||||
ValueTypeInt16 ValueType = 3
|
||||
ValueTypeUint32 ValueType = 4
|
||||
ValueTypeInt32 ValueType = 5
|
||||
ValueTypeFloat32 ValueType = 6
|
||||
ValueTypeBool ValueType = 7
|
||||
ValueTypeString ValueType = 8
|
||||
ValueTypeArray ValueType = 9
|
||||
ValueTypeUint64 ValueType = 10
|
||||
ValueTypeInt64 ValueType = 11
|
||||
ValueTypeFloat64 ValueType = 12
|
||||
)
|
||||
|
||||
var metaTypeNames = map[ValueType]string{
|
||||
ValueTypeUint8: "uint8",
|
||||
ValueTypeInt8: "int8",
|
||||
ValueTypeUint16: "uint16",
|
||||
ValueTypeInt16: "int16",
|
||||
ValueTypeUint32: "uint32",
|
||||
ValueTypeInt32: "int32",
|
||||
ValueTypeFloat32: "float32",
|
||||
ValueTypeBool: "bool",
|
||||
ValueTypeString: "string",
|
||||
ValueTypeArray: "array",
|
||||
ValueTypeUint64: "uint64",
|
||||
ValueTypeInt64: "int64",
|
||||
ValueTypeFloat64: "float64",
|
||||
}
|
||||
|
||||
type TensorInfo struct {
|
||||
Name string
|
||||
Dimensions []uint64
|
||||
Type Type
|
||||
Offset uint64
|
||||
Size uint64
|
||||
}
|
||||
|
||||
type MetaValue struct {
|
||||
Type ValueType
|
||||
Value []byte
|
||||
}
|
||||
|
||||
func (v MetaValue) String() string {
|
||||
var b strings.Builder
|
||||
b.WriteString(v.Type.String())
|
||||
b.WriteString("(")
|
||||
switch v.Type {
|
||||
case ValueTypeArray:
|
||||
b.WriteString("[...]")
|
||||
case ValueTypeString:
|
||||
b.WriteString(strconv.Quote(string(v.Value)))
|
||||
case ValueTypeBool:
|
||||
if len(v.Value) == 0 {
|
||||
b.WriteString("(!invalid bool)")
|
||||
}
|
||||
switch v.Value[0] {
|
||||
case 0:
|
||||
b.WriteString("false")
|
||||
case 1:
|
||||
b.WriteString("true")
|
||||
default:
|
||||
b.WriteString("!invalid bool")
|
||||
}
|
||||
case ValueTypeUint8, ValueTypeInt8, ValueTypeUint16, ValueTypeInt16, ValueTypeUint32, ValueTypeInt32, ValueTypeUint64, ValueTypeInt64, ValueTypeFloat32, ValueTypeFloat64:
|
||||
var buf [8]byte
|
||||
if len(v.Value) < 8 {
|
||||
copy(buf[:], v.Value)
|
||||
}
|
||||
fmt.Fprintf(&b, "%v", binary.LittleEndian.Uint64(buf[:]))
|
||||
default:
|
||||
fmt.Fprintf(&b, "%v", v.Value)
|
||||
}
|
||||
b.WriteString(")")
|
||||
return b.String()
|
||||
}
|
||||
|
||||
type MetaEntry struct {
|
||||
Key string
|
||||
Type ValueType
|
||||
Values []MetaValue
|
||||
}
|
||||
|
||||
func (e MetaEntry) String() string {
|
||||
if len(e.Values) == 0 {
|
||||
return ""
|
||||
}
|
||||
return string(e.Values[0].Value)
|
||||
}
|
||||
|
||||
func (e MetaEntry) Uint32() uint32 {
|
||||
if len(e.Values) == 0 {
|
||||
return 0
|
||||
}
|
||||
return binary.LittleEndian.Uint32(e.Values[0].Value)
|
||||
}
|
||||
|
||||
func (e MetaEntry) FileType() Type {
|
||||
if len(e.Values) == 0 {
|
||||
return TypeCount
|
||||
}
|
||||
return Type(e.Uint32())
|
||||
}
|
||||
|
||||
func (e MetaEntry) GoString() string {
|
||||
var b strings.Builder
|
||||
b.WriteString(e.Key)
|
||||
b.WriteString(": ")
|
||||
b.WriteString(e.Type.String())
|
||||
b.WriteString("(")
|
||||
for i, v := range e.Values {
|
||||
if i > 0 {
|
||||
b.WriteString(", ")
|
||||
}
|
||||
b.WriteString(v.String())
|
||||
}
|
||||
b.WriteString(")")
|
||||
return b.String()
|
||||
}
|
||||
|
||||
type Info struct {
|
||||
_ structs.Incomparable // prevent comparison of Info values so we can change the implementation later
|
||||
|
||||
Version int
|
||||
FileType Type
|
||||
}
|
||||
|
||||
func Stat(path string) (Info, error) {
|
||||
f, err := os.Open(path)
|
||||
if err != nil {
|
||||
return Info{}, err
|
||||
}
|
||||
defer f.Close()
|
||||
return StatReader(f)
|
||||
}
|
||||
|
||||
// StatReader reads the header information from r and returns an Info
|
||||
// struct with the version and file type.
|
||||
//
|
||||
// It returns an error if any.
|
||||
//
|
||||
// As a special case, it returns ErrBadMagic if the file does not start with
|
||||
// the magic bytes. This can be used to detect if the file is not a GGUF
|
||||
// file.
|
||||
func StatReader(r io.ReadSeeker) (Info, error) {
|
||||
if _, err := r.Seek(0, 0); err != nil {
|
||||
return Info{}, err
|
||||
}
|
||||
f, err := ReadFile(r)
|
||||
if err != nil {
|
||||
return Info{}, err
|
||||
}
|
||||
info := Info{Version: f.Version()}
|
||||
for m, err := range f.Metadata {
|
||||
if err != nil {
|
||||
return Info{}, err
|
||||
}
|
||||
if m.Key == "general.file_type" {
|
||||
if m.Type != ValueTypeUint32 {
|
||||
return Info{}, fmt.Errorf("unexpected type for metadata key %q: %v, want %v", m.Key, m.Type, ValueTypeUint32)
|
||||
}
|
||||
info.FileType = m.FileType()
|
||||
}
|
||||
}
|
||||
return info, nil
|
||||
}
|
||||
|
||||
type File struct {
|
||||
version uint32
|
||||
numMetaValues uint64
|
||||
numTensors uint64
|
||||
|
||||
gr *ggufReader
|
||||
}
|
||||
|
||||
// ReadFile reads header information from r and returns a File, ready for
|
||||
// iteration over Metadata and Tensors.
|
||||
func ReadFile(r io.Reader) (*File, error) {
|
||||
f, err := readFile(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return f, nil
|
||||
}
|
||||
|
||||
func (f *File) Version() int {
|
||||
return int(f.version)
|
||||
}
|
||||
|
||||
// Metadata iterates over the metadata in the file. It must be exhausted
|
||||
// before calling Tensors.
|
||||
//
|
||||
// It is not resumable.
|
||||
func (f *File) Metadata(yield func(MetaEntry, error) bool) {
|
||||
var n int
|
||||
for range f.numMetaValues {
|
||||
meta, err := f.gr.readMetaEntry()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("error reading metadata entry %d: %w", n, err)
|
||||
yield(MetaEntry{}, err)
|
||||
return
|
||||
}
|
||||
if !yield(meta, nil) {
|
||||
return
|
||||
}
|
||||
n++
|
||||
}
|
||||
}
|
||||
|
||||
// Tensors iterates over the tensors in the file. It must only be called
|
||||
// after exhausting the metadata iterator.
|
||||
//
|
||||
// It is not resumable.
|
||||
func (f *File) Tensors(yield func(TensorInfo, error) bool) {
|
||||
var last TensorInfo
|
||||
for range f.numTensors {
|
||||
info, err := f.gr.readTensorInfo()
|
||||
|
||||
// If the last tensor had a valid offset, yield it.
|
||||
//
|
||||
// NOTE: No tensor should have an offset of 0 because the
|
||||
// offset is the start of the tensor data which is always
|
||||
// afer the magic bytes, version, numMetaValues, and
|
||||
// numTensors, which MUST all be non-zero bytes as per the
|
||||
// GGUF spec.
|
||||
if last.Offset > 0 {
|
||||
if !yield(last, err) {
|
||||
return
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
yield(TensorInfo{}, err)
|
||||
return
|
||||
}
|
||||
|
||||
// Tensor data does not include size, so we need to
|
||||
// calculate it based on the offset of the previous tensor
|
||||
// offset to the current.
|
||||
offset0 := last.Offset
|
||||
last = info
|
||||
last.Size = info.Offset - offset0
|
||||
}
|
||||
if last.Offset > 0 {
|
||||
yield(last, nil)
|
||||
}
|
||||
}
|
||||
|
||||
var magicBytes = []byte{0x47, 0x47, 0x55, 0x46}
|
||||
|
||||
func readFile(r io.Reader) (*File, error) {
|
||||
gr := &ggufReader{r: &reader{r: r}}
|
||||
magic, err := gr.next(4)
|
||||
if err != nil {
|
||||
return nil, errors.Join(err, ErrBadMagic)
|
||||
}
|
||||
if !bytes.Equal(magic, magicBytes) {
|
||||
return nil, ErrBadMagic
|
||||
}
|
||||
version, err := gr.readUint32()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if version != 3 {
|
||||
return nil, fmt.Errorf("%w: %d", ErrUnsupportedVersion, version)
|
||||
}
|
||||
numTensors, err := gr.readUint64()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
numMetaValues, err := gr.readUint64()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
info := &File{
|
||||
version: version,
|
||||
|
||||
numMetaValues: numMetaValues,
|
||||
numTensors: numTensors,
|
||||
gr: gr,
|
||||
}
|
||||
return info, nil
|
||||
}
|
||||
345
x/encoding/gguf/gguf_test.go
Normal file
345
x/encoding/gguf/gguf_test.go
Normal file
@@ -0,0 +1,345 @@
|
||||
package gguf
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"io"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"kr.dev/diff"
|
||||
)
|
||||
|
||||
func TestStat(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
data string
|
||||
wantInfo Info
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
wantErr: ErrBadMagic,
|
||||
},
|
||||
{
|
||||
name: "bad magic",
|
||||
data: "\xBB\xAA\xDD\x00",
|
||||
wantErr: ErrBadMagic,
|
||||
},
|
||||
{
|
||||
name: "bad version",
|
||||
data: string(magicBytes) +
|
||||
"\x02\x00\x00\x00" + // version
|
||||
"",
|
||||
wantErr: ErrUnsupportedVersion,
|
||||
},
|
||||
{
|
||||
name: "valid general.file_type",
|
||||
data: string(magicBytes) + // magic
|
||||
"\x03\x00\x00\x00" + // version
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
|
||||
|
||||
// general.file_type key
|
||||
"\x11\x00\x00\x00\x00\x00\x00\x00" + // key length
|
||||
"general.file_type" + // key
|
||||
"\x04\x00\x00\x00" + // type (uint32)
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // uint32 value
|
||||
"",
|
||||
wantInfo: Info{
|
||||
Version: 3,
|
||||
FileType: 1,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
info, err := StatReader(strings.NewReader(tt.data))
|
||||
if tt.wantErr != nil {
|
||||
if !errors.Is(err, tt.wantErr) {
|
||||
t.Fatalf("err = %v; want %q", err, tt.wantErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
diff.Test(t, t.Errorf, info, tt.wantInfo)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadInfo(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
data string
|
||||
|
||||
wantMeta []MetaEntry
|
||||
wantTensor []TensorInfo
|
||||
wantReadErr error
|
||||
wantMetaErr error
|
||||
wantTensorErr error
|
||||
wantInfo Info
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
wantReadErr: io.ErrUnexpectedEOF,
|
||||
},
|
||||
{
|
||||
name: "bad magic",
|
||||
data: "\xBB\xAA\xDD\x00",
|
||||
wantReadErr: ErrBadMagic,
|
||||
},
|
||||
{
|
||||
name: "bad version",
|
||||
data: string(magicBytes) +
|
||||
"\x02\x00\x00\x00" + // version
|
||||
"",
|
||||
wantReadErr: ErrUnsupportedVersion,
|
||||
},
|
||||
{
|
||||
name: "no metadata or tensors",
|
||||
data: string(magicBytes) + // magic
|
||||
"\x03\x00\x00\x00" + // version
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
|
||||
"",
|
||||
wantReadErr: nil,
|
||||
},
|
||||
{
|
||||
name: "good metadata",
|
||||
data: string(magicBytes) + // magic
|
||||
"\x03\x00\x00\x00" + // version
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // key length
|
||||
"K" + // key
|
||||
"\x08\x00\x00\x00" + // type (string)
|
||||
"\x02\x00\x00\x00\x00\x00\x00\x00" + // string length
|
||||
"VV" + // string value
|
||||
"",
|
||||
wantMeta: []MetaEntry{
|
||||
{Key: "K", Type: ValueTypeString, Values: []MetaValue{{Type: ValueTypeString, Value: []byte("VV")}}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "good metadata with multiple values",
|
||||
data: string(magicBytes) + // magic
|
||||
"\x03\x00\x00\x00" + // version
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
|
||||
"\x02\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
|
||||
|
||||
// MetaEntry 1
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // key length
|
||||
"x" + // key
|
||||
"\x08\x00\x00\x00" + // type (string)
|
||||
"\x02\x00\x00\x00\x00\x00\x00\x00" + // string length
|
||||
"XX" + // string value
|
||||
|
||||
// MetaEntry 2
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // key length
|
||||
"y" + // key
|
||||
"\x04\x00\x00\x00" + // type (uint32)
|
||||
"\x99\x88\x77\x66" + // uint32 value
|
||||
"",
|
||||
wantMeta: []MetaEntry{
|
||||
{Key: "x", Type: ValueTypeString, Values: []MetaValue{{
|
||||
Type: ValueTypeString,
|
||||
Value: []byte("XX"),
|
||||
}}},
|
||||
{Key: "y", Type: ValueTypeUint32, Values: []MetaValue{{
|
||||
Type: ValueTypeUint32,
|
||||
Value: []byte{0x99, 0x88, 0x77, 0x66},
|
||||
}}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "negative string length in meta key",
|
||||
data: string(magicBytes) + // magic
|
||||
"\x03\x00\x00\x00" + // version
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
|
||||
"\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF" + // key length
|
||||
"K" + // key
|
||||
"\x08\x00\x00\x00" + // type (string)
|
||||
"\x02\x00\x00\x00\x00\x00\x00\x00" + // string length
|
||||
"VV" + // string value
|
||||
"",
|
||||
wantMetaErr: ErrMangled,
|
||||
},
|
||||
|
||||
// Tensor tests
|
||||
{
|
||||
name: "good tensor",
|
||||
data: string(magicBytes) + // magic
|
||||
"\x03\x00\x00\x00" + // version
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // numTensors
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
|
||||
|
||||
// Tensor 1
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
|
||||
"t" +
|
||||
|
||||
// dimensions
|
||||
"\x01\x00\x00\x00" + // dimensions length
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // dimension[0]
|
||||
|
||||
"\x03\x00\x00\x00" + // type (i8)
|
||||
"\x00\x01\x00\x00\x00\x00\x00\x00" + // offset
|
||||
"",
|
||||
wantTensor: []TensorInfo{
|
||||
{
|
||||
Name: "t",
|
||||
Dimensions: []uint64{1},
|
||||
Type: TypeQ4_1,
|
||||
Offset: 256,
|
||||
Size: 256,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "too many dimensions",
|
||||
data: string(magicBytes) + // magic
|
||||
"\x03\x00\x00\x00" + // version
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // numTensors
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
|
||||
|
||||
// Tensor 1
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
|
||||
"t" +
|
||||
|
||||
"\x00\x00\x00\x01" + // dimensions length
|
||||
"",
|
||||
wantTensorErr: ErrMangled,
|
||||
},
|
||||
{
|
||||
name: "size computed",
|
||||
data: string(magicBytes) + // magic
|
||||
"\x03\x00\x00\x00" + // version
|
||||
"\x02\x00\x00\x00\x00\x00\x00\x00" + // numTensors
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
|
||||
|
||||
// Tensor 1
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
|
||||
"t" +
|
||||
"\x01\x00\x00\x00" + // dimensions length
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // dimension[0]
|
||||
"\x03\x00\x00\x00" + // type (i8)
|
||||
"\x00\x01\x00\x00\x00\x00\x00\x00" + // offset
|
||||
|
||||
// Tensor 2
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
|
||||
"t" +
|
||||
"\x01\x00\x00\x00" + // dimensions length
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // dimension[0]
|
||||
"\x03\x00\x00\x00" + // type (i8)
|
||||
"\x00\x03\x00\x00\x00\x00\x00\x00" + // offset
|
||||
"",
|
||||
wantTensor: []TensorInfo{
|
||||
{
|
||||
Name: "t",
|
||||
Dimensions: []uint64{1},
|
||||
Type: TypeQ4_1,
|
||||
Offset: 256,
|
||||
Size: 256,
|
||||
},
|
||||
{
|
||||
Name: "t",
|
||||
Dimensions: []uint64{1},
|
||||
Type: TypeQ4_1,
|
||||
Offset: 768,
|
||||
Size: 512,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
f, err := ReadFile(strings.NewReader(tt.data))
|
||||
if err != nil {
|
||||
if !errors.Is(err, tt.wantReadErr) {
|
||||
t.Fatalf("unexpected ReadFile error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var got []MetaEntry
|
||||
for meta, err := range f.Metadata {
|
||||
if !errors.Is(err, tt.wantMetaErr) {
|
||||
t.Fatalf("err = %v; want %v", err, ErrMangled)
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
got = append(got, meta)
|
||||
}
|
||||
diff.Test(t, t.Errorf, got, tt.wantMeta)
|
||||
|
||||
var gotT []TensorInfo
|
||||
for tinfo, err := range f.Tensors {
|
||||
if !errors.Is(err, tt.wantTensorErr) {
|
||||
t.Fatalf("err = %v; want %v", err, tt.wantTensorErr)
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
gotT = append(gotT, tinfo)
|
||||
}
|
||||
diff.Test(t, t.Errorf, gotT, tt.wantTensor)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func FuzzReadInfo(f *testing.F) {
|
||||
f.Add(string(magicBytes))
|
||||
f.Add(string(magicBytes) +
|
||||
"\x03\x00\x00\x00" + // version
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
|
||||
"\x00\x00\x00\x00\x00\x00\x00\x00" + // numTensors
|
||||
"")
|
||||
f.Add(string(magicBytes) +
|
||||
"\x03\x00\x00\x00" + // version
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // numMetaValues
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // numTensors
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // key length
|
||||
"K" + // key
|
||||
"\x08\x00\x00\x00" + // type (string)
|
||||
"\x02\x00\x00\x00\x00\x00\x00\x00" + // string length
|
||||
"VV" + // string value
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // name length
|
||||
"t" +
|
||||
"\x01\x00\x00\x00" + // dimensions length
|
||||
"\x01\x00\x00\x00\x00\x00\x00\x00" + // dimension[0]
|
||||
"\x03\x00\x00\x00" + // type (i8)
|
||||
"\x05\x00\x00\x00\x00\x00\x00\x00" + // offset
|
||||
"")
|
||||
|
||||
f.Fuzz(func(t *testing.T, data string) {
|
||||
gf, err := ReadFile(strings.NewReader(data))
|
||||
if err != nil {
|
||||
t.Logf("ReadFile error: %v", err)
|
||||
t.Skip()
|
||||
}
|
||||
for _, err := range gf.Metadata {
|
||||
if err != nil {
|
||||
t.Logf("metadata error: %v", err)
|
||||
t.Skip()
|
||||
}
|
||||
}
|
||||
for tinfo, err := range gf.Tensors {
|
||||
if err != nil {
|
||||
t.Logf("tensor error: %v", err)
|
||||
t.Skip()
|
||||
}
|
||||
if tinfo.Offset <= 0 {
|
||||
t.Logf("invalid tensor offset: %+v", t)
|
||||
t.Skip()
|
||||
}
|
||||
if tinfo.Size <= 0 {
|
||||
t.Logf("invalid tensor size: %+v", t)
|
||||
t.Skip()
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
195
x/encoding/gguf/ggufio.go
Normal file
195
x/encoding/gguf/ggufio.go
Normal file
@@ -0,0 +1,195 @@
|
||||
package gguf
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"iter"
|
||||
)
|
||||
|
||||
type ggufReader struct {
|
||||
r *reader
|
||||
n int
|
||||
}
|
||||
|
||||
func (r *ggufReader) readMetaEntry() (MetaEntry, error) {
|
||||
key, err := r.readString()
|
||||
if err != nil {
|
||||
return MetaEntry{}, err
|
||||
}
|
||||
typ, err := r.readValueType()
|
||||
if err != nil {
|
||||
return MetaEntry{}, err
|
||||
}
|
||||
var values []MetaValue
|
||||
for v, err := range r.readMetaValues(typ) {
|
||||
if err != nil {
|
||||
err = fmt.Errorf("(key=%q type=%s): %w", key, typ, err)
|
||||
return MetaEntry{}, err
|
||||
}
|
||||
values = append(values, v)
|
||||
}
|
||||
return MetaEntry{
|
||||
Key: string(key),
|
||||
Type: typ,
|
||||
Values: values,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *ggufReader) readMetaValue(typ ValueType) (MetaValue, error) {
|
||||
var value []byte
|
||||
var err error
|
||||
switch typ {
|
||||
case ValueTypeUint8, ValueTypeInt8:
|
||||
value, err = r.next(1)
|
||||
case ValueTypeUint16, ValueTypeInt16:
|
||||
value, err = r.next(2)
|
||||
case ValueTypeUint32, ValueTypeInt32, ValueTypeFloat32:
|
||||
value, err = r.next(4)
|
||||
case ValueTypeUint64, ValueTypeInt64, ValueTypeFloat64:
|
||||
value, err = r.next(8)
|
||||
case ValueTypeBool:
|
||||
value, err = r.next(1)
|
||||
case ValueTypeString:
|
||||
value, err = r.readString()
|
||||
case ValueTypeArray:
|
||||
err = fmt.Errorf("nested arrays are not supported")
|
||||
default:
|
||||
err = fmt.Errorf("unsupported metadata type: %d", typ)
|
||||
}
|
||||
if err != nil {
|
||||
return MetaValue{}, err
|
||||
}
|
||||
return MetaValue{
|
||||
Type: typ,
|
||||
Value: bytes.Clone(value),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *ggufReader) readMetaValues(typ ValueType) iter.Seq2[MetaValue, error] {
|
||||
return func(yield func(MetaValue, error) bool) {
|
||||
if typ == ValueTypeArray {
|
||||
atyp, err := r.readValueType()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("invalid type: %w", err)
|
||||
yield(MetaValue{}, err)
|
||||
return
|
||||
}
|
||||
n, err := r.readUint64()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("invalid length: %w", err)
|
||||
yield(MetaValue{}, err)
|
||||
return
|
||||
}
|
||||
for i := range n {
|
||||
v, err := r.readMetaValue(atyp)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("invalid entry (type=%s) %d: %w", atyp, i, err)
|
||||
yield(MetaValue{}, err)
|
||||
return
|
||||
}
|
||||
if !yield(v, nil) {
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
v, err := r.readMetaValue(typ)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("error reading metadata value: %w", err)
|
||||
yield(MetaValue{}, err)
|
||||
return
|
||||
}
|
||||
yield(v, nil)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *ggufReader) readValueType() (ValueType, error) {
|
||||
typ, err := r.readUint32()
|
||||
return ValueType(typ), err
|
||||
}
|
||||
|
||||
func (r *ggufReader) readTensorInfo() (TensorInfo, error) {
|
||||
name, err := r.readString()
|
||||
if err != nil {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
|
||||
numDimensions, err := r.readUint32()
|
||||
if err != nil {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
if numDimensions > MaxDimensions {
|
||||
return TensorInfo{}, fmt.Errorf("%w: dimensions length (%d) exceeds %d", ErrMangled, numDimensions, MaxDimensions)
|
||||
}
|
||||
|
||||
dims := make([]uint64, numDimensions)
|
||||
for i := range dims {
|
||||
d, err := r.readUint64()
|
||||
if err != nil {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
dims[i] = d
|
||||
}
|
||||
typ, err := r.readUint32()
|
||||
if err != nil {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
offset, err := r.readUint64()
|
||||
if err != nil {
|
||||
return TensorInfo{}, err
|
||||
}
|
||||
|
||||
// TODO(bmizerany): check offset is multiple of ALIGNMENT
|
||||
return TensorInfo{
|
||||
Name: string(name),
|
||||
Dimensions: dims,
|
||||
Type: Type(typ),
|
||||
Offset: offset,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *ggufReader) next(n int) ([]byte, error) {
|
||||
if n < 0 {
|
||||
return nil, errors.Join(fmt.Errorf("invalid read length: %d", n), ErrMangled)
|
||||
}
|
||||
w := r.r.window()
|
||||
for len(w) < n {
|
||||
if r.r.extend() == 0 {
|
||||
return nil, io.ErrUnexpectedEOF
|
||||
}
|
||||
w = r.r.window()
|
||||
}
|
||||
r.r.release(n)
|
||||
r.n += n
|
||||
return w[:n], nil
|
||||
}
|
||||
|
||||
func (r *ggufReader) readString() ([]byte, error) {
|
||||
n, err := r.readUint64()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// TODO(bmizerany): limit max string length
|
||||
return r.next(int(n))
|
||||
}
|
||||
|
||||
func (r *ggufReader) readUint32() (uint32, error) {
|
||||
b, err := r.next(4)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
n := binary.LittleEndian.Uint32(b)
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (r *ggufReader) readUint64() (uint64, error) {
|
||||
b, err := r.next(8)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
n := binary.LittleEndian.Uint64(b)
|
||||
return n, nil
|
||||
}
|
||||
70
x/encoding/gguf/reader.go
Normal file
70
x/encoding/gguf/reader.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package gguf
|
||||
|
||||
import "io"
|
||||
|
||||
// A reader implements a sliding window over an io.Reader.
|
||||
type reader struct {
|
||||
data []byte
|
||||
offset int
|
||||
r io.Reader
|
||||
err error
|
||||
}
|
||||
|
||||
// release discards n bytes from the front of the window.
|
||||
func (b *reader) release(n int) {
|
||||
b.offset += n
|
||||
}
|
||||
|
||||
// window returns the current window.
|
||||
// The window is invalidated by calls to release or extend.
|
||||
func (b *reader) window() []byte {
|
||||
return b.data[b.offset:]
|
||||
}
|
||||
|
||||
// tuning constants for byteReader.extend.
|
||||
const (
|
||||
newBufferSize = 8 << 10
|
||||
minReadSize = newBufferSize >> 2
|
||||
)
|
||||
|
||||
// extend extends the window with data from the underlying reader.
|
||||
func (b *reader) extend() int {
|
||||
if b.err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
remaining := len(b.data) - b.offset
|
||||
if remaining == 0 {
|
||||
b.data = b.data[:0]
|
||||
b.offset = 0
|
||||
}
|
||||
if cap(b.data)-len(b.data) >= minReadSize {
|
||||
// nothing to do, enough space exists between len and cap.
|
||||
} else if cap(b.data)-remaining >= minReadSize {
|
||||
// buffer has enough space if we move the data to the front.
|
||||
b.compact()
|
||||
} else {
|
||||
// otherwise, we must allocate/extend a new buffer
|
||||
b.grow()
|
||||
}
|
||||
remaining += b.offset
|
||||
n, err := b.r.Read(b.data[remaining:cap(b.data)])
|
||||
// reduce length to the existing plus the data we read.
|
||||
b.data = b.data[:remaining+n]
|
||||
b.err = err
|
||||
return n
|
||||
}
|
||||
|
||||
// grow grows the buffer, moving the active data to the front.
|
||||
func (b *reader) grow() {
|
||||
buf := make([]byte, max(cap(b.data)*2, newBufferSize))
|
||||
copy(buf, b.data[b.offset:])
|
||||
b.data = buf
|
||||
b.offset = 0
|
||||
}
|
||||
|
||||
// compact moves the active data to the front of the buffer.
|
||||
func (b *reader) compact() {
|
||||
copy(b.data, b.data[b.offset:])
|
||||
b.offset = 0
|
||||
}
|
||||
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/787da6e90e4be491
vendored
Normal file
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/787da6e90e4be491
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
go test fuzz v1
|
||||
string("GGUF\x03\x00\x00\x00\x00\x00\x800\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\xa6\x00\x00\x00\x00\x00\x00\x00\x02\x000\x00\x10\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\xe3\xe3\xe3\xe3\x00")
|
||||
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/8b42c37d144cd2c6
vendored
Normal file
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/8b42c37d144cd2c6
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
go test fuzz v1
|
||||
string("GGUF\x03\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\xe3\xe3\xe3\xe3\x00")
|
||||
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/92b890e394a77cfc
vendored
Normal file
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/92b890e394a77cfc
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
go test fuzz v1
|
||||
string("GGUF\x03\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\xfd\xff\xff\xff\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\xe3\xe3\xe3\xe3\x00")
|
||||
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/9cfd6a48931a2753
vendored
Normal file
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/9cfd6a48931a2753
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
go test fuzz v1
|
||||
string("GGUF\x03\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x00K\b\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x00VV\x01\x00\x00\x00\x00\\x00\\x00\\x00\\x00")
|
||||
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/a8c5454e2a164af2
vendored
Normal file
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/a8c5454e2a164af2
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
go test fuzz v1
|
||||
string("GGUF\x03\x00\x00\x00\x00\x00\x800\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\xa6\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x10\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\xe3\xe3\xe3\xe3\x00")
|
||||
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/a931e37cb6f932d4
vendored
Normal file
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/a931e37cb6f932d4
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
go test fuzz v1
|
||||
string("GGUF\x03\x00\x00\x00\x00\x00\x800\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x01\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x000\x00\x00\x00\xe3\xe3\xe3\xe3\x00")
|
||||
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/bcd20fa73e7351a2
vendored
Normal file
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/bcd20fa73e7351a2
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
go test fuzz v1
|
||||
string("GGUF\x03\x00\x00\x0000000000000000000000000\xe5")
|
||||
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/d29846a68e32052d
vendored
Normal file
2
x/encoding/gguf/testdata/fuzz/FuzzReadInfo/d29846a68e32052d
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
go test fuzz v1
|
||||
string("GGUF\x03\x00\x00\x0000000000\x01\x00\x00\x00\x00\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x000\b\x00\x00\x00\x02\x00\x00\x00\x00\x00\x00\x0000\x01\x00\x00\x00\x00\x00\x00\x000\x01\x00\x001\x01\x00\x00\x00\x00\x00\x00\x00\x03\x00\x00\x00\x05\x00\x00\x00\x00\x00\x00\a")
|
||||
134
x/model/digest.go
Normal file
134
x/model/digest.go
Normal file
@@ -0,0 +1,134 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"strings"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// Digest represents a digest of a model Manifest. It is a comparable value
|
||||
// type and is immutable.
|
||||
//
|
||||
// The zero Digest is not a valid digest.
|
||||
type Digest struct {
|
||||
s string
|
||||
}
|
||||
|
||||
// Type returns the digest type of the digest.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// ParseDigest("sha256-1234").Type() // returns "sha256"
|
||||
func (d Digest) Type() string {
|
||||
typ, _, _ := strings.Cut(d.s, "-")
|
||||
return typ
|
||||
}
|
||||
|
||||
// String returns the digest in the form of "<digest-type>-<digest>", or the
|
||||
// empty string if the digest is invalid.
|
||||
func (d Digest) String() string { return d.s }
|
||||
|
||||
// IsValid returns true if the digest is valid (not zero).
|
||||
//
|
||||
// A valid digest may be created only by ParseDigest, or
|
||||
// ParseName(name).Digest().
|
||||
func (d Digest) IsValid() bool { return d.s != "" }
|
||||
|
||||
// MarshalText implements encoding.TextMarshaler.
|
||||
func (d Digest) MarshalText() ([]byte, error) {
|
||||
return []byte(d.String()), nil
|
||||
}
|
||||
|
||||
// UnmarshalText implements encoding.TextUnmarshaler.
|
||||
func (d *Digest) UnmarshalText(text []byte) error {
|
||||
if d.IsValid() {
|
||||
return errors.New("model.Digest: illegal UnmarshalText on valid Digest")
|
||||
}
|
||||
*d = ParseDigest(string(text))
|
||||
return nil
|
||||
}
|
||||
|
||||
// LogValue implements slog.Value.
|
||||
func (d Digest) LogValue() slog.Value {
|
||||
return slog.StringValue(d.String())
|
||||
}
|
||||
|
||||
var (
|
||||
_ driver.Valuer = Digest{}
|
||||
_ sql.Scanner = (*Digest)(nil)
|
||||
_ slog.LogValuer = Digest{}
|
||||
)
|
||||
|
||||
// Scan implements the sql.Scanner interface.
|
||||
func (d *Digest) Scan(src any) error {
|
||||
if d.IsValid() {
|
||||
return errors.New("model.Digest: illegal Scan on valid Digest")
|
||||
}
|
||||
switch v := src.(type) {
|
||||
case string:
|
||||
*d = ParseDigest(v)
|
||||
return nil
|
||||
case []byte:
|
||||
*d = ParseDigest(string(v))
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("model.Digest: invalid Scan source %T", src)
|
||||
}
|
||||
|
||||
// Value implements the driver.Valuer interface.
|
||||
func (d Digest) Value() (driver.Value, error) {
|
||||
return d.String(), nil
|
||||
}
|
||||
|
||||
// ParseDigest parses a string in the form of "<digest-type>-<digest>" into a
|
||||
// Digest.
|
||||
func ParseDigest(s string) Digest {
|
||||
typ, digest, ok := strings.Cut(s, "-")
|
||||
if ok && isValidDigestType(typ) && isValidHex(digest) {
|
||||
return Digest{s: s}
|
||||
}
|
||||
return Digest{}
|
||||
}
|
||||
|
||||
// isValidDigest returns true if the given string in the form of
|
||||
// "<digest-type>-<digest>", and <digest-type> is in the form of [a-z0-9]+
|
||||
// and <digest> is a valid hex string.
|
||||
//
|
||||
// It does not check if the digest is a valid hash for the given digest
|
||||
// type, or restrict the digest type to a known set of types. This is left
|
||||
// up to ueers of this package.
|
||||
func isValidDigest(s string) bool {
|
||||
typ, digest, ok := strings.Cut(s, "-")
|
||||
res := ok && isValidDigestType(typ) && isValidHex(digest)
|
||||
fmt.Printf("DEBUG: %q: typ: %s, digest: %s, ok: %v res: %v\n", s, typ, digest, ok, res)
|
||||
return res
|
||||
}
|
||||
|
||||
func isValidDigestType(s string) bool {
|
||||
if len(s) == 0 {
|
||||
return false
|
||||
}
|
||||
for _, r := range s {
|
||||
if !unicode.IsLower(r) && !unicode.IsDigit(r) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func isValidHex(s string) bool {
|
||||
if len(s) == 0 {
|
||||
return false
|
||||
}
|
||||
for i := range s {
|
||||
c := s[i]
|
||||
if c < '0' || c > '9' && c < 'a' || c > 'f' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
83
x/model/digest_test.go
Normal file
83
x/model/digest_test.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package model
|
||||
|
||||
import "testing"
|
||||
|
||||
// - test scan
|
||||
// - test marshal text
|
||||
// - test unmarshal text
|
||||
// - test log value
|
||||
// - test string
|
||||
// - test type
|
||||
// - test digest
|
||||
// - test valid
|
||||
// - test driver valuer
|
||||
// - test sql scanner
|
||||
// - test parse digest
|
||||
|
||||
var testDigests = map[string]Digest{
|
||||
"": {},
|
||||
"sha256-1234": {s: "sha256-1234"},
|
||||
"sha256-5678": {s: "sha256-5678"},
|
||||
"blake2-9abc": {s: "blake2-9abc"},
|
||||
"-1234": {},
|
||||
"sha256-": {},
|
||||
"sha256-1234-5678": {},
|
||||
"sha256-P": {}, // invalid hex
|
||||
"sha256-1234P": {},
|
||||
"---": {},
|
||||
}
|
||||
|
||||
func TestDigestParse(t *testing.T) {
|
||||
// Test cases.
|
||||
for s, want := range testDigests {
|
||||
got := ParseDigest(s)
|
||||
t.Logf("ParseDigest(%q) = %#v", s, got)
|
||||
if got != want {
|
||||
t.Errorf("ParseDigest(%q) = %q; want %q", s, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDigestString(t *testing.T) {
|
||||
// Test cases.
|
||||
for s, d := range testDigests {
|
||||
want := s
|
||||
if !d.IsValid() {
|
||||
want = ""
|
||||
}
|
||||
got := d.String()
|
||||
if got != want {
|
||||
t.Errorf("ParseDigest(%q).String() = %q; want %q", s, got, want)
|
||||
}
|
||||
|
||||
got = ParseDigest(s).String()
|
||||
if got != want {
|
||||
t.Errorf("roundtrip ParseDigest(%q).String() = %q; want %q", s, got, want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestDigestUnmarshalText(t *testing.T) {
|
||||
const testDigest = "sha256-1234"
|
||||
t.Run("UnmarshalText (into Valid)", func(t *testing.T) {
|
||||
d := ParseDigest(testDigest)
|
||||
if !d.IsValid() {
|
||||
panic("invalid test")
|
||||
}
|
||||
if err := d.UnmarshalText(nil); err == nil {
|
||||
t.Errorf("UnmarshalText on valid Digest did not return error")
|
||||
}
|
||||
if d.String() != testDigest {
|
||||
t.Errorf("UnmarshalText on valid Digest changed Digest: %q", d.String())
|
||||
}
|
||||
})
|
||||
t.Run("UnmarshalText make safe copy", func(t *testing.T) {
|
||||
data := []byte(testDigest)
|
||||
var d Digest
|
||||
d.UnmarshalText(data)
|
||||
data[0] = 'x'
|
||||
if d.String() != testDigest {
|
||||
t.Errorf("UnmarshalText did not make a safe copy")
|
||||
}
|
||||
})
|
||||
}
|
||||
132
x/model/file.go
Normal file
132
x/model/file.go
Normal file
@@ -0,0 +1,132 @@
|
||||
// Package model implements the File and Name types for working with and
|
||||
// representing Modelfiles and model Names.
|
||||
//
|
||||
// The Name type should be used when working with model names, and the File
|
||||
// type should be used when working with Modelfiles.
|
||||
package model
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"io"
|
||||
"iter"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ParamPragma struct {
|
||||
Key string
|
||||
Value string
|
||||
}
|
||||
|
||||
type MessagePragma struct {
|
||||
Role string
|
||||
Content string
|
||||
}
|
||||
|
||||
type File struct {
|
||||
// From is a required pragma that specifies the source of the model,
|
||||
// either on disk, or by reference (see model.ParseName).
|
||||
From string
|
||||
|
||||
// Optional
|
||||
Params []ParamPragma
|
||||
Template string
|
||||
System string
|
||||
Adapter string
|
||||
Messages []MessagePragma
|
||||
|
||||
License string
|
||||
}
|
||||
|
||||
type FileError struct {
|
||||
Pragma string
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *FileError) Error() string {
|
||||
return e.Pragma + ": " + e.Message
|
||||
}
|
||||
|
||||
// Pragma represents a single pragma in a Modelfile.
|
||||
type Pragma struct {
|
||||
// The pragma name
|
||||
Name string
|
||||
|
||||
// Args contains the user-defined arguments for the pragma. If no
|
||||
// arguments were provided, it is nil.
|
||||
Args []string
|
||||
}
|
||||
|
||||
func (p Pragma) Arg(i int) string {
|
||||
if i >= len(p.Args) {
|
||||
return ""
|
||||
}
|
||||
return p.Args[i]
|
||||
}
|
||||
|
||||
func FilePragmas(r io.Reader) iter.Seq2[Pragma, error] {
|
||||
return func(yield func(Pragma, error) bool) {
|
||||
sc := bufio.NewScanner(r)
|
||||
for sc.Scan() {
|
||||
line := sc.Text()
|
||||
|
||||
// TODO(bmizerany): set a max num fields/args to
|
||||
// prevent mem bloat
|
||||
args := strings.Fields(line)
|
||||
if len(args) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
p := Pragma{
|
||||
Name: strings.ToUpper(args[0]),
|
||||
}
|
||||
if p.Name == "MESSAGE" {
|
||||
// handle special case where message content
|
||||
// is space separated on the _rest_ of the
|
||||
// line like: `MESSAGE user Is Ontario in
|
||||
// Canada?`
|
||||
panic("TODO")
|
||||
}
|
||||
if len(args) > 1 {
|
||||
p.Args = args[1:]
|
||||
}
|
||||
if !yield(p, nil) {
|
||||
return
|
||||
}
|
||||
}
|
||||
if sc.Err() != nil {
|
||||
yield(Pragma{}, sc.Err())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func ParseFile(r io.Reader) (File, error) {
|
||||
var f File
|
||||
for p, err := range FilePragmas(r) {
|
||||
if err != nil {
|
||||
return File{}, err
|
||||
}
|
||||
switch p.Name {
|
||||
case "FROM":
|
||||
f.From = p.Arg(0)
|
||||
case "PARAMETER":
|
||||
f.Params = append(f.Params, ParamPragma{
|
||||
Key: strings.ToLower(p.Arg(0)),
|
||||
Value: p.Arg(1),
|
||||
})
|
||||
case "TEMPLATE":
|
||||
f.Template = p.Arg(0)
|
||||
case "SYSTEM":
|
||||
f.System = p.Arg(0)
|
||||
case "ADAPTER":
|
||||
f.Adapter = p.Arg(0)
|
||||
case "MESSAGE":
|
||||
f.Messages = append(f.Messages, MessagePragma{
|
||||
Role: p.Arg(0),
|
||||
Content: p.Arg(1),
|
||||
})
|
||||
case "LICENSE":
|
||||
f.License = p.Arg(0)
|
||||
}
|
||||
}
|
||||
return f, nil
|
||||
}
|
||||
593
x/model/name.go
Normal file
593
x/model/name.go
Normal file
@@ -0,0 +1,593 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"errors"
|
||||
"hash/maphash"
|
||||
"io"
|
||||
"iter"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/ollama/ollama/x/types/structs"
|
||||
)
|
||||
|
||||
// Errors
|
||||
var (
|
||||
// ErrInvalidName is not used by this package, but is exported so that
|
||||
// other packages do not need to invent their own error type when they
|
||||
// need to return an error for an invalid name.
|
||||
ErrIncompleteName = errors.New("incomplete model name")
|
||||
ErrInvalidDigest = errors.New("invalid digest")
|
||||
)
|
||||
|
||||
const MaxNamePartLen = 128
|
||||
|
||||
type PartKind int
|
||||
|
||||
// Levels of concreteness
|
||||
const (
|
||||
// Each value aligns with its index in the Name.parts array.
|
||||
|
||||
PartHost PartKind = iota
|
||||
PartNamespace
|
||||
PartModel
|
||||
PartTag
|
||||
PartBuild
|
||||
PartDigest
|
||||
|
||||
// Invalid is a special part that is used to indicate that a part is
|
||||
// invalid. It is not a valid part of a Name.
|
||||
//
|
||||
// It should be kept as the last part in the list.
|
||||
PartInvalid
|
||||
)
|
||||
|
||||
var kindNames = map[PartKind]string{
|
||||
PartHost: "Host",
|
||||
PartNamespace: "Namespace",
|
||||
PartModel: "Name",
|
||||
PartTag: "Tag",
|
||||
PartBuild: "Build",
|
||||
PartDigest: "Digest",
|
||||
PartInvalid: "Invalid",
|
||||
}
|
||||
|
||||
func (k PartKind) String() string {
|
||||
return cmp.Or(kindNames[k], "Unknown")
|
||||
}
|
||||
|
||||
// Name is an opaque reference to a model. It holds the parts of a model
|
||||
// with the case preserved, but is not directly comparable with other Names
|
||||
// since model names can be represented with different caseing depending on
|
||||
// the use case. For instance, "Mistral" and "mistral" are the same model
|
||||
// but each version may have come from different sources (e.g. copied from a
|
||||
// Web page, or from a file path).
|
||||
//
|
||||
// Valid Names can ONLY be constructed by calling [ParseName].
|
||||
//
|
||||
// A Name is valid if and only if is have a valid Model part. The other parts
|
||||
// are optional.
|
||||
//
|
||||
// A Name is considered "complete" if it has all parts present. To check if a
|
||||
// Name is complete, use [Name.IsComplete].
|
||||
//
|
||||
// To compare two names in a case-insensitive manner, use [Name.EqualFold].
|
||||
//
|
||||
// The parts of a Name are:
|
||||
//
|
||||
// - Host: the domain of the model (optional)
|
||||
// - Namespace: the namespace of the model (optional)
|
||||
// - Model: the name of the model (required)
|
||||
// - Tag: the tag of the model (optional)
|
||||
// - Build: the build of the model; usually the quantization or "file type" (optional)
|
||||
//
|
||||
// The parts can be obtained in their original form by calling [Name.Parts].
|
||||
//
|
||||
// To check if a Name has at minimum a valid model part, use [Name.IsValid].
|
||||
//
|
||||
// To make a Name by filling in missing parts from another Name, use [Fill].
|
||||
type Name struct {
|
||||
_ structs.Incomparable
|
||||
parts [6]string // host, namespace, model, tag, build
|
||||
|
||||
// TODO(bmizerany): track offsets and hold s (raw string) here? We
|
||||
// could pack the offests all into a single uint64 since the first
|
||||
// parts take less bits since their max offset is less than the max
|
||||
// offset of the next part. This would save a ton of bytes per Name
|
||||
// and mean zero allocations for String.
|
||||
}
|
||||
|
||||
// ParseName parses s into a Name. The input string must be a valid string
|
||||
// representation of a model name in the form:
|
||||
//
|
||||
// [host/][namespace/]<model>[:tag][+build][@<digest-type>-<digest>]
|
||||
//
|
||||
// The name part is required, all others are optional. If a part is missing,
|
||||
// it is left empty in the returned Name. If a part is invalid, the zero Ref
|
||||
// value is returned.
|
||||
//
|
||||
// The build part is normalized to uppercase.
|
||||
//
|
||||
// Examples of valid paths:
|
||||
//
|
||||
// "example.com/library/mistral:7b+x"
|
||||
// "example.com/eva/mistral:7b+Q4_0"
|
||||
// "mistral:7b+x"
|
||||
// "example.com/mike/mistral:latest+Q4_0"
|
||||
// "example.com/bruce/mistral:latest"
|
||||
// "example.com/mistral:7b+Q4_0@sha256-1234567890abcdef"
|
||||
//
|
||||
// Examples of invalid paths:
|
||||
//
|
||||
// "example.com/mistral:7b+"
|
||||
// "example.com/mistral:7b+Q4_0+"
|
||||
// "x/y/z/z:8n+I"
|
||||
// ""
|
||||
//
|
||||
// It returns the zero value if any part is invalid.
|
||||
//
|
||||
// As a rule of thumb, an valid name is one that can be round-tripped with
|
||||
// the [Name.String] method. That means ("x+") is invalid because
|
||||
// [Name.String] will not print a "+" if the build is empty.
|
||||
func ParseName(s string) Name {
|
||||
var r Name
|
||||
for kind, part := range Parts(s) {
|
||||
if kind == PartInvalid {
|
||||
return Name{}
|
||||
}
|
||||
if kind == PartDigest && !ParseDigest(part).IsValid() {
|
||||
return Name{}
|
||||
}
|
||||
r.parts[kind] = part
|
||||
}
|
||||
if r.IsValid() || r.IsResolved() {
|
||||
return r
|
||||
}
|
||||
return Name{}
|
||||
}
|
||||
|
||||
func MustParseName(s string) Name {
|
||||
r := ParseName(s)
|
||||
if !r.IsValid() {
|
||||
panic("model.MustParseName: invalid name: " + s)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// Fill fills in the missing parts of dst with the parts of src.
|
||||
//
|
||||
// The returned Name will only be valid if dst is valid.
|
||||
func Fill(dst, src Name) Name {
|
||||
var r Name
|
||||
for i := range r.parts {
|
||||
r.parts[i] = cmp.Or(dst.parts[i], src.parts[i])
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// WithBuild returns a copy of r with the build set to the given string.
|
||||
func (r Name) WithBuild(build string) Name {
|
||||
r.parts[PartBuild] = build
|
||||
return r
|
||||
}
|
||||
|
||||
func (r Name) WithDigest(digest Digest) Name {
|
||||
r.parts[PartDigest] = digest.String()
|
||||
return r
|
||||
}
|
||||
|
||||
var mapHashSeed = maphash.MakeSeed()
|
||||
|
||||
// MapHash returns a case insensitive hash for use in maps and equality
|
||||
// checks. For a convienent way to compare names, use [Name.EqualFold].
|
||||
func (r Name) MapHash() uint64 {
|
||||
// correctly hash the parts with case insensitive comparison
|
||||
var h maphash.Hash
|
||||
h.SetSeed(mapHashSeed)
|
||||
for _, part := range r.Parts() {
|
||||
// downcase the part for hashing
|
||||
for i := range part {
|
||||
c := part[i]
|
||||
if c >= 'A' && c <= 'Z' {
|
||||
c = c - 'A' + 'a'
|
||||
}
|
||||
h.WriteByte(c)
|
||||
}
|
||||
}
|
||||
return h.Sum64()
|
||||
}
|
||||
|
||||
func (r Name) slice(from, to PartKind) Name {
|
||||
var v Name
|
||||
copy(v.parts[from:to+1], r.parts[from:to+1])
|
||||
return v
|
||||
}
|
||||
|
||||
// DisplayModel returns the a display string composed of the model only.
|
||||
func (r Name) DisplayModel() string {
|
||||
return r.parts[PartModel]
|
||||
}
|
||||
|
||||
// DisplayFullest returns the fullest possible display string in form:
|
||||
//
|
||||
// <host>/<namespace>/<model>:<tag>
|
||||
//
|
||||
// If any part is missing, it is omitted from the display string.
|
||||
//
|
||||
// It does not include the build part. For the fullest possible display
|
||||
// string with the build, use [Name.String].
|
||||
func (r Name) DisplayFullest() string {
|
||||
return r.slice(PartHost, PartTag).String()
|
||||
}
|
||||
|
||||
// DisplayShort returns the fullest possible display string in form:
|
||||
//
|
||||
// <model>:<tag>
|
||||
//
|
||||
// If any part is missing, it is omitted from the display string.
|
||||
func (r Name) DisplayShort() string {
|
||||
return r.slice(PartModel, PartTag).String()
|
||||
}
|
||||
|
||||
// DisplayLong returns the fullest possible display string in form:
|
||||
//
|
||||
// <namespace>/<model>:<tag>
|
||||
//
|
||||
// If any part is missing, it is omitted from the display string.
|
||||
func (r Name) DisplayLong() string {
|
||||
return r.slice(PartNamespace, PartTag).String()
|
||||
}
|
||||
|
||||
var seps = [...]string{
|
||||
PartHost: "/",
|
||||
PartNamespace: "/",
|
||||
PartModel: ":",
|
||||
PartTag: "+",
|
||||
PartBuild: "@",
|
||||
PartDigest: "",
|
||||
}
|
||||
|
||||
// WriteTo implements io.WriterTo. It writes the fullest possible display
|
||||
// string in form:
|
||||
//
|
||||
// <host>/<namespace>/<model>:<tag>+<build>@<digest-type>-<digest>
|
||||
//
|
||||
// Missing parts and their seperators are not written.
|
||||
//
|
||||
// The full digest is always prefixed with "@". That is if [Name.IsValid]
|
||||
// reports false and [Name.IsResolved] reports true, then the string is
|
||||
// returned as "@<digest-type>-<digest>".
|
||||
func (r Name) writeTo(w io.StringWriter) {
|
||||
var partsWritten int
|
||||
for i := range r.parts {
|
||||
if r.parts[i] == "" {
|
||||
continue
|
||||
}
|
||||
if partsWritten > 0 || i == int(PartDigest) {
|
||||
w.WriteString(seps[i-1])
|
||||
}
|
||||
w.WriteString(r.parts[i])
|
||||
partsWritten++
|
||||
}
|
||||
}
|
||||
|
||||
var builderPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return &strings.Builder{}
|
||||
},
|
||||
}
|
||||
|
||||
// String returns the fullest possible display string in form:
|
||||
//
|
||||
// <host>/<namespace>/<model>:<tag>+<build>
|
||||
//
|
||||
// If any part is missing, it is omitted from the display string.
|
||||
//
|
||||
// For the fullest possible display string without the build, use
|
||||
// [Name.DisplayFullest].
|
||||
func (r Name) String() string {
|
||||
b := builderPool.Get().(*strings.Builder)
|
||||
defer builderPool.Put(b)
|
||||
b.Reset()
|
||||
b.Grow(50) // arbitrarily long enough for most names
|
||||
r.writeTo(b)
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// GoString implements fmt.GoStringer. It returns a string suitable for
|
||||
// debugging and logging. It is similar to [Name.String] but it always
|
||||
// returns a string that includes all parts of the Name, with missing parts
|
||||
// replaced with a ("?").
|
||||
func (r Name) GoString() string {
|
||||
for i := range r.parts {
|
||||
r.parts[i] = cmp.Or(r.parts[i], "?")
|
||||
}
|
||||
return r.String()
|
||||
}
|
||||
|
||||
// LogValue implements slog.Valuer.
|
||||
func (r Name) LogValue() slog.Value {
|
||||
return slog.StringValue(r.GoString())
|
||||
}
|
||||
|
||||
var bufPool = sync.Pool{
|
||||
New: func() interface{} {
|
||||
return new(bytes.Buffer)
|
||||
},
|
||||
}
|
||||
|
||||
// MarshalText implements [encoding.TextMarshaler].
|
||||
func (r Name) MarshalText() ([]byte, error) {
|
||||
b := bufPool.Get().(*bytes.Buffer)
|
||||
b.Reset()
|
||||
b.Grow(50) // arbitrarily long enough for most names
|
||||
defer bufPool.Put(b)
|
||||
r.writeTo(b)
|
||||
// TODO: We can remove this alloc if/when
|
||||
// https://github.com/golang/go/issues/62384 lands.
|
||||
return b.Bytes(), nil
|
||||
}
|
||||
|
||||
// UnmarshalText implements [encoding.TextUnmarshaler].
|
||||
//
|
||||
// It is an error to call UnmarshalText on a valid Name.
|
||||
func (r *Name) UnmarshalText(text []byte) error {
|
||||
if r.IsValid() {
|
||||
// The invariant of UnmarshalText is that it should only be
|
||||
// called on an invalid/zero Name. If we allow UnmarshalText
|
||||
// on a valid Name, then the Name will be mutated, breaking
|
||||
// the immutability of the Name.
|
||||
return errors.New("model.Name: illegal UnmarshalText on valid Name")
|
||||
}
|
||||
|
||||
// The contract of UnmarshalText is that we copy to keep the text.
|
||||
*r = ParseName(string(text))
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
_ driver.Valuer = Name{}
|
||||
_ sql.Scanner = (*Name)(nil)
|
||||
)
|
||||
|
||||
// Scan implements [database/sql.Scanner].
|
||||
func (r *Name) Scan(src any) error {
|
||||
if r.IsValid() {
|
||||
// The invariant of Scan is that it should only be called on an
|
||||
// invalid/zero Name. If we allow Scan on a valid Name, then the
|
||||
// Name will be mutated, breaking the immutability of the Name.
|
||||
return errors.New("model.Name: illegal Scan on valid Name")
|
||||
|
||||
}
|
||||
switch v := src.(type) {
|
||||
case string:
|
||||
*r = ParseName(v)
|
||||
return nil
|
||||
case []byte:
|
||||
*r = ParseName(string(v))
|
||||
return nil
|
||||
}
|
||||
return errors.New("model.Name: invalid Scan source")
|
||||
}
|
||||
|
||||
// Value implements [database/sql/driver.Valuer].
|
||||
func (r Name) Value() (driver.Value, error) {
|
||||
return r.String(), nil
|
||||
}
|
||||
|
||||
// IsComplete reports whether the Name is fully qualified. That is it has a
|
||||
// domain, namespace, name, tag, and build.
|
||||
func (r Name) IsComplete() bool {
|
||||
return !slices.Contains(r.parts[:PartDigest], "")
|
||||
}
|
||||
|
||||
// IsCompleteNoBuild is like [Name.IsComplete] but it does not require the
|
||||
// build part to be present.
|
||||
func (r Name) IsCompleteNoBuild() bool {
|
||||
return !slices.Contains(r.parts[:PartBuild], "")
|
||||
}
|
||||
|
||||
// IsResolved reports true if the Name has a valid digest.
|
||||
//
|
||||
// It is possible to have a valid Name, or a complete Name that is not
|
||||
// resolved.
|
||||
func (r Name) IsResolved() bool {
|
||||
return r.Digest().IsValid()
|
||||
}
|
||||
|
||||
// Digest returns the digest part of the Name, if any.
|
||||
//
|
||||
// If Digest returns a non-empty string, then [Name.IsResolved] will return
|
||||
// true, and digest is considered valid.
|
||||
func (r Name) Digest() Digest {
|
||||
// This was already validated by ParseName, so we can just return it.
|
||||
return Digest{r.parts[PartDigest]}
|
||||
}
|
||||
|
||||
// EqualFold reports whether r and o are equivalent model names, ignoring
|
||||
// case.
|
||||
func (r Name) EqualFold(o Name) bool {
|
||||
return r.CompareFold(o) == 0
|
||||
}
|
||||
|
||||
// CompareFold performs a case-insensitive cmp.Compare on r and o.
|
||||
//
|
||||
// This can be used with [slices.SortFunc].
|
||||
//
|
||||
// For simple equality checks, use [Name.EqualFold].
|
||||
func (r Name) CompareFold(o Name) int {
|
||||
return slices.CompareFunc(r.parts[:], o.parts[:], compareFold)
|
||||
}
|
||||
|
||||
func compareFold(a, b string) int {
|
||||
return slices.CompareFunc([]rune(a), []rune(b), func(a, b rune) int {
|
||||
return cmp.Compare(downcase(a), downcase(b))
|
||||
})
|
||||
}
|
||||
|
||||
func downcase(r rune) rune {
|
||||
if r >= 'A' && r <= 'Z' {
|
||||
return r - 'A' + 'a'
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// TODO(bmizerany): driver.Value? (MarshalText etc should be enough)
|
||||
|
||||
// Parts returns the parts of the Name in order of concreteness.
|
||||
//
|
||||
// The length of the returned slice is always 5.
|
||||
func (r Name) Parts() []string {
|
||||
return slices.Clone(r.parts[:])
|
||||
}
|
||||
|
||||
// Parts returns a sequence of the parts of a Name string from most specific
|
||||
// to least specific.
|
||||
//
|
||||
// It normalizes the input string by removing "http://" and "https://" only.
|
||||
// No other normalization is done.
|
||||
func Parts(s string) iter.Seq2[PartKind, string] {
|
||||
return func(yield func(PartKind, string) bool) {
|
||||
if strings.HasPrefix(s, "http://") {
|
||||
s = s[len("http://"):]
|
||||
}
|
||||
if strings.HasPrefix(s, "https://") {
|
||||
s = s[len("https://"):]
|
||||
}
|
||||
|
||||
if len(s) > MaxNamePartLen || len(s) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
yieldValid := func(kind PartKind, part string) bool {
|
||||
if !isValidPart(kind, part) {
|
||||
yield(PartInvalid, "")
|
||||
return false
|
||||
}
|
||||
return yield(kind, part)
|
||||
}
|
||||
|
||||
partLen := 0
|
||||
state, j := PartDigest, len(s)
|
||||
for i := len(s) - 1; i >= 0; i-- {
|
||||
if partLen++; partLen > MaxNamePartLen {
|
||||
// catch a part that is too long early, so
|
||||
// we don't keep spinning on it, waiting for
|
||||
// an isInValidPart check which would scan
|
||||
// over it again.
|
||||
yield(PartInvalid, "")
|
||||
return
|
||||
}
|
||||
|
||||
switch s[i] {
|
||||
case '@':
|
||||
switch state {
|
||||
case PartDigest:
|
||||
if !yieldValid(PartDigest, s[i+1:j]) {
|
||||
return
|
||||
}
|
||||
if i == 0 {
|
||||
// This is the form
|
||||
// "@<digest>" which is valid.
|
||||
//
|
||||
// We're done.
|
||||
return
|
||||
}
|
||||
state, j, partLen = PartBuild, i, 0
|
||||
default:
|
||||
yield(PartInvalid, "")
|
||||
return
|
||||
}
|
||||
case '+':
|
||||
switch state {
|
||||
case PartBuild, PartDigest:
|
||||
if !yieldValid(PartBuild, s[i+1:j]) {
|
||||
return
|
||||
}
|
||||
state, j, partLen = PartTag, i, 0
|
||||
default:
|
||||
yield(PartInvalid, "")
|
||||
return
|
||||
}
|
||||
case ':':
|
||||
switch state {
|
||||
case PartTag, PartBuild, PartDigest:
|
||||
if !yieldValid(PartTag, s[i+1:j]) {
|
||||
return
|
||||
}
|
||||
state, j, partLen = PartModel, i, 0
|
||||
default:
|
||||
yield(PartInvalid, "")
|
||||
return
|
||||
}
|
||||
case '/':
|
||||
switch state {
|
||||
case PartModel, PartTag, PartBuild, PartDigest:
|
||||
if !yieldValid(PartModel, s[i+1:j]) {
|
||||
return
|
||||
}
|
||||
state, j = PartNamespace, i
|
||||
case PartNamespace:
|
||||
if !yieldValid(PartNamespace, s[i+1:j]) {
|
||||
return
|
||||
}
|
||||
state, j, partLen = PartHost, i, 0
|
||||
default:
|
||||
yield(PartInvalid, "")
|
||||
return
|
||||
}
|
||||
default:
|
||||
if !isValidByte(state, s[i]) {
|
||||
yield(PartInvalid, "")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if state <= PartNamespace {
|
||||
yieldValid(state, s[:j])
|
||||
} else {
|
||||
yieldValid(PartModel, s[:j])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// IsValid returns true if the Name hPartas a valid nick. To know if a Name is
|
||||
// "complete", use [Name.IsComplete].
|
||||
func (r Name) IsValid() bool {
|
||||
// Parts ensures we only have valid parts, so no need to validate
|
||||
// them here, only check if we have a name or not.
|
||||
return r.parts[PartModel] != ""
|
||||
}
|
||||
|
||||
// isValidPart returns Parttrue if given part is valid ascii [a-zA-Z0-9_\.-]
|
||||
func isValidPart(kind PartKind, s string) bool {
|
||||
if s == "" {
|
||||
return false
|
||||
}
|
||||
for _, c := range []byte(s) {
|
||||
if !isValidByte(kind, c) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func isValidByte(kind PartKind, c byte) bool {
|
||||
if kind == PartNamespace && c == '.' {
|
||||
return false
|
||||
}
|
||||
if c == '.' || c == '-' {
|
||||
return true
|
||||
}
|
||||
if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' || c >= '0' && c <= '9' || c == '_' {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
572
x/model/name_test.go
Normal file
572
x/model/name_test.go
Normal file
@@ -0,0 +1,572 @@
|
||||
package model
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"cmp"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type fields struct {
|
||||
host, namespace, model, tag, build string
|
||||
digest string
|
||||
}
|
||||
|
||||
func fieldsFromName(p Name) fields {
|
||||
return fields{
|
||||
host: p.parts[PartHost],
|
||||
namespace: p.parts[PartNamespace],
|
||||
model: p.parts[PartModel],
|
||||
tag: p.parts[PartTag],
|
||||
build: p.parts[PartBuild],
|
||||
digest: p.parts[PartDigest],
|
||||
}
|
||||
}
|
||||
|
||||
var testNames = map[string]fields{
|
||||
"mistral:latest": {model: "mistral", tag: "latest"},
|
||||
"mistral": {model: "mistral"},
|
||||
"mistral:30B": {model: "mistral", tag: "30B"},
|
||||
"mistral:7b": {model: "mistral", tag: "7b"},
|
||||
"mistral:7b+Q4_0": {model: "mistral", tag: "7b", build: "Q4_0"},
|
||||
"mistral+KQED": {model: "mistral", build: "KQED"},
|
||||
"mistral.x-3:7b+Q4_0": {model: "mistral.x-3", tag: "7b", build: "Q4_0"},
|
||||
"mistral:7b+q4_0": {model: "mistral", tag: "7b", build: "q4_0"},
|
||||
"llama2": {model: "llama2"},
|
||||
"user/model": {namespace: "user", model: "model"},
|
||||
"example.com/ns/mistral:7b+Q4_0": {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "Q4_0"},
|
||||
"example.com/ns/mistral:7b+X": {host: "example.com", namespace: "ns", model: "mistral", tag: "7b", build: "X"},
|
||||
|
||||
// invalid digest
|
||||
"mistral:latest@invalid256-": {},
|
||||
"mistral:latest@-123": {},
|
||||
"mistral:latest@!-123": {},
|
||||
"mistral:latest@1-!": {},
|
||||
"mistral:latest@": {},
|
||||
|
||||
// resolved
|
||||
"x@sha123-1": {model: "x", digest: "sha123-1"},
|
||||
"@sha456-2": {digest: "sha456-2"},
|
||||
|
||||
"@@sha123-1": {},
|
||||
|
||||
// preserves case for build
|
||||
"x+b": {model: "x", build: "b"},
|
||||
|
||||
// invalid (includes fuzzing trophies)
|
||||
" / / : + ": {},
|
||||
" / : + ": {},
|
||||
" : + ": {},
|
||||
" + ": {},
|
||||
" : ": {},
|
||||
" / ": {},
|
||||
" /": {},
|
||||
"/ ": {},
|
||||
"/": {},
|
||||
":": {},
|
||||
"+": {},
|
||||
|
||||
// (".") in namepsace is not allowed
|
||||
"invalid.com/7b+x": {},
|
||||
|
||||
"invalid:7b+Q4_0:latest": {},
|
||||
"in valid": {},
|
||||
"invalid/y/z/foo": {},
|
||||
"/0": {},
|
||||
"0 /0": {},
|
||||
"0 /": {},
|
||||
"0/": {},
|
||||
":/0": {},
|
||||
"+0/00000": {},
|
||||
"0+.\xf2\x80\xf6\x9d00000\xe5\x99\xe6\xd900\xd90\xa60\x91\xdc0\xff\xbf\x99\xe800\xb9\xdc\xd6\xc300\x970\xfb\xfd0\xe0\x8a\xe1\xad\xd40\x9700\xa80\x980\xdd0000\xb00\x91000\xfe0\x89\x9b\x90\x93\x9f0\xe60\xf7\x84\xb0\x87\xa5\xff0\xa000\x9a\x85\xf6\x85\xfe\xa9\xf9\xe9\xde00\xf4\xe0\x8f\x81\xad\xde00\xd700\xaa\xe000000\xb1\xee0\x91": {},
|
||||
"0//0": {},
|
||||
"m+^^^": {},
|
||||
"file:///etc/passwd": {},
|
||||
"file:///etc/passwd:latest": {},
|
||||
"file:///etc/passwd:latest+u": {},
|
||||
|
||||
strings.Repeat("a", MaxNamePartLen): {model: strings.Repeat("a", MaxNamePartLen)},
|
||||
strings.Repeat("a", MaxNamePartLen+1): {},
|
||||
}
|
||||
|
||||
func TestNameParts(t *testing.T) {
|
||||
var p Name
|
||||
if w, g := int(PartDigest+1), len(p.Parts()); w != g {
|
||||
t.Errorf("Parts() = %d; want %d", g, w)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNamePartString(t *testing.T) {
|
||||
if g := PartKind(-2).String(); g != "Unknown" {
|
||||
t.Errorf("Unknown part = %q; want %q", g, "Unknown")
|
||||
}
|
||||
for kind, name := range kindNames {
|
||||
if g := kind.String(); g != name {
|
||||
t.Errorf("%s = %q; want %q", kind, g, name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseName(t *testing.T) {
|
||||
for baseName, want := range testNames {
|
||||
for _, prefix := range []string{"", "https://", "http://"} {
|
||||
// We should get the same results with or without the
|
||||
// http(s) prefixes
|
||||
s := prefix + baseName
|
||||
|
||||
t.Run(s, func(t *testing.T) {
|
||||
for kind, part := range Parts(s) {
|
||||
t.Logf("Part: %s: %q", kind, part)
|
||||
}
|
||||
|
||||
name := ParseName(s)
|
||||
got := fieldsFromName(name)
|
||||
if got != want {
|
||||
t.Errorf("ParseName(%q) = %q; want %q", s, got, want)
|
||||
}
|
||||
|
||||
// test round-trip
|
||||
if !ParseName(name.String()).EqualFold(name) {
|
||||
t.Errorf("ParseName(%q).String() = %s; want %s", s, name.String(), baseName)
|
||||
}
|
||||
|
||||
if name.IsValid() && name.DisplayModel() == "" {
|
||||
t.Errorf("Valid() = true; Model() = %q; want non-empty name", got.model)
|
||||
} else if !name.IsValid() && name.DisplayModel() != "" {
|
||||
t.Errorf("Valid() = false; Model() = %q; want empty name", got.model)
|
||||
}
|
||||
|
||||
if name.IsResolved() && !name.Digest().IsValid() {
|
||||
t.Errorf("Resolved() = true; Digest() = %q; want non-empty digest", got.digest)
|
||||
} else if !name.IsResolved() && name.Digest().IsValid() {
|
||||
t.Errorf("Resolved() = false; Digest() = %q; want empty digest", got.digest)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompleteWithAndWithoutBuild(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
complete bool
|
||||
completeNoBuild bool
|
||||
}{
|
||||
{"", false, false},
|
||||
{"incomplete/mistral:7b+x", false, false},
|
||||
{"incomplete/mistral:7b+Q4_0", false, false},
|
||||
{"incomplete:7b+x", false, false},
|
||||
{"complete.com/x/mistral:latest+Q4_0", true, true},
|
||||
{"complete.com/x/mistral:latest", false, true},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.in, func(t *testing.T) {
|
||||
p := ParseName(tt.in)
|
||||
t.Logf("ParseName(%q) = %#v", tt.in, p)
|
||||
if g := p.IsComplete(); g != tt.complete {
|
||||
t.Errorf("Complete(%q) = %v; want %v", tt.in, g, tt.complete)
|
||||
}
|
||||
if g := p.IsCompleteNoBuild(); g != tt.completeNoBuild {
|
||||
t.Errorf("CompleteNoBuild(%q) = %v; want %v", tt.in, g, tt.completeNoBuild)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Complete uses Parts which returns a slice, but it should be
|
||||
// inlined when used in Complete, preventing any allocations or
|
||||
// escaping to the heap.
|
||||
allocs := testing.AllocsPerRun(1000, func() {
|
||||
keep(ParseName("complete.com/x/mistral:latest+Q4_0").IsComplete())
|
||||
})
|
||||
if allocs > 0 {
|
||||
t.Errorf("Complete allocs = %v; want 0", allocs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestNameLogValue(t *testing.T) {
|
||||
cases := []string{
|
||||
"example.com/library/mistral:latest+Q4_0",
|
||||
"mistral:latest",
|
||||
"mistral:7b+Q4_0",
|
||||
}
|
||||
for _, s := range cases {
|
||||
t.Run(s, func(t *testing.T) {
|
||||
var b bytes.Buffer
|
||||
log := slog.New(slog.NewTextHandler(&b, nil))
|
||||
name := ParseName(s)
|
||||
log.Info("", "name", name)
|
||||
want := fmt.Sprintf("name=%s", name.GoString())
|
||||
got := b.String()
|
||||
if !strings.Contains(got, want) {
|
||||
t.Errorf("expected log output to contain %q; got %q", want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNameDisplay(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
in string
|
||||
wantShort string
|
||||
wantLong string
|
||||
wantComplete string
|
||||
wantString string
|
||||
wantModel string
|
||||
wantGoString string // default is tt.in
|
||||
}{
|
||||
{
|
||||
name: "Complete Name",
|
||||
in: "example.com/library/mistral:latest+Q4_0",
|
||||
wantShort: "mistral:latest",
|
||||
wantLong: "library/mistral:latest",
|
||||
wantComplete: "example.com/library/mistral:latest",
|
||||
wantModel: "mistral",
|
||||
wantGoString: "example.com/library/mistral:latest+Q4_0@?",
|
||||
},
|
||||
{
|
||||
name: "Short Name",
|
||||
in: "mistral:latest",
|
||||
wantShort: "mistral:latest",
|
||||
wantLong: "mistral:latest",
|
||||
wantComplete: "mistral:latest",
|
||||
wantModel: "mistral",
|
||||
wantGoString: "?/?/mistral:latest+?@?",
|
||||
},
|
||||
{
|
||||
name: "Long Name",
|
||||
in: "library/mistral:latest",
|
||||
wantShort: "mistral:latest",
|
||||
wantLong: "library/mistral:latest",
|
||||
wantComplete: "library/mistral:latest",
|
||||
wantModel: "mistral",
|
||||
wantGoString: "?/library/mistral:latest+?@?",
|
||||
},
|
||||
{
|
||||
name: "Case Preserved",
|
||||
in: "Library/Mistral:Latest",
|
||||
wantShort: "Mistral:Latest",
|
||||
wantLong: "Library/Mistral:Latest",
|
||||
wantComplete: "Library/Mistral:Latest",
|
||||
wantModel: "Mistral",
|
||||
wantGoString: "?/Library/Mistral:Latest+?@?",
|
||||
},
|
||||
{
|
||||
name: "With digest",
|
||||
in: "Library/Mistral:Latest@sha256-123456",
|
||||
wantShort: "Mistral:Latest",
|
||||
wantLong: "Library/Mistral:Latest",
|
||||
wantComplete: "Library/Mistral:Latest",
|
||||
wantModel: "Mistral",
|
||||
wantGoString: "?/Library/Mistral:Latest+?@sha256-123456",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p := ParseName(tt.in)
|
||||
if g := p.DisplayShort(); g != tt.wantShort {
|
||||
t.Errorf("DisplayShort = %q; want %q", g, tt.wantShort)
|
||||
}
|
||||
if g := p.DisplayLong(); g != tt.wantLong {
|
||||
t.Errorf("DisplayLong = %q; want %q", g, tt.wantLong)
|
||||
}
|
||||
if g := p.DisplayFullest(); g != tt.wantComplete {
|
||||
t.Errorf("DisplayFullest = %q; want %q", g, tt.wantComplete)
|
||||
}
|
||||
if g := p.String(); g != tt.in {
|
||||
t.Errorf("String(%q) = %q; want %q", tt.in, g, tt.in)
|
||||
}
|
||||
if g := p.DisplayModel(); g != tt.wantModel {
|
||||
t.Errorf("Model = %q; want %q", g, tt.wantModel)
|
||||
}
|
||||
|
||||
tt.wantGoString = cmp.Or(tt.wantGoString, tt.in)
|
||||
if g := fmt.Sprintf("%#v", p); g != tt.wantGoString {
|
||||
t.Errorf("GoString() = %q; want %q", g, tt.wantGoString)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseNameAllocs(t *testing.T) {
|
||||
allocs := testing.AllocsPerRun(1000, func() {
|
||||
keep(ParseName("example.com/mistral:7b+Q4_0"))
|
||||
})
|
||||
if allocs > 0 {
|
||||
t.Errorf("ParseName allocs = %v; want 0", allocs)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkParseName(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
for range b.N {
|
||||
keep(ParseName("example.com/mistral:7b+Q4_0"))
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkNameDisplay(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
r := ParseName("example.com/mistral:7b+Q4_0")
|
||||
b.Run("Short", func(b *testing.B) {
|
||||
for range b.N {
|
||||
keep(r.DisplayShort())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func FuzzParseName(f *testing.F) {
|
||||
f.Add("example.com/mistral:7b+Q4_0")
|
||||
f.Add("example.com/mistral:7b+q4_0")
|
||||
f.Add("example.com/mistral:7b+x")
|
||||
f.Add("x/y/z:8n+I")
|
||||
f.Fuzz(func(t *testing.T, s string) {
|
||||
r0 := ParseName(s)
|
||||
if !r0.IsValid() {
|
||||
if !r0.EqualFold(Name{}) {
|
||||
t.Errorf("expected invalid path to be zero value; got %#v", r0)
|
||||
}
|
||||
t.Skipf("invalid path: %q", s)
|
||||
}
|
||||
|
||||
for _, p := range r0.Parts() {
|
||||
if len(p) > MaxNamePartLen {
|
||||
t.Errorf("part too long: %q", p)
|
||||
}
|
||||
}
|
||||
|
||||
if !strings.EqualFold(r0.String(), s) {
|
||||
t.Errorf("String() did not round-trip with case insensitivity: %q\ngot = %q\nwant = %q", s, r0.String(), s)
|
||||
}
|
||||
|
||||
r1 := ParseName(r0.String())
|
||||
if !r0.EqualFold(r1) {
|
||||
t.Errorf("round-trip mismatch: %+v != %+v", r0, r1)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFill(t *testing.T) {
|
||||
cases := []struct {
|
||||
dst string
|
||||
src string
|
||||
want string
|
||||
}{
|
||||
{"mistral", "o.com/library/PLACEHOLDER:latest+Q4_0", "o.com/library/mistral:latest+Q4_0"},
|
||||
{"o.com/library/mistral", "PLACEHOLDER:latest+Q4_0", "o.com/library/mistral:latest+Q4_0"},
|
||||
{"", "o.com/library/mistral:latest+Q4_0", "o.com/library/mistral:latest+Q4_0"},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.dst, func(t *testing.T) {
|
||||
r := Fill(ParseName(tt.dst), ParseName(tt.src))
|
||||
if r.String() != tt.want {
|
||||
t.Errorf("Fill(%q, %q) = %q; want %q", tt.dst, tt.src, r, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNameTextMarshal(t *testing.T) {
|
||||
cases := []struct {
|
||||
in string
|
||||
want string
|
||||
wantErr error
|
||||
}{
|
||||
{"example.com/mistral:latest+Q4_0", "", nil},
|
||||
{"mistral:latest+Q4_0", "mistral:latest+Q4_0", nil},
|
||||
{"mistral:latest", "mistral:latest", nil},
|
||||
{"mistral", "mistral", nil},
|
||||
{"mistral:7b", "mistral:7b", nil},
|
||||
{"example.com/library/mistral:latest+Q4_0", "example.com/library/mistral:latest+Q4_0", nil},
|
||||
}
|
||||
|
||||
for _, tt := range cases {
|
||||
t.Run(tt.in, func(t *testing.T) {
|
||||
p := ParseName(tt.in)
|
||||
got, err := p.MarshalText()
|
||||
if !errors.Is(err, tt.wantErr) {
|
||||
t.Fatalf("MarshalText() error = %v; want %v", err, tt.wantErr)
|
||||
}
|
||||
if string(got) != tt.want {
|
||||
t.Errorf("MarshalText() = %q; want %q", got, tt.want)
|
||||
}
|
||||
|
||||
var r Name
|
||||
if err := r.UnmarshalText(got); err != nil {
|
||||
t.Fatalf("UnmarshalText() error = %v; want nil", err)
|
||||
}
|
||||
if !r.EqualFold(p) {
|
||||
t.Errorf("UnmarshalText() = %q; want %q", r, p)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("UnmarshalText into valid Name", func(t *testing.T) {
|
||||
// UnmarshalText should not be called on a valid Name.
|
||||
p := MustParseName("x")
|
||||
if err := p.UnmarshalText([]byte("mistral:latest+Q4_0")); err == nil {
|
||||
t.Error("UnmarshalText() = nil; want error")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("TextMarshal allocs", func(t *testing.T) {
|
||||
var data []byte
|
||||
name := ParseName("example.com/ns/mistral:latest+Q4_0")
|
||||
if !name.IsComplete() {
|
||||
// sanity check
|
||||
panic("sanity check failed")
|
||||
}
|
||||
|
||||
allocs := testing.AllocsPerRun(1000, func() {
|
||||
var err error
|
||||
data, err = name.MarshalText()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(data) == 0 {
|
||||
t.Fatal("MarshalText() = 0; want non-zero")
|
||||
}
|
||||
})
|
||||
if allocs > 0 {
|
||||
// TODO: Update when/if this lands:
|
||||
// https://github.com/golang/go/issues/62384
|
||||
//
|
||||
// Currently, the best we can do is 1 alloc.
|
||||
t.Errorf("MarshalText allocs = %v; want <= 1", allocs)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("UnmarshalTest makes safe copy", func(t *testing.T) {
|
||||
// UnmarshalText should make a copy of the data.
|
||||
data := []byte("mistral:latest+Q4_0")
|
||||
p := Name{}
|
||||
if err := p.UnmarshalText(data); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
data[0] = 'x'
|
||||
if p.String() != "mistral:latest+Q4_0" {
|
||||
t.Errorf("UnmarshalText() did not make a copy")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestSQL(t *testing.T) {
|
||||
t.Run("Scan for already valid Name", func(t *testing.T) {
|
||||
p := MustParseName("x")
|
||||
if err := p.Scan("mistral:latest+Q4_0"); err == nil {
|
||||
t.Error("Scan() = nil; want error")
|
||||
}
|
||||
})
|
||||
t.Run("Scan for invalid Name", func(t *testing.T) {
|
||||
p := Name{}
|
||||
if err := p.Scan("mistral:latest+Q4_0"); err != nil {
|
||||
t.Errorf("Scan() = %v; want nil", err)
|
||||
}
|
||||
if p.String() != "mistral:latest+Q4_0" {
|
||||
t.Errorf("String() = %q; want %q", p, "mistral:latest+Q4_0")
|
||||
}
|
||||
})
|
||||
t.Run("Value", func(t *testing.T) {
|
||||
p := MustParseName("x")
|
||||
if g, err := p.Value(); err != nil {
|
||||
t.Errorf("Value() error = %v; want nil", err)
|
||||
} else if g != "x" {
|
||||
t.Errorf("Value() = %q; want %q", g, "x")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestNameStringAllocs(t *testing.T) {
|
||||
name := ParseName("example.com/ns/mistral:latest+Q4_0")
|
||||
allocs := testing.AllocsPerRun(1000, func() {
|
||||
keep(name.String())
|
||||
})
|
||||
if allocs > 1 {
|
||||
t.Errorf("String allocs = %v; want 0", allocs)
|
||||
}
|
||||
}
|
||||
|
||||
func ExampleFill() {
|
||||
defaults := ParseName("registry.ollama.com/library/PLACEHOLDER:latest+Q4_0")
|
||||
r := Fill(ParseName("mistral"), defaults)
|
||||
fmt.Println(r)
|
||||
|
||||
// Output:
|
||||
// registry.ollama.com/library/mistral:latest+Q4_0
|
||||
}
|
||||
|
||||
func ExampleName_MapHash() {
|
||||
m := map[uint64]bool{}
|
||||
|
||||
// key 1
|
||||
m[ParseName("mistral:latest+q4").MapHash()] = true
|
||||
m[ParseName("miSTRal:latest+Q4").MapHash()] = true
|
||||
m[ParseName("mistral:LATest+Q4").MapHash()] = true
|
||||
|
||||
// key 2
|
||||
m[ParseName("mistral:LATest").MapHash()] = true
|
||||
|
||||
fmt.Println(len(m))
|
||||
// Output:
|
||||
// 2
|
||||
}
|
||||
|
||||
func ExampleName_CompareFold_sort() {
|
||||
names := []Name{
|
||||
ParseName("mistral:latest"),
|
||||
ParseName("mistRal:7b+q4"),
|
||||
ParseName("MIstral:7b"),
|
||||
}
|
||||
|
||||
slices.SortFunc(names, Name.CompareFold)
|
||||
|
||||
for _, n := range names {
|
||||
fmt.Println(n)
|
||||
}
|
||||
|
||||
// Output:
|
||||
// MIstral:7b
|
||||
// mistRal:7b+q4
|
||||
// mistral:latest
|
||||
}
|
||||
|
||||
func ExampleName_completeAndResolved() {
|
||||
for _, s := range []string{
|
||||
"x/y/z:latest+q4_0@sha123-1",
|
||||
"x/y/z:latest+q4_0",
|
||||
"@sha123-1",
|
||||
} {
|
||||
p := ParseName(s)
|
||||
fmt.Printf("complete:%v resolved:%v digest:%s\n", p.IsComplete(), p.IsResolved(), p.Digest())
|
||||
}
|
||||
|
||||
// Output:
|
||||
// complete:true resolved:true digest:sha123-1
|
||||
// complete:true resolved:false digest:
|
||||
// complete:false resolved:true digest:sha123-1
|
||||
}
|
||||
|
||||
func ExampleName_DisplayFullest() {
|
||||
for _, s := range []string{
|
||||
"example.com/jmorganca/mistral:latest+Q4_0",
|
||||
"mistral:latest+Q4_0",
|
||||
"mistral:latest",
|
||||
} {
|
||||
fmt.Println(ParseName(s).DisplayFullest())
|
||||
}
|
||||
|
||||
// Output:
|
||||
// example.com/jmorganca/mistral:latest
|
||||
// mistral:latest
|
||||
// mistral:latest
|
||||
}
|
||||
|
||||
func keep[T any](v T) T { return v }
|
||||
2
x/model/testdata/fuzz/FuzzParseRef/1d43ee52085cb4aa
vendored
Normal file
2
x/model/testdata/fuzz/FuzzParseRef/1d43ee52085cb4aa
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
go test fuzz v1
|
||||
string("/0")
|
||||
2
x/model/testdata/fuzz/FuzzParseRef/27fd759314f0e6d6
vendored
Normal file
2
x/model/testdata/fuzz/FuzzParseRef/27fd759314f0e6d6
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
go test fuzz v1
|
||||
string("0//0")
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user