Compare commits

..

109 Commits

Author SHA1 Message Date
Jeffrey Morgan
09a6f76f4c fix error on ollama run with a non-existent model 2024-02-01 23:11:52 -08:00
Jeffrey Morgan
e135167484 Add multimodel support to ollama run in noninteractive mopde (#2317) 2024-02-01 21:33:06 -08:00
Jeffrey Morgan
38296ab352 clear previous images when submitting an image to ollama run (#2316) 2024-02-01 21:30:26 -08:00
Daniel Hiltgen
f43dea68d1 Merge pull request #2318 from dhiltgen/more_clean
Harden generate patching model
2024-02-01 20:41:29 -08:00
Daniel Hiltgen
e1f50377f4 Harden generate patching model
Only apply patches if we have any, and make sure to cleanup
every file we patched at the end to leave the tree clean
2024-02-01 19:34:36 -08:00
Jeffrey Morgan
7913104527 Improvements to ollama run for multimodal models (#2300) 2024-02-01 17:09:51 -08:00
Michael Yang
bfbf2f7cf7 Merge pull request #2296 from ollama/mxyng/img-tags
append image tags to user content
2024-02-01 13:16:59 -08:00
Michael Yang
fe3cbd014f Merge pull request #2298 from ollama/mxyng/debug-prompt
structured debug prompt
2024-02-01 13:16:49 -08:00
Michael Yang
3d6f48507a structured debug prompt 2024-02-01 11:56:28 -08:00
Michael Yang
f3761405c8 use image id 2024-02-01 11:52:42 -08:00
Michael Yang
e49dc9f3d8 fix tests 2024-02-01 11:48:11 -08:00
Michael Yang
d125510b4b remove image tags 2024-02-01 11:32:51 -08:00
Russell Canfield
1ca386aa9e Feature - Add Wingman Extension (#2313) 2024-02-01 11:16:24 -08:00
Michael Yang
fb56988014 account for image projection in token count 2024-02-01 09:50:48 -08:00
Michael Yang
d046bee790 use llm.ImageData for chat 2024-01-31 19:18:25 -08:00
Jeffrey Morgan
f11bf0740b use llm.ImageData 2024-01-31 19:13:48 -08:00
Michael Yang
8450bf66e6 trim images 2024-01-31 19:13:47 -08:00
Michael Yang
b4e11be8ef append image tags to user content 2024-01-31 19:13:10 -08:00
Bruce MacDonald
a896079705 preserve last system message from modelfile (#2289) 2024-01-31 21:45:01 -05:00
Michael Yang
583950c828 Merge pull request #2294 from ollama/mxyng/slog-source
update slog handler options
2024-01-31 15:29:11 -08:00
Michael Yang
8ac08a0eec update slog handler options
- consistent format by using text handler for debug and non-debug
- truncate source file to just the file name
2024-01-31 15:15:00 -08:00
Michael Yang
60f47be64c Merge pull request #2284 from ollama/mxyng/parse-raw
remove unnecessary parse raw
2024-01-31 09:40:48 -08:00
Daniel Hiltgen
6e56077ada Merge pull request #2263 from dhiltgen/bump_llamacpp
Bump llama.cpp to b1999
2024-01-31 08:39:41 -08:00
Hoang Nguyen
98ae9467bb Added MindMac to Community Integrations -> Web & Desktop section (#1957) 2024-01-31 07:48:37 -08:00
Richard Macarthy
b7a24af083 Add twinny vscode extension to Extensions and Plugins (#1950) 2024-01-31 06:25:06 -08:00
Michael Yang
c8b1f2369e remove unnecessary parse raw 2024-01-30 17:00:53 -08:00
Daniel Hiltgen
72b12c3be7 Bump llama.cpp to b1999
This requires an upstream change to support graceful termination,
carried as a patch.
2024-01-30 16:52:12 -08:00
Bruce MacDonald
0632dff3f8 trim chat prompt based on llm context size (#1963) 2024-01-30 15:59:29 -05:00
Maximilian Weber
509e2dec8a Update README.md (#2252)
Added - [Ollama for R - rollama](https://github.com/JBGruber/rollama) in Libraries in README.md
2024-01-30 11:56:51 -08:00
Daniel Hiltgen
78a48de804 Merge pull request #2256 from dhiltgen/container_logs
Add container hints for troubleshooting
2024-01-30 08:12:48 -08:00
Daniel Hiltgen
e7dbb00331 Add container hints for troubleshooting
Some users are new to containers and unsure where the server logs go
2024-01-29 08:53:41 -08:00
Jeffrey Morgan
2e06ed01d5 remove unknown CPPFLAGS option 2024-01-28 17:51:23 -08:00
Daniel Hiltgen
4072b5879b Merge pull request #2246 from dhiltgen/reject_cuda_without_avx
Don't disable GPUs on arm without AVX
2024-01-28 16:26:55 -08:00
Daniel Hiltgen
15562e887d Don't disable GPUs on arm without AVX
AVX is an x86 feature, so ARM should be excluded from
the check.
2024-01-28 15:22:38 -08:00
Jeffrey Morgan
f2245c7c77 print prompt with OLLAMA_DEBUG=1 (#2245) 2024-01-28 15:22:35 -08:00
Jeffrey Morgan
e4b9b72f2a Do not repeat system prompt for chat templating (#2241) 2024-01-28 14:15:56 -08:00
Daniel Hiltgen
311f8e0c3f Merge pull request #2243 from dhiltgen/harden_zero_gpus
Harden for zero detected GPUs
2024-01-28 13:30:44 -08:00
Daniel Hiltgen
f07f8b7a9e Harden for zero detected GPUs
At least with the ROCm libraries, its possible to have the library
present with zero GPUs.  This fix avoids a divide by zero bug in llm.go
when we try to calculate GPU memory with zero GPUs.
2024-01-28 13:13:10 -08:00
Daniel Hiltgen
e02ecfb6c8 Merge pull request #2116 from dhiltgen/cc_50_80
Add support for CUDA 5.0 cards
2024-01-27 10:28:38 -08:00
Daniel Hiltgen
c8059b4dcf Merge pull request #2224 from jaglinux/fix_rocm_get_version_message
ROCm: Correct the response string in rocm_get_version function
2024-01-27 07:29:32 -08:00
Jagadish Krishnamoorthy
59d87127f5 Update gpu_info_rocm.c 2024-01-26 22:08:27 -08:00
Patrick Devine
b5cf31b460 add keep_alive to generate/chat/embedding api endpoints (#2146) 2024-01-26 14:28:02 -08:00
Daniel Hiltgen
cc4915e262 Merge pull request #2214 from dhiltgen/reject_cuda_without_avx
Detect lack of AVX and fallback to CPU mode
2024-01-26 12:06:44 -08:00
Daniel Hiltgen
667a2ba18a Detect lack of AVX and fallback to CPU mode
We build the GPU libraries with AVX enabled to ensure that if not all
layers fit on the GPU we get better performance in a mixed mode.
If the user is using a virtualization/emulation system that lacks AVX
this used to result in an illegal instruction error and crash before this
fix.  Now we will report a warning in the server log, and just use
CPU mode to ensure we don't crash.
2024-01-26 11:36:03 -08:00
Michael Yang
e054ebe059 Merge pull request #2212 from ollama/mxyng/fix-build
fix build
2024-01-26 11:19:08 -08:00
Michael Yang
9d3dcfd0ec fix logging 2024-01-26 11:04:27 -08:00
Michael Yang
6e0ea5ecc8 Merge pull request #1916 from ollama/mxyng/inactivity-monitor
download: add inactivity monitor
2024-01-26 10:56:00 -08:00
Daniel Hiltgen
a47d8b2557 Merge pull request #2197 from dhiltgen/remove_rocm_image
Add back ROCm container support
2024-01-26 09:34:23 -08:00
Daniel Hiltgen
30c43c285c Merge pull request #2195 from dhiltgen/rocm_real_gpus
Ignore AMD integrated GPUs
2024-01-26 09:30:24 -08:00
Daniel Hiltgen
23a7ea593b Merge pull request #2209 from dhiltgen/harden_mgmt
Fix crash on cuda ml init failure
2024-01-26 09:30:13 -08:00
Daniel Hiltgen
75c44aa319 Add back ROCm container support
This adds ROCm support back as a discrete image.
2024-01-26 09:24:29 -08:00
Daniel Hiltgen
9d7b5d6c91 Ignore AMD integrated GPUs
Detect and ignore integrated GPUs reported by rocm.
2024-01-26 09:21:35 -08:00
Daniel Hiltgen
5d9c4a5f5a Fix crash on cuda ml init failure
The new driver lookup code was triggering after init failure due to a missing return
2024-01-26 09:18:33 -08:00
Daniel Hiltgen
197e420a97 Merge pull request #2196 from dhiltgen/remove_rocm_image
Switch back to ubuntu base
2024-01-25 16:50:32 -08:00
Daniel Hiltgen
a34e1ad3cf Switch back to ubuntu base
The size increase for rocm support in the standard image is problematic
We'll revisit multiple tags for rocm support in a follow up PR.
2024-01-25 16:46:01 -08:00
Michael Yang
2ae0556292 Merge pull request #1679 from ollama/mxyng/build-gpus
build cuda and rocm
2024-01-25 16:38:14 -08:00
Jeffrey Morgan
5be9bdd444 Update modelfile.md 2024-01-25 16:29:48 -08:00
Jeffrey Morgan
b706794905 Update modelfile.md to include MESSAGE 2024-01-25 16:29:32 -08:00
Michael Yang
a8c5413d06 only generate gpu libs 2024-01-25 15:41:56 -08:00
Michael Yang
5580de4571 archive ollama binaries 2024-01-25 15:40:16 -08:00
Michael Yang
946431d5b0 build cuda and rocm 2024-01-25 15:40:15 -08:00
Michael Yang
0610126049 remove env setting 2024-01-25 15:39:43 -08:00
Jeffrey Morgan
3ebd6a83fc update submodule to cd4fddb29f81d6a1f6d51a0c016bc6b486d68def 2024-01-25 13:54:11 -08:00
Jeffrey Morgan
a64570dcae Fix clearing kv cache between requests with the same prompt (#2186)
* Fix clearing kv cache between requests with the same prompt

* fix powershell script
2024-01-25 13:46:20 -08:00
Patrick Devine
7c40a67841 Save and load sessions (#2063) 2024-01-25 12:12:36 -08:00
Michael Yang
e64b5b07a2 Merge pull request #2181 from ollama/mxyng/stub-lint
stub generate outputs for lint
2024-01-25 11:55:15 -08:00
Michael Yang
9e1e295cdc Merge pull request #2175 from ollama/mxyng/refactor-tensor-read
refactor tensor read
2024-01-25 09:22:42 -08:00
Jeffrey Morgan
a643823f86 Update README.md 2024-01-24 21:36:56 -08:00
Michael Yang
8e5d359a03 stub generate outputs for lint 2024-01-24 17:36:10 -08:00
Daniel Hiltgen
a170888dd4 Merge pull request #2174 from dhiltgen/rocm_real_gpus
More logging for gpu management
2024-01-24 11:09:17 -08:00
Michael Yang
cd22855ef8 refactor tensor read 2024-01-24 10:48:31 -08:00
Daniel Hiltgen
013fd07139 More logging for gpu management
Fix an ordering glitch of dlerr/dlclose and add more logging to help
root cause some crashes users are hitting. This also refines the
function pointer names to use the underlying function names instead
of simplified names for readability.
2024-01-24 10:32:36 -08:00
Daniel Hiltgen
f63dc2db5c Merge pull request #2162 from dhiltgen/rocm_real_gpus
Report more information about GPUs in verbose mode
2024-01-23 17:45:40 -08:00
Jeffrey Morgan
eaa5a396d9 Update README.md 2024-01-23 16:08:15 -08:00
Jeffrey Morgan
8ed22f5d72 Update README.md 2024-01-23 14:38:01 -08:00
Daniel Hiltgen
987c16b2f7 Report more information about GPUs in verbose mode
This adds additional calls to both CUDA and ROCm management libraries to
discover additional attributes about the GPU(s) detected in the system, and
wires up runtime verbosity selection.  When users hit problems with GPUs we can
ask them to run with `OLLAMA_DEBUG=1 ollama serve` and share the results.
2024-01-23 11:37:02 -08:00
Jeffrey Morgan
950f636d64 Update README.md 2024-01-23 10:29:10 -08:00
Jeffrey Morgan
4458efb73a Load all layers on arm64 macOS if model is small enough (#2149) 2024-01-22 17:40:06 -08:00
Daniel Hiltgen
ceea599494 Merge pull request #2150 from dhiltgen/default_version
Set a default version using git describe
2024-01-22 17:38:27 -08:00
Daniel Hiltgen
3005ec74b3 Set a default version using git describe
If a VERSION is not specified, this will generate a version string that
represents the state of the repo.  For example `0.1.21-12-gffaf52e-dirty`
representing 12 commits away from 0.1.21 tag, on commit gffaf52e
and the tree is dirty.
2024-01-22 17:12:20 -08:00
Daniel Hiltgen
0759d8996e Merge pull request #2148 from dhiltgen/intel_mac
Refine Accelerate usage on mac
2024-01-22 16:56:58 -08:00
Daniel Hiltgen
0f5b843319 Refine Accelerate usage on mac
For old macs, accelerate seems to cause crashes, but for
AVX2 capable macs, it does not.
2024-01-22 16:25:56 -08:00
Jeffrey Morgan
ffaf52e1e9 update submodule to 011e8ec577fd135cbc02993d3ea9840c516d6a1c 2024-01-22 15:16:54 -08:00
Michael Yang
940b10b036 Merge pull request #2144 from jmorganca/mxyng/update-faq
faq: update to use launchctl setenv
2024-01-22 13:46:57 -08:00
Daniel Hiltgen
3bc28736cd Merge pull request #2143 from dhiltgen/llm_verbosity
Refine debug logging for llm
2024-01-22 13:19:16 -08:00
Michael Yang
93a756266c faq: update to use launchctl setenv 2024-01-22 13:10:13 -08:00
Daniel Hiltgen
a0a829bf7a Merge pull request #2142 from dhiltgen/debug_on_fail
Debug logging on init failure
2024-01-22 12:29:22 -08:00
Daniel Hiltgen
730dcfcc7a Refine debug logging for llm
This wires up logging in llama.cpp to always go to stderr, and also
turns up logging if OLLAMA_DEBUG is set.
2024-01-22 12:26:49 -08:00
Daniel Hiltgen
27a2d5af54 Debug logging on init failure 2024-01-22 12:08:22 -08:00
Jeffrey Morgan
5f81a33f43 update submodule to 6f9939d (#2115) 2024-01-22 11:56:40 -08:00
Michael Yang
6225fde046 Merge pull request #2102 from jmorganca/mxyng/fix-create-override
fix: remove overwritten model layers
2024-01-22 09:37:48 -08:00
Meng Zhuo
069184562b readline: drop not use min function (#2134) 2024-01-22 08:15:08 -08:00
Daniel Hiltgen
5576bb2348 Merge pull request #2130 from dhiltgen/more_faster
Make CPU builds parallel and customizable AMD GPUs
2024-01-21 16:14:12 -08:00
Daniel Hiltgen
2738837786 Merge pull request #2131 from dhiltgen/probe_cards_at_init
Probe GPUs before backend init
2024-01-21 16:13:47 -08:00
Daniel Hiltgen
ec3764538d Probe GPUs before backend init
Detect potential error scenarios so we can fallback to CPU mode without
hitting asserts.
2024-01-21 15:59:38 -08:00
Daniel Hiltgen
df54c723ae Make CPU builds parallel and customizable AMD GPUs
The linux build now support parallel CPU builds to speed things up.
This also exposes AMD GPU targets as an optional setting for advaced
users who want to alter our default set.
2024-01-21 15:12:21 -08:00
Daniel Hiltgen
fa8c990e58 Merge pull request #2127 from dhiltgen/rocm_container
Combine the 2 Dockerfiles and add ROCm
2024-01-21 11:49:01 -08:00
Daniel Hiltgen
da72235ebf Combine the 2 Dockerfiles and add ROCm
This renames Dockerfile.build to Dockerfile, and adds some new stages
to support 2 modes of building - the build_linux.sh script uses
intermediate stages to extract the artifacts for ./dist, and the default
build generates a container image usable by both cuda and rocm cards.
This required transitioniing the x86 base to the rocm image to avoid
layer bloat.
2024-01-21 11:37:11 -08:00
Jeffrey Morgan
89c4aee29e Unlock mutex when failing to load model (#2117) 2024-01-20 20:54:46 -05:00
Daniel Hiltgen
a447a083f2 Add compute capability 5.0, 7.5, and 8.0 2024-01-20 14:24:05 -08:00
Jeffrey Morgan
f32ea81b21 increase minimum overhead to 1024MiB (#2114) 2024-01-20 17:11:38 -05:00
Daniel Hiltgen
681a914990 Add support for CUDA 5.2 cards 2024-01-20 10:48:43 -08:00
Jeffrey Morgan
4c54f0ddeb sign dylibs on macOS (#2101) 2024-01-19 19:24:11 -05:00
Michael Yang
c08dfaa23d fix: remove overwritten model layers
if create overrides a manifest, first add the older manifest's layers to
the delete map so they can be cleaned up
2024-01-19 14:58:37 -08:00
Daniel Hiltgen
3b76e736ae Merge pull request #2100 from dhiltgen/more_wsl_globs
More WSL paths
2024-01-19 13:41:08 -08:00
Daniel Hiltgen
552db98bf1 More WSL paths 2024-01-19 13:23:29 -08:00
Daniel Hiltgen
fdcdfef620 Merge pull request #2099 from dhiltgen/fix_cuda_model_swap
Switch to local dlopen symbols
2024-01-19 12:22:04 -08:00
Daniel Hiltgen
6a042438af Switch to local dlopen symbols 2024-01-19 11:37:02 -08:00
Michael Yang
27331ae3a8 download: add inactivity monitor
if a download part is inactive for some time, restart it
2024-01-12 15:23:15 -08:00
44 changed files with 2085 additions and 666 deletions

View File

@@ -23,29 +23,72 @@ jobs:
with: with:
go-version: '1.21' go-version: '1.21'
cache: true cache: true
- if: ${{ startsWith(matrix.os, 'windows-') }}
shell: pwsh
run: |
$path = vswhere -latest -products * -requires Microsoft.VisualStudio.Component.VC.Tools.x86.x64 -property installationPath
if ($path) {
$path = join-path $path 'Common7\Tools\vsdevcmd.bat'
if (test-path $path) {
cmd /s /c """$path"" $args && set" | where { $_ -match '(\w+)=(.*)' } | foreach {
echo "$($Matches[1])=$($Matches[2])" | Out-File -FilePath $Env:GITHUB_ENV -Encoding utf8 -Append
}
}
}
echo "C:\Program Files\Git\usr\bin" | Out-File -FilePath $Env:GITHUB_PATH -Encoding utf8 -Append
- run: go get ./... - run: go get ./...
- run: go generate -x ./... - run: go generate -x ./...
- uses: actions/upload-artifact@v4 - uses: actions/upload-artifact@v4
with: with:
name: ${{ matrix.os }}-${{ matrix.arch }}-libraries name: ${{ matrix.os }}-${{ matrix.arch }}-libraries
path: | path: llm/llama.cpp/build/**/lib/*
llm/llama.cpp/build/**/lib/* generate-cuda:
strategy:
matrix:
cuda-version:
- '11.8.0'
runs-on: ubuntu-latest
container: nvidia/cuda:${{ matrix.cuda-version }}-devel-ubuntu20.04
steps:
- run: |
apt-get update && apt-get install -y git build-essential curl
curl -fsSL https://github.com/Kitware/CMake/releases/download/v3.28.1/cmake-3.28.1-linux-x86_64.tar.gz \
| tar -zx -C /usr --strip-components 1
env:
DEBIAN_FRONTEND: noninteractive
- uses: actions/checkout@v4
- uses: actions/setup-go@v4
with:
go-version: '1.21'
cache: true
- run: go get ./...
- run: |
git config --global --add safe.directory /__w/ollama/ollama
go generate -x ./...
env:
OLLAMA_SKIP_CPU_GENERATE: '1'
- uses: actions/upload-artifact@v4
with:
name: cuda-${{ matrix.cuda-version }}-libraries
path: llm/llama.cpp/build/**/lib/*
generate-rocm:
strategy:
matrix:
rocm-version:
- '5.7.1'
- '6.0'
runs-on: ubuntu-latest
container: rocm/dev-ubuntu-20.04:${{ matrix.rocm-version }}
steps:
- run: |
apt-get update && apt-get install -y git build-essential curl rocm-libs
curl -fsSL https://github.com/Kitware/CMake/releases/download/v3.28.1/cmake-3.28.1-linux-x86_64.tar.gz \
| tar -zx -C /usr --strip-components 1
env:
DEBIAN_FRONTEND: noninteractive
- uses: actions/checkout@v4
- uses: actions/setup-go@v4
with:
go-version: '1.21'
cache: true
- run: go get ./...
- run: |
git config --global --add safe.directory /__w/ollama/ollama
go generate -x ./...
env:
OLLAMA_SKIP_CPU_GENERATE: '1'
- uses: actions/upload-artifact@v4
with:
name: rocm-${{ matrix.rocm-version }}-libraries
path: llm/llama.cpp/build/**/lib/*
lint: lint:
needs: generate
strategy: strategy:
matrix: matrix:
os: [ubuntu-latest, macos-latest, windows-latest] os: [ubuntu-latest, macos-latest, windows-latest]
@@ -69,10 +112,19 @@ jobs:
with: with:
go-version: '1.21' go-version: '1.21'
cache: false cache: false
- uses: actions/download-artifact@v4 - run: |
with: mkdir -p llm/llama.cpp/build/linux/${{ matrix.arch }}/stub/lib/
name: ${{ matrix.os }}-${{ matrix.arch }}-libraries touch llm/llama.cpp/build/linux/${{ matrix.arch }}/stub/lib/stub.so
path: llm/llama.cpp/build if: ${{ startsWith(matrix.os, 'ubuntu-') }}
- run: |
mkdir -p llm/llama.cpp/build/darwin/${{ matrix.arch }}/stub/lib/
touch llm/llama.cpp/build/darwin/${{ matrix.arch }}/stub/lib/stub.dylib
touch llm/llama.cpp/ggml-metal.metal
if: ${{ startsWith(matrix.os, 'macos-') }}
- run: |
mkdir -p llm/llama.cpp/build/windows/${{ matrix.arch }}/stub/lib/
touch llm/llama.cpp/build/windows/${{ matrix.arch }}/stub/lib/stub.dll
if: ${{ startsWith(matrix.os, 'windows-') }}
- uses: golangci/golangci-lint-action@v3 - uses: golangci/golangci-lint-action@v3
test: test:
needs: generate needs: generate
@@ -104,3 +156,7 @@ jobs:
path: llm/llama.cpp/build path: llm/llama.cpp/build
- run: go build - run: go build
- run: go test -v ./... - run: go test -v ./...
- uses: actions/upload-artifact@v4
with:
name: ${{ matrix.os }}-binaries
path: ollama

View File

@@ -1,27 +1,135 @@
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04 ARG GOLANG_VERSION=1.21.3
ARG CMAKE_VERSION=3.22.1
ARG CUDA_VERSION=11.3.1
ARG TARGETARCH # Copy the minimal context we need to run the generate scripts
ARG GOFLAGS="'-ldflags=-w -s'" FROM scratch AS llm-code
COPY .git .git
COPY .gitmodules .gitmodules
COPY llm llm
FROM --platform=linux/amd64 nvidia/cuda:$CUDA_VERSION-devel-centos7 AS cuda-build-amd64
ARG CMAKE_VERSION
COPY ./scripts/rh_linux_deps.sh /
RUN CMAKE_VERSION=${CMAKE_VERSION} sh /rh_linux_deps.sh
ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
COPY --from=llm-code / /go/src/github.com/jmorganca/ollama/
WORKDIR /go/src/github.com/jmorganca/ollama/llm/generate
ARG CGO_CFLAGS
RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
FROM --platform=linux/arm64 nvidia/cuda:$CUDA_VERSION-devel-rockylinux8 AS cuda-build-arm64
ARG CMAKE_VERSION
COPY ./scripts/rh_linux_deps.sh /
RUN CMAKE_VERSION=${CMAKE_VERSION} sh /rh_linux_deps.sh
ENV PATH /opt/rh/gcc-toolset-10/root/usr/bin:$PATH
COPY --from=llm-code / /go/src/github.com/jmorganca/ollama/
WORKDIR /go/src/github.com/jmorganca/ollama/llm/generate
ARG CGO_CFLAGS
RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
FROM --platform=linux/amd64 rocm/dev-centos-7:5.7.1-complete AS rocm-5-build-amd64
ARG CMAKE_VERSION
COPY ./scripts/rh_linux_deps.sh /
RUN CMAKE_VERSION=${CMAKE_VERSION} sh /rh_linux_deps.sh
ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
ENV LIBRARY_PATH /opt/amdgpu/lib64
COPY --from=llm-code / /go/src/github.com/jmorganca/ollama/
WORKDIR /go/src/github.com/jmorganca/ollama/llm/generate
ARG CGO_CFLAGS
ARG AMDGPU_TARGETS
RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
FROM --platform=linux/amd64 rocm/dev-centos-7:6.0-complete AS rocm-6-build-amd64
ARG CMAKE_VERSION
COPY ./scripts/rh_linux_deps.sh /
RUN CMAKE_VERSION=${CMAKE_VERSION} sh /rh_linux_deps.sh
ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
ENV LIBRARY_PATH /opt/amdgpu/lib64
COPY --from=llm-code / /go/src/github.com/jmorganca/ollama/
WORKDIR /go/src/github.com/jmorganca/ollama/llm/generate
ARG CGO_CFLAGS
ARG AMDGPU_TARGETS
RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
FROM --platform=linux/amd64 centos:7 AS cpu-builder-amd64
ARG CMAKE_VERSION
ARG GOLANG_VERSION
COPY ./scripts/rh_linux_deps.sh /
RUN CMAKE_VERSION=${CMAKE_VERSION} GOLANG_VERSION=${GOLANG_VERSION} sh /rh_linux_deps.sh
ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
COPY --from=llm-code / /go/src/github.com/jmorganca/ollama/
ARG OLLAMA_CUSTOM_CPU_DEFS
ARG CGO_CFLAGS
WORKDIR /go/src/github.com/jmorganca/ollama/llm/generate
FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu-build-amd64
RUN OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh
FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu_avx-build-amd64
RUN OLLAMA_CPU_TARGET="cpu_avx" sh gen_linux.sh
FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu_avx2-build-amd64
RUN OLLAMA_CPU_TARGET="cpu_avx2" sh gen_linux.sh
FROM --platform=linux/arm64 centos:7 AS cpu-build-arm64
ARG CMAKE_VERSION
ARG GOLANG_VERSION
COPY ./scripts/rh_linux_deps.sh /
RUN CMAKE_VERSION=${CMAKE_VERSION} GOLANG_VERSION=${GOLANG_VERSION} sh /rh_linux_deps.sh
ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
COPY --from=llm-code / /go/src/github.com/jmorganca/ollama/
WORKDIR /go/src/github.com/jmorganca/ollama/llm/generate
# Note, we only build the "base" CPU variant on arm since avx/avx2 are x86 features
ARG OLLAMA_CUSTOM_CPU_DEFS
ARG CGO_CFLAGS
RUN OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh
# Intermediate stage used for ./scripts/build_linux.sh
FROM --platform=linux/amd64 cpu-build-amd64 AS build-amd64
ENV CGO_ENABLED 1
WORKDIR /go/src/github.com/jmorganca/ollama WORKDIR /go/src/github.com/jmorganca/ollama
RUN apt-get update && apt-get install -y git build-essential cmake
ADD https://dl.google.com/go/go1.21.3.linux-$TARGETARCH.tar.gz /tmp/go1.21.3.tar.gz
RUN mkdir -p /usr/local && tar xz -C /usr/local </tmp/go1.21.3.tar.gz
COPY . . COPY . .
ENV GOARCH=$TARGETARCH COPY --from=cpu_avx-build-amd64 /go/src/github.com/jmorganca/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
ENV GOFLAGS=$GOFLAGS COPY --from=cpu_avx2-build-amd64 /go/src/github.com/jmorganca/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
RUN /usr/local/go/bin/go generate ./... \ COPY --from=cuda-build-amd64 /go/src/github.com/jmorganca/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
&& /usr/local/go/bin/go build . COPY --from=rocm-5-build-amd64 /go/src/github.com/jmorganca/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
COPY --from=rocm-6-build-amd64 /go/src/github.com/jmorganca/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
ARG GOFLAGS
ARG CGO_CFLAGS
RUN go build .
FROM ubuntu:22.04 # Intermediate stage used for ./scripts/build_linux.sh
FROM --platform=linux/arm64 cpu-build-arm64 AS build-arm64
ENV CGO_ENABLED 1
ARG GOLANG_VERSION
WORKDIR /go/src/github.com/jmorganca/ollama
COPY . .
COPY --from=cuda-build-arm64 /go/src/github.com/jmorganca/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
ARG GOFLAGS
ARG CGO_CFLAGS
RUN go build .
# Runtime stages
FROM --platform=linux/amd64 ubuntu:22.04 as runtime-amd64
RUN apt-get update && apt-get install -y ca-certificates RUN apt-get update && apt-get install -y ca-certificates
COPY --from=0 /go/src/github.com/jmorganca/ollama/ollama /bin/ollama COPY --from=build-amd64 /go/src/github.com/jmorganca/ollama/ollama /bin/ollama
FROM --platform=linux/arm64 ubuntu:22.04 as runtime-arm64
RUN apt-get update && apt-get install -y ca-certificates
COPY --from=build-arm64 /go/src/github.com/jmorganca/ollama/ollama /bin/ollama
# Radeon images are much larger so we keep it distinct from the CPU/CUDA image
FROM --platform=linux/amd64 rocm/dev-centos-7:5.7.1-complete as runtime-rocm
RUN update-pciids
COPY --from=build-amd64 /go/src/github.com/jmorganca/ollama/ollama /bin/ollama
EXPOSE 11434 EXPOSE 11434
ENV OLLAMA_HOST 0.0.0.0 ENV OLLAMA_HOST 0.0.0.0
# set some environment variable for better NVIDIA compatibility ENTRYPOINT ["/bin/ollama"]
ENV PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin CMD ["serve"]
FROM runtime-$TARGETARCH
EXPOSE 11434
ENV OLLAMA_HOST 0.0.0.0
ENV PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64 ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64
ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility

View File

@@ -1,99 +0,0 @@
ARG GOLANG_VERSION=1.21.3
ARG CMAKE_VERSION=3.22.1
ARG CUDA_VERSION=11.3.1
# Copy the minimal context we need to run the generate scripts
FROM scratch AS llm-code
COPY .git .git
COPY .gitmodules .gitmodules
COPY llm llm
FROM --platform=linux/amd64 nvidia/cuda:$CUDA_VERSION-devel-centos7 AS cuda-build-amd64
ARG CMAKE_VERSION
ARG CGO_CFLAGS
COPY ./scripts/rh_linux_deps.sh /
RUN CMAKE_VERSION=${CMAKE_VERSION} sh /rh_linux_deps.sh
ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
COPY --from=llm-code / /go/src/github.com/jmorganca/ollama/
WORKDIR /go/src/github.com/jmorganca/ollama/llm/generate
RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
FROM --platform=linux/arm64 nvidia/cuda:$CUDA_VERSION-devel-rockylinux8 AS cuda-build-arm64
ARG CMAKE_VERSION
ARG CGO_CFLAGS
COPY ./scripts/rh_linux_deps.sh /
RUN CMAKE_VERSION=${CMAKE_VERSION} sh /rh_linux_deps.sh
ENV PATH /opt/rh/gcc-toolset-10/root/usr/bin:$PATH
COPY --from=llm-code / /go/src/github.com/jmorganca/ollama/
WORKDIR /go/src/github.com/jmorganca/ollama/llm/generate
RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
FROM --platform=linux/amd64 rocm/dev-centos-7:5.7.1-complete AS rocm-5-build-amd64
ARG CMAKE_VERSION
ARG CGO_CFLAGS
COPY ./scripts/rh_linux_deps.sh /
RUN CMAKE_VERSION=${CMAKE_VERSION} sh /rh_linux_deps.sh
ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
ENV LIBRARY_PATH /opt/amdgpu/lib64
COPY --from=llm-code / /go/src/github.com/jmorganca/ollama/
WORKDIR /go/src/github.com/jmorganca/ollama/llm/generate
RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
FROM --platform=linux/amd64 rocm/dev-centos-7:6.0-complete AS rocm-6-build-amd64
ARG CMAKE_VERSION
ARG CGO_CFLAGS
COPY ./scripts/rh_linux_deps.sh /
RUN CMAKE_VERSION=${CMAKE_VERSION} sh /rh_linux_deps.sh
ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
ENV LIBRARY_PATH /opt/amdgpu/lib64
COPY --from=llm-code / /go/src/github.com/jmorganca/ollama/
WORKDIR /go/src/github.com/jmorganca/ollama/llm/generate
RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
FROM --platform=linux/amd64 centos:7 AS cpu-build-amd64
ARG CMAKE_VERSION
ARG GOLANG_VERSION
ARG OLLAMA_CUSTOM_CPU_DEFS
ARG CGO_CFLAGS
COPY ./scripts/rh_linux_deps.sh /
RUN CMAKE_VERSION=${CMAKE_VERSION} GOLANG_VERSION=${GOLANG_VERSION} sh /rh_linux_deps.sh
ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
COPY --from=llm-code / /go/src/github.com/jmorganca/ollama/
WORKDIR /go/src/github.com/jmorganca/ollama/llm/generate
RUN sh gen_linux.sh
FROM --platform=linux/arm64 centos:7 AS cpu-build-arm64
ARG CMAKE_VERSION
ARG GOLANG_VERSION
ARG OLLAMA_CUSTOM_CPU_DEFS
ARG CGO_CFLAGS
COPY ./scripts/rh_linux_deps.sh /
RUN CMAKE_VERSION=${CMAKE_VERSION} GOLANG_VERSION=${GOLANG_VERSION} sh /rh_linux_deps.sh
ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
COPY --from=llm-code / /go/src/github.com/jmorganca/ollama/
WORKDIR /go/src/github.com/jmorganca/ollama/llm/generate
RUN sh gen_linux.sh
FROM --platform=linux/amd64 cpu-build-amd64 AS build-amd64
ENV CGO_ENABLED 1
ARG GOFLAGS
ARG CGO_CFLAGS
WORKDIR /go/src/github.com/jmorganca/ollama
COPY . .
COPY --from=cuda-build-amd64 /go/src/github.com/jmorganca/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
COPY --from=rocm-5-build-amd64 /go/src/github.com/jmorganca/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
COPY --from=rocm-6-build-amd64 /go/src/github.com/jmorganca/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
RUN go build .
FROM --platform=linux/arm64 cpu-build-arm64 AS build-arm64
ENV CGO_ENABLED 1
ARG GOLANG_VERSION
ARG GOFLAGS
ARG CGO_CFLAGS
WORKDIR /go/src/github.com/jmorganca/ollama
COPY . .
COPY --from=cuda-build-arm64 /go/src/github.com/jmorganca/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
RUN go build .
FROM build-$TARGETARCH

View File

@@ -1,8 +1,5 @@
<div align="center"> <div align="center">
<picture> <img alt="ollama" height="200px" src="https://github.com/jmorganca/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
<source media="(prefers-color-scheme: dark)" height="200px" srcset="https://github.com/jmorganca/ollama/assets/3325447/56ea1849-1284-4645-8970-956de6e51c3c">
<img alt="logo" height="200px" src="https://github.com/jmorganca/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
</picture>
</div> </div>
# Ollama # Ollama
@@ -31,6 +28,11 @@ curl https://ollama.ai/install.sh | sh
The official [Ollama Docker image](https://hub.docker.com/r/ollama/ollama) `ollama/ollama` is available on Docker Hub. The official [Ollama Docker image](https://hub.docker.com/r/ollama/ollama) `ollama/ollama` is available on Docker Hub.
### Libraries
- [ollama-python](https://github.com/ollama/ollama-python)
- [ollama-js](https://github.com/ollama/ollama-js)
## Quickstart ## Quickstart
To run and chat with [Llama 2](https://ollama.ai/library/llama2): To run and chat with [Llama 2](https://ollama.ai/library/llama2):
@@ -198,18 +200,21 @@ brew install cmake go
``` ```
Then generate dependencies: Then generate dependencies:
``` ```
go generate ./... go generate ./...
``` ```
Then build the binary: Then build the binary:
``` ```
go build . go build .
``` ```
More detailed instructions can be found in the [developer guide](https://github.com/jmorganca/ollama/blob/main/docs/development.md) More detailed instructions can be found in the [developer guide](https://github.com/jmorganca/ollama/blob/main/docs/development.md)
### Running local builds ### Running local builds
Next, start the server: Next, start the server:
``` ```
@@ -248,13 +253,10 @@ curl http://localhost:11434/api/chat -d '{
See the [API documentation](./docs/api.md) for all endpoints. See the [API documentation](./docs/api.md) for all endpoints.
## Integrations
- [ollama-python](https://github.com/jmorganca/ollama-python)
## Community Integrations ## Community Integrations
### Web & Desktop ### Web & Desktop
- [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt) - [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
- [HTML UI](https://github.com/rtcfirefly/ollama-ui) - [HTML UI](https://github.com/rtcfirefly/ollama-ui)
- [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama) - [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
@@ -267,7 +269,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Amica](https://github.com/semperai/amica) - [Amica](https://github.com/semperai/amica)
- [chatd](https://github.com/BruceMacD/chatd) - [chatd](https://github.com/BruceMacD/chatd)
- [Ollama-SwiftUI](https://github.com/kghandour/Ollama-SwiftUI) - [Ollama-SwiftUI](https://github.com/kghandour/Ollama-SwiftUI)
- [MindMac](https://mindmac.app)
### Terminal ### Terminal
@@ -306,7 +308,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [LangChainDart](https://github.com/davidmigloz/langchain_dart) - [LangChainDart](https://github.com/davidmigloz/langchain_dart)
- [Semantic Kernel - Python](https://github.com/microsoft/semantic-kernel/tree/main/python/semantic_kernel/connectors/ai/ollama) - [Semantic Kernel - Python](https://github.com/microsoft/semantic-kernel/tree/main/python/semantic_kernel/connectors/ai/ollama)
- [Haystack](https://github.com/deepset-ai/haystack-integrations/blob/main/integrations/ollama.md) - [Haystack](https://github.com/deepset-ai/haystack-integrations/blob/main/integrations/ollama.md)
- [Ollama for R - rollama](https://github.com/JBGruber/rollama)
### Mobile ### Mobile
@@ -327,4 +329,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Rivet plugin](https://github.com/abrenneke/rivet-plugin-ollama) - [Rivet plugin](https://github.com/abrenneke/rivet-plugin-ollama)
- [Llama Coder](https://github.com/ex3ndr/llama-coder) (Copilot alternative using Ollama) - [Llama Coder](https://github.com/ex3ndr/llama-coder) (Copilot alternative using Ollama)
- [Obsidian BMO Chatbot plugin](https://github.com/longy2k/obsidian-bmo-chatbot) - [Obsidian BMO Chatbot plugin](https://github.com/longy2k/obsidian-bmo-chatbot)
- [Open Interpreter](https://docs.openinterpreter.com/language-model-setup/local-models/ollama) - [Open Interpreter](https://docs.openinterpreter.com/language-model-setup/local-models/ollama)
- [twinny](https://github.com/rjmacarthy/twinny) (Copilot and Copilot chat alternative using Ollama)
- [Wingman-AI](https://github.com/RussellCanfield/wingman-ai) (Copilot code and chat alternative using Ollama and HuggingFace)

View File

@@ -34,24 +34,26 @@ func (e StatusError) Error() string {
type ImageData []byte type ImageData []byte
type GenerateRequest struct { type GenerateRequest struct {
Model string `json:"model"` Model string `json:"model"`
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
System string `json:"system"` System string `json:"system"`
Template string `json:"template"` Template string `json:"template"`
Context []int `json:"context,omitempty"` Context []int `json:"context,omitempty"`
Stream *bool `json:"stream,omitempty"` Stream *bool `json:"stream,omitempty"`
Raw bool `json:"raw,omitempty"` Raw bool `json:"raw,omitempty"`
Format string `json:"format"` Format string `json:"format"`
Images []ImageData `json:"images,omitempty"` KeepAlive *Duration `json:"keep_alive,omitempty"`
Images []ImageData `json:"images,omitempty"`
Options map[string]interface{} `json:"options"` Options map[string]interface{} `json:"options"`
} }
type ChatRequest struct { type ChatRequest struct {
Model string `json:"model"` Model string `json:"model"`
Messages []Message `json:"messages"` Messages []Message `json:"messages"`
Stream *bool `json:"stream,omitempty"` Stream *bool `json:"stream,omitempty"`
Format string `json:"format"` Format string `json:"format"`
KeepAlive *Duration `json:"keep_alive,omitempty"`
Options map[string]interface{} `json:"options"` Options map[string]interface{} `json:"options"`
} }
@@ -126,8 +128,9 @@ type Runner struct {
} }
type EmbeddingRequest struct { type EmbeddingRequest struct {
Model string `json:"model"` Model string `json:"model"`
Prompt string `json:"prompt"` Prompt string `json:"prompt"`
KeepAlive *Duration `json:"keep_alive,omitempty"`
Options map[string]interface{} `json:"options"` Options map[string]interface{} `json:"options"`
} }
@@ -171,6 +174,7 @@ type ShowResponse struct {
Template string `json:"template,omitempty"` Template string `json:"template,omitempty"`
System string `json:"system,omitempty"` System string `json:"system,omitempty"`
Details ModelDetails `json:"details,omitempty"` Details ModelDetails `json:"details,omitempty"`
Messages []Message `json:"messages,omitempty"`
} }
type CopyRequest struct { type CopyRequest struct {
@@ -236,6 +240,7 @@ type GenerateResponse struct {
} }
type ModelDetails struct { type ModelDetails struct {
ParentModel string `json:"parent_model"`
Format string `json:"format"` Format string `json:"format"`
Family string `json:"family"` Family string `json:"family"`
Families []string `json:"families"` Families []string `json:"families"`
@@ -411,14 +416,19 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
case float64: case float64:
if t < 0 { if t < 0 {
t = math.MaxFloat64 t = math.MaxFloat64
d.Duration = time.Duration(t)
} else {
d.Duration = time.Duration(t * float64(time.Second))
} }
d.Duration = time.Duration(t)
case string: case string:
d.Duration, err = time.ParseDuration(t) d.Duration, err = time.ParseDuration(t)
if err != nil { if err != nil {
return err return err
} }
if d.Duration < 0 {
mf := math.MaxFloat64
d.Duration = time.Duration(mf)
}
} }
return nil return nil

View File

@@ -25,6 +25,7 @@ import (
"github.com/olekukonko/tablewriter" "github.com/olekukonko/tablewriter"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"golang.org/x/exp/slices"
"golang.org/x/term" "golang.org/x/term"
"github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/api"
@@ -146,19 +147,68 @@ func RunHandler(cmd *cobra.Command, args []string) error {
} }
name := args[0] name := args[0]
// check if the model exists on the server // check if the model exists on the server
_, err = client.Show(cmd.Context(), &api.ShowRequest{Name: name}) show, err := client.Show(cmd.Context(), &api.ShowRequest{Name: name})
var statusError api.StatusError var statusError api.StatusError
switch { switch {
case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound: case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound:
if err := PullHandler(cmd, []string{name}); err != nil { if err := PullHandler(cmd, []string{name}); err != nil {
return err return err
} }
show, err = client.Show(cmd.Context(), &api.ShowRequest{Name: name})
if err != nil {
return err
}
case err != nil: case err != nil:
return err return err
} }
return RunGenerate(cmd, args) interactive := true
opts := runOptions{
Model: args[0],
WordWrap: os.Getenv("TERM") == "xterm-256color",
Options: map[string]interface{}{},
MultiModal: slices.Contains(show.Details.Families, "clip"),
ParentModel: show.Details.ParentModel,
}
format, err := cmd.Flags().GetString("format")
if err != nil {
return err
}
opts.Format = format
prompts := args[1:]
// prepend stdin to the prompt if provided
if !term.IsTerminal(int(os.Stdin.Fd())) {
in, err := io.ReadAll(os.Stdin)
if err != nil {
return err
}
prompts = append([]string{string(in)}, prompts...)
opts.WordWrap = false
interactive = false
}
opts.Prompt = strings.Join(prompts, " ")
if len(prompts) > 0 {
interactive = false
}
nowrap, err := cmd.Flags().GetBool("nowordwrap")
if err != nil {
return err
}
opts.WordWrap = !nowrap
if !interactive {
return generate(cmd, opts)
}
return generateInteractive(cmd, opts)
} }
func PushHandler(cmd *cobra.Command, args []string) error { func PushHandler(cmd *cobra.Command, args []string) error {
@@ -410,63 +460,20 @@ func PullHandler(cmd *cobra.Command, args []string) error {
return nil return nil
} }
func RunGenerate(cmd *cobra.Command, args []string) error {
interactive := true
opts := runOptions{
Model: args[0],
WordWrap: os.Getenv("TERM") == "xterm-256color",
Options: map[string]interface{}{},
}
format, err := cmd.Flags().GetString("format")
if err != nil {
return err
}
opts.Format = format
prompts := args[1:]
// prepend stdin to the prompt if provided
if !term.IsTerminal(int(os.Stdin.Fd())) {
in, err := io.ReadAll(os.Stdin)
if err != nil {
return err
}
prompts = append([]string{string(in)}, prompts...)
opts.WordWrap = false
interactive = false
}
opts.Prompt = strings.Join(prompts, " ")
if len(prompts) > 0 {
interactive = false
}
nowrap, err := cmd.Flags().GetBool("nowordwrap")
if err != nil {
return err
}
opts.WordWrap = !nowrap
if !interactive {
return generate(cmd, opts)
}
return generateInteractive(cmd, opts)
}
type generateContextKey string type generateContextKey string
type runOptions struct { type runOptions struct {
Model string Model string
Prompt string ParentModel string
Messages []api.Message Prompt string
WordWrap bool Messages []api.Message
Format string WordWrap bool
System string Format string
Template string System string
Images []api.ImageData Template string
Options map[string]interface{} Images []api.ImageData
Options map[string]interface{}
MultiModal bool
} }
type displayResponseState struct { type displayResponseState struct {
@@ -628,10 +635,18 @@ func generate(cmd *cobra.Command, opts runOptions) error {
return nil return nil
} }
if opts.MultiModal {
opts.Prompt, opts.Images, err = extractFileData(opts.Prompt)
if err != nil {
return err
}
}
request := api.GenerateRequest{ request := api.GenerateRequest{
Model: opts.Model, Model: opts.Model,
Prompt: opts.Prompt, Prompt: opts.Prompt,
Context: generateContext, Context: generateContext,
Images: opts.Images,
Format: opts.Format, Format: opts.Format,
System: opts.System, System: opts.System,
Template: opts.Template, Template: opts.Template,

View File

@@ -6,13 +6,16 @@ import (
"io" "io"
"net/http" "net/http"
"os" "os"
"path/filepath"
"regexp" "regexp"
"sort"
"strings" "strings"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"golang.org/x/exp/slices" "golang.org/x/exp/slices"
"github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/progress"
"github.com/jmorganca/ollama/readline" "github.com/jmorganca/ollama/readline"
) )
@@ -25,45 +28,82 @@ const (
MultilineTemplate MultilineTemplate
) )
func modelIsMultiModal(cmd *cobra.Command, name string) bool { func loadModel(cmd *cobra.Command, opts *runOptions) error {
// get model details
client, err := api.ClientFromEnvironment() client, err := api.ClientFromEnvironment()
if err != nil { if err != nil {
fmt.Println("error: couldn't connect to ollama server") return err
return false
} }
req := api.ShowRequest{Name: name} p := progress.NewProgress(os.Stderr)
resp, err := client.Show(cmd.Context(), &req) defer p.StopAndClear()
spinner := progress.NewSpinner("")
p.Add("", spinner)
showReq := api.ShowRequest{Name: opts.Model}
showResp, err := client.Show(cmd.Context(), &showReq)
if err != nil { if err != nil {
return false return err
}
opts.MultiModal = slices.Contains(showResp.Details.Families, "clip")
opts.ParentModel = showResp.Details.ParentModel
if len(showResp.Messages) > 0 {
opts.Messages = append(opts.Messages, showResp.Messages...)
} }
return slices.Contains(resp.Details.Families, "clip") chatReq := &api.ChatRequest{
Model: opts.Model,
Messages: []api.Message{},
}
err = client.Chat(cmd.Context(), chatReq, func(resp api.ChatResponse) error {
p.StopAndClear()
if len(opts.Messages) > 0 {
for _, msg := range opts.Messages {
switch msg.Role {
case "user":
fmt.Printf(">>> %s\n", msg.Content)
case "assistant":
state := &displayResponseState{}
displayResponse(msg.Content, opts.WordWrap, state)
fmt.Println()
fmt.Println()
}
}
}
return nil
})
if err != nil {
return err
}
return nil
} }
func generateInteractive(cmd *cobra.Command, opts runOptions) error { func generateInteractive(cmd *cobra.Command, opts runOptions) error {
multiModal := modelIsMultiModal(cmd, opts.Model) opts.Messages = make([]api.Message, 0)
// load the model err := loadModel(cmd, &opts)
loadOpts := runOptions{ if err != nil {
Model: opts.Model,
Prompt: "",
Messages: []api.Message{},
}
if _, err := chat(cmd, loadOpts); err != nil {
return err return err
} }
usage := func() { usage := func() {
fmt.Fprintln(os.Stderr, "Available Commands:") fmt.Fprintln(os.Stderr, "Available Commands:")
fmt.Fprintln(os.Stderr, " /set Set session variables") fmt.Fprintln(os.Stderr, " /set Set session variables")
fmt.Fprintln(os.Stderr, " /show Show model information") fmt.Fprintln(os.Stderr, " /show Show model information")
fmt.Fprintln(os.Stderr, " /bye Exit") fmt.Fprintln(os.Stderr, " /load <model> Load a session or model")
fmt.Fprintln(os.Stderr, " /?, /help Help for a command") fmt.Fprintln(os.Stderr, " /save <model> Save your current session")
fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts") fmt.Fprintln(os.Stderr, " /bye Exit")
fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts")
fmt.Fprintln(os.Stderr, "") fmt.Fprintln(os.Stderr, "")
fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.") fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.")
if opts.MultiModal {
fmt.Fprintf(os.Stderr, "Use %s to include .jpg or .png images.\n", filepath.FromSlash("/path/to/file"))
}
fmt.Fprintln(os.Stderr, "") fmt.Fprintln(os.Stderr, "")
} }
@@ -140,7 +180,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
var sb strings.Builder var sb strings.Builder
var multiline MultilineState var multiline MultilineState
opts.Messages = make([]api.Message, 0)
for { for {
line, err := scanner.Readline() line, err := scanner.Readline()
@@ -174,6 +213,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
switch multiline { switch multiline {
case MultilineSystem: case MultilineSystem:
opts.System = sb.String() opts.System = sb.String()
opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System})
fmt.Println("Set system message.") fmt.Println("Set system message.")
sb.Reset() sb.Reset()
case MultilineTemplate: case MultilineTemplate:
@@ -193,7 +233,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(&sb) fmt.Fprintln(&sb)
multiline = MultilinePrompt multiline = MultilinePrompt
scanner.Prompt.UseAlt = true scanner.Prompt.UseAlt = true
break
} }
case scanner.Pasting: case scanner.Pasting:
fmt.Fprintln(&sb, line) fmt.Fprintln(&sb, line)
@@ -203,6 +242,44 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
if err := ListHandler(cmd, args[1:]); err != nil { if err := ListHandler(cmd, args[1:]); err != nil {
return err return err
} }
case strings.HasPrefix(line, "/load"):
args := strings.Fields(line)
if len(args) != 2 {
fmt.Println("Usage:\n /load <modelname>")
continue
}
opts.Model = args[1]
opts.Messages = []api.Message{}
fmt.Printf("Loading model '%s'\n", opts.Model)
if err := loadModel(cmd, &opts); err != nil {
return err
}
continue
case strings.HasPrefix(line, "/save"):
args := strings.Fields(line)
if len(args) != 2 {
fmt.Println("Usage:\n /save <modelname>")
continue
}
client, err := api.ClientFromEnvironment()
if err != nil {
fmt.Println("error: couldn't connect to ollama server")
return err
}
req := &api.CreateRequest{
Name: args[1],
Modelfile: buildModelfile(opts),
}
fn := func(resp api.ProgressResponse) error { return nil }
err = client.Create(cmd.Context(), req, fn)
if err != nil {
fmt.Println("error: couldn't save model")
return err
}
fmt.Printf("Created new model '%s'\n", args[1])
continue
case strings.HasPrefix(line, "/set"): case strings.HasPrefix(line, "/set"):
args := strings.Fields(line) args := strings.Fields(line)
if len(args) > 1 { if len(args) > 1 {
@@ -278,10 +355,13 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
if args[1] == "system" { if args[1] == "system" {
opts.System = sb.String() opts.System = sb.String()
opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System})
fmt.Println("Set system message.") fmt.Println("Set system message.")
sb.Reset()
} else if args[1] == "template" { } else if args[1] == "template" {
opts.Template = sb.String() opts.Template = sb.String()
fmt.Println("Set prompt template.") fmt.Println("Set prompt template.")
sb.Reset()
} }
sb.Reset() sb.Reset()
@@ -389,7 +469,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
args := strings.Fields(line) args := strings.Fields(line)
isFile := false isFile := false
if multiModal { if opts.MultiModal {
for _, f := range extractFileNames(line) { for _, f := range extractFileNames(line) {
if strings.HasPrefix(f, args[0]) { if strings.HasPrefix(f, args[0]) {
isFile = true isFile = true
@@ -411,34 +491,23 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
if sb.Len() > 0 && multiline == MultilineNone { if sb.Len() > 0 && multiline == MultilineNone {
newMessage := api.Message{Role: "user", Content: sb.String()} newMessage := api.Message{Role: "user", Content: sb.String()}
if multiModal { if opts.MultiModal {
msg, images, err := extractFileData(sb.String()) msg, images, err := extractFileData(sb.String())
if err != nil { if err != nil {
return err return err
} }
newMessage.Content = msg
// reset the context if we find another image // clear all previous images for better responses
if len(images) > 0 { if len(images) > 0 {
newMessage.Images = append(newMessage.Images, images...) for i := range opts.Messages {
// reset the context for the new image opts.Messages[i].Images = nil
opts.Messages = []api.Message{}
} else {
if len(opts.Messages) > 1 {
newMessage.Images = append(newMessage.Images, opts.Messages[len(opts.Messages)-2].Images...)
} }
} }
if len(newMessage.Images) == 0 {
fmt.Println("This model requires you to add a jpeg, png, or svg image.") newMessage.Content = msg
fmt.Println() newMessage.Images = images
sb.Reset()
continue
}
} }
if opts.System != "" {
opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System})
}
opts.Messages = append(opts.Messages, newMessage) opts.Messages = append(opts.Messages, newMessage)
assistant, err := chat(cmd, opts) assistant, err := chat(cmd, opts)
@@ -454,6 +523,38 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
} }
} }
func buildModelfile(opts runOptions) string {
var mf strings.Builder
model := opts.ParentModel
if model == "" {
model = opts.Model
}
fmt.Fprintf(&mf, "FROM %s\n", model)
if opts.System != "" {
fmt.Fprintf(&mf, "SYSTEM \"\"\"%s\"\"\"\n", opts.System)
}
if opts.Template != "" {
fmt.Fprintf(&mf, "TEMPLATE \"\"\"%s\"\"\"\n", opts.Template)
}
keys := make([]string, 0)
for k := range opts.Options {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
fmt.Fprintf(&mf, "PARAMETER %s %v\n", k, opts.Options[k])
}
fmt.Fprintln(&mf)
for _, msg := range opts.Messages {
fmt.Fprintf(&mf, "MESSAGE %s \"\"\"%s\"\"\"\n", msg.Role, msg.Content)
}
return mf.String()
}
func normalizeFilePath(fp string) string { func normalizeFilePath(fp string) string {
// Define a map of escaped characters and their replacements // Define a map of escaped characters and their replacements
replacements := map[string]string{ replacements := map[string]string{
@@ -500,10 +601,10 @@ func extractFileData(input string) (string, []api.ImageData, error) {
if os.IsNotExist(err) { if os.IsNotExist(err) {
continue continue
} }
fmt.Printf("Couldn't process image: %q\n", err) fmt.Fprintf(os.Stderr, "Couldn't process image: %q\n", err)
return "", imgs, err return "", imgs, err
} }
fmt.Printf("Added image '%s'\n", nfp) fmt.Fprintf(os.Stderr, "Added image '%s'\n", nfp)
input = strings.ReplaceAll(input, fp, "") input = strings.ReplaceAll(input, fp, "")
imgs = append(imgs, data) imgs = append(imgs, data)
} }

View File

@@ -1,9 +1,13 @@
package cmd package cmd
import ( import (
"bytes"
"testing" "testing"
"text/template"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/jmorganca/ollama/api"
) )
func TestExtractFilenames(t *testing.T) { func TestExtractFilenames(t *testing.T) {
@@ -49,3 +53,64 @@ d:\path with\spaces\seven.svg inbetween7 c:\users\jdoe\eight.png inbetween8
assert.Contains(t, res[9], "ten.svg") assert.Contains(t, res[9], "ten.svg")
assert.Contains(t, res[9], "E:") assert.Contains(t, res[9], "E:")
} }
func TestModelfileBuilder(t *testing.T) {
opts := runOptions{
Model: "hork",
System: "You are part horse and part shark, but all hork. Do horklike things",
Template: "This is a template.",
Messages: []api.Message{
{Role: "user", Content: "Hey there hork!"},
{Role: "assistant", Content: "Yes it is true, I am half horse, half shark."},
},
Options: map[string]interface{}{},
}
opts.Options["temperature"] = 0.9
opts.Options["seed"] = 42
opts.Options["penalize_newline"] = false
opts.Options["stop"] = []string{"hi", "there"}
mf := buildModelfile(opts)
expectedModelfile := `FROM {{.Model}}
SYSTEM """{{.System}}"""
TEMPLATE """{{.Template}}"""
PARAMETER penalize_newline false
PARAMETER seed 42
PARAMETER stop [hi there]
PARAMETER temperature 0.9
MESSAGE user """Hey there hork!"""
MESSAGE assistant """Yes it is true, I am half horse, half shark."""
`
tmpl, err := template.New("").Parse(expectedModelfile)
assert.Nil(t, err)
var buf bytes.Buffer
err = tmpl.Execute(&buf, opts)
assert.Nil(t, err)
assert.Equal(t, buf.String(), mf)
opts.ParentModel = "horseshark"
mf = buildModelfile(opts)
expectedModelfile = `FROM {{.ParentModel}}
SYSTEM """{{.System}}"""
TEMPLATE """{{.Template}}"""
PARAMETER penalize_newline false
PARAMETER seed 42
PARAMETER stop [hi there]
PARAMETER temperature 0.9
MESSAGE user """Hey there hork!"""
MESSAGE assistant """Yes it is true, I am half horse, half shark."""
`
tmpl, err = template.New("").Parse(expectedModelfile)
assert.Nil(t, err)
var parentBuf bytes.Buffer
err = tmpl.Execute(&parentBuf, opts)
assert.Nil(t, err)
assert.Equal(t, parentBuf.String(), mf)
}

View File

@@ -50,7 +50,8 @@ development and runtime packages.
Typically the build scripts will auto-detect CUDA, however, if your Linux distro Typically the build scripts will auto-detect CUDA, however, if your Linux distro
or installation approach uses unusual paths, you can specify the location by or installation approach uses unusual paths, you can specify the location by
specifying an environment variable `CUDA_LIB_DIR` to the location of the shared specifying an environment variable `CUDA_LIB_DIR` to the location of the shared
libraries, and `CUDACXX` to the location of the nvcc compiler. libraries, and `CUDACXX` to the location of the nvcc compiler. You can customize
set set of target CUDA architectues by setting `CMAKE_CUDA_ARCHITECTURES` (e.g. "50;60;70")
Then generate dependencies: Then generate dependencies:
@@ -74,7 +75,8 @@ Typically the build scripts will auto-detect ROCm, however, if your Linux distro
or installation approach uses unusual paths, you can specify the location by or installation approach uses unusual paths, you can specify the location by
specifying an environment variable `ROCM_PATH` to the location of the ROCm specifying an environment variable `ROCM_PATH` to the location of the ROCm
install (typically `/opt/rocm`), and `CLBlast_DIR` to the location of the install (typically `/opt/rocm`), and `CLBlast_DIR` to the location of the
CLBlast install (typically `/usr/lib/cmake/CLBlast`). CLBlast install (typically `/usr/lib/cmake/CLBlast`). You can also customize
the AMD GPU targets by setting AMDGPU_TARGETS (e.g. `AMDGPU_TARGETS="gfx1101;gfx1102"`)
``` ```
go generate ./... go generate ./...

View File

@@ -8,35 +8,38 @@ To upgrade Ollama, run the installation process again. On the Mac, click the Oll
Review the [Troubleshooting](./troubleshooting.md) docs for more about using logs. Review the [Troubleshooting](./troubleshooting.md) docs for more about using logs.
## How do I use Ollama server environment variables on Mac ## How do I configure Ollama server?
On macOS, Ollama runs in the background and is managed by the menubar app. If adding environment variables, Ollama will need to be run manually. Ollama server can be configured with environment variables.
1. Click the menubar icon for Ollama and choose **Quit Ollama**. ### Setting environment variables on Mac
2. Open a new terminal window and run the following command (this example uses `OLLAMA_HOST` with an IP address of `123.1.1.1`):
```bash If Ollama is run as a macOS application, environment variables should be set using `launchctl`:
OLLAMA_HOST=123.1.1.1 ollama serve
```
## How do I use Ollama server environment variables on Linux? 1. For each environment variable, call `launchctl setenv`.
If Ollama is installed with the install script, a systemd service was created, running as the Ollama user. To add an environment variable, such as OLLAMA_HOST, follow these steps: ```bash
launchctl setenv OLLAMA_HOST "0.0.0.0"
```
1. Create a `systemd` drop-in directory and add a config file. This is only needed once. 2. Restart Ollama application.
```bash ### Setting environment variables on Linux
mkdir -p /etc/systemd/system/ollama.service.d
echo '[Service]' >>/etc/systemd/system/ollama.service.d/environment.conf
```
2. For each environment variable, add it to the config file: If Ollama is run as a systemd service, environment variables should be set using `systemctl`:
```bash 1. Edit the systemd service by calling `systemctl edit ollama.service`. This will open an editor.
echo 'Environment="OLLAMA_HOST=0.0.0.0:11434"' >>/etc/systemd/system/ollama.service.d/environment.conf
```
3. Reload `systemd` and restart Ollama: 2. For each environment variable, add a line `Environment` under section `[Service]`:
```ini
[Service]
Environment="OLLAMA_HOST=0.0.0.0"
```
3. Save and exit.
4. Reload `systemd` and restart Ollama:
```bash ```bash
systemctl daemon-reload systemctl daemon-reload
@@ -45,26 +48,26 @@ If Ollama is installed with the install script, a systemd service was created, r
## How can I expose Ollama on my network? ## How can I expose Ollama on my network?
Ollama binds to 127.0.0.1 port 11434 by default. Change the bind address with the `OLLAMA_HOST` environment variable. Refer to the section above for how to use environment variables on your platform. Ollama binds 127.0.0.1 port 11434 by default. Change the bind address with the `OLLAMA_HOST` environment variable.
Refer to the section [above](#how-do-i-configure-ollama-server) for how to set environment variables on your platform.
## How can I allow additional web origins to access Ollama? ## How can I allow additional web origins to access Ollama?
Ollama allows cross-origin requests from `127.0.0.1` and `0.0.0.0` by default. Add additional origins with the `OLLAMA_ORIGINS` environment variable. For example, to add all ports on 192.168.1.1 and https://example.com, use: Ollama allows cross-origin requests from `127.0.0.1` and `0.0.0.0` by default. Additional origins can be configured with `OLLAMA_ORIGINS`.
```shell Refer to the section [above](#how-do-i-configure-ollama-server) for how to set environment variables on your platform.
OLLAMA_ORIGINS=http://192.168.1.1:*,https://example.com
```
Refer to the section above for how to use environment variables on your platform.
## Where are models stored? ## Where are models stored?
- macOS: `~/.ollama/models`. - macOS: `~/.ollama/models`.
- Linux: `/usr/share/ollama/.ollama/models` - Linux: `/usr/share/ollama/.ollama/models`
## How do I set them to a different location? ### How do I set them to a different location?
If a different directory needs to be used, set the environment variable `OLLAMA_MODELS` to the chosen directory. Refer to the section above for how to use environment variables on your platform. If a different directory needs to be used, set the environment variable `OLLAMA_MODELS` to the chosen directory.
Refer to the section [above](#how-do-i-configure-ollama-server) for how to set environment variables on your platform.
## Does Ollama send my prompts and answers back to Ollama.ai to use in any way? ## Does Ollama send my prompts and answers back to Ollama.ai to use in any way?

View File

@@ -19,6 +19,7 @@ A model file is the blueprint to create and share models with Ollama.
- [SYSTEM](#system) - [SYSTEM](#system)
- [ADAPTER](#adapter) - [ADAPTER](#adapter)
- [LICENSE](#license) - [LICENSE](#license)
- [MESSAGE](#message)
- [Notes](#notes) - [Notes](#notes)
## Format ## Format
@@ -38,6 +39,7 @@ INSTRUCTION arguments
| [`SYSTEM`](#system) | Specifies the system message that will be set in the template. | | [`SYSTEM`](#system) | Specifies the system message that will be set in the template. |
| [`ADAPTER`](#adapter) | Defines the (Q)LoRA adapters to apply to the model. | | [`ADAPTER`](#adapter) | Defines the (Q)LoRA adapters to apply to the model. |
| [`LICENSE`](#license) | Specifies the legal license. | | [`LICENSE`](#license) | Specifies the legal license. |
| [`MESSAGE`](#message) | Specify message history. |
## Examples ## Examples
@@ -205,6 +207,19 @@ LICENSE """
""" """
``` ```
### MESSAGE
The `MESSAGE` instruction allows you to specify a message history for the model to use when responding:
```modelfile
MESSAGE user Is Toronto in Canada?
MESSAGE assistant yes
MESSAGE user Is Sacramento in Canada?
MESSAGE assistant no
MESSAGE user Is Ontario in Canada?
MESSAGE assistant yes
```
## Notes ## Notes
- the **`Modelfile` is not case sensitive**. In the examples, uppercase instructions are used to make it easier to distinguish it from arguments. - the **`Modelfile` is not case sensitive**. In the examples, uppercase instructions are used to make it easier to distinguish it from arguments.

View File

@@ -12,6 +12,13 @@ On Linux systems with systemd, the logs can be found with this command:
journalctl -u ollama journalctl -u ollama
``` ```
When you run Ollama in a container, the logs go to stdout/stderr in the container:
```shell
docker logs <container-name>
```
(Use `docker ps` to find the container name)
If manually running `ollama serve` in a terminal, the logs will be on that terminal. If manually running `ollama serve` in a terminal, the logs will be on that terminal.
Join the [Discord](https://discord.gg/ollama) for help interpreting the logs. Join the [Discord](https://discord.gg/ollama) for help interpreting the logs.

View File

@@ -16,6 +16,7 @@ import (
"os" "os"
"path/filepath" "path/filepath"
"runtime" "runtime"
"strconv"
"strings" "strings"
"sync" "sync"
"unsafe" "unsafe"
@@ -29,8 +30,8 @@ type handles struct {
var gpuMutex sync.Mutex var gpuMutex sync.Mutex
var gpuHandles *handles = nil var gpuHandles *handles = nil
// With our current CUDA compile flags, 5.2 and older will not work properly // With our current CUDA compile flags, older than 5.0 will not work properly
const CudaComputeMajorMin = 6 var CudaComputeMin = [2]C.int{5, 0}
// Possible locations for the nvidia-ml library // Possible locations for the nvidia-ml library
var CudaLinuxGlobs = []string{ var CudaLinuxGlobs = []string{
@@ -38,12 +39,15 @@ var CudaLinuxGlobs = []string{
"/usr/lib/x86_64-linux-gnu/nvidia/current/libnvidia-ml.so*", "/usr/lib/x86_64-linux-gnu/nvidia/current/libnvidia-ml.so*",
"/usr/lib/x86_64-linux-gnu/libnvidia-ml.so*", "/usr/lib/x86_64-linux-gnu/libnvidia-ml.so*",
"/usr/lib/wsl/lib/libnvidia-ml.so*", "/usr/lib/wsl/lib/libnvidia-ml.so*",
"/usr/lib/wsl/drivers/*/libnvidia-ml.so*",
"/opt/cuda/lib64/libnvidia-ml.so*", "/opt/cuda/lib64/libnvidia-ml.so*",
"/opt/cuda/targets/x86_64-linux/lib/stubs/libnvidia-ml.so*",
"/usr/lib*/libnvidia-ml.so*", "/usr/lib*/libnvidia-ml.so*",
"/usr/local/lib*/libnvidia-ml.so*", "/usr/local/lib*/libnvidia-ml.so*",
"/usr/lib/aarch64-linux-gnu/nvidia/current/libnvidia-ml.so*", "/usr/lib/aarch64-linux-gnu/nvidia/current/libnvidia-ml.so*",
"/usr/lib/aarch64-linux-gnu/libnvidia-ml.so*", "/usr/lib/aarch64-linux-gnu/libnvidia-ml.so*",
// TODO: are these stubs ever valid?
"/opt/cuda/targets/x86_64-linux/lib/stubs/libnvidia-ml.so*",
} }
var CudaWindowsGlobs = []string{ var CudaWindowsGlobs = []string{
@@ -118,33 +122,60 @@ func GetGPUInfo() GpuInfo {
initGPUHandles() initGPUHandles()
} }
// All our GPU builds on x86 have AVX enabled, so fallback to CPU if we don't detect at least AVX
cpuVariant := GetCPUVariant()
if cpuVariant == "" && runtime.GOARCH == "amd64" {
slog.Warn("CPU does not have AVX or AVX2, disabling GPU support.")
}
var memInfo C.mem_info_t var memInfo C.mem_info_t
resp := GpuInfo{} resp := GpuInfo{}
if gpuHandles.cuda != nil { if gpuHandles.cuda != nil && (cpuVariant != "" || runtime.GOARCH != "amd64") {
C.cuda_check_vram(*gpuHandles.cuda, &memInfo) C.cuda_check_vram(*gpuHandles.cuda, &memInfo)
if memInfo.err != nil { if memInfo.err != nil {
slog.Info(fmt.Sprintf("error looking up CUDA GPU memory: %s", C.GoString(memInfo.err))) slog.Info(fmt.Sprintf("error looking up CUDA GPU memory: %s", C.GoString(memInfo.err)))
C.free(unsafe.Pointer(memInfo.err)) C.free(unsafe.Pointer(memInfo.err))
} else { } else if memInfo.count > 0 {
// Verify minimum compute capability // Verify minimum compute capability
var cc C.cuda_compute_capability_t var cc C.cuda_compute_capability_t
C.cuda_compute_capability(*gpuHandles.cuda, &cc) C.cuda_compute_capability(*gpuHandles.cuda, &cc)
if cc.err != nil { if cc.err != nil {
slog.Info(fmt.Sprintf("error looking up CUDA GPU compute capability: %s", C.GoString(cc.err))) slog.Info(fmt.Sprintf("error looking up CUDA GPU compute capability: %s", C.GoString(cc.err)))
C.free(unsafe.Pointer(cc.err)) C.free(unsafe.Pointer(cc.err))
} else if cc.major >= CudaComputeMajorMin { } else if cc.major > CudaComputeMin[0] || (cc.major == CudaComputeMin[0] && cc.minor >= CudaComputeMin[1]) {
slog.Info(fmt.Sprintf("CUDA Compute Capability detected: %d.%d", cc.major, cc.minor)) slog.Info(fmt.Sprintf("CUDA Compute Capability detected: %d.%d", cc.major, cc.minor))
resp.Library = "cuda" resp.Library = "cuda"
} else { } else {
slog.Info(fmt.Sprintf("CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor)) slog.Info(fmt.Sprintf("CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor))
} }
} }
} else if gpuHandles.rocm != nil { } else if gpuHandles.rocm != nil && (cpuVariant != "" || runtime.GOARCH != "amd64") {
C.rocm_check_vram(*gpuHandles.rocm, &memInfo) C.rocm_check_vram(*gpuHandles.rocm, &memInfo)
if memInfo.err != nil { if memInfo.err != nil {
slog.Info(fmt.Sprintf("error looking up ROCm GPU memory: %s", C.GoString(memInfo.err))) slog.Info(fmt.Sprintf("error looking up ROCm GPU memory: %s", C.GoString(memInfo.err)))
C.free(unsafe.Pointer(memInfo.err)) C.free(unsafe.Pointer(memInfo.err))
} else { } else if memInfo.igpu_index >= 0 && memInfo.count == 1 {
// Only one GPU detected and it appears to be an integrated GPU - skip it
slog.Info("ROCm unsupported integrated GPU detected")
} else if memInfo.count > 0 {
if memInfo.igpu_index >= 0 {
// We have multiple GPUs reported, and one of them is an integrated GPU
// so we have to set the env var to bypass it
// If the user has specified their own ROCR_VISIBLE_DEVICES, don't clobber it
val := os.Getenv("ROCR_VISIBLE_DEVICES")
if val == "" {
devices := []string{}
for i := 0; i < int(memInfo.count); i++ {
if i == int(memInfo.igpu_index) {
continue
}
devices = append(devices, strconv.Itoa(i))
}
val = strings.Join(devices, ",")
os.Setenv("ROCR_VISIBLE_DEVICES", val)
}
slog.Info(fmt.Sprintf("ROCm integrated GPU detected - ROCR_VISIBLE_DEVICES=%s", val))
}
resp.Library = "rocm" resp.Library = "rocm"
var version C.rocm_version_resp_t var version C.rocm_version_resp_t
C.rocm_get_version(*gpuHandles.rocm, &version) C.rocm_get_version(*gpuHandles.rocm, &version)
@@ -160,7 +191,7 @@ func GetGPUInfo() GpuInfo {
if resp.Library == "" { if resp.Library == "" {
C.cpu_check_ram(&memInfo) C.cpu_check_ram(&memInfo)
resp.Library = "cpu" resp.Library = "cpu"
resp.Variant = GetCPUVariant() resp.Variant = cpuVariant
} }
if memInfo.err != nil { if memInfo.err != nil {
slog.Info(fmt.Sprintf("error looking up CPU memory: %s", C.GoString(memInfo.err))) slog.Info(fmt.Sprintf("error looking up CPU memory: %s", C.GoString(memInfo.err)))
@@ -190,7 +221,15 @@ func getCPUMem() (memInfo, error) {
func CheckVRAM() (int64, error) { func CheckVRAM() (int64, error) {
gpuInfo := GetGPUInfo() gpuInfo := GetGPUInfo()
if gpuInfo.FreeMemory > 0 && (gpuInfo.Library == "cuda" || gpuInfo.Library == "rocm") { if gpuInfo.FreeMemory > 0 && (gpuInfo.Library == "cuda" || gpuInfo.Library == "rocm") {
return int64(gpuInfo.FreeMemory), nil // leave 10% or 1024MiB of VRAM free per GPU to handle unaccounted for overhead
overhead := gpuInfo.FreeMemory / 10
gpus := uint64(gpuInfo.DeviceCount)
if overhead < gpus*1024*1024*1024 {
overhead = gpus * 1024 * 1024 * 1024
}
avail := int64(gpuInfo.FreeMemory - overhead)
slog.Debug(fmt.Sprintf("%s detected %d devices with %dM available memory", gpuInfo.Library, gpuInfo.DeviceCount, avail/1024/1024))
return avail, nil
} }
return 0, fmt.Errorf("no GPU detected") // TODO - better handling of CPU based memory determiniation return 0, fmt.Errorf("no GPU detected") // TODO - better handling of CPU based memory determiniation
@@ -252,6 +291,7 @@ func FindGPULibs(baseLibName string, patterns []string) []string {
func LoadCUDAMgmt(cudaLibPaths []string) *C.cuda_handle_t { func LoadCUDAMgmt(cudaLibPaths []string) *C.cuda_handle_t {
var resp C.cuda_init_resp_t var resp C.cuda_init_resp_t
resp.ch.verbose = getVerboseState()
for _, libPath := range cudaLibPaths { for _, libPath := range cudaLibPaths {
lib := C.CString(libPath) lib := C.CString(libPath)
defer C.free(unsafe.Pointer(lib)) defer C.free(unsafe.Pointer(lib))
@@ -268,6 +308,7 @@ func LoadCUDAMgmt(cudaLibPaths []string) *C.cuda_handle_t {
func LoadROCMMgmt(rocmLibPaths []string) *C.rocm_handle_t { func LoadROCMMgmt(rocmLibPaths []string) *C.rocm_handle_t {
var resp C.rocm_init_resp_t var resp C.rocm_init_resp_t
resp.rh.verbose = getVerboseState()
for _, libPath := range rocmLibPaths { for _, libPath := range rocmLibPaths {
lib := C.CString(libPath) lib := C.CString(libPath)
defer C.free(unsafe.Pointer(lib)) defer C.free(unsafe.Pointer(lib))
@@ -281,3 +322,10 @@ func LoadROCMMgmt(rocmLibPaths []string) *C.rocm_handle_t {
} }
return nil return nil
} }
func getVerboseState() C.uint16_t {
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
return C.uint16_t(1)
}
return C.uint16_t(0)
}

View File

@@ -27,6 +27,13 @@
#endif #endif
#define LOG(verbose, ...) \
do { \
if (verbose) { \
fprintf(stderr, __VA_ARGS__); \
} \
} while (0)
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
#endif #endif
@@ -35,6 +42,7 @@ typedef struct mem_info {
uint64_t total; uint64_t total;
uint64_t free; uint64_t free;
unsigned int count; unsigned int count;
int igpu_index; // If >= 0, we detected an integrated GPU to ignore
char *err; // If non-nill, caller responsible for freeing char *err; // If non-nill, caller responsible for freeing
} mem_info_t; } mem_info_t;

View File

@@ -4,8 +4,6 @@
#include <string.h> #include <string.h>
#define CUDA_LOOKUP_SIZE 6
void cuda_init(char *cuda_lib_path, cuda_init_resp_t *resp) { void cuda_init(char *cuda_lib_path, cuda_init_resp_t *resp) {
nvmlReturn_t ret; nvmlReturn_t ret;
resp->err = NULL; resp->err = NULL;
@@ -16,18 +14,26 @@ void cuda_init(char *cuda_lib_path, cuda_init_resp_t *resp) {
struct lookup { struct lookup {
char *s; char *s;
void **p; void **p;
} l[CUDA_LOOKUP_SIZE] = { } l[] = {
{"nvmlInit_v2", (void *)&resp->ch.initFn}, {"nvmlInit_v2", (void *)&resp->ch.nvmlInit_v2},
{"nvmlShutdown", (void *)&resp->ch.shutdownFn}, {"nvmlShutdown", (void *)&resp->ch.nvmlShutdown},
{"nvmlDeviceGetHandleByIndex", (void *)&resp->ch.getHandle}, {"nvmlDeviceGetHandleByIndex", (void *)&resp->ch.nvmlDeviceGetHandleByIndex},
{"nvmlDeviceGetMemoryInfo", (void *)&resp->ch.getMemInfo}, {"nvmlDeviceGetMemoryInfo", (void *)&resp->ch.nvmlDeviceGetMemoryInfo},
{"nvmlDeviceGetCount_v2", (void *)&resp->ch.getCount}, {"nvmlDeviceGetCount_v2", (void *)&resp->ch.nvmlDeviceGetCount_v2},
{"nvmlDeviceGetCudaComputeCapability", (void *)&resp->ch.getComputeCapability}, {"nvmlDeviceGetCudaComputeCapability", (void *)&resp->ch.nvmlDeviceGetCudaComputeCapability},
{"nvmlSystemGetDriverVersion", (void *)&resp->ch.nvmlSystemGetDriverVersion},
{"nvmlDeviceGetName", (void *)&resp->ch.nvmlDeviceGetName},
{"nvmlDeviceGetSerial", (void *)&resp->ch.nvmlDeviceGetSerial},
{"nvmlDeviceGetVbiosVersion", (void *)&resp->ch.nvmlDeviceGetVbiosVersion},
{"nvmlDeviceGetBoardPartNumber", (void *)&resp->ch.nvmlDeviceGetBoardPartNumber},
{"nvmlDeviceGetBrand", (void *)&resp->ch.nvmlDeviceGetBrand},
{NULL, NULL},
}; };
resp->ch.handle = LOAD_LIBRARY(cuda_lib_path, RTLD_LAZY); resp->ch.handle = LOAD_LIBRARY(cuda_lib_path, RTLD_LAZY);
if (!resp->ch.handle) { if (!resp->ch.handle) {
char *msg = LOAD_ERR(); char *msg = LOAD_ERR();
LOG(resp->ch.verbose, "library %s load err: %s\n", cuda_lib_path, msg);
snprintf(buf, buflen, snprintf(buf, buflen,
"Unable to load %s library to query for Nvidia GPUs: %s", "Unable to load %s library to query for Nvidia GPUs: %s",
cuda_lib_path, msg); cuda_lib_path, msg);
@@ -36,12 +42,19 @@ void cuda_init(char *cuda_lib_path, cuda_init_resp_t *resp) {
return; return;
} }
for (i = 0; i < CUDA_LOOKUP_SIZE; i++) { // TODO - fix this to use a null terminated list // TODO once we've squashed the remaining corner cases remove this log
LOG(resp->ch.verbose, "wiring nvidia management library functions in %s\n", cuda_lib_path);
for (i = 0; l[i].s != NULL; i++) {
// TODO once we've squashed the remaining corner cases remove this log
LOG(resp->ch.verbose, "dlsym: %s\n", l[i].s);
*l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s); *l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
if (!l[i].p) { if (!l[i].p) {
UNLOAD_LIBRARY(resp->ch.handle);
resp->ch.handle = NULL; resp->ch.handle = NULL;
char *msg = LOAD_ERR(); char *msg = LOAD_ERR();
LOG(resp->ch.verbose, "dlerr: %s\n", msg);
UNLOAD_LIBRARY(resp->ch.handle);
snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s, snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s,
msg); msg);
free(msg); free(msg);
@@ -50,15 +63,23 @@ void cuda_init(char *cuda_lib_path, cuda_init_resp_t *resp) {
} }
} }
ret = (*resp->ch.initFn)(); ret = (*resp->ch.nvmlInit_v2)();
if (ret != NVML_SUCCESS) { if (ret != NVML_SUCCESS) {
LOG(resp->ch.verbose, "nvmlInit_v2 err: %d\n", ret);
UNLOAD_LIBRARY(resp->ch.handle); UNLOAD_LIBRARY(resp->ch.handle);
resp->ch.handle = NULL; resp->ch.handle = NULL;
snprintf(buf, buflen, "nvml vram init failure: %d", ret); snprintf(buf, buflen, "nvml vram init failure: %d", ret);
resp->err = strdup(buf); resp->err = strdup(buf);
return;
} }
return; // Report driver version if we're in verbose mode, ignore errors
ret = (*resp->ch.nvmlSystemGetDriverVersion)(buf, buflen);
if (ret != NVML_SUCCESS) {
LOG(resp->ch.verbose, "nvmlSystemGetDriverVersion failed: %d\n", ret);
} else {
LOG(resp->ch.verbose, "CUDA driver version: %s\n", buf);
}
} }
void cuda_check_vram(cuda_handle_t h, mem_info_t *resp) { void cuda_check_vram(cuda_handle_t h, mem_info_t *resp) {
@@ -75,7 +96,7 @@ void cuda_check_vram(cuda_handle_t h, mem_info_t *resp) {
return; return;
} }
ret = (*h.getCount)(&resp->count); ret = (*h.nvmlDeviceGetCount_v2)(&resp->count);
if (ret != NVML_SUCCESS) { if (ret != NVML_SUCCESS) {
snprintf(buf, buflen, "unable to get device count: %d", ret); snprintf(buf, buflen, "unable to get device count: %d", ret);
resp->err = strdup(buf); resp->err = strdup(buf);
@@ -85,19 +106,57 @@ void cuda_check_vram(cuda_handle_t h, mem_info_t *resp) {
resp->total = 0; resp->total = 0;
resp->free = 0; resp->free = 0;
for (i = 0; i < resp->count; i++) { for (i = 0; i < resp->count; i++) {
ret = (*h.getHandle)(i, &device); ret = (*h.nvmlDeviceGetHandleByIndex)(i, &device);
if (ret != NVML_SUCCESS) { if (ret != NVML_SUCCESS) {
snprintf(buf, buflen, "unable to get device handle %d: %d", i, ret); snprintf(buf, buflen, "unable to get device handle %d: %d", i, ret);
resp->err = strdup(buf); resp->err = strdup(buf);
return; return;
} }
ret = (*h.getMemInfo)(device, &memInfo); ret = (*h.nvmlDeviceGetMemoryInfo)(device, &memInfo);
if (ret != NVML_SUCCESS) { if (ret != NVML_SUCCESS) {
snprintf(buf, buflen, "device memory info lookup failure %d: %d", i, ret); snprintf(buf, buflen, "device memory info lookup failure %d: %d", i, ret);
resp->err = strdup(buf); resp->err = strdup(buf);
return; return;
} }
if (h.verbose) {
nvmlBrandType_t brand = 0;
// When in verbose mode, report more information about
// the card we discover, but don't fail on error
ret = (*h.nvmlDeviceGetName)(device, buf, buflen);
if (ret != RSMI_STATUS_SUCCESS) {
LOG(h.verbose, "nvmlDeviceGetName failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] CUDA device name: %s\n", i, buf);
}
ret = (*h.nvmlDeviceGetBoardPartNumber)(device, buf, buflen);
if (ret != RSMI_STATUS_SUCCESS) {
LOG(h.verbose, "nvmlDeviceGetBoardPartNumber failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] CUDA part number: %s\n", i, buf);
}
ret = (*h.nvmlDeviceGetSerial)(device, buf, buflen);
if (ret != RSMI_STATUS_SUCCESS) {
LOG(h.verbose, "nvmlDeviceGetSerial failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] CUDA S/N: %s\n", i, buf);
}
ret = (*h.nvmlDeviceGetVbiosVersion)(device, buf, buflen);
if (ret != RSMI_STATUS_SUCCESS) {
LOG(h.verbose, "nvmlDeviceGetVbiosVersion failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] CUDA vbios version: %s\n", i, buf);
}
ret = (*h.nvmlDeviceGetBrand)(device, &brand);
if (ret != RSMI_STATUS_SUCCESS) {
LOG(h.verbose, "nvmlDeviceGetBrand failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] CUDA brand: %d\n", i, brand);
}
}
LOG(h.verbose, "[%d] CUDA totalMem %ld\n", i, memInfo.total);
LOG(h.verbose, "[%d] CUDA usedMem %ld\n", i, memInfo.free);
resp->total += memInfo.total; resp->total += memInfo.total;
resp->free += memInfo.free; resp->free += memInfo.free;
@@ -122,7 +181,7 @@ void cuda_compute_capability(cuda_handle_t h, cuda_compute_capability_t *resp) {
} }
unsigned int devices; unsigned int devices;
ret = (*h.getCount)(&devices); ret = (*h.nvmlDeviceGetCount_v2)(&devices);
if (ret != NVML_SUCCESS) { if (ret != NVML_SUCCESS) {
snprintf(buf, buflen, "unable to get device count: %d", ret); snprintf(buf, buflen, "unable to get device count: %d", ret);
resp->err = strdup(buf); resp->err = strdup(buf);
@@ -130,14 +189,14 @@ void cuda_compute_capability(cuda_handle_t h, cuda_compute_capability_t *resp) {
} }
for (i = 0; i < devices; i++) { for (i = 0; i < devices; i++) {
ret = (*h.getHandle)(i, &device); ret = (*h.nvmlDeviceGetHandleByIndex)(i, &device);
if (ret != NVML_SUCCESS) { if (ret != NVML_SUCCESS) {
snprintf(buf, buflen, "unable to get device handle %d: %d", i, ret); snprintf(buf, buflen, "unable to get device handle %d: %d", i, ret);
resp->err = strdup(buf); resp->err = strdup(buf);
return; return;
} }
ret = (*h.getComputeCapability)(device, &major, &minor); ret = (*h.nvmlDeviceGetCudaComputeCapability)(device, &major, &minor);
if (ret != NVML_SUCCESS) { if (ret != NVML_SUCCESS) {
snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret); snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret);
resp->err = strdup(buf); resp->err = strdup(buf);

View File

@@ -15,14 +15,26 @@ typedef struct nvmlMemory_st {
unsigned long long used; unsigned long long used;
} nvmlMemory_t; } nvmlMemory_t;
typedef enum nvmlBrandType_enum
{
NVML_BRAND_UNKNOWN = 0,
} nvmlBrandType_t;
typedef struct cuda_handle { typedef struct cuda_handle {
void *handle; void *handle;
nvmlReturn_t (*initFn)(void); uint16_t verbose;
nvmlReturn_t (*shutdownFn)(void); nvmlReturn_t (*nvmlInit_v2)(void);
nvmlReturn_t (*getHandle)(unsigned int, nvmlDevice_t *); nvmlReturn_t (*nvmlShutdown)(void);
nvmlReturn_t (*getMemInfo)(nvmlDevice_t, nvmlMemory_t *); nvmlReturn_t (*nvmlDeviceGetHandleByIndex)(unsigned int, nvmlDevice_t *);
nvmlReturn_t (*getCount)(unsigned int *); nvmlReturn_t (*nvmlDeviceGetMemoryInfo)(nvmlDevice_t, nvmlMemory_t *);
nvmlReturn_t (*getComputeCapability)(nvmlDevice_t, int* major, int* minor); nvmlReturn_t (*nvmlDeviceGetCount_v2)(unsigned int *);
nvmlReturn_t (*nvmlDeviceGetCudaComputeCapability)(nvmlDevice_t, int* major, int* minor);
nvmlReturn_t (*nvmlSystemGetDriverVersion) (char* version, unsigned int length);
nvmlReturn_t (*nvmlDeviceGetName) (nvmlDevice_t device, char* name, unsigned int length);
nvmlReturn_t (*nvmlDeviceGetSerial) (nvmlDevice_t device, char* serial, unsigned int length);
nvmlReturn_t (*nvmlDeviceGetVbiosVersion) (nvmlDevice_t device, char* version, unsigned int length);
nvmlReturn_t (*nvmlDeviceGetBoardPartNumber) (nvmlDevice_t device, char* partNumber, unsigned int length);
nvmlReturn_t (*nvmlDeviceGetBrand) (nvmlDevice_t device, nvmlBrandType_t* type);
} cuda_handle_t; } cuda_handle_t;
typedef struct cuda_init_resp { typedef struct cuda_init_resp {

View File

@@ -4,8 +4,6 @@
#include <string.h> #include <string.h>
#define ROCM_LOOKUP_SIZE 5
void rocm_init(char *rocm_lib_path, rocm_init_resp_t *resp) { void rocm_init(char *rocm_lib_path, rocm_init_resp_t *resp) {
rsmi_status_t ret; rsmi_status_t ret;
resp->err = NULL; resp->err = NULL;
@@ -15,13 +13,22 @@ void rocm_init(char *rocm_lib_path, rocm_init_resp_t *resp) {
struct lookup { struct lookup {
char *s; char *s;
void **p; void **p;
} l[ROCM_LOOKUP_SIZE] = { } l[] = {
{"rsmi_init", (void *)&resp->rh.initFn}, {"rsmi_init", (void *)&resp->rh.rsmi_init},
{"rsmi_shut_down", (void *)&resp->rh.shutdownFn}, {"rsmi_shut_down", (void *)&resp->rh.rsmi_shut_down},
{"rsmi_dev_memory_total_get", (void *)&resp->rh.totalMemFn}, {"rsmi_dev_memory_total_get", (void *)&resp->rh.rsmi_dev_memory_total_get},
{"rsmi_dev_memory_usage_get", (void *)&resp->rh.usageMemFn}, {"rsmi_dev_memory_usage_get", (void *)&resp->rh.rsmi_dev_memory_usage_get},
{"rsmi_version_get", (void *)&resp->rh.versionGetFn}, {"rsmi_version_get", (void *)&resp->rh.rsmi_version_get},
// { "rsmi_dev_id_get", (void*)&resp->rh.getHandle }, {"rsmi_num_monitor_devices", (void*)&resp->rh.rsmi_num_monitor_devices},
{"rsmi_dev_id_get", (void*)&resp->rh.rsmi_dev_id_get},
{"rsmi_dev_name_get", (void *)&resp->rh.rsmi_dev_name_get},
{"rsmi_dev_brand_get", (void *)&resp->rh.rsmi_dev_brand_get},
{"rsmi_dev_vendor_name_get", (void *)&resp->rh.rsmi_dev_vendor_name_get},
{"rsmi_dev_vram_vendor_get", (void *)&resp->rh.rsmi_dev_vram_vendor_get},
{"rsmi_dev_serial_number_get", (void *)&resp->rh.rsmi_dev_serial_number_get},
{"rsmi_dev_subsystem_name_get", (void *)&resp->rh.rsmi_dev_subsystem_name_get},
{"rsmi_dev_vbios_version_get", (void *)&resp->rh.rsmi_dev_vbios_version_get},
{NULL, NULL},
}; };
resp->rh.handle = LOAD_LIBRARY(rocm_lib_path, RTLD_LAZY); resp->rh.handle = LOAD_LIBRARY(rocm_lib_path, RTLD_LAZY);
@@ -35,12 +42,19 @@ void rocm_init(char *rocm_lib_path, rocm_init_resp_t *resp) {
return; return;
} }
for (i = 0; i < ROCM_LOOKUP_SIZE; i++) { // TODO once we've squashed the remaining corner cases remove this log
LOG(resp->rh.verbose, "wiring rocm management library functions in %s\n", rocm_lib_path);
for (i = 0; l[i].s != NULL; i++) {
// TODO once we've squashed the remaining corner cases remove this log
LOG(resp->rh.verbose, "dlsym: %s\n", l[i].s);
*l[i].p = LOAD_SYMBOL(resp->rh.handle, l[i].s); *l[i].p = LOAD_SYMBOL(resp->rh.handle, l[i].s);
if (!l[i].p) { if (!l[i].p) {
UNLOAD_LIBRARY(resp->rh.handle);
resp->rh.handle = NULL; resp->rh.handle = NULL;
char *msg = LOAD_ERR(); char *msg = LOAD_ERR();
LOG(resp->rh.verbose, "dlerr: %s\n", msg);
UNLOAD_LIBRARY(resp->rh.handle);
snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s, snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s,
msg); msg);
free(msg); free(msg);
@@ -49,8 +63,9 @@ void rocm_init(char *rocm_lib_path, rocm_init_resp_t *resp) {
} }
} }
ret = (*resp->rh.initFn)(0); ret = (*resp->rh.rsmi_init)(0);
if (ret != RSMI_STATUS_SUCCESS) { if (ret != RSMI_STATUS_SUCCESS) {
LOG(resp->rh.verbose, "rsmi_init err: %d\n", ret);
UNLOAD_LIBRARY(resp->rh.handle); UNLOAD_LIBRARY(resp->rh.handle);
resp->rh.handle = NULL; resp->rh.handle = NULL;
snprintf(buf, buflen, "rocm vram init failure: %d", ret); snprintf(buf, buflen, "rocm vram init failure: %d", ret);
@@ -62,8 +77,7 @@ void rocm_init(char *rocm_lib_path, rocm_init_resp_t *resp) {
void rocm_check_vram(rocm_handle_t h, mem_info_t *resp) { void rocm_check_vram(rocm_handle_t h, mem_info_t *resp) {
resp->err = NULL; resp->err = NULL;
// uint32_t num_devices; resp->igpu_index = -1;
// uint16_t device;
uint64_t totalMem = 0; uint64_t totalMem = 0;
uint64_t usedMem = 0; uint64_t usedMem = 0;
rsmi_status_t ret; rsmi_status_t ret;
@@ -76,47 +90,101 @@ void rocm_check_vram(rocm_handle_t h, mem_info_t *resp) {
return; return;
} }
// TODO - iterate through devices... ret = ret = (*h.rsmi_num_monitor_devices)(&resp->count);
// rsmi_num_monitor_devices(&num_devices);
// ret = (*h.getHandle)(0, &device);
// if (ret != RSMI_STATUS_SUCCESS) {
// printf("rocm vram device lookup failure: %d\n", ret);
// return -1;
// }
// Get total memory - used memory for available memory
ret = (*h.totalMemFn)(0, RSMI_MEM_TYPE_VRAM, &totalMem);
if (ret != RSMI_STATUS_SUCCESS) { if (ret != RSMI_STATUS_SUCCESS) {
snprintf(buf, buflen, "rocm total mem lookup failure: %d", ret); snprintf(buf, buflen, "unable to get device count: %d", ret);
resp->err = strdup(buf);
return;
}
ret = (*h.usageMemFn)(0, RSMI_MEM_TYPE_VRAM, &usedMem);
if (ret != RSMI_STATUS_SUCCESS) {
snprintf(buf, buflen, "rocm usage mem lookup failure: %d", ret);
resp->err = strdup(buf); resp->err = strdup(buf);
return; return;
} }
LOG(h.verbose, "discovered %d ROCm GPU Devices\n", resp->count);
// TODO: set this to the actual number of devices resp->total = 0;
resp->count = 1; resp->free = 0;
resp->total = totalMem; for (i = 0; i < resp->count; i++) {
resp->free = totalMem - usedMem; if (h.verbose) {
return; // When in verbose mode, report more information about
// the card we discover, but don't fail on error
ret = (*h.rsmi_dev_name_get)(i, buf, buflen);
if (ret != RSMI_STATUS_SUCCESS) {
LOG(h.verbose, "rsmi_dev_name_get failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] ROCm device name: %s\n", i, buf);
}
ret = (*h.rsmi_dev_brand_get)(i, buf, buflen);
if (ret != RSMI_STATUS_SUCCESS) {
LOG(h.verbose, "rsmi_dev_brand_get failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] ROCm brand: %s\n", i, buf);
}
ret = (*h.rsmi_dev_vendor_name_get)(i, buf, buflen);
if (ret != RSMI_STATUS_SUCCESS) {
LOG(h.verbose, "rsmi_dev_vendor_name_get failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] ROCm vendor: %s\n", i, buf);
}
ret = (*h.rsmi_dev_vram_vendor_get)(i, buf, buflen);
if (ret != RSMI_STATUS_SUCCESS) {
LOG(h.verbose, "rsmi_dev_vram_vendor_get failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] ROCm VRAM vendor: %s\n", i, buf);
}
ret = (*h.rsmi_dev_serial_number_get)(i, buf, buflen);
if (ret != RSMI_STATUS_SUCCESS) {
LOG(h.verbose, "rsmi_dev_serial_number_get failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] ROCm S/N: %s\n", i, buf);
}
ret = (*h.rsmi_dev_subsystem_name_get)(i, buf, buflen);
if (ret != RSMI_STATUS_SUCCESS) {
LOG(h.verbose, "rsmi_dev_subsystem_name_get failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] ROCm subsystem name: %s\n", i, buf);
}
ret = (*h.rsmi_dev_vbios_version_get)(i, buf, buflen);
if (ret != RSMI_STATUS_SUCCESS) {
LOG(h.verbose, "rsmi_dev_vbios_version_get failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] ROCm vbios version: %s\n", i, buf);
}
}
// Get total memory - used memory for available memory
ret = (*h.rsmi_dev_memory_total_get)(i, RSMI_MEM_TYPE_VRAM, &totalMem);
if (ret != RSMI_STATUS_SUCCESS) {
snprintf(buf, buflen, "rocm total mem lookup failure: %d", ret);
resp->err = strdup(buf);
return;
}
ret = (*h.rsmi_dev_memory_usage_get)(i, RSMI_MEM_TYPE_VRAM, &usedMem);
if (ret != RSMI_STATUS_SUCCESS) {
snprintf(buf, buflen, "rocm usage mem lookup failure: %d", ret);
resp->err = strdup(buf);
return;
}
LOG(h.verbose, "[%d] ROCm totalMem %ld\n", i, totalMem);
LOG(h.verbose, "[%d] ROCm usedMem %ld\n", i, usedMem);
if (totalMem < 1024 * 1024 * 1024) {
// Do not add up integrated GPU memory capacity, it's a bogus 512M, and actually uses system memory
LOG(h.verbose, "[%d] ROCm integrated GPU\n", i);
resp->igpu_index = i;
} else {
resp->total += totalMem;
resp->free += totalMem - usedMem;
}
}
} }
void rocm_get_version(rocm_handle_t h, rocm_version_resp_t *resp) { void rocm_get_version(rocm_handle_t h, rocm_version_resp_t *resp) {
const int buflen = 256; const int buflen = 256;
char buf[buflen + 1]; char buf[buflen + 1];
if (h.handle == NULL) { if (h.handle == NULL) {
resp->str = strdup("nvml handle not initialized"); resp->str = strdup("rocm handle not initialized");
resp->status = 1; resp->status = 1;
return; return;
} }
rsmi_version_t ver; rsmi_version_t ver;
rsmi_status_t ret; rsmi_status_t ret;
ret = h.versionGetFn(&ver); ret = h.rsmi_version_get(&ver);
if (ret != RSMI_STATUS_SUCCESS) { if (ret != RSMI_STATUS_SUCCESS) {
snprintf(buf, buflen, "unexpected response on version lookup %d", ret); snprintf(buf, buflen, "unexpected response on version lookup %d", ret);
resp->status = 1; resp->status = 1;
@@ -127,4 +195,4 @@ void rocm_get_version(rocm_handle_t h, rocm_version_resp_t *resp) {
resp->str = strdup(buf); resp->str = strdup(buf);
} }
#endif // __APPLE__ #endif // __APPLE__

View File

@@ -24,12 +24,21 @@ typedef enum rsmi_memory_type {
typedef struct rocm_handle { typedef struct rocm_handle {
void *handle; void *handle;
rsmi_status_t (*initFn)(uint64_t); uint16_t verbose;
rsmi_status_t (*shutdownFn)(void); rsmi_status_t (*rsmi_init)(uint64_t);
rsmi_status_t (*totalMemFn)(uint32_t, rsmi_memory_type_t, uint64_t *); rsmi_status_t (*rsmi_shut_down)(void);
rsmi_status_t (*usageMemFn)(uint32_t, rsmi_memory_type_t, uint64_t *); rsmi_status_t (*rsmi_dev_memory_total_get)(uint32_t, rsmi_memory_type_t, uint64_t *);
rsmi_status_t (*versionGetFn) (rsmi_version_t *version); rsmi_status_t (*rsmi_dev_memory_usage_get)(uint32_t, rsmi_memory_type_t, uint64_t *);
// rsmi_status_t (*getHandle)(uint32_t, uint16_t *); rsmi_status_t (*rsmi_version_get) (rsmi_version_t *version);
rsmi_status_t (*rsmi_num_monitor_devices) (uint32_t *);
rsmi_status_t (*rsmi_dev_id_get)(uint32_t, uint16_t *);
rsmi_status_t (*rsmi_dev_name_get) (uint32_t,char *,size_t);
rsmi_status_t (*rsmi_dev_brand_get) (uint32_t, char *, uint32_t);
rsmi_status_t (*rsmi_dev_vendor_name_get) (uint32_t, char *, uint32_t);
rsmi_status_t (*rsmi_dev_vram_vendor_get) (uint32_t, char *, uint32_t);
rsmi_status_t (*rsmi_dev_serial_number_get) (uint32_t, char *, uint32_t);
rsmi_status_t (*rsmi_dev_subsystem_name_get) (uint32_t, char *, uint32_t);
rsmi_status_t (*rsmi_dev_vbios_version_get) (uint32_t, char *, uint32_t);
} rocm_handle_t; } rocm_handle_t;
typedef struct rocm_init_resp { typedef struct rocm_init_resp {

View File

@@ -59,7 +59,7 @@ void dyn_init(const char *libPath, struct dynamic_llama_server *s,
}; };
printf("loading library %s\n", libPath); printf("loading library %s\n", libPath);
s->handle = LOAD_LIBRARY(libPath, RTLD_GLOBAL|RTLD_NOW); s->handle = LOAD_LIBRARY(libPath, RTLD_LOCAL|RTLD_NOW);
if (!s->handle) { if (!s->handle) {
err->id = -1; err->id = -1;
char *msg = LOAD_ERR(); char *msg = LOAD_ERR();

View File

@@ -4,7 +4,7 @@ package llm
#cgo CFLAGS: -I${SRCDIR}/ext_server -I${SRCDIR}/llama.cpp -I${SRCDIR}/llama.cpp/common -I${SRCDIR}/llama.cpp/examples/server #cgo CFLAGS: -I${SRCDIR}/ext_server -I${SRCDIR}/llama.cpp -I${SRCDIR}/llama.cpp/common -I${SRCDIR}/llama.cpp/examples/server
#cgo CFLAGS: -DNDEBUG -DLLAMA_SERVER_LIBRARY=1 -D_XOPEN_SOURCE=600 -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64 #cgo CFLAGS: -DNDEBUG -DLLAMA_SERVER_LIBRARY=1 -D_XOPEN_SOURCE=600 -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
#cgo CFLAGS: -Wmissing-noreturn -Wextra -Wcast-qual -Wno-unused-function -Wno-array-bounds #cgo CFLAGS: -Wmissing-noreturn -Wextra -Wcast-qual -Wno-unused-function -Wno-array-bounds
#cgo CPPFLAGS: -Ofast -Wextra -Wno-unused-function -Wno-unused-variable -Wno-deprecated-declarations -Wno-unused-but-set-variable #cgo CPPFLAGS: -Ofast -Wextra -Wno-unused-function -Wno-unused-variable -Wno-deprecated-declarations
#cgo darwin CFLAGS: -D_DARWIN_C_SOURCE #cgo darwin CFLAGS: -D_DARWIN_C_SOURCE
#cgo darwin CPPFLAGS: -DGGML_USE_ACCELERATE #cgo darwin CPPFLAGS: -DGGML_USE_ACCELERATE
#cgo darwin CPPFLAGS: -DGGML_USE_METAL -DGGML_METAL_NDEBUG #cgo darwin CPPFLAGS: -DGGML_USE_METAL -DGGML_METAL_NDEBUG
@@ -136,12 +136,21 @@ func newDynExtServer(library, model string, adapters, projectors []string, opts
sparams.n_threads = C.uint(opts.NumThread) sparams.n_threads = C.uint(opts.NumThread)
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
sparams.verbose_logging = C.bool(true)
} else {
sparams.verbose_logging = C.bool(false)
}
slog.Info("Initializing llama server") slog.Info("Initializing llama server")
initResp := newExtServerResp(128) initResp := newExtServerResp(128)
defer freeExtServerResp(initResp) defer freeExtServerResp(initResp)
C.dyn_llama_server_init(llm.s, &sparams, &initResp) C.dyn_llama_server_init(llm.s, &sparams, &initResp)
if initResp.id < 0 { if initResp.id < 0 {
return nil, extServerResponseToErr(initResp) mutex.Unlock()
err := extServerResponseToErr(initResp)
slog.Debug(fmt.Sprintf("failure during initialization: %s", err))
return nil, err
} }
slog.Info("Starting llama main loop") slog.Info("Starting llama main loop")
@@ -152,13 +161,10 @@ func newDynExtServer(library, model string, adapters, projectors []string, opts
func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error { func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
resp := newExtServerResp(128) resp := newExtServerResp(128)
defer freeExtServerResp(resp) defer freeExtServerResp(resp)
var imageData []ImageData
if len(predict.Images) > 0 { if len(predict.Images) > 0 {
for cnt, i := range predict.Images { slog.Info(fmt.Sprintf("loaded %d images", len(predict.Images)))
imageData = append(imageData, ImageData{Data: i, ID: cnt})
}
} }
slog.Info(fmt.Sprintf("loaded %d images", len(imageData)))
request := map[string]any{ request := map[string]any{
"prompt": predict.Prompt, "prompt": predict.Prompt,
@@ -180,7 +186,8 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu
"penalize_nl": predict.Options.PenalizeNewline, "penalize_nl": predict.Options.PenalizeNewline,
"seed": predict.Options.Seed, "seed": predict.Options.Seed,
"stop": predict.Options.Stop, "stop": predict.Options.Stop,
"image_data": imageData, "image_data": predict.Images,
"cache_prompt": true,
} }
if predict.Format == "json" { if predict.Format == "json" {

View File

@@ -3,22 +3,44 @@
// Necessary evil since the server types are not defined in a header // Necessary evil since the server types are not defined in a header
#include "server.cpp" #include "server.cpp"
// Low level API access to verify GPU access
#if defined(GGML_USE_CUBLAS)
#if defined(GGML_USE_HIPBLAS)
#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>
#include <hip/hip_fp16.h>
#ifdef __HIP_PLATFORM_AMD__
// for rocblas_initialize()
#include "rocblas/rocblas.h"
#endif // __HIP_PLATFORM_AMD__
#define cudaGetDevice hipGetDevice
#define cudaError_t hipError_t
#define cudaSuccess hipSuccess
#define cudaGetErrorString hipGetErrorString
#else
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#endif // defined(GGML_USE_HIPBLAS)
#endif // GGML_USE_CUBLAS
// Expose the llama server as a callable extern "C" API // Expose the llama server as a callable extern "C" API
llama_server_context *llama = NULL; llama_server_context *llama = NULL;
std::atomic<bool> ext_server_running(false);
std::thread ext_server_thread; std::thread ext_server_thread;
void llama_server_init(ext_server_params *sparams, ext_server_resp_t *err) { void llama_server_init(ext_server_params *sparams, ext_server_resp_t *err) {
#if SERVER_VERBOSE != 1
log_disable();
#endif
LOG_TEE("system info: %s", llama_print_system_info());
assert(err != NULL && sparams != NULL); assert(err != NULL && sparams != NULL);
log_set_target(stderr);
if (!sparams->verbose_logging) {
server_verbose = true;
log_disable();
}
LOG_TEE("system info: %s\n", llama_print_system_info());
err->id = 0; err->id = 0;
err->msg[0] = '\0'; err->msg[0] = '\0';
try { try {
llama = new llama_server_context; llama = new llama_server_context;
log_set_target(stdout);
gpt_params params; gpt_params params;
params.n_ctx = sparams->n_ctx; params.n_ctx = sparams->n_ctx;
params.n_batch = sparams->n_batch; params.n_batch = sparams->n_batch;
@@ -60,6 +82,18 @@ void llama_server_init(ext_server_params *sparams, ext_server_resp_t *err) {
params.mmproj = std::string(sparams->mmproj); params.mmproj = std::string(sparams->mmproj);
} }
#if defined(GGML_USE_CUBLAS)
// Before attempting to init the backend which will assert on error, verify the CUDA/ROCM GPU is accessible
LOG_TEE("Performing pre-initialization of GPU\n");
int id;
cudaError_t cudaErr = cudaGetDevice(&id);
if (cudaErr != cudaSuccess) {
err->id = -1;
snprintf(err->msg, err->msg_len, "Unable to init GPU: %s", cudaGetErrorString(cudaErr));
return;
}
#endif
llama_backend_init(params.numa); llama_backend_init(params.numa);
// load the model // load the model
@@ -88,18 +122,23 @@ void llama_server_start() {
assert(llama != NULL); assert(llama != NULL);
// TODO mutex to protect thread creation // TODO mutex to protect thread creation
ext_server_thread = std::thread([&]() { ext_server_thread = std::thread([&]() {
ext_server_running = true;
try { try {
LOG_TEE("llama server main loop starting\n"); LOG_TEE("llama server main loop starting\n");
ggml_time_init(); ggml_time_init();
while (ext_server_running.load()) { llama->queue_tasks.on_new_task(std::bind(
if (!llama->update_slots()) { &llama_server_context::process_single_task, llama, std::placeholders::_1));
LOG_TEE( llama->queue_tasks.on_finish_multitask(std::bind(
"unexpected error in llama server update_slots - exiting main " &llama_server_context::on_finish_multitask, llama, std::placeholders::_1));
"loop\n"); llama->queue_tasks.on_all_tasks_finished(std::bind(
break; &llama_server_context::run_on_all_tasks_finished, llama));
} llama->queue_results.on_multitask_update(std::bind(
} &llama_server_queue::update_multitask,
&llama->queue_tasks,
std::placeholders::_1,
std::placeholders::_2,
std::placeholders::_3
));
llama->queue_tasks.start_loop();
} catch (std::exception &e) { } catch (std::exception &e) {
LOG_TEE("caught exception in llama server main loop: %s\n", e.what()); LOG_TEE("caught exception in llama server main loop: %s\n", e.what());
} catch (...) { } catch (...) {
@@ -112,13 +151,10 @@ void llama_server_start() {
void llama_server_stop() { void llama_server_stop() {
assert(llama != NULL); assert(llama != NULL);
// TODO - too verbose, remove once things are solid LOG_TEE("\ninitiating shutdown - draining remaining tasks...\n");
LOG_TEE("requesting llama server shutdown\n"); // This may take a while for any pending tasks to drain
ext_server_running = false; // TODO - consider a timeout to cancel tasks if it's taking too long
llama->queue_tasks.terminate();
// unblocks the update_slots() loop so it can clean up and exit
llama->request_cancel(0);
ext_server_thread.join(); ext_server_thread.join();
delete llama; delete llama;
llama = NULL; llama = NULL;
@@ -131,7 +167,9 @@ void llama_server_completion(const char *json_req, ext_server_resp_t *resp) {
resp->msg[0] = '\0'; resp->msg[0] = '\0';
try { try {
json data = json::parse(json_req); json data = json::parse(json_req);
resp->id = llama->request_completion(data, false, false, -1); resp->id = llama->queue_tasks.get_new_id();
llama->queue_results.add_waiting_task_id(resp->id);
llama->request_completion(resp->id, data, false, false, -1);
} catch (std::exception &e) { } catch (std::exception &e) {
snprintf(resp->msg, resp->msg_len, "exception %s", e.what()); snprintf(resp->msg, resp->msg_len, "exception %s", e.what());
} catch (...) { } catch (...) {
@@ -149,16 +187,22 @@ void llama_server_completion_next_result(const int task_id,
resp->json_resp = NULL; resp->json_resp = NULL;
std::string result_json; std::string result_json;
try { try {
task_result result = llama->next_result(task_id); task_result result = llama->queue_results.recv(task_id);
result_json = result_json =
result.result_json.dump(-1, ' ', false, json::error_handler_t::replace); result.result_json.dump(-1, ' ', false, json::error_handler_t::replace);
resp->id = result.id; resp->id = result.id;
resp->stop = result.stop; resp->stop = result.stop;
resp->error = result.error; resp->error = result.error;
if (result.error) { if (result.error) {
LOG_TEE("next result cancel on error\n");
llama->request_cancel(task_id); llama->request_cancel(task_id);
LOG_TEE("next result removing waiting tak ID: %d\n", task_id);
llama->queue_results.remove_waiting_task_id(task_id);
} else if (result.stop) { } else if (result.stop) {
LOG_TEE("next result cancel on stop\n");
llama->request_cancel(task_id); llama->request_cancel(task_id);
LOG_TEE("next result removing waiting task ID: %d\n", task_id);
llama->queue_results.remove_waiting_task_id(task_id);
} }
} catch (std::exception &e) { } catch (std::exception &e) {
resp->error = true; resp->error = true;
@@ -189,6 +233,7 @@ void llama_server_completion_cancel(const int task_id, ext_server_resp_t *err) {
err->msg[0] = '\0'; err->msg[0] = '\0';
try { try {
llama->request_cancel(task_id); llama->request_cancel(task_id);
llama->queue_results.remove_waiting_task_id(task_id);
} catch (std::exception &e) { } catch (std::exception &e) {
err->id = -1; err->id = -1;
snprintf(err->msg, err->msg_len, "exception %s", e.what()); snprintf(err->msg, err->msg_len, "exception %s", e.what());
@@ -273,13 +318,15 @@ void llama_server_embedding(const char *json_req, char **json_resp,
} else { } else {
prompt = ""; prompt = "";
} }
const int task_id = llama->request_completion( const int task_id = llama->queue_tasks.get_new_id();
{{"prompt", prompt}, {"n_predict", 0}}, false, true, -1); llama->queue_results.add_waiting_task_id(task_id);
task_result result = llama->next_result(task_id); llama->request_completion(task_id, {{"prompt", prompt}, {"n_predict", 0}}, false, true, -1);
task_result result = llama->queue_results.recv(task_id);
std::string result_json = result.result_json.dump(); std::string result_json = result.result_json.dump();
const std::string::size_type size = result_json.size() + 1; const std::string::size_type size = result_json.size() + 1;
*json_resp = new char[size]; *json_resp = new char[size];
snprintf(*json_resp, size, "%s", result_json.c_str()); snprintf(*json_resp, size, "%s", result_json.c_str());
llama->queue_results.remove_waiting_task_id(task_id);
} catch (std::exception &e) { } catch (std::exception &e) {
err->id = -1; err->id = -1;
snprintf(err->msg, err->msg_len, "exception %s", e.what()); snprintf(err->msg, err->msg_len, "exception %s", e.what());

View File

@@ -45,6 +45,7 @@ typedef struct ext_server_params {
bool embedding; // get only sentence embedding bool embedding; // get only sentence embedding
ext_server_lora_adapter_t *lora_adapters; ext_server_lora_adapter_t *lora_adapters;
char *mmproj; char *mmproj;
bool verbose_logging; // Enable verbose logging of the server
} ext_server_params_t; } ext_server_params_t;
typedef struct ext_server_task_result { typedef struct ext_server_task_result {

View File

@@ -39,6 +39,9 @@ init_vars() {
*) *)
;; ;;
esac esac
if [ -z "${CMAKE_CUDA_ARCHITECTURES}" ] ; then
CMAKE_CUDA_ARCHITECTURES="50;52;61;70;75;80"
fi
} }
git_module_setup() { git_module_setup() {
@@ -61,6 +64,19 @@ apply_patches() {
if ! grep ollama ${LLAMACPP_DIR}/examples/server/CMakeLists.txt; then if ! grep ollama ${LLAMACPP_DIR}/examples/server/CMakeLists.txt; then
echo 'include (../../../ext_server/CMakeLists.txt) # ollama' >>${LLAMACPP_DIR}/examples/server/CMakeLists.txt echo 'include (../../../ext_server/CMakeLists.txt) # ollama' >>${LLAMACPP_DIR}/examples/server/CMakeLists.txt
fi fi
if [ -n "$(ls -A ../patches/*.diff)" ]; then
# apply temporary patches until fix is upstream
for patch in ../patches/*.diff; do
for file in $(grep "^+++ " ${patch} | cut -f2 -d' ' | cut -f2- -d/); do
(cd ${LLAMACPP_DIR}; git checkout ${file})
done
done
for patch in ../patches/*.diff; do
(cd ${LLAMACPP_DIR} && git apply ${patch})
done
fi
# Avoid duplicate main symbols when we link into the cgo binary # Avoid duplicate main symbols when we link into the cgo binary
sed -e 's/int main(/int __main(/g' <${LLAMACPP_DIR}/examples/server/server.cpp >${LLAMACPP_DIR}/examples/server/server.cpp.tmp && sed -e 's/int main(/int __main(/g' <${LLAMACPP_DIR}/examples/server/server.cpp >${LLAMACPP_DIR}/examples/server/server.cpp.tmp &&
mv ${LLAMACPP_DIR}/examples/server/server.cpp.tmp ${LLAMACPP_DIR}/examples/server/server.cpp mv ${LLAMACPP_DIR}/examples/server/server.cpp.tmp ${LLAMACPP_DIR}/examples/server/server.cpp
@@ -83,8 +99,9 @@ build() {
compress_libs() { compress_libs() {
echo "Compressing payloads to reduce overall binary size..." echo "Compressing payloads to reduce overall binary size..."
pids="" pids=""
rm -rf ${BUILD_DIR}/lib/*.${LIB_EXT}*.gz
for lib in ${BUILD_DIR}/lib/*.${LIB_EXT}* ; do for lib in ${BUILD_DIR}/lib/*.${LIB_EXT}* ; do
gzip --best ${lib} & gzip --best -f ${lib} &
pids+=" $!" pids+=" $!"
done done
echo echo
@@ -97,4 +114,12 @@ compress_libs() {
# Keep the local tree clean after we're done with the build # Keep the local tree clean after we're done with the build
cleanup() { cleanup() {
(cd ${LLAMACPP_DIR}/examples/server/ && git checkout CMakeLists.txt server.cpp) (cd ${LLAMACPP_DIR}/examples/server/ && git checkout CMakeLists.txt server.cpp)
if [ -n "$(ls -A ../patches/*.diff)" ]; then
for patch in ../patches/*.diff; do
for file in $(grep "^+++ " ${patch} | cut -f2 -d' ' | cut -f2- -d/); do
(cd ${LLAMACPP_DIR}; git checkout ${file})
done
done
fi
} }

View File

@@ -12,7 +12,13 @@ init_vars
git_module_setup git_module_setup
apply_patches apply_patches
COMMON_DARWIN_DEFS="-DCMAKE_OSX_DEPLOYMENT_TARGET=11.0 -DCMAKE_SYSTEM_NAME=Darwin -DLLAMA_ACCELERATE=off" sign() {
if [ -n "$APPLE_IDENTITY" ]; then
codesign -f --timestamp --deep --options=runtime --sign "$APPLE_IDENTITY" --identifier ai.ollama.ollama $1
fi
}
COMMON_DARWIN_DEFS="-DCMAKE_OSX_DEPLOYMENT_TARGET=11.0 -DCMAKE_SYSTEM_NAME=Darwin"
case "${GOARCH}" in case "${GOARCH}" in
"amd64") "amd64")
@@ -21,10 +27,11 @@ case "${GOARCH}" in
# #
# CPU first for the default library, set up as lowest common denominator for maximum compatibility (including Rosetta) # CPU first for the default library, set up as lowest common denominator for maximum compatibility (including Rosetta)
# #
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}" CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
BUILD_DIR="${LLAMACPP_DIR}/build/darwin/${ARCH}/cpu" BUILD_DIR="${LLAMACPP_DIR}/build/darwin/${ARCH}/cpu"
echo "Building LCD CPU" echo "Building LCD CPU"
build build
sign ${LLAMACPP_DIR}/build/darwin/${ARCH}/cpu/lib/libext_server.dylib
compress_libs compress_libs
# #
@@ -32,10 +39,11 @@ case "${GOARCH}" in
# Approximately 400% faster than LCD on same CPU # Approximately 400% faster than LCD on same CPU
# #
init_vars init_vars
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}" CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=off -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
BUILD_DIR="${LLAMACPP_DIR}/build/darwin/${ARCH}/cpu_avx" BUILD_DIR="${LLAMACPP_DIR}/build/darwin/${ARCH}/cpu_avx"
echo "Building AVX CPU" echo "Building AVX CPU"
build build
sign ${LLAMACPP_DIR}/build/darwin/${ARCH}/cpu_avx/lib/libext_server.dylib
compress_libs compress_libs
# #
@@ -43,17 +51,20 @@ case "${GOARCH}" in
# Approximately 10% faster than AVX on same CPU # Approximately 10% faster than AVX on same CPU
# #
init_vars init_vars
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_AVX=on -DLLAMA_AVX2=on -DLLAMA_AVX512=off -DLLAMA_FMA=on -DLLAMA_F16C=on ${CMAKE_DEFS}" CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=on -DLLAMA_AVX=on -DLLAMA_AVX2=on -DLLAMA_AVX512=off -DLLAMA_FMA=on -DLLAMA_F16C=on ${CMAKE_DEFS}"
BUILD_DIR="${LLAMACPP_DIR}/build/darwin/${ARCH}/cpu_avx2" BUILD_DIR="${LLAMACPP_DIR}/build/darwin/${ARCH}/cpu_avx2"
echo "Building AVX2 CPU" echo "Building AVX2 CPU"
EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation"
build build
sign ${LLAMACPP_DIR}/build/darwin/${ARCH}/cpu_avx2/lib/libext_server.dylib
compress_libs compress_libs
;; ;;
"arm64") "arm64")
CMAKE_DEFS="${COMMON_DARWIN_DEFS} -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=on -DLLAMA_ACCELERATE=on ${CMAKE_DEFS}" CMAKE_DEFS="${COMMON_DARWIN_DEFS} -DLLAMA_ACCELERATE=on -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=on ${CMAKE_DEFS}"
BUILD_DIR="${LLAMACPP_DIR}/build/darwin/${ARCH}/metal" BUILD_DIR="${LLAMACPP_DIR}/build/darwin/${ARCH}/metal"
EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders" EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders"
build build
sign ${LLAMACPP_DIR}/build/darwin/${ARCH}/metal/lib/libext_server.dylib
compress_libs compress_libs
;; ;;
*) *)

View File

@@ -16,6 +16,10 @@ set -o pipefail
# See https://llvm.org/docs/AMDGPUUsage.html#processors for reference # See https://llvm.org/docs/AMDGPUUsage.html#processors for reference
amdGPUs() { amdGPUs() {
if [ -n "${AMDGPU_TARGETS}" ]; then
echo "${AMDGPU_TARGETS}"
return
fi
GPU_LIST=( GPU_LIST=(
"gfx803" "gfx803"
"gfx900" "gfx900"
@@ -73,36 +77,42 @@ if [ -z "${OLLAMA_SKIP_CPU_GENERATE}" ]; then
# -DLLAMA_AVX512_VNNI -- 2021 Intel Alder Lake # -DLLAMA_AVX512_VNNI -- 2021 Intel Alder Lake
COMMON_CPU_DEFS="-DCMAKE_POSITION_INDEPENDENT_CODE=on -DLLAMA_NATIVE=off" COMMON_CPU_DEFS="-DCMAKE_POSITION_INDEPENDENT_CODE=on -DLLAMA_NATIVE=off"
# if [ -z "${OLLAMA_CPU_TARGET}" -o "${OLLAMA_CPU_TARGET}" = "cpu" ]; then
# CPU first for the default library, set up as lowest common denominator for maximum compatibility (including Rosetta) #
# # CPU first for the default library, set up as lowest common denominator for maximum compatibility (including Rosetta)
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}" #
BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/cpu" CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
echo "Building LCD CPU" BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/cpu"
build echo "Building LCD CPU"
compress_libs build
compress_libs
fi
# if [ -z "${OLLAMA_CPU_TARGET}" -o "${OLLAMA_CPU_TARGET}" = "cpu_avx" ]; then
# ~2011 CPU Dynamic library with more capabilities turned on to optimize performance #
# Approximately 400% faster than LCD on same CPU # ~2011 CPU Dynamic library with more capabilities turned on to optimize performance
# # Approximately 400% faster than LCD on same CPU
init_vars #
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}" init_vars
BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/cpu_avx" CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
echo "Building AVX CPU" BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/cpu_avx"
build echo "Building AVX CPU"
compress_libs build
compress_libs
fi
# if [ -z "${OLLAMA_CPU_TARGET}" -o "${OLLAMA_CPU_TARGET}" = "cpu_avx2" ]; then
# ~2013 CPU Dynamic library #
# Approximately 10% faster than AVX on same CPU # ~2013 CPU Dynamic library
# # Approximately 10% faster than AVX on same CPU
init_vars #
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_AVX=on -DLLAMA_AVX2=on -DLLAMA_AVX512=off -DLLAMA_FMA=on -DLLAMA_F16C=on ${CMAKE_DEFS}" init_vars
BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/cpu_avx2" CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_AVX=on -DLLAMA_AVX2=on -DLLAMA_AVX512=off -DLLAMA_FMA=on -DLLAMA_F16C=on ${CMAKE_DEFS}"
echo "Building AVX2 CPU" BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/cpu_avx2"
build echo "Building AVX2 CPU"
compress_libs build
compress_libs
fi
fi fi
else else
echo "Skipping CPU generation step as requested" echo "Skipping CPU generation step as requested"
@@ -125,7 +135,7 @@ if [ -d "${CUDA_LIB_DIR}" ]; then
if [ -n "${CUDA_MAJOR}" ]; then if [ -n "${CUDA_MAJOR}" ]; then
CUDA_VARIANT=_v${CUDA_MAJOR} CUDA_VARIANT=_v${CUDA_MAJOR}
fi fi
CMAKE_DEFS="-DLLAMA_CUBLAS=on ${COMMON_CMAKE_DEFS} ${CMAKE_DEFS}" CMAKE_DEFS="-DLLAMA_CUBLAS=on -DLLAMA_CUDA_FORCE_MMQ=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${COMMON_CMAKE_DEFS} ${CMAKE_DEFS}"
BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/cuda${CUDA_VARIANT}" BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/cuda${CUDA_VARIANT}"
EXTRA_LIBS="-L${CUDA_LIB_DIR} -lcudart -lcublas -lcublasLt -lcuda" EXTRA_LIBS="-L${CUDA_LIB_DIR} -lcudart -lcublas -lcublasLt -lcuda"
build build

View File

@@ -25,6 +25,11 @@ function init_vars {
} }
$script:GZIP=(get-command -ea 'silentlycontinue' gzip).path $script:GZIP=(get-command -ea 'silentlycontinue' gzip).path
$script:DUMPBIN=(get-command -ea 'silentlycontinue' dumpbin).path $script:DUMPBIN=(get-command -ea 'silentlycontinue' dumpbin).path
if ($null -eq $env:CMAKE_CUDA_ARCHITECTURES) {
$script:CMAKE_CUDA_ARCHITECTURES="50;52;61;70;75;80"
} else {
$script:CMAKE_CUDA_ARCHITECTURES=$env:CMAKE_CUDA_ARCHITECTURES
}
} }
function git_module_setup { function git_module_setup {
@@ -40,6 +45,29 @@ function apply_patches {
if (!(Select-String -Path "${script:llamacppDir}/examples/server/CMakeLists.txt" -Pattern 'ollama')) { if (!(Select-String -Path "${script:llamacppDir}/examples/server/CMakeLists.txt" -Pattern 'ollama')) {
Add-Content -Path "${script:llamacppDir}/examples/server/CMakeLists.txt" -Value 'include (../../../ext_server/CMakeLists.txt) # ollama' Add-Content -Path "${script:llamacppDir}/examples/server/CMakeLists.txt" -Value 'include (../../../ext_server/CMakeLists.txt) # ollama'
} }
# Apply temporary patches until fix is upstream
$patches = Get-ChildItem "../patches/*.diff"
foreach ($patch in $patches) {
# Extract file paths from the patch file
$filePaths = Get-Content $patch.FullName | Where-Object { $_ -match '^\+\+\+ ' } | ForEach-Object {
$parts = $_ -split ' '
($parts[1] -split '/', 2)[1]
}
# Checkout each file
foreach ($file in $filePaths) {
Set-Location -Path ${script:llamacppDir}
git checkout $file
}
}
# Apply each patch
foreach ($patch in $patches) {
Set-Location -Path ${script:llamacppDir}
git apply $patch.FullName
}
# Avoid duplicate main symbols when we link into the cgo binary # Avoid duplicate main symbols when we link into the cgo binary
$content = Get-Content -Path "${script:llamacppDir}/examples/server/server.cpp" $content = Get-Content -Path "${script:llamacppDir}/examples/server/server.cpp"
$content = $content -replace 'int main\(', 'int __main(' $content = $content -replace 'int main\(', 'int __main('
@@ -76,7 +104,7 @@ function compress_libs {
write-host "Compressing dlls..." write-host "Compressing dlls..."
$libs = dir "${script:buildDir}/lib/*.dll" $libs = dir "${script:buildDir}/lib/*.dll"
foreach ($file in $libs) { foreach ($file in $libs) {
& "$script:GZIP" --best $file & "$script:GZIP" --best -f $file
} }
} }
@@ -128,7 +156,7 @@ if ($null -ne $script:CUDA_LIB_DIR) {
} }
init_vars init_vars
$script:buildDir="${script:llamacppDir}/build/windows/${script:ARCH}/cuda$script:CUDA_VARIANT" $script:buildDir="${script:llamacppDir}/build/windows/${script:ARCH}/cuda$script:CUDA_VARIANT"
$script:cmakeDefs += @("-DLLAMA_CUBLAS=ON", "-DLLAMA_AVX=on") $script:cmakeDefs += @("-DLLAMA_CUBLAS=ON", "-DLLAMA_AVX=on", "-DCMAKE_CUDA_ARCHITECTURES=${script:CMAKE_CUDA_ARCHITECTURES}")
build build
install install
cp "${script:CUDA_LIB_DIR}/cudart64_*.dll" "${script:buildDir}/lib" cp "${script:CUDA_LIB_DIR}/cudart64_*.dll" "${script:buildDir}/lib"

View File

@@ -69,12 +69,65 @@ type tensor struct {
name string name string
kind uint32 kind uint32
offset uint64 offset uint64
size uint64
// shape is the number of elements in each dimension // shape is the number of elements in each dimension
shape [4]uint64 shape [4]uint64
} }
func (t tensor) blockSize() uint64 {
switch {
case t.kind < 2:
return 1
case t.kind < 10:
return 32
default:
return 256
}
}
func (t tensor) typeSize() uint64 {
blockSize := t.blockSize()
switch t.kind {
case 0: // FP32
return 4
case 1: // FP16
return 2
case 2: // Q4_0
return 2 + blockSize/2
case 3: // Q4_1
return 2 + 2 + blockSize/2
case 6: // Q5_0
return 2 + 4 + blockSize/2
case 7: // Q5_1
return 2 + 2 + 4 + blockSize/2
case 8: // Q8_0
return 2 + blockSize
case 9: // Q8_1
return 4 + 4 + blockSize
case 10: // Q2_K
return blockSize/16 + blockSize/4 + 2 + 2
case 11: // Q3_K
return blockSize/8 + blockSize/4 + 12 + 2
case 12: // Q4_K
return 2 + 2 + 12 + blockSize/2
case 13: // Q5_K
return 2 + 2 + 12 + blockSize/8 + blockSize/2
case 14: // Q6_K
return blockSize/2 + blockSize/4 + blockSize/16 + 2
default:
return 0
}
}
func (t tensor) parameters() uint64 {
return t.shape[0] * t.shape[1] * t.shape[2] * t.shape[3]
}
func (t tensor) size() uint64 {
return t.parameters() * t.typeSize() / t.blockSize()
}
type ggufModel struct { type ggufModel struct {
*containerGGUF *containerGGUF
@@ -201,61 +254,15 @@ func (llm *ggufModel) Decode(rso *readSeekOffset) error {
shape[i] = llm.readU64(rso) shape[i] = llm.readU64(rso)
} }
kind := llm.readU32(rso) tensor := tensor{
offset := llm.readU64(rso)
var blockSize uint64
switch {
case kind < 2:
blockSize = 1
case kind < 10:
blockSize = 32
default:
blockSize = 256
}
var typeSize uint64
switch kind {
case 0: // FP32
typeSize = 4
case 1: // FP16
typeSize = 2
case 2: // Q4_0
typeSize = 2 + blockSize/2
case 3: // Q4_1
typeSize = 2 + 2 + blockSize/2
case 6: // Q5_0
typeSize = 2 + 4 + blockSize/2
case 7: // Q5_1
typeSize = 2 + 2 + 4 + blockSize/2
case 8: // Q8_0
typeSize = 2 + blockSize
case 9: // Q8_1
typeSize = 4 + 4 + blockSize
case 10: // Q2_K
typeSize = blockSize/16 + blockSize/4 + 2 + 2
case 11: // Q3_K
typeSize = blockSize/8 + blockSize/4 + 12 + 2
case 12: // Q4_K
typeSize = 2 + 2 + 12 + blockSize/2
case 13: // Q5_K
typeSize = 2 + 2 + 12 + blockSize/8 + blockSize/2
case 14: // Q6_K
typeSize = blockSize/2 + blockSize/4 + blockSize/16 + 2
}
parameters := shape[0] * shape[1] * shape[2] * shape[3]
size := parameters * typeSize / blockSize
llm.tensors = append(llm.tensors, tensor{
name: name, name: name,
kind: kind, kind: llm.readU32(rso),
offset: offset, offset: llm.readU64(rso),
size: size,
shape: shape, shape: shape,
}) }
llm.parameters += parameters llm.tensors = append(llm.tensors, tensor)
llm.parameters += tensor.parameters()
} }
alignment, ok := llm.kv["general.alignment"].(uint32) alignment, ok := llm.kv["general.alignment"].(uint32)
@@ -265,7 +272,7 @@ func (llm *ggufModel) Decode(rso *readSeekOffset) error {
rso.Seek(int64(alignment)-rso.offset%int64(alignment), io.SeekCurrent) rso.Seek(int64(alignment)-rso.offset%int64(alignment), io.SeekCurrent)
for _, tensor := range llm.tensors { for _, tensor := range llm.tensors {
padded := (int64(tensor.size) + int64(alignment) - 1) & ^(int64(alignment) - 1) padded := (int64(tensor.size()) + int64(alignment) - 1) & ^(int64(alignment) - 1)
rso.Seek(padded, io.SeekCurrent) rso.Seek(padded, io.SeekCurrent)
} }

View File

@@ -62,7 +62,7 @@ const maxRetries = 3
type PredictOpts struct { type PredictOpts struct {
Prompt string Prompt string
Format string Format string
Images []api.ImageData Images []ImageData
Options api.Options Options api.Options
} }

View File

@@ -50,10 +50,10 @@ func New(workDir, model string, adapters, projectors []string, opts api.Options)
// fp16 k,v matrices require = n_ctx * n_layer * n_embd / n_head * n_head_kv * 2 bytes each * 2 key and value // fp16 k,v matrices require = n_ctx * n_layer * n_embd / n_head * n_head_kv * 2 bytes each * 2 key and value
kv := 2 * 2 * int64(opts.NumCtx) * int64(ggml.NumLayers()) * int64(ggml.NumEmbed()) * int64(ggml.NumHeadKv()) / int64(ggml.NumHead()) kv := 2 * 2 * int64(opts.NumCtx) * int64(ggml.NumLayers()) * int64(ggml.NumEmbed()) * int64(ggml.NumHeadKv()) / int64(ggml.NumHead())
// rough estimation for scratch space based on context size, batch size and number of layers in the model // this amount is the overhead + tensors in memory
// TODO: instead call llama.cpp's alloc functions to measure required memory // TODO: get this from the llama.cpp's graph calculations instead of
// TODO: account for quantization levels // estimating it's 1/6 * kv_cache_size * num_gqa
scratch := 8*int64(opts.NumCtx)*int64(opts.NumBatch)*int64(ggml.NumLayers()) + 1536*1024*1024 // 1536MiB overhead graph := int64(ggml.NumGQA()) * kv / 6
info := gpu.GetGPUInfo() info := gpu.GetGPUInfo()
switch runtime.GOOS { switch runtime.GOOS {
@@ -62,7 +62,7 @@ func New(workDir, model string, adapters, projectors []string, opts api.Options)
break break
} }
if size+kv+scratch > vram { if size+kv+graph > vram {
slog.Info("not enough vram available, falling back to CPU only") slog.Info("not enough vram available, falling back to CPU only")
info.Library = "cpu" info.Library = "cpu"
info.Variant = gpu.GetCPUVariant() info.Variant = gpu.GetCPUVariant()
@@ -70,7 +70,8 @@ func New(workDir, model string, adapters, projectors []string, opts api.Options)
break break
} }
opts.NumGPU = 1 // TODO: implement layer splitting on macOS
opts.NumGPU = 999
default: default:
if info.Library == "cpu" { if info.Library == "cpu" {
slog.Info("GPU not available, falling back to CPU") slog.Info("GPU not available, falling back to CPU")
@@ -99,13 +100,13 @@ func New(workDir, model string, adapters, projectors []string, opts api.Options)
maxlayers := int64(ggml.NumLayers()) + 1 maxlayers := int64(ggml.NumLayers()) + 1
devices := int64(info.DeviceCount) devices := int64(info.DeviceCount)
avg := vram / devices avg := vram / devices
layers := maxlayers * (avg - scratch) / (kv + size/devices) layers := maxlayers * (avg - graph) / (kv + size/devices)
if layers > maxlayers { if layers > maxlayers {
layers = maxlayers layers = maxlayers
} }
// 1 + 2 must fit on the main gpu // 1 + 2 must fit on the main gpu
min := scratch + kv*layers/maxlayers min := graph + kv*layers/maxlayers
if layers <= 0 || min > avg { if layers <= 0 || min > avg {
slog.Info("not enough vram available, falling back to CPU only") slog.Info("not enough vram available, falling back to CPU only")
info.Library = "cpu" info.Library = "cpu"

30
llm/patches/01-cache.diff Normal file
View File

@@ -0,0 +1,30 @@
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index a48582ad..9fffffd8 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -1564,12 +1564,6 @@ struct llama_server_context
LOG_TEE("slot %d : in cache: %i tokens | to process: %i tokens\n", slot.id, slot.n_past, slot.num_prompt_tokens_processed);
}
- LOG_TEE("slot %d : kv cache rm - [%d, end)\n", slot.id, (int) system_tokens.size() + slot.n_past);
-
- llama_kv_cache_seq_rm(ctx, slot.id, system_tokens.size() + slot.n_past, -1);
-
- slot.cache_tokens = prompt_tokens;
-
if (slot.n_past == slot.num_prompt_tokens && slot.n_past > 0)
{
// we have to evaluate at least 1 token to generate logits.
@@ -1581,6 +1575,12 @@ struct llama_server_context
}
}
+ LOG_TEE("slot %d : kv cache rm - [%d, end)\n", slot.id, (int) system_tokens.size() + slot.n_past);
+
+ llama_kv_cache_seq_rm(ctx, slot.id, system_tokens.size() + slot.n_past, -1);
+
+ slot.cache_tokens = prompt_tokens;
+
LOG_VERBOSE("prompt ingested", {
{"n_past", slot.n_past},
{"cached", tokens_to_str(ctx, slot.cache_tokens.cbegin(), slot.cache_tokens.cbegin() + slot.n_past)},

View File

@@ -0,0 +1,90 @@
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 11dd82c3..311495a8 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -28,6 +28,7 @@
#include <chrono>
#include <condition_variable>
#include <atomic>
+#include <signal.h>
using json = nlohmann::json;
@@ -2394,6 +2395,9 @@ static void append_to_generated_text_from_generated_token_probs(llama_server_con
}
}
+std::function<void(int)> shutdown_handler;
+inline void signal_handler(int signal) { shutdown_handler(signal); }
+
int main(int argc, char **argv)
{
#if SERVER_VERBOSE != 1
@@ -3014,8 +3018,14 @@ int main(int argc, char **argv)
std::placeholders::_2,
std::placeholders::_3
));
- llama.queue_tasks.start_loop();
+ shutdown_handler = [&](int) {
+ llama.queue_tasks.terminate();
+ };
+ signal(SIGTERM, signal_handler);
+ signal(SIGINT, signal_handler);
+ llama.queue_tasks.start_loop();
+ svr.stop();
t.join();
llama_backend_free();
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 70cce072..2acb1eab 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -6,6 +6,7 @@
#include <mutex>
#include <condition_variable>
#include <unordered_map>
+#include <atomic>
#include "json.hpp"
@@ -190,6 +191,7 @@ inline std::string format_chatml(std::vector<json> messages)
struct llama_server_queue {
int id = 0;
std::mutex mutex_tasks;
+ std::atomic<bool> running;
// queues
std::vector<task_server> queue_tasks;
std::vector<task_server> queue_tasks_deferred;
@@ -248,9 +250,15 @@ struct llama_server_queue {
queue_tasks_deferred.clear();
}
- // Start the main loop. This call is blocking
- [[noreturn]]
+ // end the start_loop routine
+ void terminate() {
+ running = false;
+ condition_tasks.notify_all();
+ }
+
+ // Start the main loop.
void start_loop() {
+ running = true;
while (true) {
// new task arrived
LOG_VERBOSE("have new task", {});
@@ -294,8 +302,12 @@ struct llama_server_queue {
{
std::unique_lock<std::mutex> lock(mutex_tasks);
if (queue_tasks.empty()) {
+ if (!running.load()) {
+ LOG_VERBOSE("ending start_loop", {});
+ return;
+ }
condition_tasks.wait(lock, [&]{
- return !queue_tasks.empty();
+ return (!queue_tasks.empty() || !running.load());
});
}
}

View File

@@ -7,6 +7,7 @@ import (
"fmt" "fmt"
"io" "io"
"log/slog" "log/slog"
"slices"
) )
type Command struct { type Command struct {
@@ -56,6 +57,16 @@ func Parse(reader io.Reader) ([]Command, error) {
command.Args = string(bytes.TrimSpace(fields[1])) command.Args = string(bytes.TrimSpace(fields[1]))
case "EMBED": case "EMBED":
return nil, fmt.Errorf("deprecated command: EMBED is no longer supported, use the /embed API endpoint instead") return nil, fmt.Errorf("deprecated command: EMBED is no longer supported, use the /embed API endpoint instead")
case "MESSAGE":
command.Name = string(bytes.ToLower(fields[0]))
fields = bytes.SplitN(fields[1], []byte(" "), 2)
if len(fields) < 2 {
return nil, fmt.Errorf("should be in the format <role> <message>")
}
if !slices.Contains([]string{"system", "user", "assistant"}, string(bytes.ToLower(fields[0]))) {
return nil, fmt.Errorf("role must be one of \"system\", \"user\", or \"assistant\"")
}
command.Args = fmt.Sprintf("%s: %s", string(bytes.ToLower(fields[0])), string(fields[1]))
default: default:
if !bytes.HasPrefix(fields[0], []byte("#")) { if !bytes.HasPrefix(fields[0], []byte("#")) {
// log a warning for unknown commands // log a warning for unknown commands

View File

@@ -61,3 +61,38 @@ PARAMETER param1
assert.ErrorContains(t, err, "missing value for [param1]") assert.ErrorContains(t, err, "missing value for [param1]")
} }
func Test_Parser_Messages(t *testing.T) {
input := `
FROM foo
MESSAGE system You are a Parser. Always Parse things.
MESSAGE user Hey there!
MESSAGE assistant Hello, I want to parse all the things!
`
reader := strings.NewReader(input)
commands, err := Parse(reader)
assert.Nil(t, err)
expectedCommands := []Command{
{Name: "model", Args: "foo"},
{Name: "message", Args: "system: You are a Parser. Always Parse things."},
{Name: "message", Args: "user: Hey there!"},
{Name: "message", Args: "assistant: Hello, I want to parse all the things!"},
}
assert.Equal(t, expectedCommands, commands)
}
func Test_Parser_Messages_BadRole(t *testing.T) {
input := `
FROM foo
MESSAGE badguy I'm a bad guy!
`
reader := strings.NewReader(input)
_, err := Parse(reader)
assert.ErrorContains(t, err, "role must be one of \"system\", \"user\", or \"assistant\"")
}

View File

@@ -133,13 +133,6 @@ func (b *Buffer) Size() int {
return b.Buf.Size() return b.Buf.Size()
} }
func min(n, m int) int {
if n > m {
return m
}
return n
}
func (b *Buffer) Add(r rune) { func (b *Buffer) Add(r rune) {
if b.Pos == b.Buf.Size() { if b.Pos == b.Buf.Size() {
fmt.Printf("%c", r) fmt.Printf("%c", r)

View File

@@ -2,7 +2,7 @@
set -e set -e
export VERSION=${VERSION:-0.0.0} export VERSION=${VERSION:-$(git describe --tags --first-parent --abbrev=7 --long --dirty --always | sed -e "s/^v//g")}
export GOFLAGS="'-ldflags=-w -s \"-X=github.com/jmorganca/ollama/version.Version=$VERSION\" \"-X=github.com/jmorganca/ollama/server.mode=release\"'" export GOFLAGS="'-ldflags=-w -s \"-X=github.com/jmorganca/ollama/version.Version=$VERSION\" \"-X=github.com/jmorganca/ollama/server.mode=release\"'"
mkdir -p dist mkdir -p dist

View File

@@ -2,7 +2,7 @@
set -eu set -eu
export VERSION=${VERSION:-0.0.0} export VERSION=${VERSION:-$(git describe --tags --first-parent --abbrev=7 --long --dirty --always | sed -e "s/^v//g")}
export GOFLAGS="'-ldflags=-w -s \"-X=github.com/jmorganca/ollama/version.Version=$VERSION\" \"-X=github.com/jmorganca/ollama/server.mode=release\"'" export GOFLAGS="'-ldflags=-w -s \"-X=github.com/jmorganca/ollama/version.Version=$VERSION\" \"-X=github.com/jmorganca/ollama/server.mode=release\"'"
docker build \ docker build \
@@ -13,3 +13,13 @@ docker build \
-f Dockerfile \ -f Dockerfile \
-t ollama/ollama:$VERSION \ -t ollama/ollama:$VERSION \
. .
docker build \
--load \
--platform=linux/amd64 \
--build-arg=VERSION \
--build-arg=GOFLAGS \
--target runtime-rocm \
-f Dockerfile \
-t ollama/ollama:$VERSION-rocm \
.

View File

@@ -2,14 +2,24 @@
set -eu set -eu
export VERSION=${VERSION:-0.0.0} export VERSION=${VERSION:-$(git describe --tags --first-parent --abbrev=7 --long --dirty --always | sed -e "s/^v//g")}
export GOFLAGS="'-ldflags=-w -s \"-X=github.com/jmorganca/ollama/version.Version=$VERSION\" \"-X=github.com/jmorganca/ollama/server.mode=release\"'" export GOFLAGS="'-ldflags=-w -s \"-X=github.com/jmorganca/ollama/version.Version=$VERSION\" \"-X=github.com/jmorganca/ollama/server.mode=release\"'"
BUILD_ARCH=${BUILD_ARCH:-"amd64 arm64"} BUILD_ARCH=${BUILD_ARCH:-"amd64 arm64"}
export AMDGPU_TARGETS=${AMDGPU_TARGETS:=""}
mkdir -p dist mkdir -p dist
for TARGETARCH in ${BUILD_ARCH}; do for TARGETARCH in ${BUILD_ARCH}; do
docker build --platform=linux/$TARGETARCH --build-arg=GOFLAGS --build-arg=CGO_CFLAGS --build-arg=OLLAMA_CUSTOM_CPU_DEFS -f Dockerfile.build -t builder:$TARGETARCH . docker build \
--platform=linux/$TARGETARCH \
--build-arg=GOFLAGS \
--build-arg=CGO_CFLAGS \
--build-arg=OLLAMA_CUSTOM_CPU_DEFS \
--build-arg=AMDGPU_TARGETS \
--target build-$TARGETARCH \
-f Dockerfile \
-t builder:$TARGETARCH \
.
docker create --platform linux/$TARGETARCH --name builder-$TARGETARCH builder:$TARGETARCH docker create --platform linux/$TARGETARCH --name builder-$TARGETARCH builder:$TARGETARCH
docker cp builder-$TARGETARCH:/go/src/github.com/jmorganca/ollama/ollama ./dist/ollama-linux-$TARGETARCH docker cp builder-$TARGETARCH:/go/src/github.com/jmorganca/ollama/ollama ./dist/ollama-linux-$TARGETARCH
docker rm builder-$TARGETARCH docker rm builder-$TARGETARCH

View File

@@ -147,12 +147,7 @@ func (s SignatureData) Bytes() []byte {
// SignData takes a SignatureData object and signs it with a raw private key // SignData takes a SignatureData object and signs it with a raw private key
func (s SignatureData) Sign(rawKey []byte) (string, error) { func (s SignatureData) Sign(rawKey []byte) (string, error) {
privateKey, err := ssh.ParseRawPrivateKey(rawKey) signer, err := ssh.ParsePrivateKey(rawKey)
if err != nil {
return "", err
}
signer, err := ssh.NewSignerFromKey(privateKey)
if err != nil { if err != nil {
return "", err return "", err
} }

View File

@@ -25,6 +25,11 @@ import (
"github.com/jmorganca/ollama/format" "github.com/jmorganca/ollama/format"
) )
const maxRetries = 6
var errMaxRetriesExceeded = errors.New("max retries exceeded")
var errPartStalled = errors.New("part stalled")
var blobDownloadManager sync.Map var blobDownloadManager sync.Map
type blobDownload struct { type blobDownload struct {
@@ -44,10 +49,11 @@ type blobDownload struct {
} }
type blobDownloadPart struct { type blobDownloadPart struct {
N int N int
Offset int64 Offset int64
Size int64 Size int64
Completed int64 Completed int64
lastUpdated time.Time
*blobDownload `json:"-"` *blobDownload `json:"-"`
} }
@@ -72,6 +78,13 @@ func (p *blobDownloadPart) StopsAt() int64 {
return p.Offset + p.Size return p.Offset + p.Size
} }
func (p *blobDownloadPart) Write(b []byte) (n int, err error) {
n = len(b)
p.blobDownload.Completed.Add(int64(n))
p.lastUpdated = time.Now()
return n, nil
}
func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error { func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
partFilePaths, err := filepath.Glob(b.Name + "-partial-*") partFilePaths, err := filepath.Glob(b.Name + "-partial-*")
if err != nil { if err != nil {
@@ -157,6 +170,9 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis
case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC): case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
// return immediately if the context is canceled or the device is out of space // return immediately if the context is canceled or the device is out of space
return err return err
case errors.Is(err, errPartStalled):
try--
continue
case err != nil: case err != nil:
sleep := time.Second * time.Duration(math.Pow(2, float64(try))) sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
slog.Info(fmt.Sprintf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep)) slog.Info(fmt.Sprintf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep))
@@ -195,28 +211,54 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis
} }
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *RegistryOptions) error { func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *RegistryOptions) error {
headers := make(http.Header) g, ctx := errgroup.WithContext(ctx)
headers.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1)) g.Go(func() error {
resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts) headers := make(http.Header)
if err != nil { headers.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
return err resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts)
} if err != nil {
defer resp.Body.Close() return err
}
defer resp.Body.Close()
n, err := io.Copy(w, io.TeeReader(resp.Body, b)) n, err := io.Copy(w, io.TeeReader(resp.Body, part))
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) { if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
// rollback progress // rollback progress
b.Completed.Add(-n) b.Completed.Add(-n)
return err return err
} }
part.Completed += n part.Completed += n
if err := b.writePart(part.Name(), part); err != nil { if err := b.writePart(part.Name(), part); err != nil {
return err return err
} }
// return nil or context.Canceled or UnexpectedEOF (resumable) // return nil or context.Canceled or UnexpectedEOF (resumable)
return err return err
})
g.Go(func() error {
ticker := time.NewTicker(time.Second)
for {
select {
case <-ticker.C:
if part.Completed >= part.Size {
return nil
}
if !part.lastUpdated.IsZero() && time.Since(part.lastUpdated) > 5*time.Second {
slog.Info(fmt.Sprintf("%s part %d stalled; retrying", b.Digest[7:19], part.N))
// reset last updated
part.lastUpdated = time.Time{}
return errPartStalled
}
case <-ctx.Done():
return ctx.Err()
}
}
})
return g.Wait()
} }
func (b *blobDownload) newPart(offset, size int64) error { func (b *blobDownload) newPart(offset, size int64) error {
@@ -255,12 +297,6 @@ func (b *blobDownload) writePart(partName string, part *blobDownloadPart) error
return json.NewEncoder(partFile).Encode(part) return json.NewEncoder(partFile).Encode(part)
} }
func (b *blobDownload) Write(p []byte) (n int, err error) {
n = len(p)
b.Completed.Add(int64(n))
return n, nil
}
func (b *blobDownload) acquire() { func (b *blobDownload) acquire() {
b.references.Add(1) b.references.Add(1)
} }
@@ -279,20 +315,19 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
for { for {
select { select {
case <-ticker.C: case <-ticker.C:
fn(api.ProgressResponse{
Status: fmt.Sprintf("pulling %s", b.Digest[7:19]),
Digest: b.Digest,
Total: b.Total,
Completed: b.Completed.Load(),
})
if b.done || b.err != nil {
return b.err
}
case <-ctx.Done(): case <-ctx.Done():
return ctx.Err() return ctx.Err()
} }
fn(api.ProgressResponse{
Status: fmt.Sprintf("pulling %s", b.Digest[7:19]),
Digest: b.Digest,
Total: b.Total,
Completed: b.Completed.Load(),
})
if b.done || b.err != nil {
return b.err
}
} }
} }
@@ -303,10 +338,6 @@ type downloadOpts struct {
fn func(api.ProgressResponse) fn func(api.ProgressResponse)
} }
const maxRetries = 6
var errMaxRetriesExceeded = errors.New("max retries exceeded")
// downloadBlob downloads a blob from the registry and stores it in the blobs directory // downloadBlob downloads a blob from the registry and stores it in the blobs directory
func downloadBlob(ctx context.Context, opts downloadOpts) error { func downloadBlob(ctx context.Context, opts downloadOpts) error {
fp, err := GetBlobsPath(opts.digest) fp, err := GetBlobsPath(opts.digest)

View File

@@ -41,7 +41,7 @@ type Model struct {
Config ConfigV2 Config ConfigV2
ShortName string ShortName string
ModelPath string ModelPath string
OriginalModel string ParentModel string
AdapterPaths []string AdapterPaths []string
ProjectorPaths []string ProjectorPaths []string
Template string Template string
@@ -50,6 +50,12 @@ type Model struct {
Digest string Digest string
Size int64 Size int64
Options map[string]interface{} Options map[string]interface{}
Messages []Message
}
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
} }
type PromptVars struct { type PromptVars struct {
@@ -57,6 +63,7 @@ type PromptVars struct {
Prompt string Prompt string
Response string Response string
First bool First bool
Images []llm.ImageData
} }
// extractParts extracts the parts of the template before and after the {{.Response}} node. // extractParts extracts the parts of the template before and after the {{.Response}} node.
@@ -113,10 +120,6 @@ func Prompt(promptTemplate string, p PromptVars) (string, error) {
// PreResponsePrompt returns the prompt before the response tag // PreResponsePrompt returns the prompt before the response tag
func (m *Model) PreResponsePrompt(p PromptVars) (string, error) { func (m *Model) PreResponsePrompt(p PromptVars) (string, error) {
if p.System == "" {
// use the default system prompt for this model if one is not specified
p.System = m.System
}
pre, _, err := extractParts(m.Template) pre, _, err := extractParts(m.Template)
if err != nil { if err != nil {
return "", err return "", err
@@ -144,62 +147,68 @@ func (m *Model) PostResponseTemplate(p PromptVars) (string, error) {
return Prompt(post, p) return Prompt(post, p)
} }
func (m *Model) ChatPrompt(msgs []api.Message) (string, []api.ImageData, error) { type ChatHistory struct {
Prompts []PromptVars
LastSystem string
}
// ChatPrompts returns a list of formatted chat prompts from a list of messages
func (m *Model) ChatPrompts(msgs []api.Message) (*ChatHistory, error) {
// build the prompt from the list of messages // build the prompt from the list of messages
var prompt strings.Builder lastSystem := m.System
var currentImages []api.ImageData
currentVars := PromptVars{ currentVars := PromptVars{
First: true, First: true,
System: m.System, System: m.System,
} }
writePrompt := func() error { prompts := []PromptVars{}
p, err := Prompt(m.Template, currentVars) var images []llm.ImageData
if err != nil {
return err
}
prompt.WriteString(p)
currentVars = PromptVars{}
return nil
}
for _, msg := range msgs { for _, msg := range msgs {
switch strings.ToLower(msg.Role) { switch strings.ToLower(msg.Role) {
case "system": case "system":
if currentVars.System != "" { // if this is the first message it overrides the system prompt in the modelfile
if err := writePrompt(); err != nil { if !currentVars.First && currentVars.System != "" {
return "", nil, err prompts = append(prompts, currentVars)
} currentVars = PromptVars{}
} }
currentVars.System = msg.Content currentVars.System = msg.Content
lastSystem = msg.Content
case "user": case "user":
if currentVars.Prompt != "" { if currentVars.Prompt != "" {
if err := writePrompt(); err != nil { prompts = append(prompts, currentVars)
return "", nil, err currentVars = PromptVars{}
}
} }
currentVars.Prompt = msg.Content currentVars.Prompt = msg.Content
currentImages = msg.Images for i := range msg.Images {
id := len(images) + i
currentVars.Prompt += fmt.Sprintf(" [img-%d]", id)
currentVars.Images = append(currentVars.Images, llm.ImageData{
ID: id,
Data: msg.Images[i],
})
}
images = append(images, currentVars.Images...)
case "assistant": case "assistant":
currentVars.Response = msg.Content currentVars.Response = msg.Content
if err := writePrompt(); err != nil { prompts = append(prompts, currentVars)
return "", nil, err currentVars = PromptVars{}
}
default: default:
return "", nil, fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role) return nil, fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
} }
} }
// Append the last set of vars if they are non-empty // Append the last set of vars if they are non-empty
if currentVars.Prompt != "" || currentVars.System != "" { if currentVars.Prompt != "" || currentVars.System != "" {
p, err := m.PreResponsePrompt(currentVars) prompts = append(prompts, currentVars)
if err != nil {
return "", nil, fmt.Errorf("pre-response template: %w", err)
}
prompt.WriteString(p)
} }
return prompt.String(), currentImages, nil return &ChatHistory{
Prompts: prompts,
LastSystem: lastSystem,
}, nil
} }
type ManifestV2 struct { type ManifestV2 struct {
@@ -333,7 +342,7 @@ func GetModel(name string) (*Model, error) {
switch layer.MediaType { switch layer.MediaType {
case "application/vnd.ollama.image.model": case "application/vnd.ollama.image.model":
model.ModelPath = filename model.ModelPath = filename
model.OriginalModel = layer.From model.ParentModel = layer.From
case "application/vnd.ollama.image.embed": case "application/vnd.ollama.image.embed":
// Deprecated in versions > 0.1.2 // Deprecated in versions > 0.1.2
// TODO: remove this warning in a future version // TODO: remove this warning in a future version
@@ -374,6 +383,16 @@ func GetModel(name string) (*Model, error) {
if err = json.NewDecoder(params).Decode(&model.Options); err != nil { if err = json.NewDecoder(params).Decode(&model.Options); err != nil {
return nil, err return nil, err
} }
case "application/vnd.ollama.image.messages":
msgs, err := os.Open(filename)
if err != nil {
return nil, err
}
defer msgs.Close()
if err = json.NewDecoder(msgs).Decode(&model.Messages); err != nil {
return nil, err
}
case "application/vnd.ollama.image.license": case "application/vnd.ollama.image.license":
bts, err := os.ReadFile(filename) bts, err := os.ReadFile(filename)
if err != nil { if err != nil {
@@ -412,6 +431,13 @@ func realpath(mfDir, from string) string {
} }
func CreateModel(ctx context.Context, name, modelFileDir string, commands []parser.Command, fn func(resp api.ProgressResponse)) error { func CreateModel(ctx context.Context, name, modelFileDir string, commands []parser.Command, fn func(resp api.ProgressResponse)) error {
deleteMap := make(map[string]struct{})
if manifest, _, err := GetManifest(ParseModelPath(name)); err == nil {
for _, layer := range append(manifest.Layers, manifest.Config) {
deleteMap[layer.Digest] = struct{}{}
}
}
config := ConfigV2{ config := ConfigV2{
OS: "linux", OS: "linux",
Architecture: "amd64", Architecture: "amd64",
@@ -420,15 +446,13 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
}, },
} }
deleteMap := make(map[string]struct{})
var layers Layers var layers Layers
messages := []string{}
params := make(map[string][]string) params := make(map[string][]string)
fromParams := make(map[string]any) fromParams := make(map[string]any)
for _, c := range commands { for _, c := range commands {
slog.Info(fmt.Sprintf("[%s] - %s", c.Name, c.Args))
mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name) mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
switch c.Name { switch c.Name {
@@ -602,11 +626,37 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
} }
layers.Replace(layer) layers.Replace(layer)
case "message":
messages = append(messages, c.Args)
default: default:
params[c.Name] = append(params[c.Name], c.Args) params[c.Name] = append(params[c.Name], c.Args)
} }
} }
if len(messages) > 0 {
fn(api.ProgressResponse{Status: "creating parameters layer"})
msgs := make([]api.Message, 0)
for _, m := range messages {
// todo: handle images
msg := strings.SplitN(m, ": ", 2)
msgs = append(msgs, api.Message{Role: msg[0], Content: msg[1]})
}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(msgs); err != nil {
return err
}
layer, err := NewLayer(&b, "application/vnd.ollama.image.messages")
if err != nil {
return err
}
layers.Replace(layer)
}
if len(params) > 0 { if len(params) > 0 {
fn(api.ProgressResponse{Status: "creating parameters layer"}) fn(api.ProgressResponse{Status: "creating parameters layer"})
@@ -903,8 +953,8 @@ func ShowModelfile(model *Model) (string, error) {
mt.Model = model mt.Model = model
mt.From = model.ModelPath mt.From = model.ModelPath
if model.OriginalModel != "" { if model.ParentModel != "" {
mt.From = model.OriginalModel mt.From = model.ParentModel
} }
modelFile := `# Modelfile generated by "ollama show" modelFile := `# Modelfile generated by "ollama show"

View File

@@ -1,6 +1,7 @@
package server package server
import ( import (
"bytes"
"strings" "strings"
"testing" "testing"
@@ -233,17 +234,58 @@ func TestModel_PreResponsePrompt_PostResponsePrompt(t *testing.T) {
} }
} }
func chatHistoryEqual(a, b ChatHistory) bool {
if len(a.Prompts) != len(b.Prompts) {
return false
}
for i, v := range a.Prompts {
if v.First != b.Prompts[i].First {
return false
}
if v.Response != b.Prompts[i].Response {
return false
}
if v.Prompt != b.Prompts[i].Prompt {
return false
}
if v.System != b.Prompts[i].System {
return false
}
if len(v.Images) != len(b.Prompts[i].Images) {
return false
}
for j, img := range v.Images {
if img.ID != b.Prompts[i].Images[j].ID {
return false
}
if !bytes.Equal(img.Data, b.Prompts[i].Images[j].Data) {
return false
}
}
}
return a.LastSystem == b.LastSystem
}
func TestChat(t *testing.T) { func TestChat(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
template string model Model
msgs []api.Message msgs []api.Message
want string want ChatHistory
wantErr string wantErr string
}{ }{
{ {
name: "Single Message", name: "Single Message",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", model: Model{
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
},
msgs: []api.Message{ msgs: []api.Message{
{ {
Role: "system", Role: "system",
@@ -254,34 +296,22 @@ func TestChat(t *testing.T) {
Content: "What are the potion ingredients?", Content: "What are the potion ingredients?",
}, },
}, },
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]", want: ChatHistory{
}, Prompts: []PromptVars{
{ {
name: "First Message", System: "You are a Wizard.",
template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST]", Prompt: "What are the potion ingredients?",
msgs: []api.Message{ First: true,
{ },
Role: "system",
Content: "You are a Wizard.",
},
{
Role: "user",
Content: "What are the potion ingredients?",
},
{
Role: "assistant",
Content: "eye of newt",
},
{
Role: "user",
Content: "Anything else?",
}, },
LastSystem: "You are a Wizard.",
}, },
want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST]eye of newt[INST] Anything else? [/INST]",
}, },
{ {
name: "Message History", name: "Message History",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", model: Model{
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
},
msgs: []api.Message{ msgs: []api.Message{
{ {
Role: "system", Role: "system",
@@ -300,18 +330,85 @@ func TestChat(t *testing.T) {
Content: "Anything else?", Content: "Anything else?",
}, },
}, },
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]sugar[INST] Anything else? [/INST]", want: ChatHistory{
Prompts: []PromptVars{
{
System: "You are a Wizard.",
Prompt: "What are the potion ingredients?",
Response: "sugar",
First: true,
},
{
Prompt: "Anything else?",
},
},
LastSystem: "You are a Wizard.",
},
}, },
{ {
name: "Assistant Only", name: "Assistant Only",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]", model: Model{
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
},
msgs: []api.Message{ msgs: []api.Message{
{ {
Role: "assistant", Role: "assistant",
Content: "everything nice", Content: "everything nice",
}, },
}, },
want: "[INST] [/INST]everything nice", want: ChatHistory{
Prompts: []PromptVars{
{
Response: "everything nice",
First: true,
},
},
},
},
{
name: "Last system message is preserved from modelfile",
model: Model{
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
System: "You are Mojo Jojo.",
},
msgs: []api.Message{
{
Role: "user",
Content: "hi",
},
},
want: ChatHistory{
Prompts: []PromptVars{
{
System: "You are Mojo Jojo.",
Prompt: "hi",
First: true,
},
},
LastSystem: "You are Mojo Jojo.",
},
},
{
name: "Last system message is preserved from messages",
model: Model{
Template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
System: "You are Mojo Jojo.",
},
msgs: []api.Message{
{
Role: "system",
Content: "You are Professor Utonium.",
},
},
want: ChatHistory{
Prompts: []PromptVars{
{
System: "You are Professor Utonium.",
First: true,
},
},
LastSystem: "You are Professor Utonium.",
},
}, },
{ {
name: "Invalid Role", name: "Invalid Role",
@@ -326,11 +423,8 @@ func TestChat(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
m := Model{
Template: tt.template,
}
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, _, err := m.ChatPrompt(tt.msgs) got, err := tt.model.ChatPrompts(tt.msgs)
if tt.wantErr != "" { if tt.wantErr != "" {
if err == nil { if err == nil {
t.Errorf("ChatPrompt() expected error, got nil") t.Errorf("ChatPrompt() expected error, got nil")
@@ -338,9 +432,10 @@ func TestChat(t *testing.T) {
if !strings.Contains(err.Error(), tt.wantErr) { if !strings.Contains(err.Error(), tt.wantErr) {
t.Errorf("ChatPrompt() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("ChatPrompt() error = %v, wantErr %v", err, tt.wantErr)
} }
return
} }
if got != tt.want { if !chatHistoryEqual(*got, tt.want) {
t.Errorf("ChatPrompt() got = %v, want %v", got, tt.want) t.Errorf("ChatPrompt() got = %#v, want %#v", got, tt.want)
} }
}) })
} }

View File

@@ -186,7 +186,13 @@ func GenerateHandler(c *gin.Context) {
return return
} }
sessionDuration := defaultSessionDuration var sessionDuration time.Duration
if req.KeepAlive == nil {
sessionDuration = defaultSessionDuration
} else {
sessionDuration = req.KeepAlive.Duration
}
if err := load(c, model, opts, sessionDuration); err != nil { if err := load(c, model, opts, sessionDuration); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
@@ -233,6 +239,15 @@ func GenerateHandler(c *gin.Context) {
Prompt: req.Prompt, Prompt: req.Prompt,
First: len(req.Context) == 0, First: len(req.Context) == 0,
} }
if promptVars.System == "" {
promptVars.System = model.System
}
for i := range req.Images {
promptVars.Prompt += fmt.Sprintf(" [img-%d]", i)
}
p, err := model.PreResponsePrompt(promptVars) p, err := model.PreResponsePrompt(promptVars)
if err != nil { if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
@@ -242,6 +257,8 @@ func GenerateHandler(c *gin.Context) {
prompt = rebuild.String() prompt = rebuild.String()
} }
slog.Debug("generate handler", "prompt", prompt)
ch := make(chan any) ch := make(chan any)
var generated strings.Builder var generated strings.Builder
go func() { go func() {
@@ -295,11 +312,19 @@ func GenerateHandler(c *gin.Context) {
ch <- resp ch <- resp
} }
var images []llm.ImageData
for i := range req.Images {
images = append(images, llm.ImageData{
ID: i,
Data: req.Images[i],
})
}
// Start prediction // Start prediction
predictReq := llm.PredictOpts{ predictReq := llm.PredictOpts{
Prompt: prompt, Prompt: prompt,
Format: req.Format, Format: req.Format,
Images: req.Images, Images: images,
Options: opts, Options: opts,
} }
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil { if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
@@ -378,7 +403,14 @@ func EmbeddingHandler(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
sessionDuration := defaultSessionDuration
var sessionDuration time.Duration
if req.KeepAlive == nil {
sessionDuration = defaultSessionDuration
} else {
sessionDuration = req.KeepAlive.Duration
}
if err := load(c, model, opts, sessionDuration); err != nil { if err := load(c, model, opts, sessionDuration); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
@@ -659,6 +691,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
} }
modelDetails := api.ModelDetails{ modelDetails := api.ModelDetails{
ParentModel: model.ParentModel,
Format: model.Config.ModelFormat, Format: model.Config.ModelFormat,
Family: model.Config.ModelFamily, Family: model.Config.ModelFamily,
Families: model.Config.ModelFamilies, Families: model.Config.ModelFamilies,
@@ -674,11 +707,17 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
model.Template = req.Template model.Template = req.Template
} }
msgs := make([]api.Message, 0)
for _, msg := range model.Messages {
msgs = append(msgs, api.Message{Role: msg.Role, Content: msg.Content})
}
resp := &api.ShowResponse{ resp := &api.ShowResponse{
License: strings.Join(model.License, "\n"), License: strings.Join(model.License, "\n"),
System: model.System, System: model.System,
Template: model.Template, Template: model.Template,
Details: modelDetails, Details: modelDetails,
Messages: msgs,
} }
var params []string var params []string
@@ -911,13 +950,26 @@ func (s *Server) GenerateRoutes() http.Handler {
} }
func Serve(ln net.Listener) error { func Serve(ln net.Listener) error {
level := slog.LevelInfo
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" { if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
var programLevel = new(slog.LevelVar) level = slog.LevelDebug
h := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: programLevel, AddSource: true})
slog.SetDefault(slog.New(h))
programLevel.Set(slog.LevelDebug)
slog.Debug("Debug logging enabled")
} }
handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
Level: level,
AddSource: true,
ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
if attr.Key == slog.SourceKey {
source := attr.Value.Any().(*slog.Source)
source.File = filepath.Base(source.File)
}
return attr
},
})
slog.SetDefault(slog.New(handler))
if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" { if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
// clean up unused layers and manifests // clean up unused layers and manifests
if err := PruneLayers(); err != nil { if err := PruneLayers(); err != nil {
@@ -1067,7 +1119,14 @@ func ChatHandler(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
} }
sessionDuration := defaultSessionDuration
var sessionDuration time.Duration
if req.KeepAlive == nil {
sessionDuration = defaultSessionDuration
} else {
sessionDuration = req.KeepAlive.Duration
}
if err := load(c, model, opts, sessionDuration); err != nil { if err := load(c, model, opts, sessionDuration); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return return
@@ -1075,18 +1134,32 @@ func ChatHandler(c *gin.Context) {
// an empty request loads the model // an empty request loads the model
if len(req.Messages) == 0 { if len(req.Messages) == 0 {
c.JSON(http.StatusOK, api.ChatResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true, Message: api.Message{Role: "assistant"}}) resp := api.ChatResponse{
CreatedAt: time.Now().UTC(),
Model: req.Model,
Done: true,
Message: api.Message{Role: "assistant"},
}
c.JSON(http.StatusOK, resp)
return return
} }
checkpointLoaded := time.Now() checkpointLoaded := time.Now()
prompt, images, err := model.ChatPrompt(req.Messages) chat, err := model.ChatPrompts(req.Messages)
if err != nil { if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return return
} }
prompt, images, err := trimmedPrompt(c.Request.Context(), chat, model)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
slog.Debug("chat handler", "prompt", prompt)
ch := make(chan any) ch := make(chan any)
go func() { go func() {
@@ -1160,3 +1233,115 @@ func ChatHandler(c *gin.Context) {
streamResponse(c, ch) streamResponse(c, ch)
} }
// promptInfo stores the variables used to template a prompt, and the token length of the resulting template for some model
type promptInfo struct {
vars PromptVars
tokenLen int
}
// trimmedPrompt builds a prompt to send to a running model. It ensures the prompt fits within the max context length,
// while preserving the most recent system message.
//
// Prompts are walked newest-to-oldest so that when the history exceeds
// loaded.NumCtx tokens it is the oldest exchanges that are dropped; the most
// recent prompt is always kept, even if it alone exceeds the window. It
// returns the templated prompt string, the images that fit within the
// context, and any templating or tokenization error.
func trimmedPrompt(ctx context.Context, chat *ChatHistory, model *Model) (string, []llm.ImageData, error) {
	if len(chat.Prompts) == 0 {
		return "", nil, nil
	}
	var promptsToAdd []promptInfo // accumulated most-recent-first
	var totalTokenLength int
	var systemPromptIncluded bool
	// NOTE(review): images are appended newest-first because of the reverse
	// walk below — verify consumers match images by ID, not by position.
	var images []llm.ImageData
	// reverse iterate through the prompts to build the prompt string in a way that fits the max context length
	for i := len(chat.Prompts) - 1; i >= 0; i-- {
		prompt := chat.Prompts[i]
		// The newest prompt (i == last) is rendered with the pre-response template.
		promptText, err := promptString(model, prompt, i == len(chat.Prompts)-1)
		if err != nil {
			return "", nil, err
		}
		encodedTokens, err := loaded.runner.Encode(ctx, promptText)
		if err != nil {
			return "", nil, err
		}
		// Stop once the window is full — but never drop the most recent prompt.
		if totalTokenLength+len(encodedTokens) > loaded.NumCtx && i != len(chat.Prompts)-1 {
			break // reached max context length, stop adding more prompts
		}
		// Each image is budgeted at a flat 768 tokens (an overestimate of the
		// projection cost). Images that don't fit have their tag stripped from
		// the prompt text instead of being sent.
		for j := range prompt.Images {
			if totalTokenLength+768 > loaded.NumCtx {
				// this decreases the token length but overestimating is fine
				prompt.Prompt = strings.ReplaceAll(prompt.Prompt, fmt.Sprintf(" [img-%d]", prompt.Images[j].ID), "")
				continue
			}
			totalTokenLength += 768
			images = append(images, prompt.Images[j])
		}
		totalTokenLength += len(encodedTokens)
		systemPromptIncluded = systemPromptIncluded || prompt.System != ""
		promptsToAdd = append(promptsToAdd, promptInfo{vars: prompt, tokenLen: len(encodedTokens)})
	}
	// ensure the system prompt is included, if not already
	if chat.LastSystem != "" && !systemPromptIncluded {
		var err error
		promptsToAdd, err = includeSystemPrompt(ctx, chat.LastSystem, totalTokenLength, promptsToAdd)
		if err != nil {
			return "", nil, err
		}
	}
	// The oldest surviving prompt acts as the conversation's first message
	// for templating ({{ .First }}).
	promptsToAdd[len(promptsToAdd)-1].vars.First = true
	// construct the final prompt string from the prompts which fit within the context window
	var result string
	for i, prompt := range promptsToAdd {
		// promptsToAdd is most-recent-first, so each rendered prompt is
		// prepended; index 0 (the newest) uses the pre-response template.
		promptText, err := promptString(model, prompt.vars, i == 0)
		if err != nil {
			return "", nil, err
		}
		result = promptText + result
	}
	return result, images, nil
}
// promptString renders one exchange through the model's templates. The most
// recent exchange uses the pre-response template so the model's reply can be
// appended directly after it; all earlier exchanges use the regular prompt
// template.
func promptString(model *Model, vars PromptVars, isMostRecent bool) (string, error) {
	if !isMostRecent {
		return Prompt(model.Template, vars)
	}
	rendered, err := model.PreResponsePrompt(vars)
	if err != nil {
		return "", fmt.Errorf("pre-response template: %w", err)
	}
	return rendered, nil
}
// includeSystemPrompt attaches systemPrompt to the oldest prompt that still
// leaves room for the system tokens inside the context window, dropping any
// older prompts that no longer fit. If the system prompt fits nowhere, only
// the most recent prompt is kept, with the system message set on it.
func includeSystemPrompt(ctx context.Context, systemPrompt string, totalTokenLength int, promptsToAdd []promptInfo) ([]promptInfo, error) {
	systemTokens, err := loaded.runner.Encode(ctx, systemPrompt)
	if err != nil {
		return nil, err
	}
	// promptsToAdd is ordered most-recent-first, so walk from the oldest end.
	budget := loaded.NumCtx - len(systemTokens)
	used := totalTokenLength
	for i := len(promptsToAdd) - 1; i >= 0; i-- {
		if used <= budget {
			promptsToAdd[i].vars.System = systemPrompt
			return promptsToAdd[:i+1], nil
		}
		used -= promptsToAdd[i].tokenLen
	}
	// The system prompt never fit; fall back to just the most recent prompt.
	newest := promptsToAdd[len(promptsToAdd)-1]
	newest.vars.System = systemPrompt
	return []promptInfo{newest}, nil
}

View File

@@ -16,6 +16,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/jmorganca/ollama/api" "github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/llm"
"github.com/jmorganca/ollama/parser" "github.com/jmorganca/ollama/parser"
"github.com/jmorganca/ollama/version" "github.com/jmorganca/ollama/version"
) )
@@ -239,3 +240,258 @@ func Test_Routes(t *testing.T) {
} }
} }
// Test_ChatPrompt exercises trimmedPrompt's context-window truncation and
// system-message preservation against a mock runner with fixed encodings.
func Test_ChatPrompt(t *testing.T) {
	tests := []struct {
		name     string
		template string
		chat     *ChatHistory
		numCtx   int
		runner   MockLLM
		want     string
		wantErr  string
	}{
		{
			name:     "Single Message",
			template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
			chat: &ChatHistory{
				Prompts: []PromptVars{
					{
						System: "You are a Wizard.",
						Prompt: "What are the potion ingredients?",
						First:  true,
					},
				},
				LastSystem: "You are a Wizard.",
			},
			numCtx: 1,
			runner: MockLLM{
				encoding: []int{1}, // fit the ctxLen
			},
			want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
		},
		{
			name:     "First Message",
			template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST]",
			chat: &ChatHistory{
				Prompts: []PromptVars{
					{
						System:   "You are a Wizard.",
						Prompt:   "What are the potion ingredients?",
						Response: "eye of newt",
						First:    true,
					},
					{
						Prompt: "Anything else?",
					},
				},
				LastSystem: "You are a Wizard.",
			},
			numCtx: 2,
			runner: MockLLM{
				encoding: []int{1}, // fit the ctxLen
			},
			want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST]eye of newt[INST] Anything else? [/INST]",
		},
		{
			name:     "Message History",
			template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
			chat: &ChatHistory{
				Prompts: []PromptVars{
					{
						System:   "You are a Wizard.",
						Prompt:   "What are the potion ingredients?",
						Response: "sugar",
						First:    true,
					},
					{
						Prompt: "Anything else?",
					},
				},
				LastSystem: "You are a Wizard.",
			},
			numCtx: 4,
			runner: MockLLM{
				encoding: []int{1}, // fit the ctxLen, 1 for each message
			},
			want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]sugar[INST] Anything else? [/INST]",
		},
		{
			name:     "Assistant Only",
			template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
			chat: &ChatHistory{
				Prompts: []PromptVars{
					{
						Response: "everything nice",
						First:    true,
					},
				},
			},
			numCtx: 1,
			runner: MockLLM{
				encoding: []int{1},
			},
			want: "[INST]  [/INST]everything nice",
		},
		{
			name:     "Message History Truncated, No System",
			template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
			chat: &ChatHistory{
				Prompts: []PromptVars{
					{
						Prompt:   "What are the potion ingredients?",
						Response: "sugar",
						First:    true,
					},
					{
						Prompt:   "Anything else?",
						Response: "spice",
					},
					{
						Prompt: "... and?",
					},
				},
			},
			numCtx: 2, // only 1 message from history and most recent message
			runner: MockLLM{
				encoding: []int{1},
			},
			want: "[INST]  Anything else? [/INST]spice[INST]  ... and? [/INST]",
		},
		{
			name:     "System is Preserved when Truncated",
			template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
			chat: &ChatHistory{
				Prompts: []PromptVars{
					{
						Prompt:   "What are the magic words?",
						Response: "abracadabra",
					},
					{
						Prompt: "What is the spell for invisibility?",
					},
				},
				LastSystem: "You are a wizard.",
			},
			numCtx: 2,
			runner: MockLLM{
				encoding: []int{1},
			},
			want: "[INST] You are a wizard. What is the spell for invisibility? [/INST]",
		},
		{
			name:     "System is Preserved when Length Exceeded",
			template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
			chat: &ChatHistory{
				Prompts: []PromptVars{
					{
						Prompt:   "What are the magic words?",
						Response: "abracadabra",
					},
					{
						Prompt: "What is the spell for invisibility?",
					},
				},
				LastSystem: "You are a wizard.",
			},
			numCtx: 1,
			runner: MockLLM{
				encoding: []int{1},
			},
			want: "[INST] You are a wizard. What is the spell for invisibility? [/INST]",
		},
		{
			name:     "First is Preserved when Truncated",
			template: "[INST] {{ if .First }}{{ .System }} {{ end }}{{ .Prompt }} [/INST]",
			chat: &ChatHistory{
				Prompts: []PromptVars{
					// first message omitted for test
					{
						Prompt:   "Do you have a magic hat?",
						Response: "Of course.",
					},
					{
						Prompt: "What is the spell for invisibility?",
					},
				},
				LastSystem: "You are a wizard.",
			},
			numCtx: 3, // two most recent messages and room for system message
			runner: MockLLM{
				encoding: []int{1},
			},
			want: "[INST] You are a wizard. Do you have a magic hat? [/INST]Of course.[INST] What is the spell for invisibility? [/INST]",
		},
		{
			name:     "Most recent message is returned when longer than ctxLen",
			template: "[INST] {{ .Prompt }} [/INST]",
			chat: &ChatHistory{
				Prompts: []PromptVars{
					{
						Prompt: "What is the spell for invisibility?",
						First:  true,
					},
				},
			},
			numCtx: 1, // two most recent messages
			runner: MockLLM{
				encoding: []int{1, 2},
			},
			want: "[INST] What is the spell for invisibility? [/INST]",
		},
	}

	for _, testCase := range tests {
		tt := testCase // pin the range variable for the subtest closure
		m := &Model{
			Template: tt.template,
		}
		t.Run(tt.name, func(t *testing.T) {
			loaded.runner = &tt.runner
			loaded.Options = &api.Options{
				Runner: api.Runner{
					NumCtx: tt.numCtx,
				},
			}
			// TODO: add tests for trimming images
			got, _, err := trimmedPrompt(context.Background(), tt.chat, m)
			if tt.wantErr != "" {
				// Fatalf here: the original used Errorf and then dereferenced
				// the possibly-nil err, which panics when no error occurred.
				if err == nil {
					t.Fatalf("ChatPrompt() expected error, got nil")
				}
				if !strings.Contains(err.Error(), tt.wantErr) {
					t.Errorf("ChatPrompt() error = %v, wantErr %v", err, tt.wantErr)
				}
				return
			}
			// Previously an unexpected error was silently ignored; surface it.
			if err != nil {
				t.Fatalf("ChatPrompt() unexpected error: %v", err)
			}
			if got != tt.want {
				t.Errorf("ChatPrompt() got = %v, want %v", got, tt.want)
			}
		})
	}
}
// MockLLM stands in for loaded.runner in tests. Encode returns the canned
// token slice in encoding regardless of input; every other method is a no-op.
type MockLLM struct {
	encoding []int // tokens returned by Encode for any prompt
}

// Predict is a no-op; the mock never streams predictions.
// Receiver renamed from "llm" to "m": the old name shadowed the imported
// llm package inside the method bodies.
func (m *MockLLM) Predict(ctx context.Context, pred llm.PredictOpts, fn func(llm.PredictResult)) error {
	return nil
}

// Encode returns the fixed token encoding regardless of the prompt.
func (m *MockLLM) Encode(ctx context.Context, prompt string) ([]int, error) {
	return m.encoding, nil
}

// Decode is a no-op and always returns an empty string.
func (m *MockLLM) Decode(ctx context.Context, tokens []int) (string, error) {
	return "", nil
}

// Embedding is a no-op and returns an empty embedding.
func (m *MockLLM) Embedding(ctx context.Context, input string) ([]float64, error) {
	return []float64{}, nil
}

// Close is a no-op.
func (m *MockLLM) Close() {
	// do nothing
}