Compare commits

..

1 Commits

Author SHA1 Message Date
Jeffrey Morgan
be721ca0df add more search paths for cuda libs 2024-01-10 09:51:02 -05:00
96 changed files with 2543 additions and 4585 deletions

View File

@@ -1,162 +0,0 @@
name: test
on:
pull_request:
jobs:
generate:
strategy:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
arch: [amd64, arm64]
exclude:
- os: ubuntu-latest
arch: arm64
- os: windows-latest
arch: arm64
runs-on: ${{ matrix.os }}
env:
GOARCH: ${{ matrix.arch }}
steps:
- uses: actions/checkout@v4
- uses: actions/setup-go@v5
with:
go-version: '1.21'
cache: true
- run: go get ./...
- run: go generate -x ./...
- uses: actions/upload-artifact@v4
with:
name: ${{ matrix.os }}-${{ matrix.arch }}-libraries
path: llm/llama.cpp/build/**/lib/*
generate-cuda:
strategy:
matrix:
cuda-version:
- '11.8.0'
runs-on: linux
container: nvidia/cuda:${{ matrix.cuda-version }}-devel-ubuntu20.04
steps:
- run: |
apt-get update && apt-get install -y git build-essential curl
curl -fsSL https://github.com/Kitware/CMake/releases/download/v3.28.1/cmake-3.28.1-linux-x86_64.tar.gz \
| tar -zx -C /usr --strip-components 1
env:
DEBIAN_FRONTEND: noninteractive
- uses: actions/checkout@v4
- uses: actions/setup-go@v4
with:
go-version: '1.21'
cache: true
- run: go get ./...
- run: |
git config --global --add safe.directory /__w/ollama/ollama
go generate -x ./...
env:
OLLAMA_SKIP_CPU_GENERATE: '1'
- uses: actions/upload-artifact@v4
with:
name: cuda-${{ matrix.cuda-version }}-libraries
path: llm/llama.cpp/build/**/lib/*
generate-rocm:
strategy:
matrix:
rocm-version:
- '5.7.1'
- '6.0'
runs-on: linux
container: rocm/dev-ubuntu-20.04:${{ matrix.rocm-version }}
steps:
- run: |
apt-get update && apt-get install -y git build-essential curl rocm-libs
curl -fsSL https://github.com/Kitware/CMake/releases/download/v3.28.1/cmake-3.28.1-linux-x86_64.tar.gz \
| tar -zx -C /usr --strip-components 1
env:
DEBIAN_FRONTEND: noninteractive
- uses: actions/checkout@v4
- uses: actions/setup-go@v4
with:
go-version: '1.21'
cache: true
- run: go get ./...
- run: |
git config --global --add safe.directory /__w/ollama/ollama
go generate -x ./...
env:
OLLAMA_SKIP_CPU_GENERATE: '1'
- uses: actions/upload-artifact@v4
with:
name: rocm-${{ matrix.rocm-version }}-libraries
path: llm/llama.cpp/build/**/lib/*
lint:
strategy:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
arch: [amd64, arm64]
exclude:
- os: ubuntu-latest
arch: arm64
- os: windows-latest
arch: arm64
- os: macos-latest
arch: amd64
runs-on: ${{ matrix.os }}
env:
GOARCH: ${{ matrix.arch }}
CGO_ENABLED: "1"
steps:
- uses: actions/checkout@v4
with:
submodules: recursive
- uses: actions/setup-go@v5
with:
go-version: '1.21'
cache: false
- run: |
mkdir -p llm/llama.cpp/build/linux/${{ matrix.arch }}/stub/lib/
touch llm/llama.cpp/build/linux/${{ matrix.arch }}/stub/lib/stub.so
if: ${{ startsWith(matrix.os, 'ubuntu-') }}
- run: |
mkdir -p llm/llama.cpp/build/darwin/${{ matrix.arch }}/stub/lib/
touch llm/llama.cpp/build/darwin/${{ matrix.arch }}/stub/lib/stub.dylib
touch llm/llama.cpp/ggml-metal.metal
if: ${{ startsWith(matrix.os, 'macos-') }}
- run: |
mkdir -p llm/llama.cpp/build/windows/${{ matrix.arch }}/stub/lib/
touch llm/llama.cpp/build/windows/${{ matrix.arch }}/stub/lib/stub.dll
if: ${{ startsWith(matrix.os, 'windows-') }}
- uses: golangci/golangci-lint-action@v3
test:
needs: generate
strategy:
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
arch: [amd64]
exclude:
- os: ubuntu-latest
arch: arm64
- os: windows-latest
arch: arm64
runs-on: ${{ matrix.os }}
env:
GOARCH: ${{ matrix.arch }}
CGO_ENABLED: "1"
steps:
- uses: actions/checkout@v4
with:
submodules: recursive
- uses: actions/setup-go@v5
with:
go-version: '1.21'
cache: true
- run: go get
- uses: actions/download-artifact@v4
with:
name: ${{ matrix.os }}-${{ matrix.arch }}-libraries
path: llm/llama.cpp/build
- run: go build
- run: go test -v ./...
- uses: actions/upload-artifact@v4
with:
name: ${{ matrix.os }}-binaries
path: ollama

View File

@@ -1,27 +0,0 @@
run:
timeout: 5m
linters:
enable:
- asasalint
- bidichk
- bodyclose
- containedctx
- contextcheck
- exportloopref
- gocheckcompilerdirectives
# FIXME: for some reason this errors on windows
# - gofmt
# - goimports
- misspell
- nilerr
- unused
linters-settings:
errcheck:
# exclude the following functions since we don't generally
# need to be concerned with the returned errors
exclude-functions:
- encoding/binary.Read
- (*os.File).Seek
- (*bufio.Writer).WriteString
- (*github.com/spf13/pflag.FlagSet).Set
- (*github.com/jmorganca/ollama/llm.readSeekOffset).Seek

View File

@@ -1,135 +1,27 @@
ARG GOLANG_VERSION=1.21.3
ARG CMAKE_VERSION=3.22.1
ARG CUDA_VERSION=11.3.1
FROM nvidia/cuda:11.8.0-devel-ubuntu22.04
# Copy the minimal context we need to run the generate scripts
FROM scratch AS llm-code
COPY .git .git
COPY .gitmodules .gitmodules
COPY llm llm
ARG TARGETARCH
ARG GOFLAGS="'-ldflags=-w -s'"
FROM --platform=linux/amd64 nvidia/cuda:$CUDA_VERSION-devel-centos7 AS cuda-build-amd64
ARG CMAKE_VERSION
COPY ./scripts/rh_linux_deps.sh /
RUN CMAKE_VERSION=${CMAKE_VERSION} sh /rh_linux_deps.sh
ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
COPY --from=llm-code / /go/src/github.com/jmorganca/ollama/
WORKDIR /go/src/github.com/jmorganca/ollama/llm/generate
ARG CGO_CFLAGS
RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
FROM --platform=linux/arm64 nvidia/cuda:$CUDA_VERSION-devel-rockylinux8 AS cuda-build-arm64
ARG CMAKE_VERSION
COPY ./scripts/rh_linux_deps.sh /
RUN CMAKE_VERSION=${CMAKE_VERSION} sh /rh_linux_deps.sh
ENV PATH /opt/rh/gcc-toolset-10/root/usr/bin:$PATH
COPY --from=llm-code / /go/src/github.com/jmorganca/ollama/
WORKDIR /go/src/github.com/jmorganca/ollama/llm/generate
ARG CGO_CFLAGS
RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
FROM --platform=linux/amd64 rocm/dev-centos-7:5.7.1-complete AS rocm-5-build-amd64
ARG CMAKE_VERSION
COPY ./scripts/rh_linux_deps.sh /
RUN CMAKE_VERSION=${CMAKE_VERSION} sh /rh_linux_deps.sh
ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
ENV LIBRARY_PATH /opt/amdgpu/lib64
COPY --from=llm-code / /go/src/github.com/jmorganca/ollama/
WORKDIR /go/src/github.com/jmorganca/ollama/llm/generate
ARG CGO_CFLAGS
ARG AMDGPU_TARGETS
RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
FROM --platform=linux/amd64 rocm/dev-centos-7:6.0-complete AS rocm-6-build-amd64
ARG CMAKE_VERSION
COPY ./scripts/rh_linux_deps.sh /
RUN CMAKE_VERSION=${CMAKE_VERSION} sh /rh_linux_deps.sh
ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
ENV LIBRARY_PATH /opt/amdgpu/lib64
COPY --from=llm-code / /go/src/github.com/jmorganca/ollama/
WORKDIR /go/src/github.com/jmorganca/ollama/llm/generate
ARG CGO_CFLAGS
ARG AMDGPU_TARGETS
RUN OLLAMA_SKIP_CPU_GENERATE=1 sh gen_linux.sh
FROM --platform=linux/amd64 centos:7 AS cpu-builder-amd64
ARG CMAKE_VERSION
ARG GOLANG_VERSION
COPY ./scripts/rh_linux_deps.sh /
RUN CMAKE_VERSION=${CMAKE_VERSION} GOLANG_VERSION=${GOLANG_VERSION} sh /rh_linux_deps.sh
ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
COPY --from=llm-code / /go/src/github.com/jmorganca/ollama/
ARG OLLAMA_CUSTOM_CPU_DEFS
ARG CGO_CFLAGS
WORKDIR /go/src/github.com/jmorganca/ollama/llm/generate
FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu-build-amd64
RUN OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh
FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu_avx-build-amd64
RUN OLLAMA_CPU_TARGET="cpu_avx" sh gen_linux.sh
FROM --platform=linux/amd64 cpu-builder-amd64 AS cpu_avx2-build-amd64
RUN OLLAMA_CPU_TARGET="cpu_avx2" sh gen_linux.sh
FROM --platform=linux/arm64 centos:7 AS cpu-build-arm64
ARG CMAKE_VERSION
ARG GOLANG_VERSION
COPY ./scripts/rh_linux_deps.sh /
RUN CMAKE_VERSION=${CMAKE_VERSION} GOLANG_VERSION=${GOLANG_VERSION} sh /rh_linux_deps.sh
ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
COPY --from=llm-code / /go/src/github.com/jmorganca/ollama/
WORKDIR /go/src/github.com/jmorganca/ollama/llm/generate
# Note, we only build the "base" CPU variant on arm since avx/avx2 are x86 features
ARG OLLAMA_CUSTOM_CPU_DEFS
ARG CGO_CFLAGS
RUN OLLAMA_CPU_TARGET="cpu" sh gen_linux.sh
# Intermediate stage used for ./scripts/build_linux.sh
FROM --platform=linux/amd64 cpu-build-amd64 AS build-amd64
ENV CGO_ENABLED 1
WORKDIR /go/src/github.com/jmorganca/ollama
RUN apt-get update && apt-get install -y git build-essential cmake
ADD https://dl.google.com/go/go1.21.3.linux-$TARGETARCH.tar.gz /tmp/go1.21.3.tar.gz
RUN mkdir -p /usr/local && tar xz -C /usr/local </tmp/go1.21.3.tar.gz
COPY . .
COPY --from=cpu_avx-build-amd64 /go/src/github.com/jmorganca/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
COPY --from=cpu_avx2-build-amd64 /go/src/github.com/jmorganca/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
COPY --from=cuda-build-amd64 /go/src/github.com/jmorganca/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
COPY --from=rocm-5-build-amd64 /go/src/github.com/jmorganca/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
COPY --from=rocm-6-build-amd64 /go/src/github.com/jmorganca/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
ARG GOFLAGS
ARG CGO_CFLAGS
RUN go build .
ENV GOARCH=$TARGETARCH
ENV GOFLAGS=$GOFLAGS
RUN /usr/local/go/bin/go generate ./... \
&& /usr/local/go/bin/go build .
# Intermediate stage used for ./scripts/build_linux.sh
FROM --platform=linux/arm64 cpu-build-arm64 AS build-arm64
ENV CGO_ENABLED 1
ARG GOLANG_VERSION
WORKDIR /go/src/github.com/jmorganca/ollama
COPY . .
COPY --from=cuda-build-arm64 /go/src/github.com/jmorganca/ollama/llm/llama.cpp/build/linux/ llm/llama.cpp/build/linux/
ARG GOFLAGS
ARG CGO_CFLAGS
RUN go build .
# Runtime stages
FROM --platform=linux/amd64 ubuntu:22.04 as runtime-amd64
FROM ubuntu:22.04
RUN apt-get update && apt-get install -y ca-certificates
COPY --from=build-amd64 /go/src/github.com/jmorganca/ollama/ollama /bin/ollama
FROM --platform=linux/arm64 ubuntu:22.04 as runtime-arm64
RUN apt-get update && apt-get install -y ca-certificates
COPY --from=build-arm64 /go/src/github.com/jmorganca/ollama/ollama /bin/ollama
# Radeon images are much larger so we keep it distinct from the CPU/CUDA image
FROM --platform=linux/amd64 rocm/dev-centos-7:5.7.1-complete as runtime-rocm
RUN update-pciids
COPY --from=build-amd64 /go/src/github.com/jmorganca/ollama/ollama /bin/ollama
COPY --from=0 /go/src/github.com/jmorganca/ollama/ollama /bin/ollama
EXPOSE 11434
ENV OLLAMA_HOST 0.0.0.0
ENTRYPOINT ["/bin/ollama"]
CMD ["serve"]
FROM runtime-$TARGETARCH
EXPOSE 11434
ENV OLLAMA_HOST 0.0.0.0
ENV PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
# set some environment variable for better NVIDIA compatibility
ENV PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64
ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility

101
Dockerfile.build Normal file
View File

@@ -0,0 +1,101 @@
ARG GOLANG_VERSION=1.21.3
ARG CMAKE_VERSION=3.22.1
ARG CUDA_VERSION=11.3.1
ARG ROCM_VERSION=5.7.1
FROM --platform=linux/amd64 nvidia/cuda:$CUDA_VERSION-devel-centos7 AS cuda-build-amd64
ARG CMAKE_VERSION
RUN yum install -y https://repo.ius.io/ius-release-el7.rpm centos-release-scl \
&& yum update -y \
&& yum install -y devtoolset-10-gcc devtoolset-10-gcc-c++ git236
ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
ADD https://github.com/Kitware/CMake/releases/download/v$CMAKE_VERSION/cmake-$CMAKE_VERSION-linux-x86_64.tar.gz /tmp/cmake-$CMAKE_VERSION.tar.gz
RUN tar -zx -C /usr --strip-components 1 </tmp/cmake-$CMAKE_VERSION.tar.gz
WORKDIR /go/src/github.com/jmorganca/ollama
COPY . .
WORKDIR llm/generate
RUN sh gen_linux.sh
FROM --platform=linux/arm64 nvidia/cuda:$CUDA_VERSION-devel-rockylinux8 AS cuda-build-arm64
ARG CMAKE_VERSION
RUN dnf install -y git cmake
WORKDIR /go/src/github.com/jmorganca/ollama
COPY . .
WORKDIR llm/generate
RUN sh gen_linux.sh
FROM --platform=linux/amd64 rocm/dev-centos-7:$ROCM_VERSION-complete AS rocm-build-amd64
ARG CMAKE_VERSION
RUN yum install -y https://repo.ius.io/ius-release-el7.rpm centos-release-scl \
&& yum update -y \
&& yum remove -y git \
&& yum install -y devtoolset-10-gcc devtoolset-10-gcc-c++ git236
ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
ENV LIBRARY_PATH /opt/amdgpu/lib64
ADD https://github.com/Kitware/CMake/releases/download/v$CMAKE_VERSION/cmake-$CMAKE_VERSION-linux-x86_64.tar.gz /tmp/cmake-$CMAKE_VERSION.tar.gz
RUN tar -zx -C /usr --strip-components 1 </tmp/cmake-$CMAKE_VERSION.tar.gz
WORKDIR /go/src/github.com/jmorganca/ollama
COPY . .
WORKDIR llm/generate
RUN sh gen_linux.sh
FROM --platform=linux/amd64 centos:7 AS build-amd64
ENV CGO_ENABLED 1
ARG GOLANG_VERSION
ARG GOFLAGS
ARG CGO_FLAGS
RUN yum install -y centos-release-scl \
&& yum update -y \
&& yum install -y devtoolset-10-gcc devtoolset-10-gcc-c++
ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
ADD https://dl.google.com/go/go$GOLANG_VERSION.linux-amd64.tar.gz /tmp/go-$GOLANG_VERSION.tar.gz
RUN mkdir -p /usr/local && tar xz -C /usr/local </tmp/go-$GOLANG_VERSION.tar.gz
ENV PATH /usr/local/go/bin:$PATH
WORKDIR /go/src/github.com/jmorganca/ollama
COPY . .
COPY --from=cuda-build-amd64 /go/src/github.com/jmorganca/ollama/llm/llama.cpp/build/linux/cpu/lib llm/llama.cpp/build/linux/cpu/lib
COPY --from=cuda-build-amd64 /go/src/github.com/jmorganca/ollama/llm/llama.cpp/build/linux/cuda/lib llm/llama.cpp/build/linux/cuda/lib
COPY --from=rocm-build-amd64 /go/src/github.com/jmorganca/ollama/llm/llama.cpp/build/linux/rocm/lib llm/llama.cpp/build/linux/rocm/lib
RUN go build .
FROM --platform=linux/arm64 centos:7 AS build-arm64
ENV CGO_ENABLED 1
ARG GOLANG_VERSION
ARG GOFLAGS
ARG CGO_FLAGS
RUN yum install -y centos-release-scl \
&& yum update -y \
&& yum install -y devtoolset-10-gcc devtoolset-10-gcc-c++
ENV PATH /opt/rh/devtoolset-10/root/usr/bin:$PATH
ADD https://dl.google.com/go/go$GOLANG_VERSION.linux-arm64.tar.gz /tmp/go-$GOLANG_VERSION.tar.gz
RUN mkdir -p /usr/local && tar xz -C /usr/local </tmp/go-$GOLANG_VERSION.tar.gz
ENV PATH /usr/local/go/bin:$PATH
WORKDIR /go/src/github.com/jmorganca/ollama
COPY . .
COPY --from=cuda-build-arm64 /go/src/github.com/jmorganca/ollama/llm/llama.cpp/build/linux/cpu/lib llm/llama.cpp/build/linux/cpu/lib
COPY --from=cuda-build-arm64 /go/src/github.com/jmorganca/ollama/llm/llama.cpp/build/linux/cuda/lib llm/llama.cpp/build/linux/cuda/lib
RUN go build .
FROM build-$TARGETARCH

View File

@@ -1,5 +1,8 @@
<div align="center">
<img alt="ollama" height="200px" src="https://github.com/jmorganca/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
<picture>
<source media="(prefers-color-scheme: dark)" height="200px" srcset="https://github.com/jmorganca/ollama/assets/3325447/56ea1849-1284-4645-8970-956de6e51c3c">
<img alt="logo" height="200px" src="https://github.com/jmorganca/ollama/assets/3325447/0d0b44e2-8f4a-4e99-9b52-a5c1c741c8f7">
</picture>
</div>
# Ollama
@@ -10,7 +13,7 @@ Get up and running with large language models locally.
### macOS
[Download](https://ollama.com/download/Ollama-darwin.zip)
[Download](https://ollama.ai/download/Ollama-darwin.zip)
### Windows
@@ -19,7 +22,7 @@ Coming soon! For now, you can install Ollama on Windows via WSL2.
### Linux & WSL2
```
curl -fsSL https://ollama.com/install.sh | sh
curl https://ollama.ai/install.sh | sh
```
[Manual install instructions](https://github.com/jmorganca/ollama/blob/main/docs/linux.md)
@@ -28,14 +31,9 @@ curl -fsSL https://ollama.com/install.sh | sh
The official [Ollama Docker image](https://hub.docker.com/r/ollama/ollama) `ollama/ollama` is available on Docker Hub.
### Libraries
- [ollama-python](https://github.com/ollama/ollama-python)
- [ollama-js](https://github.com/ollama/ollama-js)
## Quickstart
To run and chat with [Llama 2](https://ollama.com/library/llama2):
To run and chat with [Llama 2](https://ollama.ai/library/llama2):
```
ollama run llama2
@@ -43,7 +41,7 @@ ollama run llama2
## Model library
Ollama supports a list of open-source models available on [ollama.com/library](https://ollama.com/library 'ollama model library')
Ollama supports a list of open-source models available on [ollama.ai/library](https://ollama.ai/library 'ollama model library')
Here are some example open-source models that can be downloaded:
@@ -200,21 +198,18 @@ brew install cmake go
```
Then generate dependencies:
```
go generate ./...
```
Then build the binary:
```
go build .
```
More detailed instructions can be found in the [developer guide](https://github.com/jmorganca/ollama/blob/main/docs/development.md)
### Running local builds
### Running local builds
Next, start the server:
```
@@ -256,7 +251,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
## Community Integrations
### Web & Desktop
- [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
- [HTML UI](https://github.com/rtcfirefly/ollama-ui)
- [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
@@ -269,7 +263,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Amica](https://github.com/semperai/amica)
- [chatd](https://github.com/BruceMacD/chatd)
- [Ollama-SwiftUI](https://github.com/kghandour/Ollama-SwiftUI)
- [MindMac](https://mindmac.app)
### Terminal
@@ -282,7 +276,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [gptel Emacs client](https://github.com/karthink/gptel)
- [Oatmeal](https://github.com/dustinblackman/oatmeal)
- [cmdh](https://github.com/pgibler/cmdh)
- [llm-ollama](https://github.com/taketwo/llm-ollama) for [Datasette's LLM CLI](https://llm.datasette.io/en/stable/).
### Database
@@ -307,10 +300,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Ollama for Dart](https://github.com/breitburg/dart-ollama)
- [Ollama for Laravel](https://github.com/cloudstudio/ollama-laravel)
- [LangChainDart](https://github.com/davidmigloz/langchain_dart)
- [Semantic Kernel - Python](https://github.com/microsoft/semantic-kernel/tree/main/python/semantic_kernel/connectors/ai/ollama)
- [Haystack](https://github.com/deepset-ai/haystack-integrations/blob/main/integrations/ollama.md)
- [Ollama for R - rollama](https://github.com/JBGruber/rollama)
- [Ollama-ex for Elixir](https://github.com/lebrunel/ollama-ex)
### Mobile
@@ -331,6 +320,3 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Rivet plugin](https://github.com/abrenneke/rivet-plugin-ollama)
- [Llama Coder](https://github.com/ex3ndr/llama-coder) (Copilot alternative using Ollama)
- [Obsidian BMO Chatbot plugin](https://github.com/longy2k/obsidian-bmo-chatbot)
- [Open Interpreter](https://docs.openinterpreter.com/language-model-setup/local-models/ollama)
- [twinny](https://github.com/rjmacarthy/twinny) (Copilot and Copilot chat alternative using Ollama)
- [Wingman-AI](https://github.com/RussellCanfield/wingman-ai) (Copilot code and chat alternative using Ollama and HuggingFace)

284
api/client.py Normal file
View File

@@ -0,0 +1,284 @@
import os
import json
import requests
import os
import hashlib
import json
from pathlib import Path
BASE_URL = os.environ.get('OLLAMA_HOST', 'http://localhost:11434')
# Generate a response for a given prompt with a provided model. This is a streaming endpoint, so will be a series of responses.
# The final response object will include statistics and additional data from the request. Use the callback function to override
# the default handler.
def generate(model_name, prompt, system=None, template=None, format="", context=None, options=None, callback=None):
try:
url = f"{BASE_URL}/api/generate"
payload = {
"model": model_name,
"prompt": prompt,
"system": system,
"template": template,
"context": context,
"options": options,
"format": format,
}
# Remove keys with None values
payload = {k: v for k, v in payload.items() if v is not None}
with requests.post(url, json=payload, stream=True) as response:
response.raise_for_status()
# Creating a variable to hold the context history of the final chunk
final_context = None
# Variable to hold concatenated response strings if no callback is provided
full_response = ""
# Iterating over the response line by line and displaying the details
for line in response.iter_lines():
if line:
# Parsing each line (JSON chunk) and extracting the details
chunk = json.loads(line)
# If a callback function is provided, call it with the chunk
if callback:
callback(chunk)
else:
# If this is not the last chunk, add the "response" field value to full_response and print it
if not chunk.get("done"):
response_piece = chunk.get("response", "")
full_response += response_piece
print(response_piece, end="", flush=True)
# Check if it's the last chunk (done is true)
if chunk.get("done"):
final_context = chunk.get("context")
# Return the full response and the final context
return full_response, final_context
except requests.exceptions.RequestException as e:
print(f"An error occurred: {e}")
return None, None
# Create a blob file on the server if it doesn't exist.
def create_blob(digest, file_path):
url = f"{BASE_URL}/api/blobs/{digest}"
# Check if the blob exists
response = requests.head(url)
if response.status_code != 404:
return # Blob already exists, no need to upload
response.raise_for_status()
# Upload the blob
with open(file_path, 'rb') as file_data:
requests.post(url, data=file_data)
# Create a model from a Modelfile. Use the callback function to override the default handler.
def create(model_name, filename, callback=None):
try:
file_path = Path(filename).expanduser().resolve()
processed_lines = []
# Read and process the modelfile
with open(file_path, 'r') as f:
for line in f:
# Skip empty or whitespace-only lines
if not line.strip():
continue
command, args = line.split(maxsplit=1)
if command.upper() in ["FROM", "ADAPTER"]:
path = Path(args.strip()).expanduser()
# Check if path is relative and resolve it
if not path.is_absolute():
path = (file_path.parent / path)
# Skip if file does not exist for "model", this is handled by the server
if not path.exists():
processed_lines.append(line)
continue
# Calculate SHA-256 hash
with open(path, 'rb') as bin_file:
hash = hashlib.sha256()
hash.update(bin_file.read())
blob = f"sha256:{hash.hexdigest()}"
# Add the file to the remote server
create_blob(blob, path)
# Replace path with digest in the line
line = f"{command} @{blob}\n"
processed_lines.append(line)
# Combine processed lines back into a single string
modelfile_content = '\n'.join(processed_lines)
url = f"{BASE_URL}/api/create"
payload = {"name": model_name, "modelfile": modelfile_content}
# Making a POST request with the stream parameter set to True to handle streaming responses
with requests.post(url, json=payload, stream=True) as response:
response.raise_for_status()
# Iterating over the response line by line and displaying the status
for line in response.iter_lines():
if line:
chunk = json.loads(line)
if callback:
callback(chunk)
else:
print(f"Status: {chunk.get('status')}")
except Exception as e:
print(f"An error occurred: {e}")
# Pull a model from a the model registry. Cancelled pulls are resumed from where they left off, and multiple
# calls to will share the same download progress. Use the callback function to override the default handler.
def pull(model_name, insecure=False, callback=None):
try:
url = f"{BASE_URL}/api/pull"
payload = {
"name": model_name,
"insecure": insecure
}
# Making a POST request with the stream parameter set to True to handle streaming responses
with requests.post(url, json=payload, stream=True) as response:
response.raise_for_status()
# Iterating over the response line by line and displaying the details
for line in response.iter_lines():
if line:
# Parsing each line (JSON chunk) and extracting the details
chunk = json.loads(line)
# If a callback function is provided, call it with the chunk
if callback:
callback(chunk)
else:
# Print the status message directly to the console
print(chunk.get('status', ''), end='', flush=True)
# If there's layer data, you might also want to print that (adjust as necessary)
if 'digest' in chunk:
print(f" - Digest: {chunk['digest']}", end='', flush=True)
print(f" - Total: {chunk['total']}", end='', flush=True)
print(f" - Completed: {chunk['completed']}", end='\n', flush=True)
else:
print()
except requests.exceptions.RequestException as e:
print(f"An error occurred: {e}")
# Push a model to the model registry. Use the callback function to override the default handler.
def push(model_name, insecure=False, callback=None):
try:
url = f"{BASE_URL}/api/push"
payload = {
"name": model_name,
"insecure": insecure
}
# Making a POST request with the stream parameter set to True to handle streaming responses
with requests.post(url, json=payload, stream=True) as response:
response.raise_for_status()
# Iterating over the response line by line and displaying the details
for line in response.iter_lines():
if line:
# Parsing each line (JSON chunk) and extracting the details
chunk = json.loads(line)
# If a callback function is provided, call it with the chunk
if callback:
callback(chunk)
else:
# Print the status message directly to the console
print(chunk.get('status', ''), end='', flush=True)
# If there's layer data, you might also want to print that (adjust as necessary)
if 'digest' in chunk:
print(f" - Digest: {chunk['digest']}", end='', flush=True)
print(f" - Total: {chunk['total']}", end='', flush=True)
print(f" - Completed: {chunk['completed']}", end='\n', flush=True)
else:
print()
except requests.exceptions.RequestException as e:
print(f"An error occurred: {e}")
# List models that are available locally.
def list():
try:
response = requests.get(f"{BASE_URL}/api/tags")
response.raise_for_status()
data = response.json()
models = data.get('models', [])
return models
except requests.exceptions.RequestException as e:
print(f"An error occurred: {e}")
return None
# Copy a model. Creates a model with another name from an existing model.
def copy(source, destination):
try:
# Create the JSON payload
payload = {
"source": source,
"destination": destination
}
response = requests.post(f"{BASE_URL}/api/copy", json=payload)
response.raise_for_status()
# If the request was successful, return a message indicating that the copy was successful
return "Copy successful"
except requests.exceptions.RequestException as e:
print(f"An error occurred: {e}")
return None
# Delete a model and its data.
def delete(model_name):
try:
url = f"{BASE_URL}/api/delete"
payload = {"name": model_name}
response = requests.delete(url, json=payload)
response.raise_for_status()
return "Delete successful"
except requests.exceptions.RequestException as e:
print(f"An error occurred: {e}")
return None
# Show info about a model.
def show(model_name):
try:
url = f"{BASE_URL}/api/show"
payload = {"name": model_name}
response = requests.post(url, json=payload)
response.raise_for_status()
# Parse the JSON response and return it
data = response.json()
return data
except requests.exceptions.RequestException as e:
print(f"An error occurred: {e}")
return None
def heartbeat():
try:
url = f"{BASE_URL}/"
response = requests.head(url)
response.raise_for_status()
return "Ollama is running"
except requests.exceptions.RequestException as e:
print(f"An error occurred: {e}")
return "Ollama is not running"

View File

@@ -34,26 +34,24 @@ func (e StatusError) Error() string {
type ImageData []byte
type GenerateRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
System string `json:"system"`
Template string `json:"template"`
Context []int `json:"context,omitempty"`
Stream *bool `json:"stream,omitempty"`
Raw bool `json:"raw,omitempty"`
Format string `json:"format"`
KeepAlive *Duration `json:"keep_alive,omitempty"`
Images []ImageData `json:"images,omitempty"`
Model string `json:"model"`
Prompt string `json:"prompt"`
System string `json:"system"`
Template string `json:"template"`
Context []int `json:"context,omitempty"`
Stream *bool `json:"stream,omitempty"`
Raw bool `json:"raw,omitempty"`
Format string `json:"format"`
Images []ImageData `json:"images,omitempty"`
Options map[string]interface{} `json:"options"`
}
type ChatRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Stream *bool `json:"stream,omitempty"`
Format string `json:"format"`
KeepAlive *Duration `json:"keep_alive,omitempty"`
Model string `json:"model"`
Messages []Message `json:"messages"`
Stream *bool `json:"stream,omitempty"`
Format string `json:"format"`
Options map[string]interface{} `json:"options"`
}
@@ -128,9 +126,8 @@ type Runner struct {
}
type EmbeddingRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
KeepAlive *Duration `json:"keep_alive,omitempty"`
Model string `json:"model"`
Prompt string `json:"prompt"`
Options map[string]interface{} `json:"options"`
}
@@ -140,31 +137,23 @@ type EmbeddingResponse struct {
}
type CreateRequest struct {
Model string `json:"model"`
Name string `json:"name"`
Path string `json:"path"`
Modelfile string `json:"modelfile"`
Stream *bool `json:"stream,omitempty"`
// Name is deprecated, see Model
Name string `json:"name"`
}
type DeleteRequest struct {
Model string `json:"model"`
// Name is deprecated, see Model
Name string `json:"name"`
}
type ShowRequest struct {
Name string `json:"name"`
Model string `json:"model"`
System string `json:"system"`
Template string `json:"template"`
Options map[string]interface{} `json:"options"`
// Name is deprecated, see Model
Name string `json:"name"`
}
type ShowResponse struct {
@@ -174,7 +163,6 @@ type ShowResponse struct {
Template string `json:"template,omitempty"`
System string `json:"system,omitempty"`
Details ModelDetails `json:"details,omitempty"`
Messages []Message `json:"messages,omitempty"`
}
type CopyRequest struct {
@@ -183,14 +171,11 @@ type CopyRequest struct {
}
type PullRequest struct {
Model string `json:"model"`
Name string `json:"name"`
Insecure bool `json:"insecure,omitempty"`
Username string `json:"username"`
Password string `json:"password"`
Stream *bool `json:"stream,omitempty"`
// Name is deprecated, see Model
Name string `json:"name"`
}
type ProgressResponse struct {
@@ -201,14 +186,11 @@ type ProgressResponse struct {
}
type PushRequest struct {
Model string `json:"model"`
Name string `json:"name"`
Insecure bool `json:"insecure,omitempty"`
Username string `json:"username"`
Password string `json:"password"`
Stream *bool `json:"stream,omitempty"`
// Name is deprecated, see Model
Name string `json:"name"`
}
type ListResponse struct {
@@ -217,7 +199,6 @@ type ListResponse struct {
type ModelResponse struct {
Name string `json:"name"`
Model string `json:"model"`
ModifiedAt time.Time `json:"modified_at"`
Size int64 `json:"size"`
Digest string `json:"digest"`
@@ -240,7 +221,6 @@ type GenerateResponse struct {
}
type ModelDetails struct {
ParentModel string `json:"parent_model"`
Format string `json:"format"`
Family string `json:"family"`
Families []string `json:"families"`
@@ -415,18 +395,15 @@ func (d *Duration) UnmarshalJSON(b []byte) (err error) {
switch t := v.(type) {
case float64:
if t < 0 {
d.Duration = time.Duration(math.MaxInt64)
} else {
d.Duration = time.Duration(t * float64(time.Second))
t = math.MaxFloat64
}
d.Duration = time.Duration(t)
case string:
d.Duration, err = time.ParseDuration(t)
if err != nil {
return err
}
if d.Duration < 0 {
d.Duration = time.Duration(math.MaxInt64)
}
}
return nil

View File

@@ -25,7 +25,6 @@ import (
"github.com/olekukonko/tablewriter"
"github.com/spf13/cobra"
"golang.org/x/crypto/ssh"
"golang.org/x/exp/slices"
"golang.org/x/term"
"github.com/jmorganca/ollama/api"
@@ -36,6 +35,8 @@ import (
"github.com/jmorganca/ollama/version"
)
type ImageData []byte
func CreateHandler(cmd *cobra.Command, args []string) error {
filename, _ := cmd.Flags().GetString("file")
filename, err := filepath.Abs(filename)
@@ -147,68 +148,19 @@ func RunHandler(cmd *cobra.Command, args []string) error {
}
name := args[0]
// check if the model exists on the server
show, err := client.Show(cmd.Context(), &api.ShowRequest{Name: name})
_, err = client.Show(cmd.Context(), &api.ShowRequest{Name: name})
var statusError api.StatusError
switch {
case errors.As(err, &statusError) && statusError.StatusCode == http.StatusNotFound:
if err := PullHandler(cmd, []string{name}); err != nil {
return err
}
show, err = client.Show(cmd.Context(), &api.ShowRequest{Name: name})
if err != nil {
return err
}
case err != nil:
return err
}
interactive := true
opts := runOptions{
Model: args[0],
WordWrap: os.Getenv("TERM") == "xterm-256color",
Options: map[string]interface{}{},
MultiModal: slices.Contains(show.Details.Families, "clip"),
ParentModel: show.Details.ParentModel,
}
format, err := cmd.Flags().GetString("format")
if err != nil {
return err
}
opts.Format = format
prompts := args[1:]
// prepend stdin to the prompt if provided
if !term.IsTerminal(int(os.Stdin.Fd())) {
in, err := io.ReadAll(os.Stdin)
if err != nil {
return err
}
prompts = append([]string{string(in)}, prompts...)
opts.WordWrap = false
interactive = false
}
opts.Prompt = strings.Join(prompts, " ")
if len(prompts) > 0 {
interactive = false
}
nowrap, err := cmd.Flags().GetBool("nowordwrap")
if err != nil {
return err
}
opts.WordWrap = !nowrap
if !interactive {
return generate(cmd, opts)
}
return generateInteractive(cmd, opts)
return RunGenerate(cmd, args)
}
func PushHandler(cmd *cobra.Command, args []string) error {
@@ -460,139 +412,66 @@ func PullHandler(cmd *cobra.Command, args []string) error {
return nil
}
func RunGenerate(cmd *cobra.Command, args []string) error {
interactive := true
opts := generateOptions{
Model: args[0],
WordWrap: os.Getenv("TERM") == "xterm-256color",
Options: map[string]interface{}{},
Images: []ImageData{},
}
format, err := cmd.Flags().GetString("format")
if err != nil {
return err
}
opts.Format = format
prompts := args[1:]
// prepend stdin to the prompt if provided
if !term.IsTerminal(int(os.Stdin.Fd())) {
in, err := io.ReadAll(os.Stdin)
if err != nil {
return err
}
prompts = append([]string{string(in)}, prompts...)
opts.WordWrap = false
interactive = false
}
opts.Prompt = strings.Join(prompts, " ")
if len(prompts) > 0 {
interactive = false
}
nowrap, err := cmd.Flags().GetBool("nowordwrap")
if err != nil {
return err
}
opts.WordWrap = !nowrap
if !interactive {
return generate(cmd, opts)
}
return generateInteractive(cmd, opts)
}
type generateContextKey string
type runOptions struct {
Model string
ParentModel string
Prompt string
Messages []api.Message
WordWrap bool
Format string
System string
Template string
Images []api.ImageData
Options map[string]interface{}
MultiModal bool
type generateOptions struct {
Model string
Prompt string
WordWrap bool
Format string
System string
Template string
Images []ImageData
Options map[string]interface{}
}
type displayResponseState struct {
lineLength int
wordBuffer string
}
func displayResponse(content string, wordWrap bool, state *displayResponseState) {
termWidth, _, _ := term.GetSize(int(os.Stdout.Fd()))
if wordWrap && termWidth >= 10 {
for _, ch := range content {
if state.lineLength+1 > termWidth-5 {
if len(state.wordBuffer) > termWidth-10 {
fmt.Printf("%s%c", state.wordBuffer, ch)
state.wordBuffer = ""
state.lineLength = 0
continue
}
// backtrack the length of the last word and clear to the end of the line
fmt.Printf("\x1b[%dD\x1b[K\n", len(state.wordBuffer))
fmt.Printf("%s%c", state.wordBuffer, ch)
state.lineLength = len(state.wordBuffer) + 1
} else {
fmt.Print(string(ch))
state.lineLength += 1
switch ch {
case ' ':
state.wordBuffer = ""
case '\n':
state.lineLength = 0
default:
state.wordBuffer += string(ch)
}
}
}
} else {
fmt.Printf("%s%s", state.wordBuffer, content)
if len(state.wordBuffer) > 0 {
state.wordBuffer = ""
}
}
}
func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
client, err := api.ClientFromEnvironment()
if err != nil {
return nil, err
}
p := progress.NewProgress(os.Stderr)
defer p.StopAndClear()
spinner := progress.NewSpinner("")
p.Add("", spinner)
cancelCtx, cancel := context.WithCancel(cmd.Context())
defer cancel()
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT)
go func() {
<-sigChan
cancel()
}()
var state *displayResponseState = &displayResponseState{}
var latest api.ChatResponse
var fullResponse strings.Builder
var role string
fn := func(response api.ChatResponse) error {
p.StopAndClear()
latest = response
role = response.Message.Role
content := response.Message.Content
fullResponse.WriteString(content)
displayResponse(content, opts.WordWrap, state)
return nil
}
req := &api.ChatRequest{
Model: opts.Model,
Messages: opts.Messages,
Format: opts.Format,
Options: opts.Options,
}
if err := client.Chat(cancelCtx, req, fn); err != nil {
if errors.Is(err, context.Canceled) {
return nil, nil
}
return nil, err
}
if len(opts.Messages) > 0 {
fmt.Println()
fmt.Println()
}
verbose, err := cmd.Flags().GetBool("verbose")
if err != nil {
return nil, err
}
if verbose {
latest.Summary()
}
return &api.Message{Role: role, Content: fullResponse.String()}, nil
}
func generate(cmd *cobra.Command, opts runOptions) error {
func generate(cmd *cobra.Command, opts generateOptions) error {
client, err := api.ClientFromEnvironment()
if err != nil {
return err
@@ -611,6 +490,11 @@ func generate(cmd *cobra.Command, opts runOptions) error {
generateContext = []int{}
}
termWidth, _, err := term.GetSize(int(os.Stdout.Fd()))
if err != nil {
opts.WordWrap = false
}
ctx, cancel := context.WithCancel(cmd.Context())
defer cancel()
@@ -622,35 +506,66 @@ func generate(cmd *cobra.Command, opts runOptions) error {
cancel()
}()
var state *displayResponseState = &displayResponseState{}
var currentLineLength int
var wordBuffer string
fn := func(response api.GenerateResponse) error {
p.StopAndClear()
latest = response
content := response.Response
displayResponse(content, opts.WordWrap, state)
termWidth, _, _ = term.GetSize(int(os.Stdout.Fd()))
if opts.WordWrap && termWidth >= 10 {
for _, ch := range response.Response {
if currentLineLength+1 > termWidth-5 {
if len(wordBuffer) > termWidth-10 {
fmt.Printf("%s%c", wordBuffer, ch)
wordBuffer = ""
currentLineLength = 0
continue
}
// backtrack the length of the last word and clear to the end of the line
fmt.Printf("\x1b[%dD\x1b[K\n", len(wordBuffer))
fmt.Printf("%s%c", wordBuffer, ch)
currentLineLength = len(wordBuffer) + 1
} else {
fmt.Print(string(ch))
currentLineLength += 1
switch ch {
case ' ':
wordBuffer = ""
case '\n':
currentLineLength = 0
default:
wordBuffer += string(ch)
}
}
}
} else {
fmt.Printf("%s%s", wordBuffer, response.Response)
if len(wordBuffer) > 0 {
wordBuffer = ""
}
}
return nil
}
if opts.MultiModal {
opts.Prompt, opts.Images, err = extractFileData(opts.Prompt)
if err != nil {
return err
}
images := make([]api.ImageData, 0)
for _, i := range opts.Images {
images = append(images, api.ImageData(i))
}
request := api.GenerateRequest{
Model: opts.Model,
Prompt: opts.Prompt,
Context: generateContext,
Images: opts.Images,
Format: opts.Format,
System: opts.System,
Template: opts.Template,
Options: opts.Options,
Images: images,
}
if err := client.Generate(ctx, &request, fn); err != nil {

View File

@@ -1,21 +1,19 @@
package cmd
import (
"context"
"errors"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"regexp"
"sort"
"strings"
"github.com/spf13/cobra"
"golang.org/x/exp/slices"
"github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/progress"
"github.com/jmorganca/ollama/readline"
)
@@ -28,82 +26,45 @@ const (
MultilineTemplate
)
func loadModel(cmd *cobra.Command, opts *runOptions) error {
func modelIsMultiModal(cmd *cobra.Command, name string) bool {
// get model details
client, err := api.ClientFromEnvironment()
if err != nil {
return err
fmt.Println("error: couldn't connect to ollama server")
return false
}
p := progress.NewProgress(os.Stderr)
defer p.StopAndClear()
spinner := progress.NewSpinner("")
p.Add("", spinner)
showReq := api.ShowRequest{Name: opts.Model}
showResp, err := client.Show(cmd.Context(), &showReq)
req := api.ShowRequest{Name: name}
resp, err := client.Show(cmd.Context(), &req)
if err != nil {
return err
}
opts.MultiModal = slices.Contains(showResp.Details.Families, "clip")
opts.ParentModel = showResp.Details.ParentModel
if len(showResp.Messages) > 0 {
opts.Messages = append(opts.Messages, showResp.Messages...)
return false
}
chatReq := &api.ChatRequest{
Model: opts.Model,
Messages: []api.Message{},
}
err = client.Chat(cmd.Context(), chatReq, func(resp api.ChatResponse) error {
p.StopAndClear()
if len(opts.Messages) > 0 {
for _, msg := range opts.Messages {
switch msg.Role {
case "user":
fmt.Printf(">>> %s\n", msg.Content)
case "assistant":
state := &displayResponseState{}
displayResponse(msg.Content, opts.WordWrap, state)
fmt.Println()
fmt.Println()
}
}
}
return nil
})
if err != nil {
return err
}
return nil
return slices.Contains(resp.Details.Families, "clip")
}
func generateInteractive(cmd *cobra.Command, opts runOptions) error {
opts.Messages = make([]api.Message, 0)
func generateInteractive(cmd *cobra.Command, opts generateOptions) error {
multiModal := modelIsMultiModal(cmd, opts.Model)
err := loadModel(cmd, &opts)
if err != nil {
// load the model
loadOpts := generateOptions{
Model: opts.Model,
Prompt: "",
Images: []ImageData{},
}
if err := generate(cmd, loadOpts); err != nil {
return err
}
usage := func() {
fmt.Fprintln(os.Stderr, "Available Commands:")
fmt.Fprintln(os.Stderr, " /set Set session variables")
fmt.Fprintln(os.Stderr, " /show Show model information")
fmt.Fprintln(os.Stderr, " /load <model> Load a session or model")
fmt.Fprintln(os.Stderr, " /save <model> Save your current session")
fmt.Fprintln(os.Stderr, " /bye Exit")
fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts")
fmt.Fprintln(os.Stderr, " /set Set session variables")
fmt.Fprintln(os.Stderr, " /show Show model information")
fmt.Fprintln(os.Stderr, " /bye Exit")
fmt.Fprintln(os.Stderr, " /?, /help Help for a command")
fmt.Fprintln(os.Stderr, " /? shortcuts Help for keyboard shortcuts")
fmt.Fprintln(os.Stderr, "")
fmt.Fprintln(os.Stderr, "Use \"\"\" to begin a multi-line message.")
if opts.MultiModal {
fmt.Fprintf(os.Stderr, "Use %s to include .jpg or .png images.\n", filepath.FromSlash("/path/to/file"))
}
fmt.Fprintln(os.Stderr, "")
}
@@ -213,7 +174,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
switch multiline {
case MultilineSystem:
opts.System = sb.String()
opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System})
fmt.Println("Set system message.")
sb.Reset()
case MultilineTemplate:
@@ -233,6 +193,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Fprintln(&sb)
multiline = MultilinePrompt
scanner.Prompt.UseAlt = true
break
}
case scanner.Pasting:
fmt.Fprintln(&sb, line)
@@ -242,44 +203,6 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
if err := ListHandler(cmd, args[1:]); err != nil {
return err
}
case strings.HasPrefix(line, "/load"):
args := strings.Fields(line)
if len(args) != 2 {
fmt.Println("Usage:\n /load <modelname>")
continue
}
opts.Model = args[1]
opts.Messages = []api.Message{}
fmt.Printf("Loading model '%s'\n", opts.Model)
if err := loadModel(cmd, &opts); err != nil {
return err
}
continue
case strings.HasPrefix(line, "/save"):
args := strings.Fields(line)
if len(args) != 2 {
fmt.Println("Usage:\n /save <modelname>")
continue
}
client, err := api.ClientFromEnvironment()
if err != nil {
fmt.Println("error: couldn't connect to ollama server")
return err
}
req := &api.CreateRequest{
Name: args[1],
Modelfile: buildModelfile(opts),
}
fn := func(resp api.ProgressResponse) error { return nil }
err = client.Create(cmd.Context(), req, fn)
if err != nil {
fmt.Println("error: couldn't save model")
return err
}
fmt.Printf("Created new model '%s'\n", args[1])
continue
case strings.HasPrefix(line, "/set"):
args := strings.Fields(line)
if len(args) > 1 {
@@ -315,13 +238,16 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
usageParameters()
continue
}
params := args[3:]
var params []string
for _, p := range args[3:] {
params = append(params, p)
}
fp, err := api.FormatParams(map[string][]string{args[2]: params})
if err != nil {
fmt.Printf("Couldn't set parameter: %q\n", err)
fmt.Printf("Couldn't set parameter: %q\n\n", err)
continue
}
fmt.Printf("Set parameter '%s' to '%s'\n", args[2], strings.Join(params, ", "))
fmt.Printf("Set parameter '%s' to '%s'\n\n", args[2], strings.Join(params, ", "))
opts.Options[args[2]] = fp[args[2]]
case "system", "template":
if len(args) < 3 {
@@ -355,13 +281,10 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
if args[1] == "system" {
opts.System = sb.String()
opts.Messages = append(opts.Messages, api.Message{Role: "system", Content: opts.System})
fmt.Println("Set system message.")
sb.Reset()
} else if args[1] == "template" {
opts.Template = sb.String()
fmt.Println("Set prompt template.")
sb.Reset()
}
sb.Reset()
@@ -405,7 +328,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Println("")
case "license":
if resp.License == "" {
fmt.Println("No license was specified for this model.")
fmt.Print("No license was specified for this model.\n\n")
} else {
fmt.Println(resp.License)
}
@@ -413,7 +336,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
fmt.Println(resp.Modelfile)
case "parameters":
if resp.Parameters == "" {
fmt.Println("No parameters were specified for this model.")
fmt.Print("No parameters were specified for this model.\n\n")
} else {
if len(opts.Options) > 0 {
fmt.Println("User defined parameters:")
@@ -432,7 +355,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
case resp.System != "":
fmt.Println(resp.System + "\n")
default:
fmt.Println("No system message was specified for this model.")
fmt.Print("No system message was specified for this model.\n\n")
}
case "template":
switch {
@@ -441,7 +364,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
case resp.Template != "":
fmt.Println(resp.Template)
default:
fmt.Println("No prompt template was specified for this model.")
fmt.Print("No prompt template was specified for this model.\n\n")
}
default:
fmt.Printf("Unknown command '/show %s'. Type /? for help\n", args[1])
@@ -469,7 +392,7 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
args := strings.Fields(line)
isFile := false
if opts.MultiModal {
if multiModal {
for _, f := range extractFileNames(line) {
if strings.HasPrefix(f, args[0]) {
isFile = true
@@ -489,72 +412,38 @@ func generateInteractive(cmd *cobra.Command, opts runOptions) error {
}
if sb.Len() > 0 && multiline == MultilineNone {
newMessage := api.Message{Role: "user", Content: sb.String()}
if opts.MultiModal {
msg, images, err := extractFileData(sb.String())
opts.Prompt = sb.String()
if multiModal {
newPrompt, images, err := extractFileData(sb.String())
if err != nil {
return err
}
opts.Prompt = newPrompt
// clear all previous images for better responses
// reset the context if we find another image
if len(images) > 0 {
for i := range opts.Messages {
opts.Messages[i].Images = nil
}
opts.Images = images
ctx := cmd.Context()
ctx = context.WithValue(ctx, generateContextKey("context"), []int{})
cmd.SetContext(ctx)
}
if len(opts.Images) == 0 {
fmt.Println("This model requires you to add a jpeg, png, or svg image.")
fmt.Println()
sb.Reset()
continue
}
newMessage.Content = msg
newMessage.Images = images
}
opts.Messages = append(opts.Messages, newMessage)
assistant, err := chat(cmd, opts)
if err != nil {
if err := generate(cmd, opts); err != nil {
return err
}
if assistant != nil {
opts.Messages = append(opts.Messages, *assistant)
}
sb.Reset()
}
}
}
func buildModelfile(opts runOptions) string {
var mf strings.Builder
model := opts.ParentModel
if model == "" {
model = opts.Model
}
fmt.Fprintf(&mf, "FROM %s\n", model)
if opts.System != "" {
fmt.Fprintf(&mf, "SYSTEM \"\"\"%s\"\"\"\n", opts.System)
}
if opts.Template != "" {
fmt.Fprintf(&mf, "TEMPLATE \"\"\"%s\"\"\"\n", opts.Template)
}
keys := make([]string, 0)
for k := range opts.Options {
keys = append(keys, k)
}
sort.Strings(keys)
for _, k := range keys {
fmt.Fprintf(&mf, "PARAMETER %s %v\n", k, opts.Options[k])
}
fmt.Fprintln(&mf)
for _, msg := range opts.Messages {
fmt.Fprintf(&mf, "MESSAGE %s \"\"\"%s\"\"\"\n", msg.Role, msg.Content)
}
return mf.String()
}
func normalizeFilePath(fp string) string {
// Define a map of escaped characters and their replacements
replacements := map[string]string{
@@ -590,9 +479,9 @@ func extractFileNames(input string) []string {
return re.FindAllString(input, -1)
}
func extractFileData(input string) (string, []api.ImageData, error) {
func extractFileData(input string) (string, []ImageData, error) {
filePaths := extractFileNames(input)
var imgs []api.ImageData
var imgs []ImageData
for _, fp := range filePaths {
nfp := normalizeFilePath(fp)
@@ -601,10 +490,10 @@ func extractFileData(input string) (string, []api.ImageData, error) {
if os.IsNotExist(err) {
continue
}
fmt.Fprintf(os.Stderr, "Couldn't process image: %q\n", err)
fmt.Printf("Couldn't process image: %q\n", err)
return "", imgs, err
}
fmt.Fprintf(os.Stderr, "Added image '%s'\n", nfp)
fmt.Printf("Added image '%s'\n", nfp)
input = strings.ReplaceAll(input, fp, "")
imgs = append(imgs, data)
}
@@ -625,7 +514,7 @@ func getImageData(filePath string) ([]byte, error) {
}
contentType := http.DetectContentType(buf)
allowedTypes := []string{"image/jpeg", "image/jpg", "image/png"}
allowedTypes := []string{"image/jpeg", "image/jpg", "image/svg+xml", "image/png"}
if !slices.Contains(allowedTypes, contentType) {
return nil, fmt.Errorf("invalid image type: %s", contentType)
}

View File

@@ -1,13 +1,9 @@
package cmd
import (
"bytes"
"testing"
"text/template"
"github.com/stretchr/testify/assert"
"github.com/jmorganca/ollama/api"
)
func TestExtractFilenames(t *testing.T) {
@@ -53,64 +49,3 @@ d:\path with\spaces\seven.svg inbetween7 c:\users\jdoe\eight.png inbetween8
assert.Contains(t, res[9], "ten.svg")
assert.Contains(t, res[9], "E:")
}
func TestModelfileBuilder(t *testing.T) {
opts := runOptions{
Model: "hork",
System: "You are part horse and part shark, but all hork. Do horklike things",
Template: "This is a template.",
Messages: []api.Message{
{Role: "user", Content: "Hey there hork!"},
{Role: "assistant", Content: "Yes it is true, I am half horse, half shark."},
},
Options: map[string]interface{}{},
}
opts.Options["temperature"] = 0.9
opts.Options["seed"] = 42
opts.Options["penalize_newline"] = false
opts.Options["stop"] = []string{"hi", "there"}
mf := buildModelfile(opts)
expectedModelfile := `FROM {{.Model}}
SYSTEM """{{.System}}"""
TEMPLATE """{{.Template}}"""
PARAMETER penalize_newline false
PARAMETER seed 42
PARAMETER stop [hi there]
PARAMETER temperature 0.9
MESSAGE user """Hey there hork!"""
MESSAGE assistant """Yes it is true, I am half horse, half shark."""
`
tmpl, err := template.New("").Parse(expectedModelfile)
assert.Nil(t, err)
var buf bytes.Buffer
err = tmpl.Execute(&buf, opts)
assert.Nil(t, err)
assert.Equal(t, buf.String(), mf)
opts.ParentModel = "horseshark"
mf = buildModelfile(opts)
expectedModelfile = `FROM {{.ParentModel}}
SYSTEM """{{.System}}"""
TEMPLATE """{{.Template}}"""
PARAMETER penalize_newline false
PARAMETER seed 42
PARAMETER stop [hi there]
PARAMETER temperature 0.9
MESSAGE user """Hey there hork!"""
MESSAGE assistant """Yes it is true, I am half horse, half shark."""
`
tmpl, err = template.New("").Parse(expectedModelfile)
assert.Nil(t, err)
var parentBuf bytes.Buffer
err = tmpl.Execute(&parentBuf, opts)
assert.Nil(t, err)
assert.Equal(t, parentBuf.String(), mf)
}

View File

@@ -10,7 +10,7 @@ Create new models or modify models already in the library using the Modelfile. L
Import models using source model weights found on Hugging Face and similar sites by referring to the **[Import Documentation](./import.md)**.
Installing on Linux in most cases is easy using the script on [ollama.com/download](ollama.com/download). To get more detail about the install, including CUDA drivers, see the **[Linux Documentation](./linux.md)**.
Installing on Linux in most cases is easy using the script on Ollama.ai. To get more detail about the install, including CUDA drivers, see the **[Linux Documentation](./linux.md)**.
Many of our users like the flexibility of using our official Docker Image. Learn more about using Docker with Ollama using the **[Docker Documentation](https://hub.docker.com/r/ollama/ollama)**.

View File

@@ -49,8 +49,7 @@ Advanced parameters (optional):
- `template`: the prompt template to use (overrides what is defined in the `Modelfile`)
- `context`: the context parameter returned from a previous request to `/generate`, this can be used to keep a short conversational memory
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
- `raw`: if `true` no formatting will be applied to the prompt. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
- `raw`: if `true` no formatting will be applied to the prompt. You may choose to use the `raw` parameter if you are specifying a full templated prompt in your request to the API.
#### JSON mode
@@ -380,7 +379,6 @@ Advanced parameters (optional):
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `template`: the prompt template to use (overrides what is defined in the `Modelfile`)
- `stream`: if `false` the response will be returned as a single response object, rather than a stream of objects
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
### Examples
@@ -544,7 +542,7 @@ curl http://localhost:11434/api/chat -d '{
"role": "user",
"content": "what is in this image?",
"images": ["iVBORw0KGgoAAAANSUhEUgAAAG0AAABmCAYAAADBPx+VAAAACXBIWXMAAAsTAAALEwEAmpwYAAAAAXNSR0IArs4c6QAAAARnQU1BAACxjwv8YQUAAA3VSURBVHgB7Z27r0zdG8fX743i1bi1ikMoFMQloXRpKFFIqI7LH4BEQ+NWIkjQuSWCRIEoULk0gsK1kCBI0IhrQVT7tz/7zZo888yz1r7MnDl7z5xvsjkzs2fP3uu71nNfa7lkAsm7d++Sffv2JbNmzUqcc8m0adOSzZs3Z+/XES4ZckAWJEGWPiCxjsQNLWmQsWjRIpMseaxcuTKpG/7HP27I8P79e7dq1ars/yL4/v27S0ejqwv+cUOGEGGpKHR37tzJCEpHV9tnT58+dXXCJDdECBE2Ojrqjh071hpNECjx4cMHVycM1Uhbv359B2F79+51586daxN/+pyRkRFXKyRDAqxEp4yMlDDzXG1NPnnyJKkThoK0VFd1ELZu3TrzXKxKfW7dMBQ6bcuWLW2v0VlHjx41z717927ba22U9APcw7Nnz1oGEPeL3m3p2mTAYYnFmMOMXybPPXv2bNIPpFZr1NHn4HMw0KRBjg9NuRw95s8PEcz/6DZELQd/09C9QGq5RsmSRybqkwHGjh07OsJSsYYm3ijPpyHzoiacg35MLdDSIS/O1yM778jOTwYUkKNHWUzUWaOsylE00MyI0fcnOwIdjvtNdW/HZwNLGg+sR1kMepSNJXmIwxBZiG8tDTpEZzKg0GItNsosY8USkxDhD0Rinuiko2gfL/RbiD2LZAjU9zKQJj8RDR0vJBR1/Phx9+PHj9Z7REF4nTZkxzX4LCXHrV271qXkBAPGfP/atWvu/PnzHe4C97F48eIsRLZ9+3a3f/9+87dwP1JxaF7/3r17ba+5l4EcaVo0lj3SBq5kGTJSQmLWMjgYNei2GPT1MuMqGTDEFHzeQSP2wi/jGnkmPJ/nhccs44jvDAxpVcxnq0F6eT8h4ni/iIWpR5lPyA6ETkNXoSukvpJAD3AsXLiwpZs49+fPn5ke4j10TqYvegSfn0OnafC+Tv9ooA/JPkgQysqQNBzagXY55nO/oa1F7qvIPWkRL12WRpMWUvpVDYmxAPehxWSe8ZEXL20sadYIozfmNch4QJPAfeJgW3rNsnzphBKNJM2KKODo1rVOMRYik5ETy3ix4qWNI81qAAirizgMIc+yhTytx0JWZuNI03qsrgWlGtwjoS9XwgUhWGyhUaRZZQNNIEwCiXD16tXcAHUs79co0vSD8rrJCIW98pzvxpAWyyo3HYwqS0+H0BjStClcZJT5coMm6D2LOF8TolGJtK9fvyZpyiC5ePFi9nc/oJU4eiEP0jVoAnHa9wyJycITMP78+eMeP37sXrx44d6+fdt6f82aNdkx1pg9e3Zb5W+RSRE+n+VjksQWifvVaTKFhn5O8my63K8Qabdv33b379/PiAP//vuvW7BggZszZ072/+TJk91YgkafPn166zXB1rQHFvouAWHq9z3SEevSUerqCn2/dDCeta2jxYbr69evk4MHDyY7d+7MjhMnTiTPnz9Pfv/+nfQT2ggpO2dMF8cghuoM7Ygj5iWCqRlGFml0QC/ftGmTmzt3rmsaKDsgBSPh0/8yPeLLBihLkOKJc0jp8H8vUzcxIA1k6QJ/c78tWEyj5P3o4u9+jywNPdJi5rAH9x0KHcl4Hg570eQp3+vHXGyrmEeigzQsQsjavXt38ujRo44LQuDDhw+TW7duRS1HGgMxhNXHgflaNTOsHyKvHK5Ijo2jbFjJBQK9YwFd6RVMzfgRBmEfP37suBBm/p49e1qjEP2mwTViNRo0VJWH1deMXcNK08uUjVUu7s/zRaL+oLNxz1bpANco4npUgX4G2eFbpDFyQoQxojBCpEGSytmOH8qrH5Q9vuzD6ofQylkCUmh8DBAr+q8JCyVNtWQIidKQE9wNtLSQnS4jDSsxNHogzFuQBw4cyM61UKVsjfr3ooBkPSqqQHesUPWVtzi9/vQi1T+rJj7WiTz4Pt/l3LxUkr5P2VYZaZ4URpsE+st/dujQoaBBYokbrz/8TJNQYLSonrPS9kUaSkPeZyj1AWSj+d+VBoy1pIWVNed8P0Ll/ee5HdGRhrHhR5GGN0r4LGZBaj8oFDJitBTJzIZgFcmU0Y8ytWMZMzJOaXUSrUs5RxKnrxmbb5YXO9VGUhtpXldhEUogFr3IzIsvlpmdosVcGVGXFWp2oU9kLFL3dEkSz6NHEY1sjSRdIuDFWEhd8KxFqsRi1uM/nz9/zpxnwlESONdg6dKlbsaMGS4EHFHtjFIDHwKOo46l4TxSuxgDzi+rE2jg+BaFruOX4HXa0Nnf1lwAPufZeF8/r6zD97WK2qFnGjBxTw5qNGPxT+5T/r7/7RawFC3j4vTp09koCxkeHjqbHJqArmH5UrFKKksnxrK7FuRIs8STfBZv+luugXZ2pR/pP9Ois4z+TiMzUUkUjD0iEi1fzX8GmXyuxUBRcaUfykV0YZnlJGKQpOiGB76x5GeWkWWJc3mOrK6S7xdND+W5N6XyaRgtWJFe13GkaZnKOsYqGdOVVVbGupsyA/l7emTLHi7vwTdirNEt0qxnzAvBFcnQF16xh/TMpUuXHDowhlA9vQVraQhkudRdzOnK+04ZSP3DUhVSP61YsaLtd/ks7ZgtPcXqPqEafHkdqa84X6aCeL7YWlv6edGFHb+ZFICPlljHhg0bKuk0CSvVznWsotRu433alNdFrqG45ejoaPCaUkWERpLXjzFL2Rpllp7PJU2a/v7Ab8N05/9t27Z16KUqoFGsxnI9EosS2niSYg9SpU6B4JgTrvVW1flt1sT+0ADIJU2maXzcUTraGCRaL1Wp9rUMk16PMom8QhruxzvZIegJjFU7LLCePfS8uaQdPny4jTTL0dbee5mYokQsXTIWNY46kuMbnt8Kmec+LGWtOVIl9cT1rCB0V8WqkjAsRwta93TbwNYoGKsUSChN44lgBNCoHLHzquYKrU6qZ8lolCIN0Rh6cP0Q3U6I6IXILYOQI513hJaSKAorFpuHXJNfVlpRtmYBk1Su1obZr5dnKAO+L10Hrj3WZW+E3qh6IszE37F6EB+68mGpvKm4eb9bFrlzrok7fvr0Kfv727dvWRmdVTJHw0qiiCUSZ6wCK+7XL/AcsgNyL74DQQ730sv78Su7+t/A36MdY0sW5o40ahslXr58aZ5HtZB8GH64m9EmMZ7FpYw4T6QnrZfgenrhFxaSiSGXtPnz57e9TkNZLvTjeqhr734CNtrK41L40sUQckmj1lGKQ0rC37x544r8eNXRpnVE3ZZY7zXo8NomiO0ZUCj2uHz58rbXoZ6gc0uA+F6ZeKS/jhRDUq8MKrTho9fEkihMmhxtBI1DxKFY9XLpVcSkfoi8JGnToZO5sU5aiDQIW716ddt7ZLYtMQlhECdBGXZZMWldY5BHm5xgAroWj4C0hbYkSc/jBmggIrXJWlZM6pSETsEPGqZOndr2uuuR5rF169a2HoHPdurUKZM4CO1WTPqaDaAd+GFGKdIQkxAn9RuEWcTRyN2KSUgiSgF5aWzPTeA/lN5rZubMmR2bE4SIC4nJoltgAV/dVefZm72AtctUCJU2CMJ327hxY9t7EHbkyJFseq+EJSY16RPo3Dkq1kkr7+q0bNmyDuLQcZBEPYmHVdOBiJyIlrRDq41YPWfXOxUysi5fvtyaj+2BpcnsUV/oSoEMOk2CQGlr4ckhBwaetBhjCwH0ZHtJROPJkyc7UjcYLDjmrH7ADTEBXFfOYmB0k9oYBOjJ8b4aOYSe7QkKcYhFlq3QYLQhSidNmtS2RATwy8YOM3EQJsUjKiaWZ+vZToUQgzhkHXudb/PW5YMHD9yZM2faPsMwoc7RciYJXbGuBqJ1UIGKKLv915jsvgtJxCZDubdXr165mzdvtr1Hz5LONA8jrUwKPqsmVesKa49S3Q4WxmRPUEYdTjgiUcfUwLx589ySJUva3oMkP6IYddq6HMS4o55xBJBUeRjzfa4Zdeg56QZ43LhxoyPo7Lf1kNt7oO8wWAbNwaYjIv5lhyS7kRf96dvm5Jah8vfvX3flyhX35cuX6HfzFHOToS1H4BenCaHvO8pr8iDuwoUL7tevX+b5ZdbBair0xkFIlFDlW4ZknEClsp/TzXyAKVOmmHWFVSbDNw1l1+4f90U6IY/q4V27dpnE9bJ+v87QEydjqx/UamVVPRG+mwkNTYN+9tjkwzEx+atCm/X9WvWtDtAb68Wy9LXa1UmvCDDIpPkyOQ5ZwSzJ4jMrvFcr0rSjOUh+GcT4LSg5ugkW1Io0/SCDQBojh0hPlaJdah+tkVYrnTZowP8iq1F1TgMBBauufyB33x1v+NWFYmT5KmppgHC+NkAgbmRkpD3yn9QIseXymoTQFGQmIOKTxiZIWpvAatenVqRVXf2nTrAWMsPnKrMZHz6bJq5jvce6QK8J1cQNgKxlJapMPdZSR64/UivS9NztpkVEdKcrs5alhhWP9NeqlfWopzhZScI6QxseegZRGeg5a8C3Re1Mfl1ScP36ddcUaMuv24iOJtz7sbUjTS4qBvKmstYJoUauiuD3k5qhyr7QdUHMeCgLa1Ear9NquemdXgmum4fvJ6w1lqsuDhNrg1qSpleJK7K3TF0Q2jSd94uSZ60kK1e3qyVpQK6PVWXp2/FC3mp6jBhKKOiY2h3gtUV64TWM6wDETRPLDfSakXmH3w8g9Jlug8ZtTt4kVF0kLUYYmCCtD/DrQ5YhMGbA9L3ucdjh0y8kOHW5gU/VEEmJTcL4Pz/f7mgoAbYkAAAAAElFTkSuQmCC"]
}
},
]
}'
```
@@ -960,7 +958,6 @@ Generate embeddings from a model
Advanced parameters:
- `options`: additional model parameters listed in the documentation for the [Modelfile](./modelfile.md#valid-parameters-and-values) such as `temperature`
- `keep_alive`: controls how long the model will stay loaded into memory following the request (default: `5m`)
### Examples

View File

@@ -1,9 +1,13 @@
# Development
- Install cmake or (optionally, required tools for GPUs)
- run `go generate ./...`
- run `go build .`
Install required tools:
- cmake version 3.24 or higher
- go version 1.21 or higher
- go version 1.20 or higher
- gcc version 11.4.0 or higher
```bash
@@ -13,11 +17,7 @@ brew install go cmake gcc
Optionally enable debugging and more verbose logging:
```bash
# At build time
export CGO_CFLAGS="-g"
# At runtime
export OLLAMA_DEBUG=1
```
Get the required libraries and build the native LLM code:
@@ -44,15 +44,7 @@ Now you can run `ollama`:
*Your operating system distribution may already have packages for NVIDIA CUDA. Distro packages are often preferable, but instructions are distro-specific. Please consult distro-specific docs for dependencies if available!*
Install `cmake` and `golang` as well as [NVIDIA CUDA](https://developer.nvidia.com/cuda-downloads)
development and runtime packages.
Typically the build scripts will auto-detect CUDA, however, if your Linux distro
or installation approach uses unusual paths, you can specify the location by
specifying an environment variable `CUDA_LIB_DIR` to the location of the shared
libraries, and `CUDACXX` to the location of the nvcc compiler. You can customize
set set of target CUDA architectues by setting `CMAKE_CUDA_ARCHITECTURES` (e.g. "50;60;70")
Install `cmake` and `golang` as well as [NVIDIA CUDA](https://developer.nvidia.com/cuda-downloads) development and runtime packages.
Then generate dependencies:
```
@@ -70,16 +62,10 @@ go build .
*Your operating system distribution may already have packages for AMD ROCm and CLBlast. Distro packages are often preferable, but instructions are distro-specific. Please consult distro-specific docs for dependencies if available!*
Install [CLBlast](https://github.com/CNugteren/CLBlast/blob/master/doc/installation.md) and [ROCm](https://rocm.docs.amd.com/en/latest/deploy/linux/quick_start.html) developement packages first, as well as `cmake` and `golang`.
Typically the build scripts will auto-detect ROCm, however, if your Linux distro
or installation approach uses unusual paths, you can specify the location by
specifying an environment variable `ROCM_PATH` to the location of the ROCm
install (typically `/opt/rocm`), and `CLBlast_DIR` to the location of the
CLBlast install (typically `/usr/lib/cmake/CLBlast`). You can also customize
the AMD GPU targets by setting AMDGPU_TARGETS (e.g. `AMDGPU_TARGETS="gfx1101;gfx1102"`)
Adjust the paths below (correct for Arch) as appropriate for your distributions install locations and generate dependencies:
```
go generate ./...
CLBlast_DIR=/usr/lib/cmake/CLBlast ROCM_PATH=/opt/rocm go generate ./...
```
Then build the binary:
@@ -90,22 +76,6 @@ go build .
ROCm requires elevated privileges to access the GPU at runtime. On most distros you can add your user account to the `render` group, or run as root.
#### Advanced CPU Settings
By default, running `go generate ./...` will compile a few different variations
of the LLM library based on common CPU families and vector math capabilities,
including a lowest-common-denominator which should run on almost any 64 bit CPU
somewhat slowly. At runtime, Ollama will auto-detect the optimal variation to
load. If you would like to build a CPU-based build customized for your
processor, you can set `OLLAMA_CUSTOM_CPU_DEFS` to the llama.cpp flags you would
like to use. For example, to compile an optimized binary for an Intel i9-9880H,
you might use:
```
OLLAMA_CUSTOM_CPU_DEFS="-DLLAMA_AVX=on -DLLAMA_AVX2=on -DLLAMA_F16C=on -DLLAMA_FMA=on" go generate ./...
go build .
```
#### Containerized Linux Build
If you have Docker available, you can build linux binaries with `./scripts/build_linux.sh` which has the CUDA and ROCm dependencies included. The resulting binary is placed in `./dist`
@@ -118,7 +88,7 @@ Note: The windows build for Ollama is still under development.
Install required tools:
- MSVC toolchain - C/C++ and cmake as minimal requirements
- go version 1.21 or higher
- go version 1.20 or higher
- MinGW (pick one variant) with GCC.
- <https://www.mingw-w64.org/>
- <https://www.msys2.org/>

View File

@@ -8,38 +8,35 @@ To upgrade Ollama, run the installation process again. On the Mac, click the Oll
Review the [Troubleshooting](./troubleshooting.md) docs for more about using logs.
## How do I configure Ollama server?
## How do I use Ollama server environment variables on Mac
Ollama server can be configured with environment variables.
On macOS, Ollama runs in the background and is managed by the menubar app. If adding environment variables, Ollama will need to be run manually.
### Setting environment variables on Mac
1. Click the menubar icon for Ollama and choose **Quit Ollama**.
2. Open a new terminal window and run the following command (this example uses `OLLAMA_HOST` with an IP address of `123.1.1.1`):
If Ollama is run as a macOS application, environment variables should be set using `launchctl`:
```bash
OLLAMA_HOST=123.1.1.1 ollama serve
```
1. For each environment variable, call `launchctl setenv`.
## How do I use Ollama server environment variables on Linux?
```bash
launchctl setenv OLLAMA_HOST "0.0.0.0"
```
If Ollama is installed with the install script, a systemd service was created, running as the Ollama user. To add an environment variable, such as OLLAMA_HOST, follow these steps:
2. Restart Ollama application.
1. Create a `systemd` drop-in directory and add a config file. This is only needed once.
### Setting environment variables on Linux
```bash
mkdir -p /etc/systemd/system/ollama.service.d
echo '[Service]' >>/etc/systemd/system/ollama.service.d/environment.conf
```
If Ollama is run as a systemd service, environment variables should be set using `systemctl`:
2. For each environment variable, add it to the config file:
1. Edit the systemd service by calling `systemctl edit ollama.service`. This will open an editor.
```bash
echo 'Environment="OLLAMA_HOST=0.0.0.0:11434"' >>/etc/systemd/system/ollama.service.d/environment.conf
```
2. For each environment variable, add a line `Environment` under section `[Service]`:
```ini
[Service]
Environment="OLLAMA_HOST=0.0.0.0"
```
3. Save and exit.
4. Reload `systemd` and restart Ollama:
3. Reload `systemd` and restart Ollama:
```bash
systemctl daemon-reload
@@ -48,26 +45,26 @@ If Ollama is run as a systemd service, environment variables should be set using
## How can I expose Ollama on my network?
Ollama binds 127.0.0.1 port 11434 by default. Change the bind address with the `OLLAMA_HOST` environment variable.
Refer to the section [above](#how-do-i-configure-ollama-server) for how to set environment variables on your platform.
Ollama binds to 127.0.0.1 port 11434 by default. Change the bind address with the `OLLAMA_HOST` environment variable. Refer to the section above for how to use environment variables on your platform.
## How can I allow additional web origins to access Ollama?
Ollama allows cross-origin requests from `127.0.0.1` and `0.0.0.0` by default. Additional origins can be configured with `OLLAMA_ORIGINS`.
Ollama allows cross-origin requests from `127.0.0.1` and `0.0.0.0` by default. Add additional origins with the `OLLAMA_ORIGINS` environment variable. For example, to add all ports on 192.168.1.1 and https://example.com, use:
Refer to the section [above](#how-do-i-configure-ollama-server) for how to set environment variables on your platform.
```shell
OLLAMA_ORIGINS=http://192.168.1.1:*,https://example.com
```
Refer to the section above for how to use environment variables on your platform.
## Where are models stored?
- macOS: `~/.ollama/models`.
- Linux: `/usr/share/ollama/.ollama/models`
### How do I set them to a different location?
## How do I set them to a different location?
If a different directory needs to be used, set the environment variable `OLLAMA_MODELS` to the chosen directory.
Refer to the section [above](#how-do-i-configure-ollama-server) for how to set environment variables on your platform.
If a different directory needs to be used, set the environment variable `OLLAMA_MODELS` to the chosen directory. Refer to the section above for how to use environment variables on your platform.
## Does Ollama send my prompts and answers back to Ollama.ai to use in any way?

View File

@@ -15,7 +15,7 @@ FROM ./mistral-7b-v0.1.Q4_0.gguf
(Optional) many chat models require a prompt template in order to answer correctly. A default prompt template can be specified with the `TEMPLATE` instruction in the `Modelfile`:
```
FROM ./mistral-7b-v0.1.Q4_0.gguf
FROM ./q4_0.bin
TEMPLATE "[INST] {{ .Prompt }} [/INST]"
```
@@ -37,69 +37,55 @@ ollama run example "What is your favourite condiment?"
## Importing (PyTorch & Safetensors)
> Importing from PyTorch and Safetensors is a longer process than importing from GGUF. Improvements that make it easier are a work in progress.
### Supported models
### Setup
Ollama supports a set of model architectures, with support for more coming soon:
First, clone the `ollama/ollama` repo:
- Llama & Mistral
- Falcon & RW
- BigCode
```
git clone git@github.com:ollama/ollama.git ollama
cd ollama
```
To view a model's architecture, check the `config.json` file in its HuggingFace repo. You should see an entry under `architectures` (e.g. `LlamaForCausalLM`).
and then fetch its `llama.cpp` submodule:
```shell
git submodule init
git submodule update llm/llama.cpp
```
Next, install the Python dependencies:
```
python3 -m venv llm/llama.cpp/.venv
source llm/llama.cpp/.venv/bin/activate
pip install -r llm/llama.cpp/requirements.txt
```
Then build the `quantize` tool:
```
make -C llm/llama.cpp quantize
```
### Clone the HuggingFace repository (optional)
### Step 1: Clone the HuggingFace repository (optional)
If the model is currently hosted in a HuggingFace repository, first clone that repository to download the raw model.
Install [Git LFS](https://docs.github.com/en/repositories/working-with-files/managing-large-files/installing-git-large-file-storage), verify it's installed, and then clone the model's repository:
```
git lfs install
git clone https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1 model
git clone https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1
cd Mistral-7B-Instruct-v0.1
```
### Convert the model
### Step 2: Convert and quantize to a `.bin` file (optional, for PyTorch and Safetensors)
> Note: some model architectures require using specific convert scripts. For example, Qwen models require running `convert-hf-to-gguf.py` instead of `convert.py`
If the model is in PyTorch or Safetensors format, a [Docker image](https://hub.docker.com/r/ollama/quantize) with the tooling required to convert and quantize models is available.
First, Install [Docker](https://www.docker.com/get-started/).
Next, to convert and quantize your model, run:
```
python llm/llama.cpp/convert.py ./model --outtype f16 --outfile converted.bin
docker run --rm -v .:/model ollama/quantize -q q4_0 /model
```
### Quantize the model
This will output two files into the directory:
```
llm/llama.cpp/quantize converted.bin quantized.bin q4_0
```
- `f16.bin`: the model converted to GGUF
- `q4_0.bin` the model quantized to a 4-bit quantization (Ollama will use this file to create the Ollama model)
### Step 3: Write a `Modelfile`
Next, create a `Modelfile` for your model:
```
FROM quantized.bin
FROM ./q4_0.bin
```
(Optional) many chat models require a prompt template in order to answer correctly. A default prompt template can be specified with the `TEMPLATE` instruction in the `Modelfile`:
```
FROM ./q4_0.bin
TEMPLATE "[INST] {{ .Prompt }} [/INST]"
```
@@ -123,9 +109,9 @@ ollama run example "What is your favourite condiment?"
Publishing models is in early alpha. If you'd like to publish your model to share with others, follow these steps:
1. Create [an account](https://ollama.com/signup)
1. Create [an account](https://ollama.ai/signup)
2. Run `cat ~/.ollama/id_ed25519.pub` to view your Ollama public key. Copy this to the clipboard.
3. Add your public key to your [Ollama account](https://ollama.com/settings/keys)
3. Add your public key to your [Ollama account](https://ollama.ai/settings/keys)
Next, copy your model to your username's namespace:
@@ -139,7 +125,7 @@ Then push the model:
ollama push <your username>/example
```
After publishing, your model will be available at `https://ollama.com/<your username>/example`.
After publishing, your model will be available at `https://ollama.ai/<your username>/example`.
## Quantization reference
@@ -163,3 +149,47 @@ The quantization options are as follow (from highest highest to lowest levels of
- `q6_K`
- `q8_0`
- `f16`
## Manually converting & quantizing models
### Prerequisites
Start by cloning the `llama.cpp` repo to your machine in another directory:
```
git clone https://github.com/ggerganov/llama.cpp.git
cd llama.cpp
```
Next, install the Python dependencies:
```
pip install -r requirements.txt
```
Finally, build the `quantize` tool:
```
make quantize
```
### Convert the model
Run the correct conversion script for your model architecture:
```shell
# LlamaForCausalLM or MistralForCausalLM
python convert.py <path to model directory>
# FalconForCausalLM
python convert-falcon-hf-to-gguf.py <path to model directory>
# GPTBigCodeForCausalLM
python convert-starcoder-hf-to-gguf.py <path to model directory>
```
### Quantize the model
```
quantize <path to model dir>/ggml-model-f32.bin <path to model dir>/q4_0.bin q4_0
```

View File

@@ -3,11 +3,9 @@
## Install
Install Ollama running this one-liner:
>
```bash
curl -fsSL https://ollama.com/install.sh | sh
curl https://ollama.ai/install.sh | sh
```
## Manual install
@@ -17,7 +15,7 @@ curl -fsSL https://ollama.com/install.sh | sh
Ollama is distributed as a self-contained binary. Download it to a directory in your PATH:
```bash
sudo curl -L https://ollama.com/download/ollama-linux-amd64 -o /usr/bin/ollama
sudo curl -L https://ollama.ai/download/ollama-linux-amd64 -o /usr/bin/ollama
sudo chmod +x /usr/bin/ollama
```
@@ -77,13 +75,13 @@ sudo systemctl start ollama
Update ollama by running the install script again:
```bash
curl -fsSL https://ollama.com/install.sh | sh
curl https://ollama.ai/install.sh | sh
```
Or by downloading the ollama binary:
```bash
sudo curl -L https://ollama.com/download/ollama-linux-amd64 -o /usr/bin/ollama
sudo curl -L https://ollama.ai/download/ollama-linux-amd64 -o /usr/bin/ollama
sudo chmod +x /usr/bin/ollama
```
@@ -111,10 +109,8 @@ Remove the ollama binary from your bin directory (either `/usr/local/bin`, `/usr
sudo rm $(which ollama)
```
Remove the downloaded models and Ollama service user and group:
Remove the downloaded models and Ollama service user:
```bash
sudo rm -r /usr/share/ollama
sudo userdel ollama
sudo groupdel ollama
```

View File

@@ -19,7 +19,6 @@ A model file is the blueprint to create and share models with Ollama.
- [SYSTEM](#system)
- [ADAPTER](#adapter)
- [LICENSE](#license)
- [MESSAGE](#message)
- [Notes](#notes)
## Format
@@ -39,7 +38,6 @@ INSTRUCTION arguments
| [`SYSTEM`](#system) | Specifies the system message that will be set in the template. |
| [`ADAPTER`](#adapter) | Defines the (Q)LoRA adapters to apply to the model. |
| [`LICENSE`](#license) | Specifies the legal license. |
| [`MESSAGE`](#message) | Specify message history. |
## Examples
@@ -67,13 +65,13 @@ To use this:
More examples are available in the [examples directory](../examples).
### `Modelfile`s in [ollama.com/library][1]
### `Modelfile`s in [ollama.ai/library][1]
There are two ways to view `Modelfile`s underlying the models in [ollama.com/library][1]:
There are two ways to view `Modelfile`s underlying the models in [ollama.ai/library][1]:
- Option 1: view a details page from a model's tags page:
1. Go to a particular model's tags (e.g. https://ollama.com/library/llama2/tags)
2. Click on a tag (e.g. https://ollama.com/library/llama2:13b)
1. Go to a particular model's tags (e.g. https://ollama.ai/library/llama2/tags)
2. Click on a tag (e.g. https://ollama.ai/library/llama2:13b)
3. Scroll down to "Layers"
- Note: if the [`FROM` instruction](#from-required) is not present,
it means the model was created from a local file
@@ -86,7 +84,7 @@ There are two ways to view `Modelfile`s underlying the models in [ollama.com/lib
# FROM llama2:13b
FROM /root/.ollama/models/blobs/sha256:123abc
TEMPLATE """[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>>
TEMPLATE """[INST] {{ if and .First .System }}<<SYS>>{{ .System }}<</SYS>>
{{ end }}{{ .Prompt }} [/INST] """
SYSTEM """"""
@@ -154,23 +152,31 @@ PARAMETER <parameter> <parametervalue>
### TEMPLATE
`TEMPLATE` of the full prompt template to be passed into the model. It may include (optionally) a system message, a user's message and the response from the model. Note: syntax may be model specific. Templates use Go [template syntax](https://pkg.go.dev/text/template).
`TEMPLATE` of the full prompt template to be passed into the model. It may include (optionally) a system message and a user's prompt. This is used to create a full custom prompt, and syntax may be model specific. You can usually find the template for a given model in the readme for that model.
#### Template Variables
| Variable | Description |
| ----------------- | --------------------------------------------------------------------------------------------- |
| `{{ .System }}` | The system message used to specify custom behavior. |
| `{{ .Prompt }}` | The user prompt message. |
| `{{ .Response }}` | The response from the model. When generating a response, text after this variable is omitted. |
| Variable | Description |
| ----------------- | ------------------------------------------------------------------------------------------------------------- |
| `{{ .System }}` | The system message used to specify custom behavior, this must also be set in the Modelfile as an instruction. |
| `{{ .Prompt }}` | The incoming prompt, this is not specified in the model file and will be set based on input. |
| `{{ .Response }}` | The response from the LLM, if not specified response is appended to the end of the template. |
| `{{ .First }}` | A boolean value used to render specific template information for the first generation of a session. |
```
TEMPLATE """{{ if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}{{ if .Prompt }}<|im_start|>user
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
```modelfile
TEMPLATE """
{{- if .First }}
### System:
{{ .System }}
{{- end }}
### User:
{{ .Prompt }}
### Response:
"""
SYSTEM """<system message>"""
```
### SYSTEM
@@ -199,22 +205,9 @@ LICENSE """
"""
```
### MESSAGE
The `MESSAGE` instruction allows you to specify a message history for the model to use when responding:
```modelfile
MESSAGE user Is Toronto in Canada?
MESSAGE assistant yes
MESSAGE user Is Sacramento in Canada?
MESSAGE assistant no
MESSAGE user Is Ontario in Canada?
MESSAGE assistant yes
```
## Notes
- the **`Modelfile` is not case sensitive**. In the examples, uppercase instructions are used to make it easier to distinguish it from arguments.
- Instructions can be in any order. In the examples, the `FROM` instruction is first to keep it easily readable.
[1]: https://ollama.com/library
[1]: https://ollama.ai/library

View File

@@ -1,141 +0,0 @@
# OpenAI compatibility
> **Note:** OpenAI compatibility is experimental and is subject to major adjustments including breaking changes. For fully-featured access to the Ollama API, see the Ollama [Python library](https://github.com/ollama/ollama-python), [JavaScript library](https://github.com/ollama/ollama-js) and [REST API](https://github.com/jmorganca/ollama/blob/main/docs/api.md).
Ollama provides experimental compatibility with parts of the [OpenAI API](https://platform.openai.com/docs/api-reference) to help connect existing applications to Ollama.
## Usage
### OpenAI Python library
```python
from openai import OpenAI
client = OpenAI(
base_url='http://localhost:11434/v1/',
# required but ignored
api_key='ollama',
)
chat_completion = client.chat.completions.create(
messages=[
{
'role': 'user',
'content': 'Say this is a test',
}
],
model='llama2',
)
```
### OpenAI JavaScript library
```javascript
import OpenAI from 'openai'
const openai = new OpenAI({
baseURL: 'http://localhost:11434/v1/',
// required but ignored
apiKey: 'ollama',
})
const chatCompletion = await openai.chat.completions.create({
messages: [{ role: 'user', content: 'Say this is a test' }],
model: 'llama2',
})
```
### `curl`
```
curl http://localhost:11434/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "llama2",
"messages": [
{
"role": "system",
"content": "You are a helpful assistant."
},
{
"role": "user",
"content": "Hello!"
}
]
}'
```
## Endpoints
### `/v1/chat/completions`
#### Supported features
- [x] Chat completions
- [x] Streaming
- [x] JSON mode
- [x] Reproducible outputs
- [ ] Vision
- [ ] Function calling
- [ ] Logprobs
#### Supported request fields
- [x] `model`
- [x] `messages`
- [x] Text `content`
- [ ] Array of `content` parts
- [x] `frequency_penalty`
- [x] `presence_penalty`
- [x] `response_format`
- [x] `seed`
- [x] `stop`
- [x] `stream`
- [x] `temperature`
- [x] `top_p`
- [x] `max_tokens`
- [ ] `logit_bias`
- [ ] `tools`
- [ ] `tool_choice`
- [ ] `user`
- [ ] `n`
#### Notes
- Setting `seed` will always set `temperature` to `0`
- `finish_reason` will always be `stop`
- `usage.prompt_tokens` will be 0 for completions where prompt evaluation is cached
## Models
Before using a model, pull it locally `ollama pull`:
```shell
ollama pull llama2
```
### Default model names
For tooling that relies on default OpenAI model names such as `gpt-3.5-turbo`, use `ollama cp` to copy an existing model name to a temporary name:
```
ollama cp llama2 gpt-3.5-turbo
```
Afterwards, this new model name can be specified the `model` field:
```shell
curl http://localhost:11434/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "Hello!"
}
]
}'
```

View File

@@ -12,49 +12,11 @@ On Linux systems with systemd, the logs can be found with this command:
journalctl -u ollama
```
When you run Ollama in a container, the logs go to stdout/stderr in the container:
```shell
docker logs <container-name>
```
(Use `docker ps` to find the container name)
If manually running `ollama serve` in a terminal, the logs will be on that terminal.
Join the [Discord](https://discord.gg/ollama) for help interpreting the logs.
## LLM libraries
Ollama includes multiple LLM libraries compiled for different GPUs and CPU
vector features. Ollama tries to pick the best one based on the capabilities of
your system. If this autodetection has problems, or you run into other problems
(e.g. crashes in your GPU) you can workaround this by forcing a specific LLM
library. `cpu_avx2` will perform the best, followed by `cpu_avx` an the slowest
but most compatible is `cpu`. Rosetta emulation under MacOS will work with the
`cpu` library.
In the server log, you will see a message that looks something like this (varies
from release to release):
```
Dynamic LLM libraries [rocm_v6 cpu cpu_avx cpu_avx2 cuda_v11 rocm_v5]
```
**Experimental LLM Library Override**
You can set OLLAMA_LLM_LIBRARY to any of the available LLM libraries to bypass
autodetection, so for example, if you have a CUDA card, but want to force the
CPU LLM library with AVX2 vector support, use:
```
OLLAMA_LLM_LIBRARY="cpu_avx2" ollama serve
```
You can see what features your CPU has with the following.
```
cat /proc/cpuinfo| grep flags | head -1
```
## Known issues
* N/A
* `signal: illegal instruction (core dumped)`: Ollama requires AVX support from the CPU. This was introduced in 2011 and CPUs started offering it in 2012. CPUs from before that and some lower end CPUs after that may not have AVX support and thus are not supported by Ollama. Some users have had luck with building Ollama on their machines disabling the need for AVX.

View File

@@ -17,7 +17,7 @@ Prerequisites:
Here are the steps:
- Install Ollama via standard Linux command (ignore the 404 error): `curl https://ollama.com/install.sh | sh`
- Install Ollama via standard Linux command (ignore the 404 error): `curl https://ollama.ai/install.sh | sh`
- Stop the Ollama service: `sudo systemctl stop ollama`
- Start Ollama serve in a tmux session called ollama_jetson and reference the CUDA libraries path: `tmux has-session -t ollama_jetson 2>/dev/null || tmux new-session -d -s ollama_jetson
'LD_LIBRARY_PATH=/usr/local/cuda/lib64 ollama serve'`

View File

@@ -8,7 +8,7 @@
"outputs": [],
"source": [
"# Download and run the Ollama Linux install script\n",
"!curl -fsSL https://ollama.com/install.sh | sh\n",
"!curl https://ollama.ai/install.sh | sh\n",
"!command -v systemctl >/dev/null && sudo systemctl stop ollama"
]
},

View File

@@ -2,28 +2,28 @@
## Prerequisites
- Ollama: https://ollama.com/download
- Ollama: https://ollama.ai/download
- Kubernetes cluster. This example will use Google Kubernetes Engine.
## Steps
1. Create the Ollama namespace, daemon set, and service
```bash
kubectl apply -f cpu.yaml
```
```bash
kubectl apply -f cpu.yaml
```
1. Port forward the Ollama service to connect and use it locally
```bash
kubectl -n ollama port-forward service/ollama 11434:80
```
```bash
kubectl -n ollama port-forward service/ollama 11434:80
```
1. Pull and run a model, for example `orca-mini:3b`
```bash
ollama run orca-mini:3b
```
```bash
ollama run orca-mini:3b
```
## (Optional) Hardware Acceleration

View File

@@ -1,6 +1,6 @@
# LangChain Web Summarization
This example summarizes the website, [https://ollama.com/blog/run-llama2-uncensored-locally](https://ollama.com/blog/run-llama2-uncensored-locally)
This example summarizes the website, [https://ollama.ai/blog/run-llama2-uncensored-locally](https://ollama.ai/blog/run-llama2-uncensored-locally)
## Running the Example

View File

@@ -2,7 +2,7 @@ from langchain.llms import Ollama
from langchain.document_loaders import WebBaseLoader
from langchain.chains.summarize import load_summarize_chain
loader = WebBaseLoader("https://ollama.com/blog/run-llama2-uncensored-locally")
loader = WebBaseLoader("https://ollama.ai/blog/run-llama2-uncensored-locally")
docs = loader.load()
llm = Ollama(model="llama2")

View File

@@ -40,13 +40,13 @@ You are a log file analyzer. You will receive a set of lines from a log file for
"""
```
This model is available at https://ollama.com/mattw/loganalyzer. You can customize it and add to your own namespace using the command `ollama create <namespace/modelname> -f <path-to-modelfile>` then `ollama push <namespace/modelname>`.
This model is available at https://ollama.ai/mattw/loganalyzer. You can customize it and add to your own namespace using the command `ollama create <namespace/modelname> -f <path-to-modelfile>` then `ollama push <namespace/modelname>`.
Then loganalysis.py scans all the lines in the given log file and searches for the word 'error'. When the word is found, the 10 lines before and after are set as the prompt for a call to the Generate API.
```python
data = {
"prompt": "\n".join(error_logs),
"prompt": "\n".join(error_logs),
"model": "mattw/loganalyzer"
}
```

View File

@@ -29,9 +29,9 @@ You can also add your own character to be chosen at random when you ask a questi
```bash
ollama pull stablebeluga2:70b-q4_K_M
```
2. Create a new character:
```bash
npm run charactergen "Lorne Greene"
```
@@ -41,15 +41,15 @@ You can also add your own character to be chosen at random when you ask a questi
3. Now you can create a model with this command:
```bash
ollama create <username>/lornegreene -f lornegreene/Modelfile
ollama create <YourNamespace>/lornegreene -f lornegreene/Modelfile
```
`username` is whatever name you set up when you signed up at [https://ollama.com/signup](https://ollama.com/signup).
`YourNamespace` is whatever name you set up when you signed up at [https://ollama.ai/signup](https://ollama.ai/signup).
4. To add this to your mentors, you will have to update the code as follows. On line 8 of `mentors.ts`, add an object to the array, replacing `<username>` with the username you used above.
4. To add this to your mentors, you will have to update the code as follows. On line 8 of `mentors.ts`, add an object to the array, replacing `<YourNamespace>` with the namespace you used above.
```bash
{ns: "<username>", char: "Lorne Greene"}
{ns: "<YourNamespace>", char: "Lorne Greene"}
```
## Review the Code

4
go.mod
View File

@@ -1,6 +1,6 @@
module github.com/jmorganca/ollama
go 1.21
go 1.20
require (
github.com/emirpasic/gods v1.18.1
@@ -45,7 +45,7 @@ require (
golang.org/x/crypto v0.14.0
golang.org/x/exp v0.0.0-20230817173708-d852ddb80c63
golang.org/x/net v0.17.0 // indirect
golang.org/x/sys v0.13.0
golang.org/x/sys v0.13.0 // indirect
golang.org/x/term v0.13.0
golang.org/x/text v0.13.0 // indirect
google.golang.org/protobuf v1.30.0 // indirect

View File

@@ -1,91 +0,0 @@
package gpu
import (
"bufio"
"fmt"
"io"
"log/slog"
"os"
"path/filepath"
"strconv"
"strings"
)
// TODO - windows vs. non-windows vs darwin
// Discovery logic for AMD/ROCm GPUs
const (
DriverVersionFile = "/sys/module/amdgpu/version"
GPUPropertiesFileGlob = "/sys/class/kfd/kfd/topology/nodes/*/properties"
// TODO probably break these down per GPU to make the logic simpler
GPUTotalMemoryFileGlob = "/sys/class/kfd/kfd/topology/nodes/*/mem_banks/*/properties" // size_in_bytes line
GPUUsedMemoryFileGlob = "/sys/class/kfd/kfd/topology/nodes/*/mem_banks/*/used_memory"
)
func AMDDetected() bool {
_, err := AMDDriverVersion()
return err == nil
}
func AMDDriverVersion() (string, error) {
_, err := os.Stat(DriverVersionFile)
if err != nil {
return "", err
}
fp, err := os.Open(DriverVersionFile)
if err != nil {
return "", err
}
defer fp.Close()
verString, err := io.ReadAll(fp)
if err != nil {
return "", err
}
return strings.TrimSpace(string(verString)), nil
}
func AMDGFXVersions() []Version {
res := []Version{}
matches, _ := filepath.Glob(GPUPropertiesFileGlob)
for _, match := range matches {
fp, err := os.Open(match)
if err != nil {
slog.Debug(fmt.Sprintf("failed to open sysfs node file %s: %s", match, err))
continue
}
defer fp.Close()
scanner := bufio.NewScanner(fp)
// optionally, resize scanner's capacity for lines over 64K, see next example
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if strings.HasPrefix(line, "gfx_target_version") {
ver := strings.Fields(line)
if len(ver) != 2 || len(ver[1]) < 5 {
slog.Debug("malformed " + line)
continue
}
l := len(ver[1])
patch, err1 := strconv.ParseUint(ver[1][l-2:l], 10, 32)
minor, err2 := strconv.ParseUint(ver[1][l-4:l-2], 10, 32)
major, err3 := strconv.ParseUint(ver[1][:l-4], 10, 32)
if err1 != nil || err2 != nil || err3 != nil {
slog.Debug("malformed int " + line)
continue
}
res = append(res, Version{
Major: uint(major),
Minor: uint(minor),
Patch: uint(patch),
})
}
}
}
return res
}
func (v Version) ToGFXString() string {
return fmt.Sprintf("gfx%d%d%d", v.Major, v.Minor, v.Patch)
}

View File

@@ -1,21 +0,0 @@
package gpu
import (
"log/slog"
"golang.org/x/sys/cpu"
)
func GetCPUVariant() string {
if cpu.X86.HasAVX2 {
slog.Info("CPU has AVX2")
return "avx2"
}
if cpu.X86.HasAVX {
slog.Info("CPU has AVX")
return "avx"
}
slog.Info("CPU does not have vector extensions")
// else LCD
return ""
}

View File

@@ -12,12 +12,8 @@ package gpu
import "C"
import (
"fmt"
"log/slog"
"os"
"path/filepath"
"log"
"runtime"
"strconv"
"strings"
"sync"
"unsafe"
)
@@ -30,86 +26,34 @@ type handles struct {
var gpuMutex sync.Mutex
var gpuHandles *handles = nil
// With our current CUDA compile flags, older than 5.0 will not work properly
var CudaComputeMin = [2]C.int{5, 0}
// Possible locations for the nvidia-ml library
var CudaLinuxGlobs = []string{
"/usr/local/cuda/lib64/libnvidia-ml.so*",
"/usr/lib/x86_64-linux-gnu/nvidia/current/libnvidia-ml.so*",
"/usr/lib/x86_64-linux-gnu/libnvidia-ml.so*",
"/usr/lib/wsl/lib/libnvidia-ml.so*",
"/usr/lib/wsl/drivers/*/libnvidia-ml.so*",
"/opt/cuda/lib64/libnvidia-ml.so*",
"/usr/lib*/libnvidia-ml.so*",
"/usr/local/lib*/libnvidia-ml.so*",
"/usr/lib/aarch64-linux-gnu/nvidia/current/libnvidia-ml.so*",
"/usr/lib/aarch64-linux-gnu/libnvidia-ml.so*",
// TODO: are these stubs ever valid?
"/opt/cuda/targets/x86_64-linux/lib/stubs/libnvidia-ml.so*",
}
var CudaWindowsGlobs = []string{
"c:\\Windows\\System32\\nvml.dll",
}
var RocmLinuxGlobs = []string{
"/opt/rocm*/lib*/librocm_smi64.so*",
}
var RocmWindowsGlobs = []string{
"c:\\Windows\\System32\\rocm_smi64.dll",
}
// With our current CUDA compile flags, 5.2 and older will not work properly
const CudaComputeMajorMin = 6
// Note: gpuMutex must already be held
func initGPUHandles() {
// TODO - if the ollama build is CPU only, don't do these checks as they're irrelevant and confusing
log.Printf("Detecting GPU type")
gpuHandles = &handles{nil, nil}
var cudaMgmtName string
var cudaMgmtPatterns []string
var rocmMgmtName string
var rocmMgmtPatterns []string
switch runtime.GOOS {
case "windows":
cudaMgmtName = "nvml.dll"
cudaMgmtPatterns = make([]string, len(CudaWindowsGlobs))
copy(cudaMgmtPatterns, CudaWindowsGlobs)
rocmMgmtName = "rocm_smi64.dll"
rocmMgmtPatterns = make([]string, len(RocmWindowsGlobs))
copy(rocmMgmtPatterns, RocmWindowsGlobs)
case "linux":
cudaMgmtName = "libnvidia-ml.so"
cudaMgmtPatterns = make([]string, len(CudaLinuxGlobs))
copy(cudaMgmtPatterns, CudaLinuxGlobs)
rocmMgmtName = "librocm_smi64.so"
rocmMgmtPatterns = make([]string, len(RocmLinuxGlobs))
copy(rocmMgmtPatterns, RocmLinuxGlobs)
default:
return
}
var resp C.cuda_init_resp_t
C.cuda_init(&resp)
if resp.err != nil {
log.Printf("CUDA not detected: %s", C.GoString(resp.err))
C.free(unsafe.Pointer(resp.err))
slog.Info("Detecting GPU type")
cudaLibPaths := FindGPULibs(cudaMgmtName, cudaMgmtPatterns)
if len(cudaLibPaths) > 0 {
cuda := LoadCUDAMgmt(cudaLibPaths)
if cuda != nil {
slog.Info("Nvidia GPU detected")
gpuHandles.cuda = cuda
return
}
}
rocmLibPaths := FindGPULibs(rocmMgmtName, rocmMgmtPatterns)
if len(rocmLibPaths) > 0 {
rocm := LoadROCMMgmt(rocmLibPaths)
if rocm != nil {
slog.Info("Radeon GPU detected")
gpuHandles.rocm = rocm
return
var resp C.rocm_init_resp_t
C.rocm_init(&resp)
if resp.err != nil {
log.Printf("ROCm not detected: %s", C.GoString(resp.err))
C.free(unsafe.Pointer(resp.err))
} else {
log.Printf("Radeon GPU detected")
rocm := resp.rh
gpuHandles.rocm = &rocm
}
} else {
log.Printf("Nvidia GPU detected")
cuda := resp.ch
gpuHandles.cuda = &cuda
}
}
@@ -122,99 +66,47 @@ func GetGPUInfo() GpuInfo {
initGPUHandles()
}
// All our GPU builds on x86 have AVX enabled, so fallback to CPU if we don't detect at least AVX
cpuVariant := GetCPUVariant()
if cpuVariant == "" && runtime.GOARCH == "amd64" {
slog.Warn("CPU does not have AVX or AVX2, disabling GPU support.")
}
var memInfo C.mem_info_t
resp := GpuInfo{}
if gpuHandles.cuda != nil && (cpuVariant != "" || runtime.GOARCH != "amd64") {
if gpuHandles.cuda != nil {
C.cuda_check_vram(*gpuHandles.cuda, &memInfo)
if memInfo.err != nil {
slog.Info(fmt.Sprintf("error looking up CUDA GPU memory: %s", C.GoString(memInfo.err)))
log.Printf("error looking up CUDA GPU memory: %s", C.GoString(memInfo.err))
C.free(unsafe.Pointer(memInfo.err))
} else if memInfo.count > 0 {
} else {
// Verify minimum compute capability
var cc C.cuda_compute_capability_t
C.cuda_compute_capability(*gpuHandles.cuda, &cc)
if cc.err != nil {
slog.Info(fmt.Sprintf("error looking up CUDA GPU compute capability: %s", C.GoString(cc.err)))
log.Printf("error looking up CUDA GPU compute capability: %s", C.GoString(cc.err))
C.free(unsafe.Pointer(cc.err))
} else if cc.major > CudaComputeMin[0] || (cc.major == CudaComputeMin[0] && cc.minor >= CudaComputeMin[1]) {
slog.Info(fmt.Sprintf("CUDA Compute Capability detected: %d.%d", cc.major, cc.minor))
} else if cc.major >= CudaComputeMajorMin {
log.Printf("CUDA Compute Capability detected: %d.%d", cc.major, cc.minor)
resp.Library = "cuda"
} else {
slog.Info(fmt.Sprintf("CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor))
log.Printf("CUDA GPU is too old. Falling back to CPU mode. Compute Capability detected: %d.%d", cc.major, cc.minor)
}
}
} else if AMDDetected() && gpuHandles.rocm != nil && (cpuVariant != "" || runtime.GOARCH != "amd64") {
ver, err := AMDDriverVersion()
if err == nil {
slog.Info("AMD Driver: " + ver)
}
gfx := AMDGFXVersions()
tooOld := false
for _, v := range gfx {
if v.Major < 9 {
slog.Info("AMD GPU too old, falling back to CPU " + v.ToGFXString())
tooOld = true
break
}
// TODO - remap gfx strings for unsupporetd minor/patch versions to supported for the same major
// e.g. gfx1034 works if we map it to gfx1030 at runtime
}
if !tooOld {
// TODO - this algo can be shifted over to use sysfs instead of the rocm info library...
C.rocm_check_vram(*gpuHandles.rocm, &memInfo)
if memInfo.err != nil {
slog.Info(fmt.Sprintf("error looking up ROCm GPU memory: %s", C.GoString(memInfo.err)))
C.free(unsafe.Pointer(memInfo.err))
} else if memInfo.igpu_index >= 0 && memInfo.count == 1 {
// Only one GPU detected and it appears to be an integrated GPU - skip it
slog.Info("ROCm unsupported integrated GPU detected")
} else if memInfo.count > 0 {
if memInfo.igpu_index >= 0 {
// We have multiple GPUs reported, and one of them is an integrated GPU
// so we have to set the env var to bypass it
// If the user has specified their own ROCR_VISIBLE_DEVICES, don't clobber it
val := os.Getenv("ROCR_VISIBLE_DEVICES")
if val == "" {
devices := []string{}
for i := 0; i < int(memInfo.count); i++ {
if i == int(memInfo.igpu_index) {
continue
}
devices = append(devices, strconv.Itoa(i))
}
val = strings.Join(devices, ",")
os.Setenv("ROCR_VISIBLE_DEVICES", val)
}
slog.Info(fmt.Sprintf("ROCm integrated GPU detected - ROCR_VISIBLE_DEVICES=%s", val))
}
resp.Library = "rocm"
var version C.rocm_version_resp_t
C.rocm_get_version(*gpuHandles.rocm, &version)
verString := C.GoString(version.str)
if version.status == 0 {
resp.Variant = "v" + verString
} else {
slog.Info(fmt.Sprintf("failed to look up ROCm version: %s", verString))
}
C.free(unsafe.Pointer(version.str))
}
} else if gpuHandles.rocm != nil {
C.rocm_check_vram(*gpuHandles.rocm, &memInfo)
if memInfo.err != nil {
log.Printf("error looking up ROCm GPU memory: %s", C.GoString(memInfo.err))
C.free(unsafe.Pointer(memInfo.err))
} else {
resp.Library = "rocm"
}
}
if resp.Library == "" {
C.cpu_check_ram(&memInfo)
resp.Library = "cpu"
resp.Variant = cpuVariant
// In the future we may offer multiple CPU variants to tune CPU features
if runtime.GOOS == "windows" {
resp.Library = "cpu"
} else {
resp.Library = "default"
}
}
if memInfo.err != nil {
slog.Info(fmt.Sprintf("error looking up CPU memory: %s", C.GoString(memInfo.err)))
log.Printf("error looking up CPU memory: %s", C.GoString(memInfo.err))
C.free(unsafe.Pointer(memInfo.err))
return resp
}
@@ -241,111 +133,13 @@ func getCPUMem() (memInfo, error) {
func CheckVRAM() (int64, error) {
gpuInfo := GetGPUInfo()
if gpuInfo.FreeMemory > 0 && (gpuInfo.Library == "cuda" || gpuInfo.Library == "rocm") {
// leave 10% or 1024MiB of VRAM free per GPU to handle unaccounted for overhead
overhead := gpuInfo.FreeMemory / 10
gpus := uint64(gpuInfo.DeviceCount)
if overhead < gpus*1024*1024*1024 {
overhead = gpus * 1024 * 1024 * 1024
// leave 10% or 384Mi of VRAM free for unaccounted for overhead
overhead := gpuInfo.FreeMemory * uint64(gpuInfo.DeviceCount) / 10
if overhead < 384*1024*1024 {
overhead = 384 * 1024 * 1024
}
avail := int64(gpuInfo.FreeMemory - overhead)
slog.Debug(fmt.Sprintf("%s detected %d devices with %dM available memory", gpuInfo.Library, gpuInfo.DeviceCount, avail/1024/1024))
return avail, nil
return int64(gpuInfo.FreeMemory - overhead), nil
}
return 0, fmt.Errorf("no GPU detected") // TODO - better handling of CPU based memory determiniation
}
func FindGPULibs(baseLibName string, patterns []string) []string {
// Multiple GPU libraries may exist, and some may not work, so keep trying until we exhaust them
var ldPaths []string
gpuLibPaths := []string{}
slog.Info(fmt.Sprintf("Searching for GPU management library %s", baseLibName))
switch runtime.GOOS {
case "windows":
ldPaths = strings.Split(os.Getenv("PATH"), ";")
case "linux":
ldPaths = strings.Split(os.Getenv("LD_LIBRARY_PATH"), ":")
default:
return gpuLibPaths
}
// Start with whatever we find in the PATH/LD_LIBRARY_PATH
for _, ldPath := range ldPaths {
d, err := filepath.Abs(ldPath)
if err != nil {
continue
}
patterns = append(patterns, filepath.Join(d, baseLibName+"*"))
}
slog.Debug(fmt.Sprintf("gpu management search paths: %v", patterns))
for _, pattern := range patterns {
// Ignore glob discovery errors
matches, _ := filepath.Glob(pattern)
for _, match := range matches {
// Resolve any links so we don't try the same lib multiple times
// and weed out any dups across globs
libPath := match
tmp := match
var err error
for ; err == nil; tmp, err = os.Readlink(libPath) {
if !filepath.IsAbs(tmp) {
tmp = filepath.Join(filepath.Dir(libPath), tmp)
}
libPath = tmp
}
new := true
for _, cmp := range gpuLibPaths {
if cmp == libPath {
new = false
break
}
}
if new {
gpuLibPaths = append(gpuLibPaths, libPath)
}
}
}
slog.Info(fmt.Sprintf("Discovered GPU libraries: %v", gpuLibPaths))
return gpuLibPaths
}
func LoadCUDAMgmt(cudaLibPaths []string) *C.cuda_handle_t {
var resp C.cuda_init_resp_t
resp.ch.verbose = getVerboseState()
for _, libPath := range cudaLibPaths {
lib := C.CString(libPath)
defer C.free(unsafe.Pointer(lib))
C.cuda_init(lib, &resp)
if resp.err != nil {
slog.Info(fmt.Sprintf("Unable to load CUDA management library %s: %s", libPath, C.GoString(resp.err)))
C.free(unsafe.Pointer(resp.err))
} else {
return &resp.ch
}
}
return nil
}
func LoadROCMMgmt(rocmLibPaths []string) *C.rocm_handle_t {
var resp C.rocm_init_resp_t
resp.rh.verbose = getVerboseState()
for _, libPath := range rocmLibPaths {
lib := C.CString(libPath)
defer C.free(unsafe.Pointer(lib))
C.rocm_init(lib, &resp)
if resp.err != nil {
slog.Info(fmt.Sprintf("Unable to load ROCm management library %s: %s", libPath, C.GoString(resp.err)))
C.free(unsafe.Pointer(resp.err))
} else {
return &resp.rh
}
}
return nil
}
func getVerboseState() C.uint16_t {
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
return C.uint16_t(1)
}
return C.uint16_t(0)
}

View File

@@ -32,15 +32,8 @@ func CheckVRAM() (int64, error) {
func GetGPUInfo() GpuInfo {
mem, _ := getCPUMem()
if runtime.GOARCH == "amd64" {
return GpuInfo{
Library: "cpu",
Variant: GetCPUVariant(),
memInfo: mem,
}
}
return GpuInfo{
Library: "metal",
Library: "default",
memInfo: mem,
}
}
@@ -52,3 +45,7 @@ func getCPUMem() (memInfo, error) {
DeviceCount: 0,
}, nil
}
func nativeInit() error {
return nil
}

View File

@@ -27,13 +27,6 @@
#endif
#define LOG(verbose, ...) \
do { \
if (verbose) { \
fprintf(stderr, __VA_ARGS__); \
} \
} while (0)
#ifdef __cplusplus
extern "C" {
#endif
@@ -42,7 +35,6 @@ typedef struct mem_info {
uint64_t total;
uint64_t free;
unsigned int count;
int igpu_index; // If >= 0, we detected an integrated GPU to ignore
char *err; // If non-nill, caller responsible for freeing
} mem_info_t;

View File

@@ -4,7 +4,33 @@
#include <string.h>
void cuda_init(char *cuda_lib_path, cuda_init_resp_t *resp) {
#ifndef _WIN32
const char *cuda_lib_paths[] = {
"libnvidia-ml.so",
"/usr/lib/wsl/lib/libnvidia-ml.so", // TODO Maybe glob?
"/usr/lib/wsl/lib/libnvidia-ml.so.1",
"/usr/local/cuda/lib64/libnvidia-ml.so",
"/usr/lib/libnvidia-ml.so",
"/usr/lib/libnvidia-ml.so.1",
"/usr/lib/x86_64-linux-gnu/nvidia/current/libnvidia-ml.so",
"/usr/lib/x86_64-linux-gnu/libnvidia-ml.so",
"/usr/lib/x86_64-linux-gnu/libnvidia-ml.so.1",
"/usr/lib/aarch64-linux-gnu/nvidia/current/libnvidia-ml.so",
"/usr/lib/aarch64-linux-gnu/libnvidia-ml.so",
"/usr/lib/aarch64-linux-gnu/libnvidia-ml.so.1",
NULL,
};
#else
const char *cuda_lib_paths[] = {
"nvml.dll",
"",
NULL,
};
#endif
#define CUDA_LOOKUP_SIZE 6
void cuda_init(cuda_init_resp_t *resp) {
nvmlReturn_t ret;
resp->err = NULL;
const int buflen = 256;
@@ -14,47 +40,36 @@ void cuda_init(char *cuda_lib_path, cuda_init_resp_t *resp) {
struct lookup {
char *s;
void **p;
} l[] = {
{"nvmlInit_v2", (void *)&resp->ch.nvmlInit_v2},
{"nvmlShutdown", (void *)&resp->ch.nvmlShutdown},
{"nvmlDeviceGetHandleByIndex", (void *)&resp->ch.nvmlDeviceGetHandleByIndex},
{"nvmlDeviceGetMemoryInfo", (void *)&resp->ch.nvmlDeviceGetMemoryInfo},
{"nvmlDeviceGetCount_v2", (void *)&resp->ch.nvmlDeviceGetCount_v2},
{"nvmlDeviceGetCudaComputeCapability", (void *)&resp->ch.nvmlDeviceGetCudaComputeCapability},
{"nvmlSystemGetDriverVersion", (void *)&resp->ch.nvmlSystemGetDriverVersion},
{"nvmlDeviceGetName", (void *)&resp->ch.nvmlDeviceGetName},
{"nvmlDeviceGetSerial", (void *)&resp->ch.nvmlDeviceGetSerial},
{"nvmlDeviceGetVbiosVersion", (void *)&resp->ch.nvmlDeviceGetVbiosVersion},
{"nvmlDeviceGetBoardPartNumber", (void *)&resp->ch.nvmlDeviceGetBoardPartNumber},
{"nvmlDeviceGetBrand", (void *)&resp->ch.nvmlDeviceGetBrand},
{NULL, NULL},
} l[CUDA_LOOKUP_SIZE] = {
{"nvmlInit_v2", (void *)&resp->ch.initFn},
{"nvmlShutdown", (void *)&resp->ch.shutdownFn},
{"nvmlDeviceGetHandleByIndex", (void *)&resp->ch.getHandle},
{"nvmlDeviceGetMemoryInfo", (void *)&resp->ch.getMemInfo},
{"nvmlDeviceGetCount_v2", (void *)&resp->ch.getCount},
{"nvmlDeviceGetCudaComputeCapability", (void *)&resp->ch.getComputeCapability},
};
resp->ch.handle = LOAD_LIBRARY(cuda_lib_path, RTLD_LAZY);
for (i = 0; cuda_lib_paths[i] != NULL && resp->ch.handle == NULL; i++) {
resp->ch.handle = LOAD_LIBRARY(cuda_lib_paths[i], RTLD_LAZY);
}
if (!resp->ch.handle) {
// TODO improve error message, as the LOAD_ERR will have typically have the
// final path that was checked which might be confusing.
char *msg = LOAD_ERR();
LOG(resp->ch.verbose, "library %s load err: %s\n", cuda_lib_path, msg);
snprintf(buf, buflen,
"Unable to load %s library to query for Nvidia GPUs: %s",
cuda_lib_path, msg);
cuda_lib_paths[0], msg);
free(msg);
resp->err = strdup(buf);
return;
}
// TODO once we've squashed the remaining corner cases remove this log
LOG(resp->ch.verbose, "wiring nvidia management library functions in %s\n", cuda_lib_path);
for (i = 0; l[i].s != NULL; i++) {
// TODO once we've squashed the remaining corner cases remove this log
LOG(resp->ch.verbose, "dlsym: %s\n", l[i].s);
for (i = 0; i < CUDA_LOOKUP_SIZE; i++) { // TODO - fix this to use a null terminated list
*l[i].p = LOAD_SYMBOL(resp->ch.handle, l[i].s);
if (!l[i].p) {
UNLOAD_LIBRARY(resp->ch.handle);
resp->ch.handle = NULL;
char *msg = LOAD_ERR();
LOG(resp->ch.verbose, "dlerr: %s\n", msg);
UNLOAD_LIBRARY(resp->ch.handle);
snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s,
msg);
free(msg);
@@ -63,23 +78,13 @@ void cuda_init(char *cuda_lib_path, cuda_init_resp_t *resp) {
}
}
ret = (*resp->ch.nvmlInit_v2)();
ret = (*resp->ch.initFn)();
if (ret != NVML_SUCCESS) {
LOG(resp->ch.verbose, "nvmlInit_v2 err: %d\n", ret);
UNLOAD_LIBRARY(resp->ch.handle);
resp->ch.handle = NULL;
snprintf(buf, buflen, "nvml vram init failure: %d", ret);
resp->err = strdup(buf);
return;
}
// Report driver version if we're in verbose mode, ignore errors
ret = (*resp->ch.nvmlSystemGetDriverVersion)(buf, buflen);
if (ret != NVML_SUCCESS) {
LOG(resp->ch.verbose, "nvmlSystemGetDriverVersion failed: %d\n", ret);
} else {
LOG(resp->ch.verbose, "CUDA driver version: %s\n", buf);
}
return;
}
void cuda_check_vram(cuda_handle_t h, mem_info_t *resp) {
@@ -96,7 +101,7 @@ void cuda_check_vram(cuda_handle_t h, mem_info_t *resp) {
return;
}
ret = (*h.nvmlDeviceGetCount_v2)(&resp->count);
ret = (*h.getCount)(&resp->count);
if (ret != NVML_SUCCESS) {
snprintf(buf, buflen, "unable to get device count: %d", ret);
resp->err = strdup(buf);
@@ -106,57 +111,19 @@ void cuda_check_vram(cuda_handle_t h, mem_info_t *resp) {
resp->total = 0;
resp->free = 0;
for (i = 0; i < resp->count; i++) {
ret = (*h.nvmlDeviceGetHandleByIndex)(i, &device);
ret = (*h.getHandle)(i, &device);
if (ret != NVML_SUCCESS) {
snprintf(buf, buflen, "unable to get device handle %d: %d", i, ret);
resp->err = strdup(buf);
return;
}
ret = (*h.nvmlDeviceGetMemoryInfo)(device, &memInfo);
ret = (*h.getMemInfo)(device, &memInfo);
if (ret != NVML_SUCCESS) {
snprintf(buf, buflen, "device memory info lookup failure %d: %d", i, ret);
resp->err = strdup(buf);
return;
}
if (h.verbose) {
nvmlBrandType_t brand = 0;
// When in verbose mode, report more information about
// the card we discover, but don't fail on error
ret = (*h.nvmlDeviceGetName)(device, buf, buflen);
if (ret != RSMI_STATUS_SUCCESS) {
LOG(h.verbose, "nvmlDeviceGetName failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] CUDA device name: %s\n", i, buf);
}
ret = (*h.nvmlDeviceGetBoardPartNumber)(device, buf, buflen);
if (ret != RSMI_STATUS_SUCCESS) {
LOG(h.verbose, "nvmlDeviceGetBoardPartNumber failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] CUDA part number: %s\n", i, buf);
}
ret = (*h.nvmlDeviceGetSerial)(device, buf, buflen);
if (ret != RSMI_STATUS_SUCCESS) {
LOG(h.verbose, "nvmlDeviceGetSerial failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] CUDA S/N: %s\n", i, buf);
}
ret = (*h.nvmlDeviceGetVbiosVersion)(device, buf, buflen);
if (ret != RSMI_STATUS_SUCCESS) {
LOG(h.verbose, "nvmlDeviceGetVbiosVersion failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] CUDA vbios version: %s\n", i, buf);
}
ret = (*h.nvmlDeviceGetBrand)(device, &brand);
if (ret != RSMI_STATUS_SUCCESS) {
LOG(h.verbose, "nvmlDeviceGetBrand failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] CUDA brand: %d\n", i, brand);
}
}
LOG(h.verbose, "[%d] CUDA totalMem %ld\n", i, memInfo.total);
LOG(h.verbose, "[%d] CUDA usedMem %ld\n", i, memInfo.free);
resp->total += memInfo.total;
resp->free += memInfo.free;
@@ -181,7 +148,7 @@ void cuda_compute_capability(cuda_handle_t h, cuda_compute_capability_t *resp) {
}
unsigned int devices;
ret = (*h.nvmlDeviceGetCount_v2)(&devices);
ret = (*h.getCount)(&devices);
if (ret != NVML_SUCCESS) {
snprintf(buf, buflen, "unable to get device count: %d", ret);
resp->err = strdup(buf);
@@ -189,14 +156,14 @@ void cuda_compute_capability(cuda_handle_t h, cuda_compute_capability_t *resp) {
}
for (i = 0; i < devices; i++) {
ret = (*h.nvmlDeviceGetHandleByIndex)(i, &device);
ret = (*h.getHandle)(i, &device);
if (ret != NVML_SUCCESS) {
snprintf(buf, buflen, "unable to get device handle %d: %d", i, ret);
resp->err = strdup(buf);
return;
}
ret = (*h.nvmlDeviceGetCudaComputeCapability)(device, &major, &minor);
ret = (*h.getComputeCapability)(device, &major, &minor);
if (ret != NVML_SUCCESS) {
snprintf(buf, buflen, "device compute capability lookup failure %d: %d", i, ret);
resp->err = strdup(buf);

View File

@@ -15,26 +15,14 @@ typedef struct nvmlMemory_st {
unsigned long long used;
} nvmlMemory_t;
typedef enum nvmlBrandType_enum
{
NVML_BRAND_UNKNOWN = 0,
} nvmlBrandType_t;
typedef struct cuda_handle {
void *handle;
uint16_t verbose;
nvmlReturn_t (*nvmlInit_v2)(void);
nvmlReturn_t (*nvmlShutdown)(void);
nvmlReturn_t (*nvmlDeviceGetHandleByIndex)(unsigned int, nvmlDevice_t *);
nvmlReturn_t (*nvmlDeviceGetMemoryInfo)(nvmlDevice_t, nvmlMemory_t *);
nvmlReturn_t (*nvmlDeviceGetCount_v2)(unsigned int *);
nvmlReturn_t (*nvmlDeviceGetCudaComputeCapability)(nvmlDevice_t, int* major, int* minor);
nvmlReturn_t (*nvmlSystemGetDriverVersion) (char* version, unsigned int length);
nvmlReturn_t (*nvmlDeviceGetName) (nvmlDevice_t device, char* name, unsigned int length);
nvmlReturn_t (*nvmlDeviceGetSerial) (nvmlDevice_t device, char* serial, unsigned int length);
nvmlReturn_t (*nvmlDeviceGetVbiosVersion) (nvmlDevice_t device, char* version, unsigned int length);
nvmlReturn_t (*nvmlDeviceGetBoardPartNumber) (nvmlDevice_t device, char* partNumber, unsigned int length);
nvmlReturn_t (*nvmlDeviceGetBrand) (nvmlDevice_t device, nvmlBrandType_t* type);
nvmlReturn_t (*initFn)(void);
nvmlReturn_t (*shutdownFn)(void);
nvmlReturn_t (*getHandle)(unsigned int, nvmlDevice_t *);
nvmlReturn_t (*getMemInfo)(nvmlDevice_t, nvmlMemory_t *);
nvmlReturn_t (*getCount)(unsigned int *);
nvmlReturn_t (*getComputeCapability)(nvmlDevice_t, int* major, int* minor);
} cuda_handle_t;
typedef struct cuda_init_resp {
@@ -48,7 +36,7 @@ typedef struct cuda_compute_capability {
int minor;
} cuda_compute_capability_t;
void cuda_init(char *cuda_lib_path, cuda_init_resp_t *resp);
void cuda_init(cuda_init_resp_t *resp);
void cuda_check_vram(cuda_handle_t ch, mem_info_t *resp);
void cuda_compute_capability(cuda_handle_t ch, cuda_compute_capability_t *cc);

View File

@@ -4,7 +4,22 @@
#include <string.h>
void rocm_init(char *rocm_lib_path, rocm_init_resp_t *resp) {
#ifndef _WIN32
const char *rocm_lib_paths[] = {
"librocm_smi64.so",
"/opt/rocm/lib/librocm_smi64.so",
NULL,
};
#else
// TODO untested
const char *rocm_lib_paths[] = {
"rocm_smi64.dll",
"/opt/rocm/lib/rocm_smi64.dll",
NULL,
};
#endif
void rocm_init(rocm_init_resp_t *resp) {
rsmi_status_t ret;
resp->err = NULL;
const int buflen = 256;
@@ -13,48 +28,32 @@ void rocm_init(char *rocm_lib_path, rocm_init_resp_t *resp) {
struct lookup {
char *s;
void **p;
} l[] = {
{"rsmi_init", (void *)&resp->rh.rsmi_init},
{"rsmi_shut_down", (void *)&resp->rh.rsmi_shut_down},
{"rsmi_dev_memory_total_get", (void *)&resp->rh.rsmi_dev_memory_total_get},
{"rsmi_dev_memory_usage_get", (void *)&resp->rh.rsmi_dev_memory_usage_get},
{"rsmi_version_get", (void *)&resp->rh.rsmi_version_get},
{"rsmi_num_monitor_devices", (void*)&resp->rh.rsmi_num_monitor_devices},
{"rsmi_dev_id_get", (void*)&resp->rh.rsmi_dev_id_get},
{"rsmi_dev_name_get", (void *)&resp->rh.rsmi_dev_name_get},
{"rsmi_dev_brand_get", (void *)&resp->rh.rsmi_dev_brand_get},
{"rsmi_dev_vendor_name_get", (void *)&resp->rh.rsmi_dev_vendor_name_get},
{"rsmi_dev_vram_vendor_get", (void *)&resp->rh.rsmi_dev_vram_vendor_get},
{"rsmi_dev_serial_number_get", (void *)&resp->rh.rsmi_dev_serial_number_get},
{"rsmi_dev_subsystem_name_get", (void *)&resp->rh.rsmi_dev_subsystem_name_get},
{"rsmi_dev_vbios_version_get", (void *)&resp->rh.rsmi_dev_vbios_version_get},
{NULL, NULL},
} l[4] = {
{"rsmi_init", (void *)&resp->rh.initFn},
{"rsmi_shut_down", (void *)&resp->rh.shutdownFn},
{"rsmi_dev_memory_total_get", (void *)&resp->rh.totalMemFn},
{"rsmi_dev_memory_usage_get", (void *)&resp->rh.usageMemFn},
// { "rsmi_dev_id_get", (void*)&resp->rh.getHandle },
};
resp->rh.handle = LOAD_LIBRARY(rocm_lib_path, RTLD_LAZY);
for (i = 0; rocm_lib_paths[i] != NULL && resp->rh.handle == NULL; i++) {
resp->rh.handle = LOAD_LIBRARY(rocm_lib_paths[i], RTLD_LAZY);
}
if (!resp->rh.handle) {
char *msg = LOAD_ERR();
snprintf(buf, buflen,
"Unable to load %s library to query for Radeon GPUs: %s\n",
rocm_lib_path, msg);
rocm_lib_paths[0], msg);
free(msg);
resp->err = strdup(buf);
return;
}
// TODO once we've squashed the remaining corner cases remove this log
LOG(resp->rh.verbose, "wiring rocm management library functions in %s\n", rocm_lib_path);
for (i = 0; l[i].s != NULL; i++) {
// TODO once we've squashed the remaining corner cases remove this log
LOG(resp->rh.verbose, "dlsym: %s\n", l[i].s);
for (i = 0; i < 4; i++) {
*l[i].p = LOAD_SYMBOL(resp->rh.handle, l[i].s);
if (!l[i].p) {
resp->rh.handle = NULL;
char *msg = LOAD_ERR();
LOG(resp->rh.verbose, "dlerr: %s\n", msg);
UNLOAD_LIBRARY(resp->rh.handle);
char *msg = LOAD_ERR();
snprintf(buf, buflen, "symbol lookup for %s failed: %s", l[i].s,
msg);
free(msg);
@@ -63,11 +62,8 @@ void rocm_init(char *rocm_lib_path, rocm_init_resp_t *resp) {
}
}
ret = (*resp->rh.rsmi_init)(0);
ret = (*resp->rh.initFn)(0);
if (ret != RSMI_STATUS_SUCCESS) {
LOG(resp->rh.verbose, "rsmi_init err: %d\n", ret);
UNLOAD_LIBRARY(resp->rh.handle);
resp->rh.handle = NULL;
snprintf(buf, buflen, "rocm vram init failure: %d", ret);
resp->err = strdup(buf);
}
@@ -77,7 +73,8 @@ void rocm_init(char *rocm_lib_path, rocm_init_resp_t *resp) {
void rocm_check_vram(rocm_handle_t h, mem_info_t *resp) {
resp->err = NULL;
resp->igpu_index = -1;
// uint32_t num_devices;
// uint16_t device;
uint64_t totalMem = 0;
uint64_t usedMem = 0;
rsmi_status_t ret;
@@ -86,113 +83,38 @@ void rocm_check_vram(rocm_handle_t h, mem_info_t *resp) {
int i;
if (h.handle == NULL) {
resp->err = strdup("rocm handle not initialized");
resp->err = strdup("nvml handle sn't initialized");
return;
}
ret = (*h.rsmi_num_monitor_devices)(&resp->count);
// TODO - iterate through devices... ret =
// rsmi_num_monitor_devices(&num_devices);
// ret = (*h.getHandle)(0, &device);
// if (ret != RSMI_STATUS_SUCCESS) {
// printf("rocm vram device lookup failure: %d\n", ret);
// return -1;
// }
// Get total memory - used memory for available memory
ret = (*h.totalMemFn)(0, RSMI_MEM_TYPE_VRAM, &totalMem);
if (ret != RSMI_STATUS_SUCCESS) {
snprintf(buf, buflen, "unable to get device count: %d", ret);
snprintf(buf, buflen, "rocm total mem lookup failure: %d", ret);
resp->err = strdup(buf);
return;
}
LOG(h.verbose, "discovered %d ROCm GPU Devices\n", resp->count);
resp->total = 0;
resp->free = 0;
for (i = 0; i < resp->count; i++) {
if (h.verbose) {
// When in verbose mode, report more information about
// the card we discover, but don't fail on error
ret = (*h.rsmi_dev_name_get)(i, buf, buflen);
if (ret != RSMI_STATUS_SUCCESS) {
LOG(h.verbose, "rsmi_dev_name_get failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] ROCm device name: %s\n", i, buf);
}
ret = (*h.rsmi_dev_brand_get)(i, buf, buflen);
if (ret != RSMI_STATUS_SUCCESS) {
LOG(h.verbose, "rsmi_dev_brand_get failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] ROCm brand: %s\n", i, buf);
}
ret = (*h.rsmi_dev_vendor_name_get)(i, buf, buflen);
if (ret != RSMI_STATUS_SUCCESS) {
LOG(h.verbose, "rsmi_dev_vendor_name_get failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] ROCm vendor: %s\n", i, buf);
}
ret = (*h.rsmi_dev_vram_vendor_get)(i, buf, buflen);
if (ret != RSMI_STATUS_SUCCESS) {
LOG(h.verbose, "rsmi_dev_vram_vendor_get failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] ROCm VRAM vendor: %s\n", i, buf);
}
ret = (*h.rsmi_dev_serial_number_get)(i, buf, buflen);
if (ret != RSMI_STATUS_SUCCESS) {
LOG(h.verbose, "rsmi_dev_serial_number_get failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] ROCm S/N: %s\n", i, buf);
}
ret = (*h.rsmi_dev_subsystem_name_get)(i, buf, buflen);
if (ret != RSMI_STATUS_SUCCESS) {
LOG(h.verbose, "rsmi_dev_subsystem_name_get failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] ROCm subsystem name: %s\n", i, buf);
}
ret = (*h.rsmi_dev_vbios_version_get)(i, buf, buflen);
if (ret != RSMI_STATUS_SUCCESS) {
LOG(h.verbose, "rsmi_dev_vbios_version_get failed: %d\n", ret);
} else {
LOG(h.verbose, "[%d] ROCm vbios version: %s\n", i, buf);
}
}
// Get total memory - used memory for available memory
ret = (*h.rsmi_dev_memory_total_get)(i, RSMI_MEM_TYPE_VRAM, &totalMem);
if (ret != RSMI_STATUS_SUCCESS) {
snprintf(buf, buflen, "rocm total mem lookup failure: %d", ret);
resp->err = strdup(buf);
return;
}
ret = (*h.rsmi_dev_memory_usage_get)(i, RSMI_MEM_TYPE_VRAM, &usedMem);
if (ret != RSMI_STATUS_SUCCESS) {
snprintf(buf, buflen, "rocm usage mem lookup failure: %d", ret);
resp->err = strdup(buf);
return;
}
LOG(h.verbose, "[%d] ROCm totalMem %ld\n", i, totalMem);
LOG(h.verbose, "[%d] ROCm usedMem %ld\n", i, usedMem);
if (totalMem < 1024 * 1024 * 1024) {
// Do not add up integrated GPU memory capacity, it's a bogus 512M, and actually uses system memory
LOG(h.verbose, "[%d] ROCm integrated GPU\n", i);
resp->igpu_index = i;
} else {
resp->total += totalMem;
resp->free += totalMem - usedMem;
}
}
}
void rocm_get_version(rocm_handle_t h, rocm_version_resp_t *resp) {
const int buflen = 256;
char buf[buflen + 1];
if (h.handle == NULL) {
resp->str = strdup("rocm handle not initialized");
resp->status = 1;
ret = (*h.usageMemFn)(0, RSMI_MEM_TYPE_VRAM, &usedMem);
if (ret != RSMI_STATUS_SUCCESS) {
snprintf(buf, buflen, "rocm usage mem lookup failure: %d", ret);
resp->err = strdup(buf);
return;
}
rsmi_version_t ver;
rsmi_status_t ret;
ret = h.rsmi_version_get(&ver);
if (ret != RSMI_STATUS_SUCCESS) {
snprintf(buf, buflen, "unexpected response on version lookup %d", ret);
resp->status = 1;
} else {
snprintf(buf, buflen, "%d", ver.major);
resp->status = 0;
}
resp->str = strdup(buf);
// TODO: set this to the actual number of devices
resp->count = 1;
resp->total = totalMem;
resp->free = totalMem - usedMem;
return;
}
#endif // __APPLE__
#endif // __APPLE__

View File

@@ -15,30 +15,13 @@ typedef enum rsmi_memory_type {
RSMI_MEM_TYPE_GTT,
} rsmi_memory_type_t;
typedef struct {
uint32_t major;
uint32_t minor;
uint32_t patch;
const char *build;
} rsmi_version_t;
typedef struct rocm_handle {
void *handle;
uint16_t verbose;
rsmi_status_t (*rsmi_init)(uint64_t);
rsmi_status_t (*rsmi_shut_down)(void);
rsmi_status_t (*rsmi_dev_memory_total_get)(uint32_t, rsmi_memory_type_t, uint64_t *);
rsmi_status_t (*rsmi_dev_memory_usage_get)(uint32_t, rsmi_memory_type_t, uint64_t *);
rsmi_status_t (*rsmi_version_get) (rsmi_version_t *version);
rsmi_status_t (*rsmi_num_monitor_devices) (uint32_t *);
rsmi_status_t (*rsmi_dev_id_get)(uint32_t, uint16_t *);
rsmi_status_t (*rsmi_dev_name_get) (uint32_t,char *,size_t);
rsmi_status_t (*rsmi_dev_brand_get) (uint32_t, char *, uint32_t);
rsmi_status_t (*rsmi_dev_vendor_name_get) (uint32_t, char *, uint32_t);
rsmi_status_t (*rsmi_dev_vram_vendor_get) (uint32_t, char *, uint32_t);
rsmi_status_t (*rsmi_dev_serial_number_get) (uint32_t, char *, uint32_t);
rsmi_status_t (*rsmi_dev_subsystem_name_get) (uint32_t, char *, uint32_t);
rsmi_status_t (*rsmi_dev_vbios_version_get) (uint32_t, char *, uint32_t);
rsmi_status_t (*initFn)(uint64_t);
rsmi_status_t (*shutdownFn)(void);
rsmi_status_t (*totalMemFn)(uint32_t, rsmi_memory_type_t, uint64_t *);
rsmi_status_t (*usageMemFn)(uint32_t, rsmi_memory_type_t, uint64_t *);
// rsmi_status_t (*getHandle)(uint32_t, uint16_t *);
} rocm_handle_t;
typedef struct rocm_init_resp {
@@ -46,14 +29,8 @@ typedef struct rocm_init_resp {
rocm_handle_t rh;
} rocm_init_resp_t;
typedef struct rocm_version_resp {
rsmi_status_t status;
char *str; // Contains version or error string if status != 0
} rocm_version_resp_t;
void rocm_init(char *rocm_lib_path, rocm_init_resp_t *resp);
void rocm_init(rocm_init_resp_t *resp);
void rocm_check_vram(rocm_handle_t rh, mem_info_t *resp);
void rocm_get_version(rocm_handle_t rh, rocm_version_resp_t *resp);
#endif // __GPU_INFO_ROCM_H__
#endif // __APPLE__

View File

@@ -9,7 +9,7 @@ import (
func TestBasicGetGPUInfo(t *testing.T) {
info := GetGPUInfo()
assert.Contains(t, "cuda rocm cpu metal", info.Library)
assert.Contains(t, "cuda rocm cpu default", info.Library)
switch runtime.GOOS {
case "darwin":
@@ -18,7 +18,7 @@ func TestBasicGetGPUInfo(t *testing.T) {
case "linux", "windows":
assert.Greater(t, info.TotalMemory, uint64(0))
assert.Greater(t, info.FreeMemory, uint64(0))
assert.Greater(t, info.DeviceCount, uint32(0))
assert.Greater(t, info.DeviceCount, uint64(0))
default:
return
}

View File

@@ -11,14 +11,5 @@ type GpuInfo struct {
memInfo
Library string `json:"library,omitempty"`
// Optional variant to select (e.g. versions, cpu feature flags)
Variant string `json:"variant,omitempty"`
// TODO add other useful attributes about the card here for discovery information
}
type Version struct {
Major uint
Minor uint
Patch uint
}

View File

@@ -1,11 +1,11 @@
#include "dyn_ext_server.h"
#include "dynamic_shim.h"
#include <stdio.h>
#include <string.h>
#ifdef __linux__
#include <dlfcn.h>
#define LOAD_LIBRARY(lib, flags) dlopen(lib, flags)
#define LOAD_LIBRARY(lib, flags) dlopen(lib, flags | RTLD_DEEPBIND)
#define LOAD_SYMBOL(handle, sym) dlsym(handle, sym)
#define LOAD_ERR() strdup(dlerror())
#define UNLOAD_LIBRARY(handle) dlclose(handle)
@@ -33,7 +33,7 @@ inline char *LOAD_ERR() {
#define UNLOAD_LIBRARY(handle) dlclose(handle)
#endif
void dyn_init(const char *libPath, struct dynamic_llama_server *s,
void dynamic_shim_init(const char *libPath, struct dynamic_llama_server *s,
ext_server_resp_t *err) {
int i = 0;
struct lookup {
@@ -58,8 +58,8 @@ void dyn_init(const char *libPath, struct dynamic_llama_server *s,
{"", NULL},
};
printf("loading library %s\n", libPath);
s->handle = LOAD_LIBRARY(libPath, RTLD_LOCAL|RTLD_NOW);
printf("Lazy loading %s library\n", libPath);
s->handle = LOAD_LIBRARY(libPath, RTLD_NOW);
if (!s->handle) {
err->id = -1;
char *msg = LOAD_ERR();
@@ -83,63 +83,63 @@ void dyn_init(const char *libPath, struct dynamic_llama_server *s,
}
}
inline void dyn_llama_server_init(struct dynamic_llama_server s,
inline void dynamic_shim_llama_server_init(struct dynamic_llama_server s,
ext_server_params_t *sparams,
ext_server_resp_t *err) {
s.llama_server_init(sparams, err);
}
inline void dyn_llama_server_start(struct dynamic_llama_server s) {
inline void dynamic_shim_llama_server_start(struct dynamic_llama_server s) {
s.llama_server_start();
}
inline void dyn_llama_server_stop(struct dynamic_llama_server s) {
inline void dynamic_shim_llama_server_stop(struct dynamic_llama_server s) {
s.llama_server_stop();
}
inline void dyn_llama_server_completion(struct dynamic_llama_server s,
inline void dynamic_shim_llama_server_completion(struct dynamic_llama_server s,
const char *json_req,
ext_server_resp_t *resp) {
s.llama_server_completion(json_req, resp);
}
inline void dyn_llama_server_completion_next_result(
inline void dynamic_shim_llama_server_completion_next_result(
struct dynamic_llama_server s, const int task_id,
ext_server_task_result_t *result) {
s.llama_server_completion_next_result(task_id, result);
}
inline void dyn_llama_server_completion_cancel(
inline void dynamic_shim_llama_server_completion_cancel(
struct dynamic_llama_server s, const int task_id, ext_server_resp_t *err) {
s.llama_server_completion_cancel(task_id, err);
}
inline void dyn_llama_server_release_task_result(
inline void dynamic_shim_llama_server_release_task_result(
struct dynamic_llama_server s, ext_server_task_result_t *result) {
s.llama_server_release_task_result(result);
}
inline void dyn_llama_server_tokenize(struct dynamic_llama_server s,
inline void dynamic_shim_llama_server_tokenize(struct dynamic_llama_server s,
const char *json_req,
char **json_resp,
ext_server_resp_t *err) {
s.llama_server_tokenize(json_req, json_resp, err);
}
inline void dyn_llama_server_detokenize(struct dynamic_llama_server s,
inline void dynamic_shim_llama_server_detokenize(struct dynamic_llama_server s,
const char *json_req,
char **json_resp,
ext_server_resp_t *err) {
s.llama_server_detokenize(json_req, json_resp, err);
}
inline void dyn_llama_server_embedding(struct dynamic_llama_server s,
inline void dynamic_shim_llama_server_embedding(struct dynamic_llama_server s,
const char *json_req,
char **json_resp,
ext_server_resp_t *err) {
s.llama_server_embedding(json_req, json_resp, err);
}
inline void dyn_llama_server_release_json_resp(
inline void dynamic_shim_llama_server_release_json_resp(
struct dynamic_llama_server s, char **json_resp) {
s.llama_server_release_json_resp(json_resp);
}

View File

@@ -27,46 +27,46 @@ struct dynamic_llama_server {
void (*llama_server_release_json_resp)(char **json_resp);
};
void dyn_init(const char *libPath, struct dynamic_llama_server *s,
void dynamic_shim_init(const char *libPath, struct dynamic_llama_server *s,
ext_server_resp_t *err);
// No good way to call C function pointers from Go so inline the indirection
void dyn_llama_server_init(struct dynamic_llama_server s,
void dynamic_shim_llama_server_init(struct dynamic_llama_server s,
ext_server_params_t *sparams,
ext_server_resp_t *err);
void dyn_llama_server_start(struct dynamic_llama_server s);
void dynamic_shim_llama_server_start(struct dynamic_llama_server s);
void dyn_llama_server_stop(struct dynamic_llama_server s);
void dynamic_shim_llama_server_stop(struct dynamic_llama_server s);
void dyn_llama_server_completion(struct dynamic_llama_server s,
void dynamic_shim_llama_server_completion(struct dynamic_llama_server s,
const char *json_req,
ext_server_resp_t *resp);
void dyn_llama_server_completion_next_result(
void dynamic_shim_llama_server_completion_next_result(
struct dynamic_llama_server s, const int task_id,
ext_server_task_result_t *result);
void dyn_llama_server_completion_cancel(struct dynamic_llama_server s,
void dynamic_shim_llama_server_completion_cancel(struct dynamic_llama_server s,
const int task_id,
ext_server_resp_t *err);
void dyn_llama_server_release_task_result(
void dynamic_shim_llama_server_release_task_result(
struct dynamic_llama_server s, ext_server_task_result_t *result);
void dyn_llama_server_tokenize(struct dynamic_llama_server s,
void dynamic_shim_llama_server_tokenize(struct dynamic_llama_server s,
const char *json_req, char **json_resp,
ext_server_resp_t *err);
void dyn_llama_server_detokenize(struct dynamic_llama_server s,
void dynamic_shim_llama_server_detokenize(struct dynamic_llama_server s,
const char *json_req,
char **json_resp,
ext_server_resp_t *err);
void dyn_llama_server_embedding(struct dynamic_llama_server s,
void dynamic_shim_llama_server_embedding(struct dynamic_llama_server s,
const char *json_req, char **json_resp,
ext_server_resp_t *err);
void dyn_llama_server_release_json_resp(struct dynamic_llama_server s,
void dynamic_shim_llama_server_release_json_resp(struct dynamic_llama_server s,
char **json_resp);
#ifdef __cplusplus

View File

@@ -2,24 +2,28 @@
set(TARGET ext_server)
option(LLAMA_SERVER_VERBOSE "Build verbose logging option for Server" ON)
if (WIN32)
add_library(${TARGET} SHARED ../../../ext_server/ext_server.cpp ../../llama.cpp)
else()
add_library(${TARGET} STATIC ../../../ext_server/ext_server.cpp ../../llama.cpp)
endif()
add_library(${TARGET} STATIC ../../../ext_server/ext_server.cpp)
target_include_directories(${TARGET} PRIVATE ../../common)
target_include_directories(${TARGET} PRIVATE ../..)
target_include_directories(${TARGET} PRIVATE ../../..)
target_compile_features(${TARGET} PRIVATE cxx_std_11)
target_compile_definitions(${TARGET} PUBLIC LLAMA_SERVER_LIBRARY=1)
target_link_libraries(${TARGET} PRIVATE ggml llava common )
set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_compile_definitions(${TARGET} PRIVATE SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}>)
install(TARGETS ext_server LIBRARY)
target_link_libraries(${TARGET} PRIVATE common llama llava ${CMAKE_THREAD_LIBS_INIT})
target_compile_definitions(${TARGET} PRIVATE
SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}>
)
if (BUILD_SHARED_LIBS)
set_target_properties(ext_server PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_compile_definitions(ext_server PRIVATE LLAMA_SHARED LLAMA_BUILD)
add_library(ext_server_shared SHARED $<TARGET_OBJECTS:ext_server>)
target_link_libraries(ext_server_shared PRIVATE ggml llama llava common ${CMAKE_THREAD_LIBS_INIT})
install(TARGETS ext_server_shared LIBRARY)
endif()
if (CUDAToolkit_FOUND)
target_include_directories(${TARGET} PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
if (WIN32)
target_link_libraries(${TARGET} PRIVATE nvml)
target_link_libraries(ext_server_shared PRIVATE nvml)
endif()
endif()

View File

@@ -1,18 +1,4 @@
# Extern C Server
This directory contains a thin facade we layer on top of the Llama.cpp server to
expose `extern C` interfaces to access the functionality through direct API
calls in-process. The llama.cpp code uses compile time macros to configure GPU
type along with other settings. During the `go generate ./...` execution, the
build will generate one or more copies of the llama.cpp `extern C` server based
on what GPU libraries are detected to support multiple GPU types as well as CPU
only support. The Ollama go build then embeds these different servers to support
different GPUs and settings at runtime.
If you are making changes to the code in this directory, make sure to disable
caching during your go build to ensure you pick up your changes. A typical
iteration cycle from the top of the source tree looks like:
```
go generate ./... && go build -a .
```
This directory contains a thin facade we layer on top of the Llama.cpp server
to expose `extern C` interfaces to access the functionality through direct API calls in-process

View File

@@ -1,63 +1,24 @@
#include "ext_server.h"
#include <atomic>
// Necessary evil since the server types are not defined in a header
#include "server.cpp"
// Low level API access to verify GPU access
#if defined(GGML_USE_CUBLAS)
#if defined(GGML_USE_HIPBLAS)
#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>
#include <hip/hip_fp16.h>
#ifdef __HIP_PLATFORM_AMD__
// for rocblas_initialize()
#include "rocblas/rocblas.h"
#endif // __HIP_PLATFORM_AMD__
#define cudaGetDevice hipGetDevice
#define cudaError_t hipError_t
#define cudaSuccess hipSuccess
#define cudaGetErrorString hipGetErrorString
#else
#include <cuda_runtime.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#endif // defined(GGML_USE_HIPBLAS)
#endif // GGML_USE_CUBLAS
// Expose the llama server as a callable extern "C" API
llama_server_context *llama = NULL;
std::atomic<bool> ext_server_running(false);
std::thread ext_server_thread;
bool shutting_down = false;
std::atomic_int recv_counter;
// RAII wrapper for tracking in-flight recv calls
class atomicRecv {
public:
atomicRecv(std::atomic<int> &atomic) : atomic(atomic) {
++this->atomic;
}
~atomicRecv() {
--this->atomic;
}
private:
std::atomic<int> &atomic;
};
void llama_server_init(ext_server_params *sparams, ext_server_resp_t *err) {
recv_counter = 0;
#if SERVER_VERBOSE != 1
log_disable();
#endif
LOG_TEE("system info: %s", llama_print_system_info());
assert(err != NULL && sparams != NULL);
log_set_target(stderr);
if (!sparams->verbose_logging) {
server_verbose = true;
log_disable();
}
LOG_TEE("system info: %s\n", llama_print_system_info());
err->id = 0;
err->msg[0] = '\0';
try {
llama = new llama_server_context;
log_set_target(stdout);
gpt_params params;
params.n_ctx = sparams->n_ctx;
params.n_batch = sparams->n_batch;
@@ -86,31 +47,15 @@ void llama_server_init(ext_server_params *sparams, ext_server_resp_t *err) {
params.model = sparams->model;
}
if (sparams->lora_adapters != NULL) {
for (ext_server_lora_adapter *la = sparams->lora_adapters; la != NULL;
la = la->next) {
params.lora_adapter.push_back(std::make_tuple(la->adapter, la->scale));
}
params.use_mmap = false;
for (ext_server_lora_adapter *la = sparams->lora_adapters; la != NULL;
la = la->next) {
params.lora_adapter.push_back(std::make_tuple(la->adapter, la->scale));
}
if (sparams->mmproj != NULL) {
params.mmproj = std::string(sparams->mmproj);
}
#if defined(GGML_USE_CUBLAS)
// Before attempting to init the backend which will assert on error, verify the CUDA/ROCM GPU is accessible
LOG_TEE("Performing pre-initialization of GPU\n");
int id;
cudaError_t cudaErr = cudaGetDevice(&id);
if (cudaErr != cudaSuccess) {
err->id = -1;
snprintf(err->msg, err->msg_len, "Unable to init GPU: %s", cudaGetErrorString(cudaErr));
return;
}
#endif
llama_backend_init(params.numa);
// load the model
@@ -139,23 +84,18 @@ void llama_server_start() {
assert(llama != NULL);
// TODO mutex to protect thread creation
ext_server_thread = std::thread([&]() {
ext_server_running = true;
try {
LOG_TEE("llama server main loop starting\n");
ggml_time_init();
llama->queue_tasks.on_new_task(std::bind(
&llama_server_context::process_single_task, llama, std::placeholders::_1));
llama->queue_tasks.on_finish_multitask(std::bind(
&llama_server_context::on_finish_multitask, llama, std::placeholders::_1));
llama->queue_tasks.on_all_tasks_finished(std::bind(
&llama_server_context::run_on_all_tasks_finished, llama));
llama->queue_results.on_multitask_update(std::bind(
&llama_server_queue::update_multitask,
&llama->queue_tasks,
std::placeholders::_1,
std::placeholders::_2,
std::placeholders::_3
));
llama->queue_tasks.start_loop();
while (ext_server_running.load()) {
if (!llama->update_slots()) {
LOG_TEE(
"unexpected error in llama server update_slots - exiting main "
"loop\n");
break;
}
}
} catch (std::exception &e) {
LOG_TEE("caught exception in llama server main loop: %s\n", e.what());
} catch (...) {
@@ -168,22 +108,13 @@ void llama_server_start() {
void llama_server_stop() {
assert(llama != NULL);
// Shutdown any in-flight requests and block incoming requests.
LOG_TEE("\ninitiating shutdown - draining remaining tasks...\n");
shutting_down = true;
while (recv_counter.load() > 0) {
std::this_thread::sleep_for(std::chrono::milliseconds(50));
}
// This may take a while for any pending tasks to drain
// TODO - consider a timeout to cancel tasks if it's taking too long
llama->queue_tasks.terminate();
// TODO - too verbose, remove once things are solid
LOG_TEE("requesting llama server shutdown\n");
ext_server_running = false;
ext_server_thread.join();
delete llama;
llama = NULL;
LOG_TEE("llama server shutdown complete\n");
shutting_down = false;
}
void llama_server_completion(const char *json_req, ext_server_resp_t *resp) {
@@ -191,13 +122,8 @@ void llama_server_completion(const char *json_req, ext_server_resp_t *resp) {
resp->id = -1;
resp->msg[0] = '\0';
try {
if (shutting_down) {
throw std::runtime_error("server shutting down");
}
json data = json::parse(json_req);
resp->id = llama->queue_tasks.get_new_id();
llama->queue_results.add_waiting_task_id(resp->id);
llama->request_completion(resp->id, data, false, false, -1);
resp->id = llama->request_completion(data, false, false, -1);
} catch (std::exception &e) {
snprintf(resp->msg, resp->msg_len, "exception %s", e.what());
} catch (...) {
@@ -215,28 +141,16 @@ void llama_server_completion_next_result(const int task_id,
resp->json_resp = NULL;
std::string result_json;
try {
atomicRecv ar(recv_counter);
task_result result = llama->queue_results.recv(task_id);
task_result result = llama->next_result(task_id);
result_json =
result.result_json.dump(-1, ' ', false, json::error_handler_t::replace);
resp->id = result.id;
resp->stop = result.stop;
resp->error = result.error;
if (result.error) {
LOG_TEE("next result cancel on error\n");
llama->request_cancel(task_id);
LOG_TEE("next result removing waiting tak ID: %d\n", task_id);
llama->queue_results.remove_waiting_task_id(task_id);
} else if (result.stop) {
LOG_TEE("next result cancel on stop\n");
llama->request_cancel(task_id);
LOG_TEE("next result removing waiting task ID: %d\n", task_id);
llama->queue_results.remove_waiting_task_id(task_id);
} else if (shutting_down) {
LOG_TEE("aborting completion due to shutdown %d\n", task_id);
llama->request_cancel(task_id);
llama->queue_results.remove_waiting_task_id(task_id);
resp->stop = true;
}
} catch (std::exception &e) {
resp->error = true;
@@ -267,7 +181,6 @@ void llama_server_completion_cancel(const int task_id, ext_server_resp_t *err) {
err->msg[0] = '\0';
try {
llama->request_cancel(task_id);
llama->queue_results.remove_waiting_task_id(task_id);
} catch (std::exception &e) {
err->id = -1;
snprintf(err->msg, err->msg_len, "exception %s", e.what());
@@ -285,9 +198,6 @@ void llama_server_tokenize(const char *json_req, char **json_resp,
err->id = 0;
err->msg[0] = '\0';
try {
if (shutting_down) {
throw std::runtime_error("server shutting down");
}
const json body = json::parse(json_req);
std::vector<llama_token> tokens;
if (body.count("content") != 0) {
@@ -321,9 +231,6 @@ void llama_server_detokenize(const char *json_req, char **json_resp,
err->id = 0;
err->msg[0] = '\0';
try {
if (shutting_down) {
throw std::runtime_error("server shutting down");
}
const json body = json::parse(json_req);
std::string content;
if (body.count("tokens") != 0) {
@@ -351,9 +258,6 @@ void llama_server_embedding(const char *json_req, char **json_resp,
err->id = 0;
err->msg[0] = '\0';
try {
if (shutting_down) {
throw std::runtime_error("server shutting down");
}
const json body = json::parse(json_req);
json prompt;
if (body.count("content") != 0) {
@@ -361,16 +265,13 @@ void llama_server_embedding(const char *json_req, char **json_resp,
} else {
prompt = "";
}
const int task_id = llama->queue_tasks.get_new_id();
llama->queue_results.add_waiting_task_id(task_id);
llama->request_completion(task_id, {{"prompt", prompt}, {"n_predict", 0}}, false, true, -1);
atomicRecv ar(recv_counter);
task_result result = llama->queue_results.recv(task_id);
const int task_id = llama->request_completion(
{{"prompt", prompt}, {"n_predict", 0}}, false, true, -1);
task_result result = llama->next_result(task_id);
std::string result_json = result.result_json.dump();
const std::string::size_type size = result_json.size() + 1;
*json_resp = new char[size];
snprintf(*json_resp, size, "%s", result_json.c_str());
llama->queue_results.remove_waiting_task_id(task_id);
} catch (std::exception &e) {
err->id = -1;
snprintf(err->msg, err->msg_len, "exception %s", e.what());

View File

@@ -45,7 +45,6 @@ typedef struct ext_server_params {
bool embedding; // get only sentence embedding
ext_server_lora_adapter_t *lora_adapters;
char *mmproj;
bool verbose_logging; // Enable verbose logging of the server
} ext_server_params_t;
typedef struct ext_server_task_result {

View File

@@ -4,31 +4,37 @@ package llm
#cgo CFLAGS: -I${SRCDIR}/ext_server -I${SRCDIR}/llama.cpp -I${SRCDIR}/llama.cpp/common -I${SRCDIR}/llama.cpp/examples/server
#cgo CFLAGS: -DNDEBUG -DLLAMA_SERVER_LIBRARY=1 -D_XOPEN_SOURCE=600 -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64
#cgo CFLAGS: -Wmissing-noreturn -Wextra -Wcast-qual -Wno-unused-function -Wno-array-bounds
#cgo CPPFLAGS: -Ofast -Wextra -Wno-unused-function -Wno-unused-variable -Wno-deprecated-declarations
#cgo CPPFLAGS: -Ofast -Wextra -Wno-unused-function -Wno-unused-variable -Wno-deprecated-declarations -Wno-unused-but-set-variable
#cgo darwin CFLAGS: -D_DARWIN_C_SOURCE
#cgo darwin CPPFLAGS: -DGGML_USE_ACCELERATE
#cgo darwin CPPFLAGS: -DGGML_USE_METAL -DGGML_METAL_NDEBUG
#cgo darwin LDFLAGS: -lc++ -framework Accelerate
#cgo darwin LDFLAGS: -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders
#cgo darwin LDFLAGS: ${SRCDIR}/llama.cpp/build/darwin/metal/lib/libcommon.a
#cgo darwin LDFLAGS: ${SRCDIR}/llama.cpp/build/darwin/metal/lib/libext_server.a
#cgo darwin LDFLAGS: ${SRCDIR}/llama.cpp/build/darwin/metal/lib/libllama.a
#cgo darwin LDFLAGS: ${SRCDIR}/llama.cpp/build/darwin/metal/lib/libggml_static.a
#cgo linux CFLAGS: -D_GNU_SOURCE
#cgo linux windows CFLAGS: -DGGML_CUDA_DMMV_X=32 -DGGML_CUDA_MMV_Y=1 -DGGML_CUDA_PEER_MAX_BATCH_SIZE=128 -DGGML_USE_CUBLAS
#cgo linux LDFLAGS: -L/usr/local/cuda/targets/x86_64-linux/lib -L/usr/local/cuda/lib64 -L/usr/local/cuda/targets/x86_64-linux/lib/stubs
#cgo linux LDFLAGS: ${SRCDIR}/llama.cpp/build/linux/cpu/lib/libext_server.a
#cgo linux LDFLAGS: ${SRCDIR}/llama.cpp/build/linux/cpu/lib/libcommon.a
#cgo linux LDFLAGS: ${SRCDIR}/llama.cpp/build/linux/cpu/lib/libllama.a
#cgo linux LDFLAGS: ${SRCDIR}/llama.cpp/build/linux/cpu/lib/libggml_static.a
#cgo linux LDFLAGS: -lrt -ldl -lstdc++ -lm
#cgo linux windows LDFLAGS: -lpthread
#include <stdlib.h>
#include "dyn_ext_server.h"
#include "ext_server.h"
*/
import "C"
import (
"bytes"
"context"
"encoding/json"
"fmt"
"log/slog"
"os"
"path/filepath"
"runtime"
"log"
"strings"
"sync"
"time"
@@ -37,9 +43,19 @@ import (
"github.com/jmorganca/ollama/api"
)
type dynExtServer struct {
s C.struct_dynamic_llama_server
options api.Options
type extServer interface {
LLM
llama_server_init(sparams *C.ext_server_params_t, err *C.ext_server_resp_t)
llama_server_start()
llama_server_stop()
llama_server_completion(json_req *C.char, resp *C.ext_server_resp_t)
llama_server_completion_next_result(task_id C.int, resp *C.ext_server_task_result_t)
llama_server_completion_cancel(task_id C.int, err *C.ext_server_resp_t)
llama_server_release_task_result(result *C.ext_server_task_result_t)
llama_server_tokenize(json_req *C.char, json_resp **C.char, err *C.ext_server_resp_t)
llama_server_detokenize(json_req *C.char, json_resp **C.char, err *C.ext_server_resp_t)
llama_server_embedding(json_req *C.char, json_resp **C.char, err *C.ext_server_resp_t)
llama_server_release_json_resp(json_resp **C.char)
}
// Note: current implementation does not support concurrent instantiations
@@ -64,30 +80,11 @@ func extServerResponseToErr(resp C.ext_server_resp_t) error {
return fmt.Errorf(C.GoString(resp.msg))
}
// Note: current implementation does not support concurrent instantiations
var llm *dynExtServer
func newDynExtServer(library, model string, adapters, projectors []string, opts api.Options) (LLM, error) {
func newExtServer(server extServer, model string, adapters, projectors []string, opts api.Options) (extServer, error) {
if !mutex.TryLock() {
slog.Info("concurrent llm servers not yet supported, waiting for prior server to complete")
log.Printf("concurrent llm servers not yet supported, waiting for prior server to complete")
mutex.Lock()
}
updatePath(filepath.Dir(library))
libPath := C.CString(library)
defer C.free(unsafe.Pointer(libPath))
resp := newExtServerResp(512)
defer freeExtServerResp(resp)
var srv C.struct_dynamic_llama_server
C.dyn_init(libPath, &srv, &resp)
if resp.id < 0 {
mutex.Unlock()
return nil, fmt.Errorf("Unable to load dynamic library: %s", C.GoString(resp.msg))
}
llm = &dynExtServer{
s: srv,
options: opts,
}
slog.Info(fmt.Sprintf("Loading Dynamic llm server: %s", library))
var sparams C.ext_server_params_t
sparams.model = C.CString(model)
@@ -136,36 +133,30 @@ func newDynExtServer(library, model string, adapters, projectors []string, opts
sparams.n_threads = C.uint(opts.NumThread)
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
sparams.verbose_logging = C.bool(true)
} else {
sparams.verbose_logging = C.bool(false)
}
slog.Info("Initializing llama server")
initResp := newExtServerResp(128)
defer freeExtServerResp(initResp)
C.dyn_llama_server_init(llm.s, &sparams, &initResp)
if initResp.id < 0 {
mutex.Unlock()
err := extServerResponseToErr(initResp)
slog.Debug(fmt.Sprintf("failure during initialization: %s", err))
return nil, err
}
slog.Info("Starting llama main loop")
C.dyn_llama_server_start(llm.s)
return llm, nil
}
func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn func(PredictResult)) error {
log.Printf("Initializing internal llama server")
resp := newExtServerResp(128)
defer freeExtServerResp(resp)
if len(predict.Images) > 0 {
slog.Info(fmt.Sprintf("loaded %d images", len(predict.Images)))
server.llama_server_init(&sparams, &resp)
if resp.id < 0 {
return nil, extServerResponseToErr(resp)
}
log.Printf("Starting internal llama main loop")
server.llama_server_start()
return server, nil
}
func predict(ctx context.Context, llm extServer, predict PredictOpts, fn func(PredictResult)) error {
resp := newExtServerResp(128)
defer freeExtServerResp(resp)
var imageData []ImageData
if len(predict.Images) > 0 {
for cnt, i := range predict.Images {
imageData = append(imageData, ImageData{Data: i, ID: cnt})
}
}
log.Printf("loaded %d images", len(imageData))
request := map[string]any{
"prompt": predict.Prompt,
"stream": true,
@@ -186,7 +177,7 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu
"penalize_nl": predict.Options.PenalizeNewline,
"seed": predict.Options.Seed,
"stop": predict.Options.Stop,
"image_data": predict.Images,
"image_data": imageData,
"cache_prompt": true,
}
@@ -213,7 +204,7 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu
req := C.CString(buffer.String())
defer C.free(unsafe.Pointer(req))
C.dyn_llama_server_completion(llm.s, req, &resp)
llm.llama_server_completion(req, &resp)
if resp.id < 0 {
return extServerResponseToErr(resp)
}
@@ -224,7 +215,7 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu
select {
case <-ctx.Done():
// This handles the request cancellation
C.dyn_llama_server_completion_cancel(llm.s, resp.id, &resp)
llm.llama_server_completion_cancel(resp.id, &resp)
if resp.id < 0 {
return extServerResponseToErr(resp)
} else {
@@ -232,13 +223,13 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu
}
default:
var result C.ext_server_task_result_t
C.dyn_llama_server_completion_next_result(llm.s, resp.id, &result)
llm.llama_server_completion_next_result(resp.id, &result)
json_resp := C.GoString(result.json_resp)
C.dyn_llama_server_release_task_result(llm.s, &result)
llm.llama_server_release_task_result(&result)
var p prediction
if err := json.Unmarshal([]byte(json_resp), &p); err != nil {
C.dyn_llama_server_completion_cancel(llm.s, resp.id, &resp)
llm.llama_server_completion_cancel(resp.id, &resp)
if resp.id < 0 {
return fmt.Errorf("error unmarshaling llm prediction response: %w and cancel %s", err, C.GoString(resp.msg))
} else {
@@ -258,7 +249,7 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu
})
}
if p.Stop || bool(result.stop) {
if p.Stop {
fn(PredictResult{
Done: true,
PromptEvalCount: p.Timings.PromptN,
@@ -279,7 +270,7 @@ func (llm *dynExtServer) Predict(ctx context.Context, predict PredictOpts, fn fu
return fmt.Errorf("max retries exceeded")
}
func (llm *dynExtServer) Encode(ctx context.Context, prompt string) ([]int, error) {
func encode(llm extServer, ctx context.Context, prompt string) ([]int, error) {
data, err := json.Marshal(TokenizeRequest{Content: prompt})
if err != nil {
return nil, fmt.Errorf("marshaling encode data: %w", err)
@@ -289,11 +280,11 @@ func (llm *dynExtServer) Encode(ctx context.Context, prompt string) ([]int, erro
var json_resp *C.char
resp := newExtServerResp(128)
defer freeExtServerResp(resp)
C.dyn_llama_server_tokenize(llm.s, req, &json_resp, &resp)
llm.llama_server_tokenize(req, &json_resp, &resp)
if resp.id < 0 {
return nil, extServerResponseToErr(resp)
}
defer C.dyn_llama_server_release_json_resp(llm.s, &json_resp)
defer llm.llama_server_release_json_resp(&json_resp)
var encoded TokenizeResponse
if err2 := json.Unmarshal([]byte(C.GoString(json_resp)), &encoded); err2 != nil {
@@ -303,7 +294,7 @@ func (llm *dynExtServer) Encode(ctx context.Context, prompt string) ([]int, erro
return encoded.Tokens, err
}
func (llm *dynExtServer) Decode(ctx context.Context, tokens []int) (string, error) {
func decode(llm extServer, ctx context.Context, tokens []int) (string, error) {
if len(tokens) == 0 {
return "", nil
}
@@ -317,11 +308,11 @@ func (llm *dynExtServer) Decode(ctx context.Context, tokens []int) (string, erro
var json_resp *C.char
resp := newExtServerResp(128)
defer freeExtServerResp(resp)
C.dyn_llama_server_detokenize(llm.s, req, &json_resp, &resp)
llm.llama_server_detokenize(req, &json_resp, &resp)
if resp.id < 0 {
return "", extServerResponseToErr(resp)
}
defer C.dyn_llama_server_release_json_resp(llm.s, &json_resp)
defer llm.llama_server_release_json_resp(&json_resp)
var decoded DetokenizeResponse
if err2 := json.Unmarshal([]byte(C.GoString(json_resp)), &decoded); err2 != nil {
@@ -331,7 +322,7 @@ func (llm *dynExtServer) Decode(ctx context.Context, tokens []int) (string, erro
return decoded.Content, err
}
func (llm *dynExtServer) Embedding(ctx context.Context, input string) ([]float64, error) {
func embedding(llm extServer, ctx context.Context, input string) ([]float64, error) {
data, err := json.Marshal(TokenizeRequest{Content: input})
if err != nil {
return nil, fmt.Errorf("error marshaling embed data: %w", err)
@@ -342,11 +333,11 @@ func (llm *dynExtServer) Embedding(ctx context.Context, input string) ([]float64
var json_resp *C.char
resp := newExtServerResp(128)
defer freeExtServerResp(resp)
C.dyn_llama_server_embedding(llm.s, req, &json_resp, &resp)
llm.llama_server_embedding(req, &json_resp, &resp)
if resp.id < 0 {
return nil, extServerResponseToErr(resp)
}
defer C.dyn_llama_server_release_json_resp(llm.s, &json_resp)
defer llm.llama_server_release_json_resp(&json_resp)
var embedding EmbeddingResponse
if err := json.Unmarshal([]byte(C.GoString(json_resp)), &embedding); err != nil {
@@ -356,29 +347,7 @@ func (llm *dynExtServer) Embedding(ctx context.Context, input string) ([]float64
return embedding.Embedding, nil
}
func (llm *dynExtServer) Close() {
C.dyn_llama_server_stop(llm.s)
func close(llm extServer) {
llm.llama_server_stop()
mutex.Unlock()
}
func updatePath(dir string) {
if runtime.GOOS == "windows" {
tmpDir := filepath.Dir(dir)
pathComponents := strings.Split(os.Getenv("PATH"), ";")
i := 0
for _, comp := range pathComponents {
if strings.EqualFold(comp, dir) {
return
}
// Remove any other prior paths to our temp dir
if !strings.HasPrefix(strings.ToLower(comp), strings.ToLower(tmpDir)) {
pathComponents[i] = comp
i++
}
}
newPath := strings.Join(append([]string{dir}, pathComponents...), ";")
slog.Info(fmt.Sprintf("Updating PATH to %s", newPath))
os.Setenv("PATH", newPath)
}
// linux and darwin rely on rpath
}

80
llm/ext_server_default.go Normal file
View File

@@ -0,0 +1,80 @@
//go:build !windows
package llm
/*
#include <stdlib.h>
#include "ext_server.h"
*/
import "C"
import (
"context"
"github.com/jmorganca/ollama/api"
)
type llamaExtServer struct {
api.Options
}
func (llm *llamaExtServer) llama_server_init(sparams *C.ext_server_params_t, err *C.ext_server_resp_t) {
C.llama_server_init(sparams, err)
}
func (llm *llamaExtServer) llama_server_start() {
C.llama_server_start()
}
func (llm *llamaExtServer) llama_server_stop() {
C.llama_server_stop()
}
func (llm *llamaExtServer) llama_server_completion(json_req *C.char, resp *C.ext_server_resp_t) {
C.llama_server_completion(json_req, resp)
}
func (llm *llamaExtServer) llama_server_completion_next_result(task_id C.int, resp *C.ext_server_task_result_t) {
C.llama_server_completion_next_result(task_id, resp)
}
func (llm *llamaExtServer) llama_server_completion_cancel(task_id C.int, err *C.ext_server_resp_t) {
C.llama_server_completion_cancel(task_id, err)
}
func (llm *llamaExtServer) llama_server_release_task_result(result *C.ext_server_task_result_t) {
C.llama_server_release_task_result(result)
}
func (llm *llamaExtServer) llama_server_tokenize(json_req *C.char, json_resp **C.char, err *C.ext_server_resp_t) {
C.llama_server_tokenize(json_req, json_resp, err)
}
func (llm *llamaExtServer) llama_server_detokenize(json_req *C.char, json_resp **C.char, err *C.ext_server_resp_t) {
C.llama_server_detokenize(json_req, json_resp, err)
}
func (llm *llamaExtServer) llama_server_embedding(json_req *C.char, json_resp **C.char, err *C.ext_server_resp_t) {
C.llama_server_embedding(json_req, json_resp, err)
}
func (llm *llamaExtServer) llama_server_release_json_resp(json_resp **C.char) {
C.llama_server_release_json_resp(json_resp)
}
func newDefaultExtServer(model string, adapters, projectors []string, opts api.Options) (extServer, error) {
server := &llamaExtServer{opts}
return newExtServer(server, model, adapters, projectors, opts)
}
func (llm *llamaExtServer) Predict(ctx context.Context, pred PredictOpts, fn func(PredictResult)) error {
return predict(ctx, llm, pred, fn)
}
func (llm *llamaExtServer) Encode(ctx context.Context, prompt string) ([]int, error) {
return encode(llm, ctx, prompt)
}
func (llm *llamaExtServer) Decode(ctx context.Context, tokens []int) (string, error) {
return decode(llm, ctx, tokens)
}
func (llm *llamaExtServer) Embedding(ctx context.Context, input string) ([]float64, error) {
return embedding(llm, ctx, input)
}
func (llm *llamaExtServer) Close() {
close(llm)
}

12
llm/ext_server_windows.go Normal file
View File

@@ -0,0 +1,12 @@
package llm
import (
"github.com/jmorganca/ollama/api"
)
func newDefaultExtServer(model string, adapters, projectors []string, opts api.Options) (extServer, error) {
// On windows we always load the llama.cpp libraries dynamically to avoid startup DLL dependencies
// This ensures we can update the PATH at runtime to get everything loaded
return newDynamicShimExtServer(AvailableShims["cpu"], model, adapters, projectors, opts)
}

View File

@@ -1,46 +1,14 @@
# common logic accross linux and darwin
init_vars() {
case "${GOARCH}" in
"amd64")
ARCH="x86_64"
;;
"arm64")
ARCH="arm64"
;;
*)
ARCH=$(uname -m | sed -e "s/aarch64/arm64/g")
esac
LLAMACPP_DIR=../llama.cpp
CMAKE_DEFS=""
CMAKE_TARGETS="--target ext_server"
CMAKE_TARGETS="--target ggml --target ggml_static --target llama --target build_info --target common --target ext_server --target llava_static"
if echo "${CGO_CFLAGS}" | grep -- '-g' >/dev/null; then
CMAKE_DEFS="-DCMAKE_BUILD_TYPE=RelWithDebInfo -DCMAKE_VERBOSE_MAKEFILE=on -DLLAMA_GPROF=on -DLLAMA_SERVER_VERBOSE=on ${CMAKE_DEFS}"
CMAKE_DEFS="-DCMAKE_BUILD_TYPE=RelWithDebInfo -DCMAKE_VERBOSE_MAKEFILE=on -DLLAMA_GPROF=on -DLLAMA_SERVER_VERBOSE=on"
else
# TODO - add additional optimization flags...
CMAKE_DEFS="-DCMAKE_BUILD_TYPE=Release -DLLAMA_SERVER_VERBOSE=off ${CMAKE_DEFS}"
fi
case $(uname -s) in
"Darwin")
LIB_EXT="dylib"
WHOLE_ARCHIVE="-Wl,-force_load"
NO_WHOLE_ARCHIVE=""
GCC_ARCH="-arch ${ARCH}"
;;
"Linux")
LIB_EXT="so"
WHOLE_ARCHIVE="-Wl,--whole-archive"
NO_WHOLE_ARCHIVE="-Wl,--no-whole-archive"
# Cross compiling not supported on linux - Use docker
GCC_ARCH=""
;;
*)
;;
esac
if [ -z "${CMAKE_CUDA_ARCHITECTURES}" ] ; then
CMAKE_CUDA_ARCHITECTURES="50;52;61;70;75;80"
CMAKE_DEFS="-DCMAKE_BUILD_TYPE=Release -DLLAMA_SERVER_VERBOSE=off"
fi
}
@@ -64,19 +32,6 @@ apply_patches() {
if ! grep ollama ${LLAMACPP_DIR}/examples/server/CMakeLists.txt; then
echo 'include (../../../ext_server/CMakeLists.txt) # ollama' >>${LLAMACPP_DIR}/examples/server/CMakeLists.txt
fi
if [ -n "$(ls -A ../patches/*.diff)" ]; then
# apply temporary patches until fix is upstream
for patch in ../patches/*.diff; do
for file in $(grep "^+++ " ${patch} | cut -f2 -d' ' | cut -f2- -d/); do
(cd ${LLAMACPP_DIR}; git checkout ${file})
done
done
for patch in ../patches/*.diff; do
(cd ${LLAMACPP_DIR} && git apply ${patch})
done
fi
# Avoid duplicate main symbols when we link into the cgo binary
sed -e 's/int main(/int __main(/g' <${LLAMACPP_DIR}/examples/server/server.cpp >${LLAMACPP_DIR}/examples/server/server.cpp.tmp &&
mv ${LLAMACPP_DIR}/examples/server/server.cpp.tmp ${LLAMACPP_DIR}/examples/server/server.cpp
@@ -85,41 +40,18 @@ apply_patches() {
build() {
cmake -S ${LLAMACPP_DIR} -B ${BUILD_DIR} ${CMAKE_DEFS}
cmake --build ${BUILD_DIR} ${CMAKE_TARGETS} -j8
mkdir -p ${BUILD_DIR}/lib/
g++ -fPIC -g -shared -o ${BUILD_DIR}/lib/libext_server.${LIB_EXT} \
${GCC_ARCH} \
${WHOLE_ARCHIVE} ${BUILD_DIR}/examples/server/libext_server.a ${NO_WHOLE_ARCHIVE} \
${BUILD_DIR}/common/libcommon.a \
${BUILD_DIR}/libllama.a \
-Wl,-rpath,\$ORIGIN \
-lpthread -ldl -lm \
${EXTRA_LIBS}
}
compress_libs() {
echo "Compressing payloads to reduce overall binary size..."
pids=""
rm -rf ${BUILD_DIR}/lib/*.${LIB_EXT}*.gz
for lib in ${BUILD_DIR}/lib/*.${LIB_EXT}* ; do
gzip --best -f ${lib} &
pids+=" $!"
done
echo
for pid in ${pids}; do
wait $pid
done
echo "Finished compression"
install() {
rm -rf ${BUILD_DIR}/lib
mkdir -p ${BUILD_DIR}/lib
cp ${BUILD_DIR}/examples/server/libext_server.a ${BUILD_DIR}/lib
cp ${BUILD_DIR}/common/libcommon.a ${BUILD_DIR}/lib
cp ${BUILD_DIR}/libllama.a ${BUILD_DIR}/lib
cp ${BUILD_DIR}/libggml_static.a ${BUILD_DIR}/lib
}
# Keep the local tree clean after we're done with the build
cleanup() {
(cd ${LLAMACPP_DIR}/examples/server/ && git checkout CMakeLists.txt server.cpp)
if [ -n "$(ls -A ../patches/*.diff)" ]; then
for patch in ../patches/*.diff; do
for file in $(grep "^+++ " ${patch} | cut -f2 -d' ' | cut -f2- -d/); do
(cd ${LLAMACPP_DIR}; git checkout ${file})
done
done
fi
}

View File

@@ -9,63 +9,14 @@ set -o pipefail
echo "Starting darwin generate script"
source $(dirname $0)/gen_common.sh
init_vars
git_module_setup
apply_patches
sign() {
if [ -n "$APPLE_IDENTITY" ]; then
codesign -f --timestamp --deep --options=runtime --sign "$APPLE_IDENTITY" --identifier ai.ollama.ollama $1
fi
}
COMMON_DARWIN_DEFS="-DCMAKE_OSX_DEPLOYMENT_TARGET=11.0 -DCMAKE_SYSTEM_NAME=Darwin"
CMAKE_DEFS="-DCMAKE_OSX_DEPLOYMENT_TARGET=11.0 -DCMAKE_SYSTEM_NAME=Darwin -DLLAMA_ACCELERATE=on ${CMAKE_DEFS}"
BUILD_DIR="${LLAMACPP_DIR}/build/darwin/metal"
case "${GOARCH}" in
"amd64")
COMMON_CPU_DEFS="${COMMON_DARWIN_DEFS} -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=off -DLLAMA_NATIVE=off"
#
# CPU first for the default library, set up as lowest common denominator for maximum compatibility (including Rosetta)
#
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=off -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
BUILD_DIR="${LLAMACPP_DIR}/build/darwin/${ARCH}/cpu"
echo "Building LCD CPU"
build
sign ${LLAMACPP_DIR}/build/darwin/${ARCH}/cpu/lib/libext_server.dylib
compress_libs
#
# ~2011 CPU Dynamic library with more capabilities turned on to optimize performance
# Approximately 400% faster than LCD on same CPU
#
init_vars
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=off -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
BUILD_DIR="${LLAMACPP_DIR}/build/darwin/${ARCH}/cpu_avx"
echo "Building AVX CPU"
build
sign ${LLAMACPP_DIR}/build/darwin/${ARCH}/cpu_avx/lib/libext_server.dylib
compress_libs
#
# ~2013 CPU Dynamic library
# Approximately 10% faster than AVX on same CPU
#
init_vars
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_ACCELERATE=on -DLLAMA_AVX=on -DLLAMA_AVX2=on -DLLAMA_AVX512=off -DLLAMA_FMA=on -DLLAMA_F16C=on ${CMAKE_DEFS}"
BUILD_DIR="${LLAMACPP_DIR}/build/darwin/${ARCH}/cpu_avx2"
echo "Building AVX2 CPU"
EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation"
build
sign ${LLAMACPP_DIR}/build/darwin/${ARCH}/cpu_avx2/lib/libext_server.dylib
compress_libs
CMAKE_DEFS="-DCMAKE_SYSTEM_PROCESSOR=x86_64 -DCMAKE_OSX_ARCHITECTURES=x86_64 -DLLAMA_METAL=off -DLLAMA_NATIVE=off -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
;;
"arm64")
CMAKE_DEFS="${COMMON_DARWIN_DEFS} -DLLAMA_ACCELERATE=on -DCMAKE_SYSTEM_PROCESSOR=${ARCH} -DCMAKE_OSX_ARCHITECTURES=${ARCH} -DLLAMA_METAL=on ${CMAKE_DEFS}"
BUILD_DIR="${LLAMACPP_DIR}/build/darwin/${ARCH}/metal"
EXTRA_LIBS="${EXTRA_LIBS} -framework Accelerate -framework Foundation -framework Metal -framework MetalKit -framework MetalPerformanceShaders"
build
sign ${LLAMACPP_DIR}/build/darwin/${ARCH}/metal/lib/libext_server.dylib
compress_libs
CMAKE_DEFS="-DCMAKE_SYSTEM_PROCESSOR=arm64 -DCMAKE_OSX_ARCHITECTURES=arm64 -DLLAMA_METAL=on ${CMAKE_DEFS}"
;;
*)
echo "GOARCH must be set"
@@ -74,4 +25,8 @@ case "${GOARCH}" in
;;
esac
git_module_setup
apply_patches
build
install
cleanup

View File

@@ -2,25 +2,24 @@
# This script is intended to run inside the go generate
# working directory must be llm/generate/
# First we build one or more CPU based LLM libraries
# First we build our default built-in library which will be linked into the CGO
# binary as a normal dependency. This default build is CPU based.
#
# Then if we detect CUDA, we build a CUDA dynamic library, and carry the required
# library dependencies
# Then we build a CUDA dynamic library (although statically linked with the CUDA
# library dependencies for maximum portability)
#
# Then if we detect ROCm, we build a dynamically loaded ROCm lib. The ROCM
# libraries are quite large, and also dynamically load data files at runtime
# which in turn are large, so we don't attempt to cary them as payload
# Then if we detect ROCm, we build a dynamically loaded ROCm lib. ROCm is particularly
# important to be a dynamic lib even if it's the only GPU library detected because
# we can't redistribute the objectfiles but must rely on dynamic libraries at
# runtime, which could lead the server not to start if not present.
set -ex
set -o pipefail
# See https://llvm.org/docs/AMDGPUUsage.html#processors for reference
amdGPUs() {
if [ -n "${AMDGPU_TARGETS}" ]; then
echo "${AMDGPU_TARGETS}"
return
fi
GPU_LIST=(
"gfx803"
"gfx900"
"gfx906:xnack-"
"gfx908:xnack-"
@@ -40,13 +39,8 @@ amdGPUs() {
}
echo "Starting linux generate script"
if [ -z "${CUDACXX}" ]; then
if [ -x /usr/local/cuda/bin/nvcc ]; then
export CUDACXX=/usr/local/cuda/bin/nvcc
else
# Try the default location in case it exists
export CUDACXX=$(command -v nvcc)
fi
if [ -z "${CUDACXX}" -a -x /usr/local/cuda/bin/nvcc ]; then
export CUDACXX=/usr/local/cuda/bin/nvcc
fi
COMMON_CMAKE_DEFS="-DCMAKE_POSITION_INDEPENDENT_CODE=on -DLLAMA_NATIVE=off -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off"
source $(dirname $0)/gen_common.sh
@@ -54,115 +48,38 @@ init_vars
git_module_setup
apply_patches
if [ -z "${OLLAMA_SKIP_CPU_GENERATE}" ]; then
# Users building from source can tune the exact flags we pass to cmake for configuring
# llama.cpp, and we'll build only 1 CPU variant in that case as the default.
if [ -n "${OLLAMA_CUSTOM_CPU_DEFS}" ]; then
echo "OLLAMA_CUSTOM_CPU_DEFS=\"${OLLAMA_CUSTOM_CPU_DEFS}\""
CMAKE_DEFS="${OLLAMA_CUSTOM_CPU_DEFS} -DCMAKE_POSITION_INDEPENDENT_CODE=on ${CMAKE_DEFS}"
BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/cpu"
echo "Building custom CPU"
build
compress_libs
else
# Darwin Rosetta x86 emulation does NOT support AVX, AVX2, AVX512
# -DLLAMA_AVX -- 2011 Intel Sandy Bridge & AMD Bulldozer
# -DLLAMA_F16C -- 2012 Intel Ivy Bridge & AMD 2011 Bulldozer (No significant improvement over just AVX)
# -DLLAMA_AVX2 -- 2013 Intel Haswell & 2015 AMD Excavator / 2017 AMD Zen
# -DLLAMA_FMA (FMA3) -- 2013 Intel Haswell & 2012 AMD Piledriver
# Note: the following seem to yield slower results than AVX2 - ymmv
# -DLLAMA_AVX512 -- 2017 Intel Skylake and High End DeskTop (HEDT)
# -DLLAMA_AVX512_VBMI -- 2018 Intel Cannon Lake
# -DLLAMA_AVX512_VNNI -- 2021 Intel Alder Lake
#
# CPU first for the default library
#
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS}"
BUILD_DIR="${LLAMACPP_DIR}/build/linux/cpu"
COMMON_CPU_DEFS="-DCMAKE_POSITION_INDEPENDENT_CODE=on -DLLAMA_NATIVE=off"
if [ -z "${OLLAMA_CPU_TARGET}" -o "${OLLAMA_CPU_TARGET}" = "cpu" ]; then
#
# CPU first for the default library, set up as lowest common denominator for maximum compatibility (including Rosetta)
#
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_AVX=off -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/cpu"
echo "Building LCD CPU"
build
compress_libs
fi
build
install
if [ -z "${OLLAMA_CPU_TARGET}" -o "${OLLAMA_CPU_TARGET}" = "cpu_avx" ]; then
#
# ~2011 CPU Dynamic library with more capabilities turned on to optimize performance
# Approximately 400% faster than LCD on same CPU
#
init_vars
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_AVX=on -DLLAMA_AVX2=off -DLLAMA_AVX512=off -DLLAMA_FMA=off -DLLAMA_F16C=off ${CMAKE_DEFS}"
BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/cpu_avx"
echo "Building AVX CPU"
build
compress_libs
fi
# Placeholder to keep go embed happy until we start building dynamic CPU lib variants
touch ${BUILD_DIR}/lib/dummy.so
if [ -z "${OLLAMA_CPU_TARGET}" -o "${OLLAMA_CPU_TARGET}" = "cpu_avx2" ]; then
#
# ~2013 CPU Dynamic library
# Approximately 10% faster than AVX on same CPU
#
init_vars
CMAKE_DEFS="${COMMON_CPU_DEFS} -DLLAMA_AVX=on -DLLAMA_AVX2=on -DLLAMA_AVX512=off -DLLAMA_FMA=on -DLLAMA_F16C=on ${CMAKE_DEFS}"
BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/cpu_avx2"
echo "Building AVX2 CPU"
build
compress_libs
fi
fi
else
echo "Skipping CPU generation step as requested"
fi
# If needed, look for the default CUDA toolkit location
if [ -z "${CUDA_LIB_DIR}" ] && [ -d /usr/local/cuda/lib64 ]; then
CUDA_LIB_DIR=/usr/local/cuda/lib64
fi
# If needed, look for CUDA on Arch Linux
if [ -z "${CUDA_LIB_DIR}" ] && [ -d /opt/cuda/targets/x86_64-linux/lib ]; then
CUDA_LIB_DIR=/opt/cuda/targets/x86_64-linux/lib
fi
# Allow override in case libcudart is in the wrong place
if [ -z "${CUDART_LIB_DIR}" ]; then
CUDART_LIB_DIR="${CUDA_LIB_DIR}"
fi
if [ -d "${CUDA_LIB_DIR}" ]; then
if [ -d /usr/local/cuda/lib64/ ]; then
echo "CUDA libraries detected - building dynamic CUDA library"
init_vars
CUDA_MAJOR=$(ls "${CUDA_LIB_DIR}"/libcudart.so.* | head -1 | cut -f3 -d. || true)
if [ -n "${CUDA_MAJOR}" ]; then
CUDA_VARIANT=_v${CUDA_MAJOR}
fi
CMAKE_DEFS="-DLLAMA_CUBLAS=on -DLLAMA_CUDA_FORCE_MMQ=on -DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHITECTURES} ${COMMON_CMAKE_DEFS} ${CMAKE_DEFS}"
BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/cuda${CUDA_VARIANT}"
EXTRA_LIBS="-L${CUDA_LIB_DIR} -lcudart -lcublas -lcublasLt -lcuda"
CMAKE_DEFS="-DLLAMA_CUBLAS=on ${COMMON_CMAKE_DEFS} ${CMAKE_DEFS}"
BUILD_DIR="${LLAMACPP_DIR}/build/linux/cuda"
CUDA_LIB_DIR=/usr/local/cuda/lib64
build
# Cary the CUDA libs as payloads to help reduce dependency burden on users
#
# TODO - in the future we may shift to packaging these separately and conditionally
# downloading them in the install script.
DEPS="$(ldd ${BUILD_DIR}/lib/libext_server.so )"
for lib in libcudart.so libcublas.so libcublasLt.so ; do
DEP=$(echo "${DEPS}" | grep ${lib} | cut -f1 -d' ' | xargs || true)
if [ -n "${DEP}" -a -e "${CUDA_LIB_DIR}/${DEP}" ]; then
cp "${CUDA_LIB_DIR}/${DEP}" "${BUILD_DIR}/lib/"
elif [ -e "${CUDA_LIB_DIR}/${lib}.${CUDA_MAJOR}" ]; then
cp "${CUDA_LIB_DIR}/${lib}.${CUDA_MAJOR}" "${BUILD_DIR}/lib/"
elif [ -e "${CUDART_LIB_DIR}/${lib}" ]; then
cp -d ${CUDART_LIB_DIR}/${lib}* "${BUILD_DIR}/lib/"
else
cp -d "${CUDA_LIB_DIR}/${lib}*" "${BUILD_DIR}/lib/"
fi
done
compress_libs
install
gcc -fPIC -g -shared -o ${BUILD_DIR}/lib/libext_server.so \
-Wl,--whole-archive \
${BUILD_DIR}/lib/libext_server.a \
${BUILD_DIR}/lib/libcommon.a \
${BUILD_DIR}/lib/libllama.a \
-Wl,--no-whole-archive \
${CUDA_LIB_DIR}/libcudart_static.a \
${CUDA_LIB_DIR}/libcublas_static.a \
${CUDA_LIB_DIR}/libcublasLt_static.a \
${CUDA_LIB_DIR}/libcudadevrt.a \
${CUDA_LIB_DIR}/libculibos.a \
-lrt -lpthread -ldl -lstdc++ -lm
fi
if [ -z "${ROCM_PATH}" ]; then
@@ -179,18 +96,21 @@ fi
if [ -d "${ROCM_PATH}" ]; then
echo "ROCm libraries detected - building dynamic ROCm library"
if [ -f ${ROCM_PATH}/lib/librocm_smi64.so.? ]; then
ROCM_VARIANT=_v$(ls ${ROCM_PATH}/lib/librocm_smi64.so.? | cut -f3 -d. || true)
fi
init_vars
CMAKE_DEFS="${COMMON_CMAKE_DEFS} ${CMAKE_DEFS} -DLLAMA_HIPBLAS=on -DCMAKE_C_COMPILER=$ROCM_PATH/llvm/bin/clang -DCMAKE_CXX_COMPILER=$ROCM_PATH/llvm/bin/clang++ -DAMDGPU_TARGETS=$(amdGPUs) -DGPU_TARGETS=$(amdGPUs)"
BUILD_DIR="${LLAMACPP_DIR}/build/linux/${ARCH}/rocm${ROCM_VARIANT}"
EXTRA_LIBS="-L${ROCM_PATH}/lib -L/opt/amdgpu/lib/x86_64-linux-gnu/ -Wl,-rpath,${ROCM_PATH}/lib,-rpath,/opt/amdgpu/lib/x86_64-linux-gnu/ -lhipblas -lrocblas -lamdhip64 -lrocsolver -lamd_comgr -lhsa-runtime64 -lrocsparse -ldrm -ldrm_amdgpu"
BUILD_DIR="${LLAMACPP_DIR}/build/linux/rocm"
build
# Note: the ROCM libs and runtime library files are too large to embed, so we depend on
# them being present at runtime on the host
compress_libs
install
gcc -fPIC -g -shared -o ${BUILD_DIR}/lib/libext_server.so \
-Wl,--whole-archive \
${BUILD_DIR}/lib/libext_server.a \
${BUILD_DIR}/lib/libcommon.a \
${BUILD_DIR}/lib/libllama.a \
-Wl,--no-whole-archive \
-lrt -lpthread -ldl -lstdc++ -lm \
-L/opt/rocm/lib -L/opt/amdgpu/lib/x86_64-linux-gnu/ \
-Wl,-rpath,/opt/rocm/lib,-rpath,/opt/amdgpu/lib/x86_64-linux-gnu/ \
-lhipblas -lrocblas -lamdhip64 -lrocsolver -lamd_comgr -lhsa-runtime64 -lrocsparse -ldrm -ldrm_amdgpu
fi
cleanup

View File

@@ -4,9 +4,8 @@ $ErrorActionPreference = "Stop"
function init_vars {
$script:llamacppDir = "../llama.cpp"
$script:cmakeDefs = @("-DBUILD_SHARED_LIBS=on", "-DLLAMA_NATIVE=off", "-A","x64")
$script:cmakeTargets = @("ext_server")
$script:ARCH = "amd64" # arm not yet supported.
$script:cmakeDefs = @("-DBUILD_SHARED_LIBS=on", "-DLLAMA_NATIVE=off", "-DLLAMA_F16C=off", "-DLLAMA_FMA=off", "-DLLAMA_AVX512=off", "-DLLAMA_AVX2=off", "-DLLAMA_AVX=on", "-A","x64")
$script:cmakeTargets = @("ggml", "ggml_static", "llama", "build_info", "common", "ext_server_shared", "llava_static")
if ($env:CGO_CFLAGS -contains "-g") {
$script:cmakeDefs += @("-DCMAKE_VERBOSE_MAKEFILE=on", "-DLLAMA_SERVER_VERBOSE=on")
$script:config = "RelWithDebInfo"
@@ -14,22 +13,6 @@ function init_vars {
$script:cmakeDefs += @("-DLLAMA_SERVER_VERBOSE=off")
$script:config = "Release"
}
# Try to find the CUDA dir
if ($env:CUDA_LIB_DIR -eq $null) {
$d=(get-command -ea 'silentlycontinue' nvcc).path
if ($d -ne $null) {
$script:CUDA_LIB_DIR=($d| split-path -parent)
}
} else {
$script:CUDA_LIB_DIR=$env:CUDA_LIB_DIR
}
$script:GZIP=(get-command -ea 'silentlycontinue' gzip).path
$script:DUMPBIN=(get-command -ea 'silentlycontinue' dumpbin).path
if ($null -eq $env:CMAKE_CUDA_ARCHITECTURES) {
$script:CMAKE_CUDA_ARCHITECTURES="50;52;61;70;75;80"
} else {
$script:CMAKE_CUDA_ARCHITECTURES=$env:CMAKE_CUDA_ARCHITECTURES
}
}
function git_module_setup {
@@ -45,29 +28,6 @@ function apply_patches {
if (!(Select-String -Path "${script:llamacppDir}/examples/server/CMakeLists.txt" -Pattern 'ollama')) {
Add-Content -Path "${script:llamacppDir}/examples/server/CMakeLists.txt" -Value 'include (../../../ext_server/CMakeLists.txt) # ollama'
}
# Apply temporary patches until fix is upstream
$patches = Get-ChildItem "../patches/*.diff"
foreach ($patch in $patches) {
# Extract file paths from the patch file
$filePaths = Get-Content $patch.FullName | Where-Object { $_ -match '^\+\+\+ ' } | ForEach-Object {
$parts = $_ -split ' '
($parts[1] -split '/', 2)[1]
}
# Checkout each file
foreach ($file in $filePaths) {
Set-Location -Path ${script:llamacppDir}
git checkout $file
}
}
# Apply each patch
foreach ($patch in $patches) {
Set-Location -Path ${script:llamacppDir}
git apply $patch.FullName
}
# Avoid duplicate main symbols when we link into the cgo binary
$content = Get-Content -Path "${script:llamacppDir}/examples/server/server.cpp"
$content = $content -replace 'int main\(', 'int __main('
@@ -87,25 +47,11 @@ function build {
function install {
rm -ea 0 -recurse -force -path "${script:buildDir}/lib"
md "${script:buildDir}/lib" -ea 0 > $null
cp "${script:buildDir}/bin/${script:config}/ext_server.dll" "${script:buildDir}/lib"
cp "${script:buildDir}/bin/${script:config}/ext_server_shared.dll" "${script:buildDir}/lib"
cp "${script:buildDir}/bin/${script:config}/llama.dll" "${script:buildDir}/lib"
# Display the dll dependencies in the build log
if ($script:DUMPBIN -ne $null) {
& "$script:DUMPBIN" /dependents "${script:buildDir}/bin/${script:config}/ext_server.dll" | select-string ".dll"
}
}
function compress_libs {
if ($script:GZIP -eq $null) {
write-host "gzip not installed, not compressing files"
return
}
write-host "Compressing dlls..."
$libs = dir "${script:buildDir}/lib/*.dll"
foreach ($file in $libs) {
& "$script:GZIP" --best -f $file
}
dumpbin /dependents "${script:buildDir}/bin/${script:config}/ext_server_shared.dll" | select-string ".dll"
}
function cleanup {
@@ -117,55 +63,21 @@ init_vars
git_module_setup
apply_patches
# -DLLAMA_AVX -- 2011 Intel Sandy Bridge & AMD Bulldozer
# -DLLAMA_F16C -- 2012 Intel Ivy Bridge & AMD 2011 Bulldozer (No significant improvement over just AVX)
# -DLLAMA_AVX2 -- 2013 Intel Haswell & 2015 AMD Excavator / 2017 AMD Zen
# -DLLAMA_FMA (FMA3) -- 2013 Intel Haswell & 2012 AMD Piledriver
# first build CPU based
$script:buildDir="${script:llamacppDir}/build/windows/cpu"
$script:commonCpuDefs = @("-DCMAKE_POSITION_INDEPENDENT_CODE=on", "-DLLAMA_NATIVE=off")
$script:cmakeDefs = $script:commonCpuDefs + @("-DLLAMA_AVX=off", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
$script:buildDir="${script:llamacppDir}/build/windows/${script:ARCH}/cpu"
write-host "Building LCD CPU"
build
install
compress_libs
$script:cmakeDefs = $script:commonCpuDefs + @("-DLLAMA_AVX=on", "-DLLAMA_AVX2=off", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=off", "-DLLAMA_F16C=off") + $script:cmakeDefs
$script:buildDir="${script:llamacppDir}/build/windows/${script:ARCH}/cpu_avx"
write-host "Building AVX CPU"
# Then build cuda as a dynamically loaded library
init_vars
$script:buildDir="${script:llamacppDir}/build/windows/cuda"
$script:cmakeDefs += @("-DLLAMA_CUBLAS=ON")
build
install
compress_libs
$script:cmakeDefs = $script:commonCpuDefs + @("-DLLAMA_AVX=on", "-DLLAMA_AVX2=on", "-DLLAMA_AVX512=off", "-DLLAMA_FMA=on", "-DLLAMA_F16C=on") + $script:cmakeDefs
$script:buildDir="${script:llamacppDir}/build/windows/${script:ARCH}/cpu_avx2"
write-host "Building AVX2 CPU"
build
install
compress_libs
if ($null -ne $script:CUDA_LIB_DIR) {
# Then build cuda as a dynamically loaded library
$nvcc = (get-command -ea 'silentlycontinue' nvcc)
if ($null -ne $nvcc) {
$script:CUDA_VERSION=(get-item ($nvcc | split-path | split-path)).Basename
}
if ($null -ne $script:CUDA_VERSION) {
$script:CUDA_VARIANT="_"+$script:CUDA_VERSION
}
init_vars
$script:buildDir="${script:llamacppDir}/build/windows/${script:ARCH}/cuda$script:CUDA_VARIANT"
$script:cmakeDefs += @("-DLLAMA_CUBLAS=ON", "-DLLAMA_AVX=on", "-DCMAKE_CUDA_ARCHITECTURES=${script:CMAKE_CUDA_ARCHITECTURES}")
build
install
cp "${script:CUDA_LIB_DIR}/cudart64_*.dll" "${script:buildDir}/lib"
cp "${script:CUDA_LIB_DIR}/cublas64_*.dll" "${script:buildDir}/lib"
cp "${script:CUDA_LIB_DIR}/cublasLt64_*.dll" "${script:buildDir}/lib"
compress_libs
}
# TODO - actually implement ROCm support on windows
$script:buildDir="${script:llamacppDir}/build/windows/${script:ARCH}/rocm"
$script:buildDir="${script:llamacppDir}/build/windows/rocm"
rm -ea 0 -recurse -force -path "${script:buildDir}/lib"
md "${script:buildDir}/lib" -ea 0 > $null

View File

@@ -83,7 +83,6 @@ type model interface {
NumEmbed() uint32
NumHead() uint32
NumHeadKv() uint32
NumCtx() uint32
}
type container interface {
@@ -99,9 +98,9 @@ func (c *containerLORA) Name() string {
return "ggla"
}
func (c *containerLORA) Decode(rso *readSeekOffset) (model, error) {
func (c *containerLORA) Decode(ro *readSeekOffset) (model, error) {
var version uint32
binary.Read(rso, binary.LittleEndian, &version)
binary.Read(ro, binary.LittleEndian, &version)
switch version {
case 1:
@@ -112,7 +111,7 @@ func (c *containerLORA) Decode(rso *readSeekOffset) (model, error) {
c.version = version
// remaining file contents aren't decoded
rso.Seek(0, io.SeekEnd)
ro.Seek(0, io.SeekEnd)
return nil, nil
}

View File

@@ -69,65 +69,12 @@ type tensor struct {
name string
kind uint32
offset uint64
size uint64
// shape is the number of elements in each dimension
shape [4]uint64
}
func (t tensor) blockSize() uint64 {
switch {
case t.kind < 2:
return 1
case t.kind < 10:
return 32
default:
return 256
}
}
func (t tensor) typeSize() uint64 {
blockSize := t.blockSize()
switch t.kind {
case 0: // FP32
return 4
case 1: // FP16
return 2
case 2: // Q4_0
return 2 + blockSize/2
case 3: // Q4_1
return 2 + 2 + blockSize/2
case 6: // Q5_0
return 2 + 4 + blockSize/2
case 7: // Q5_1
return 2 + 2 + 4 + blockSize/2
case 8: // Q8_0
return 2 + blockSize
case 9: // Q8_1
return 4 + 4 + blockSize
case 10: // Q2_K
return blockSize/16 + blockSize/4 + 2 + 2
case 11: // Q3_K
return blockSize/8 + blockSize/4 + 12 + 2
case 12: // Q4_K
return 2 + 2 + 12 + blockSize/2
case 13: // Q5_K
return 2 + 2 + 12 + blockSize/8 + blockSize/2
case 14: // Q6_K
return blockSize/2 + blockSize/4 + blockSize/16 + 2
default:
return 0
}
}
func (t tensor) parameters() uint64 {
return t.shape[0] * t.shape[1] * t.shape[2] * t.shape[3]
}
func (t tensor) size() uint64 {
return t.parameters() * t.typeSize() / t.blockSize()
}
type ggufModel struct {
*containerGGUF
@@ -254,15 +201,61 @@ func (llm *ggufModel) Decode(rso *readSeekOffset) error {
shape[i] = llm.readU64(rso)
}
tensor := tensor{
name: name,
kind: llm.readU32(rso),
offset: llm.readU64(rso),
shape: shape,
kind := llm.readU32(rso)
offset := llm.readU64(rso)
var blockSize uint64
switch {
case kind < 2:
blockSize = 1
case kind < 10:
blockSize = 32
default:
blockSize = 256
}
llm.tensors = append(llm.tensors, tensor)
llm.parameters += tensor.parameters()
var typeSize uint64
switch kind {
case 0: // FP32
typeSize = 4
case 1: // FP16
typeSize = 2
case 2: // Q4_0
typeSize = 2 + blockSize/2
case 3: // Q4_1
typeSize = 2 + 2 + blockSize/2
case 6: // Q5_0
typeSize = 2 + 4 + blockSize/2
case 7: // Q5_1
typeSize = 2 + 2 + 4 + blockSize/2
case 8: // Q8_0
typeSize = 2 + blockSize
case 9: // Q8_1
typeSize = 4 + 4 + blockSize
case 10: // Q2_K
typeSize = blockSize/16 + blockSize/4 + 2 + 2
case 11: // Q3_K
typeSize = blockSize/8 + blockSize/4 + 12 + 2
case 12: // Q4_K
typeSize = 2 + 2 + 12 + blockSize/2
case 13: // Q5_K
typeSize = 2 + 2 + 12 + blockSize/8 + blockSize/2
case 14: // Q6_K
typeSize = blockSize/2 + blockSize/4 + blockSize/16 + 2
}
parameters := shape[0] * shape[1] * shape[2] * shape[3]
size := parameters * typeSize / blockSize
llm.tensors = append(llm.tensors, tensor{
name: name,
kind: kind,
offset: offset,
size: size,
shape: shape,
})
llm.parameters += parameters
}
alignment, ok := llm.kv["general.alignment"].(uint32)
@@ -272,7 +265,7 @@ func (llm *ggufModel) Decode(rso *readSeekOffset) error {
rso.Seek(int64(alignment)-rso.offset%int64(alignment), io.SeekCurrent)
for _, tensor := range llm.tensors {
padded := (int64(tensor.size()) + int64(alignment) - 1) & ^(int64(alignment) - 1)
padded := (int64(tensor.size) + int64(alignment) - 1) & ^(int64(alignment) - 1)
rso.Seek(padded, io.SeekCurrent)
}
@@ -315,15 +308,6 @@ func (llm *ggufModel) NumHeadKv() uint32 {
return value.(uint32)
}
func (llm *ggufModel) NumCtx() uint32 {
value, exists := llm.kv[fmt.Sprintf("%s.context_length", llm.ModelFamily())]
if !exists {
return 0
}
return value.(uint32)
}
func (llm *ggufModel) NumGQA() uint32 {
numHeadKv := llm.NumHeadKv()
if numHeadKv == 0 {

View File

@@ -1,11 +1,17 @@
package llm
import (
"bytes"
"context"
_ "embed"
"errors"
"fmt"
"os"
"os/exec"
"time"
"github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/format"
)
const jsonGrammar = `
@@ -36,12 +42,51 @@ number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
ws ::= ([ \t\n] ws)?
`
type Running struct {
Port int
Cmd *exec.Cmd
Cancel context.CancelFunc
*StatusWriter // captures error messages from the llama runner process
}
type ImageData struct {
Data []byte `json:"data"`
ID int `json:"id"`
}
var payloadMissing = fmt.Errorf("expected dynamic library payloads not included in this build of ollama")
var (
errNvidiaSMI = errors.New("warning: gpu support may not be enabled, check that you have installed GPU drivers: nvidia-smi command failed")
errAvailableVRAM = errors.New("not enough VRAM available, falling back to CPU only")
payloadMissing = fmt.Errorf("expected dynamic library payloads not included in this build of ollama")
)
// StatusWriter is a writer that captures error messages from the llama runner process
type StatusWriter struct {
ErrCh chan error
LastErrMsg string
}
func NewStatusWriter() *StatusWriter {
return &StatusWriter{
ErrCh: make(chan error, 1),
}
}
func (w *StatusWriter) Write(b []byte) (int, error) {
var errMsg string
if _, after, ok := bytes.Cut(b, []byte("error:")); ok {
errMsg = string(bytes.TrimSpace(after))
} else if _, after, ok := bytes.Cut(b, []byte("CUDA error")); ok {
errMsg = string(bytes.TrimSpace(after))
}
if errMsg != "" {
w.LastErrMsg = errMsg
w.ErrCh <- fmt.Errorf("llama runner: %s", errMsg)
}
return os.Stderr.Write(b)
}
type prediction struct {
Content string `json:"content"`
@@ -57,12 +102,14 @@ type prediction struct {
}
}
const maxBufferSize = 512 * format.KiloByte
const maxRetries = 3
const retryDelay = 1 * time.Second
type PredictOpts struct {
Prompt string
Format string
Images []ImageData
Images []api.ImageData
Options api.Options
}

View File

@@ -3,7 +3,7 @@ package llm
import (
"context"
"fmt"
"log/slog"
"log"
"os"
"runtime"
@@ -19,6 +19,8 @@ type LLM interface {
Close()
}
var AvailableShims = map[string]string{}
func New(workDir, model string, adapters, projectors []string, opts api.Options) (LLM, error) {
if _, err := os.Stat(model); err != nil {
return nil, err
@@ -35,92 +37,95 @@ func New(workDir, model string, adapters, projectors []string, opts api.Options)
return nil, err
}
if opts.NumCtx > int(ggml.NumCtx()) {
slog.Warn(fmt.Sprintf("requested context length is greater than model's max context length (%d > %d), using %d instead", opts.NumCtx, ggml.NumCtx(), ggml.NumCtx()))
opts.NumCtx = int(ggml.NumCtx())
}
if opts.NumCtx < 4 {
opts.NumCtx = 4
}
vram, _ := gpu.CheckVRAM()
size := ggml.Size
fmt.Println("size", ggml.Size)
fmt.Println("filetype", ggml.FileType())
fmt.Println("architecture", ggml.ModelFamily())
fmt.Println("type", ggml.ModelType())
fmt.Println("name", ggml.Name())
fmt.Println("embd", ggml.NumEmbed())
fmt.Println("head", ggml.NumHead())
fmt.Println("head_kv", ggml.NumHeadKv())
fmt.Println("gqa", ggml.NumGQA())
available, _ := gpu.CheckVRAM()
// For now assume filesize = model size
// TODO: use actual model size
requiredModel := ggml.Size
// fp16 k,v matrices require = n_ctx * n_layer * n_embd / n_head * n_head_kv * 2 bytes each * 2 key and value
kv := 2 * 2 * int64(opts.NumCtx) * int64(ggml.NumLayers()) * int64(ggml.NumEmbed()) * int64(ggml.NumHeadKv()) / int64(ggml.NumHead())
requiredKv := 2 * 2 * int64(opts.NumCtx) * int64(ggml.NumLayers()) * int64(ggml.NumEmbed()) * int64(ggml.NumHeadKv()) / int64(ggml.NumHead())
// this amount is the overhead + tensors in memory
// TODO: get this from the llama.cpp's graph calculations instead of
// TODO: get this from the llama.cpp's graph calcluations instead of
// estimating it's 1/6 * kv_cache_size * num_gqa
graph := int64(ggml.NumGQA()) * kv / 6
requiredAlloc := int64(ggml.NumGQA()) * requiredKv / 6
requiredTotal := requiredModel + requiredKv + requiredAlloc
log.Println("system memory bytes:", available)
log.Println("required model bytes:", requiredModel)
log.Println("required kv bytes:", requiredKv)
log.Println("required alloc bytes:", requiredAlloc)
log.Println("required total bytes:", requiredTotal)
info := gpu.GetGPUInfo()
switch runtime.GOOS {
case "darwin":
if opts.NumGPU == 0 {
break
}
library := info.Library
if size+kv+graph > vram {
slog.Info("not enough vram available, falling back to CPU only")
info.Library = "cpu"
info.Variant = gpu.GetCPUVariant()
opts.NumGPU = 0
break
}
// TODO: implement layer splitting on macOS
opts.NumGPU = 999
default:
if info.Library == "cpu" {
slog.Info("GPU not available, falling back to CPU")
opts.NumGPU = 0
break
}
// don't use GPU at all if no layers are loaded
if opts.NumGPU == 0 {
info.Library = "cpu"
info.Variant = gpu.GetCPUVariant()
break
}
// user-defined GPU count
if opts.NumGPU != -1 {
break
}
// the "main" GPU needs the most memory and determines the limit
// of how many layers can be loaded. It needs to fit:
// 1. the full compute graph allocation for all devices (graph)
// 2. the proportional kv cache for all devices (kv * % layers)
// 3. the proportional model (size * % layers / # devices)
// This estimates the number of layers
maxlayers := int64(ggml.NumLayers()) + 1
devices := int64(info.DeviceCount)
avg := vram / devices
layers := maxlayers * (avg - graph) / (kv + size/devices)
if layers > maxlayers {
layers = maxlayers
}
// 1 + 2 must fit on the main gpu
min := graph + kv*layers/maxlayers
if layers <= 0 || min > avg {
slog.Info("not enough vram available, falling back to CPU only")
info.Library = "cpu"
info.Variant = gpu.GetCPUVariant()
opts.NumGPU = 0
break
}
opts.NumGPU = int(layers)
if opts.NumGPU == -1 {
// default to offloading all layers
opts.NumGPU = int(ggml.NumLayers()) + 1
}
// decide how many layers to put on the GPU
if opts.NumGPU > 0 {
switch runtime.GOOS {
case "darwin":
if requiredTotal > available {
log.Println("not enough vram available, falling back to CPU only")
opts.NumGPU = 0
}
default:
if library == "cpu" || library == "default" {
opts.NumGPU = 0
break
}
// no offloading required
if requiredTotal <= available {
break
}
// requiredAlloc is always loaded for the CUDA runner, so don't load it if it won't fit
if requiredAlloc > available {
log.Printf("not enough vram available, falling back to CPU only")
library = "cpu"
opts.NumGPU = 0
break
}
available -= requiredAlloc
// fill remaining vram with layers
log.Println("splitting", available, "of available memory bytes into layers")
bytesPerLayer := int64((requiredModel + requiredKv) / int64(ggml.NumLayers()))
log.Println("bytes per layer:", bytesPerLayer)
layers := available / bytesPerLayer
log.Println("total required with split:", requiredAlloc+(layers*bytesPerLayer))
if layers < int64(opts.NumGPU) {
opts.NumGPU = int(layers)
}
}
}
opts.NumGQA = 0
opts.RopeFrequencyBase = 0.0
opts.RopeFrequencyScale = 0.0
return newLlmServer(info, workDir, model, adapters, projectors, opts)
return newLlmServer(library, model, adapters, projectors, opts)
}
// Give any native cgo implementations an opportunity to initialize
@@ -128,40 +133,15 @@ func Init(workdir string) error {
return nativeInit(workdir)
}
func newLlmServer(gpuInfo gpu.GpuInfo, workDir, model string, adapters, projectors []string, opts api.Options) (LLM, error) {
dynLibs := getDynLibs(gpuInfo)
// Check to see if the user has requested a specific library instead of auto-detecting
demandLib := os.Getenv("OLLAMA_LLM_LIBRARY")
if demandLib != "" {
libPath := availableDynLibs[demandLib]
if libPath == "" {
slog.Info(fmt.Sprintf("Invalid OLLAMA_LLM_LIBRARY %s - not found", demandLib))
} else {
slog.Info(fmt.Sprintf("Loading OLLAMA_LLM_LIBRARY=%s", demandLib))
dynLibs = []string{libPath}
}
}
// We stage into a temp directory, and if we've been idle for a while, it may have been reaped
_, err := os.Stat(dynLibs[0])
if err != nil {
slog.Info(fmt.Sprintf("%s has disappeared, reloading libraries", dynLibs[0]))
err = nativeInit(workDir)
if err != nil {
return nil, err
}
}
err2 := fmt.Errorf("unable to locate suitable llm library")
for _, dynLib := range dynLibs {
srv, err := newDynExtServer(dynLib, model, adapters, projectors, opts)
func newLlmServer(library, model string, adapters, projectors []string, opts api.Options) (extServer, error) {
if _, libPresent := AvailableShims[library]; libPresent && library != "default" {
srv, err := newDynamicShimExtServer(AvailableShims[library], model, adapters, projectors, opts)
if err == nil {
return srv, nil
}
slog.Warn(fmt.Sprintf("Failed to load dynamic library %s %s", dynLib, err))
err2 = err
log.Printf("Failed to load dynamic library %s - falling back to CPU mode %s", library, err)
// TODO - update some state to indicate we were unable to load the GPU library for future "info" ux
}
return nil, err2
return newDefaultExtServer(model, adapters, projectors, opts)
}

View File

@@ -1,21 +0,0 @@
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index d86d7e04..2694e92e 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -901,13 +901,15 @@ struct llama_server_context
slot.sent_count += result.text_to_send.size();
// add the token to slot queue and cache
}
- slot.add_token_string(result);
+
if (slot.params.stream)
{
send_partial_response(slot, result);
}
}
+ slot.add_token_string(result);
+
if (incomplete)
{
slot.has_next_token = true;

View File

@@ -1,85 +0,0 @@
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 11dd82c3..311495a8 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -28,6 +28,7 @@
#include <chrono>
#include <condition_variable>
#include <atomic>
+#include <signal.h>
using json = nlohmann::json;
@@ -2394,6 +2395,9 @@ static void append_to_generated_text_from_generated_token_probs(llama_server_con
}
}
+std::function<void(int)> shutdown_handler;
+inline void signal_handler(int signal) { shutdown_handler(signal); }
+
int main(int argc, char **argv)
{
#if SERVER_VERBOSE != 1
@@ -3014,8 +3018,14 @@ int main(int argc, char **argv)
std::placeholders::_2,
std::placeholders::_3
));
- llama.queue_tasks.start_loop();
+ shutdown_handler = [&](int) {
+ llama.queue_tasks.terminate();
+ };
+ signal(SIGTERM, signal_handler);
+ signal(SIGINT, signal_handler);
+ llama.queue_tasks.start_loop();
+ svr.stop();
t.join();
llama_backend_free();
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 70cce072..9124869a 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -190,6 +190,7 @@ inline std::string format_chatml(std::vector<json> messages)
struct llama_server_queue {
int id = 0;
std::mutex mutex_tasks;
+ bool running;
// queues
std::vector<task_server> queue_tasks;
std::vector<task_server> queue_tasks_deferred;
@@ -248,9 +249,18 @@ struct llama_server_queue {
queue_tasks_deferred.clear();
}
- // Start the main loop. This call is blocking
- [[noreturn]]
+ // end the start_loop routine
+ void terminate() {
+ {
+ std::unique_lock<std::mutex> lock(mutex_tasks);
+ running = false;
+ }
+ condition_tasks.notify_all();
+ }
+
+ // Start the main loop.
void start_loop() {
+ running = true;
while (true) {
// new task arrived
LOG_VERBOSE("have new task", {});
@@ -294,8 +304,12 @@ struct llama_server_queue {
{
std::unique_lock<std::mutex> lock(mutex_tasks);
if (queue_tasks.empty()) {
+ if (!running) {
+ LOG_VERBOSE("ending start_loop", {});
+ return;
+ }
condition_tasks.wait(lock, [&]{
- return !queue_tasks.empty();
+ return (!queue_tasks.empty() || !running);
});
}
}

View File

@@ -1,284 +0,0 @@
package llm
import (
"compress/gzip"
"errors"
"fmt"
"io"
"io/fs"
"log/slog"
"os"
"path/filepath"
"runtime"
"strings"
"golang.org/x/exp/slices"
"golang.org/x/sync/errgroup"
"github.com/jmorganca/ollama/gpu"
)
// Libraries names may contain an optional variant separated by '_'
// For example, "rocm_v6" and "rocm_v5" or "cpu" and "cpu_avx2"
// Any library without a variant is the lowest common denominator
var availableDynLibs = map[string]string{}
const pathComponentCount = 7
// getDynLibs returns an ordered list of LLM libraries to try, starting with the best
func getDynLibs(gpuInfo gpu.GpuInfo) []string {
// Short circuit if we know we're using the default built-in (darwin only)
if gpuInfo.Library == "default" {
return []string{"default"}
}
// TODO - temporary until we have multiple CPU variations for Darwin
// Short circuit on darwin with metal only
if len(availableDynLibs) == 1 {
if _, onlyMetal := availableDynLibs["metal"]; onlyMetal {
return []string{availableDynLibs["metal"]}
}
}
exactMatch := ""
dynLibs := []string{}
altDynLibs := []string{}
requested := gpuInfo.Library
if gpuInfo.Variant != "" {
requested += "_" + gpuInfo.Variant
}
// Try to find an exact match
for cmp := range availableDynLibs {
if requested == cmp {
exactMatch = cmp
dynLibs = []string{availableDynLibs[cmp]}
break
}
}
// Then for GPUs load alternates and sort the list for consistent load ordering
if gpuInfo.Library != "cpu" {
for cmp := range availableDynLibs {
if gpuInfo.Library == strings.Split(cmp, "_")[0] && cmp != exactMatch {
altDynLibs = append(altDynLibs, cmp)
}
}
slices.Sort(altDynLibs)
for _, altDynLib := range altDynLibs {
dynLibs = append(dynLibs, availableDynLibs[altDynLib])
}
}
// Load up the best CPU variant if not primary requested
if gpuInfo.Library != "cpu" {
variant := gpu.GetCPUVariant()
// If no variant, then we fall back to default
// If we have a variant, try that if we find an exact match
// Attempting to run the wrong CPU instructions will panic the
// process
if variant != "" {
for cmp := range availableDynLibs {
if cmp == "cpu_"+variant {
dynLibs = append(dynLibs, availableDynLibs[cmp])
break
}
}
} else {
dynLibs = append(dynLibs, availableDynLibs["cpu"])
}
}
// Finally, if we didn't find any matches, LCD CPU FTW
if len(dynLibs) == 0 {
dynLibs = []string{availableDynLibs["cpu"]}
}
slog.Debug(fmt.Sprintf("ordered list of LLM libraries to try %v", dynLibs))
return dynLibs
}
func rocmDynLibPresent() bool {
for dynLibName := range availableDynLibs {
if strings.HasPrefix(dynLibName, "rocm") {
return true
}
}
return false
}
func nativeInit(workdir string) error {
slog.Info("Extracting dynamic libraries...")
if runtime.GOOS == "darwin" {
err := extractPayloadFiles(workdir, "llama.cpp/ggml-metal.metal")
if err != nil {
if err == payloadMissing {
// TODO perhaps consider this a hard failure on arm macs?
slog.Info("ggml-meta.metal payload missing")
return nil
}
return err
}
os.Setenv("GGML_METAL_PATH_RESOURCES", workdir)
}
libs, err := extractDynamicLibs(workdir, "llama.cpp/build/*/*/*/lib/*")
if err != nil {
if err == payloadMissing {
slog.Info(fmt.Sprintf("%s", payloadMissing))
return nil
}
return err
}
for _, lib := range libs {
// The last dir component is the variant name
variant := filepath.Base(filepath.Dir(lib))
availableDynLibs[variant] = lib
}
if err := verifyDriverAccess(); err != nil {
return err
}
// Report which dynamic libraries we have loaded to assist troubleshooting
variants := make([]string, len(availableDynLibs))
i := 0
for variant := range availableDynLibs {
variants[i] = variant
i++
}
slog.Info(fmt.Sprintf("Dynamic LLM libraries %v", variants))
slog.Debug("Override detection logic by setting OLLAMA_LLM_LIBRARY")
return nil
}
func extractDynamicLibs(workDir, glob string) ([]string, error) {
files, err := fs.Glob(libEmbed, glob)
if err != nil || len(files) == 0 {
return nil, payloadMissing
}
libs := []string{}
// TODO consider making this idempotent with some sort of persistent directory (where we store models probably)
// and tracking by version so we don't reexpand the files every time
// Also maybe consider lazy loading only what is needed
g := new(errgroup.Group)
for _, file := range files {
pathComps := strings.Split(file, "/")
if len(pathComps) != pathComponentCount {
slog.Error(fmt.Sprintf("unexpected payload components: %v", pathComps))
continue
}
file := file
g.Go(func() error {
// llama.cpp/build/$OS/$GOARCH/$VARIANT/lib/$LIBRARY
// Include the variant in the path to avoid conflicts between multiple server libs
targetDir := filepath.Join(workDir, pathComps[pathComponentCount-3])
srcFile, err := libEmbed.Open(file)
if err != nil {
return fmt.Errorf("read payload %s: %v", file, err)
}
defer srcFile.Close()
if err := os.MkdirAll(targetDir, 0o755); err != nil {
return fmt.Errorf("create payload temp dir %s: %v", workDir, err)
}
src := io.Reader(srcFile)
filename := file
if strings.HasSuffix(file, ".gz") {
src, err = gzip.NewReader(src)
if err != nil {
return fmt.Errorf("decompress payload %s: %v", file, err)
}
filename = strings.TrimSuffix(filename, ".gz")
}
destFile := filepath.Join(targetDir, filepath.Base(filename))
if strings.Contains(destFile, "server") {
libs = append(libs, destFile)
}
_, err = os.Stat(destFile)
switch {
case errors.Is(err, os.ErrNotExist):
destFile, err := os.OpenFile(destFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
if err != nil {
return fmt.Errorf("write payload %s: %v", file, err)
}
defer destFile.Close()
if _, err := io.Copy(destFile, src); err != nil {
return fmt.Errorf("copy payload %s: %v", file, err)
}
case err != nil:
return fmt.Errorf("stat payload %s: %v", file, err)
}
return nil
})
}
return libs, g.Wait()
}
func extractPayloadFiles(workDir, glob string) error {
files, err := fs.Glob(libEmbed, glob)
if err != nil || len(files) == 0 {
return payloadMissing
}
for _, file := range files {
srcFile, err := libEmbed.Open(file)
if err != nil {
return fmt.Errorf("read payload %s: %v", file, err)
}
defer srcFile.Close()
if err := os.MkdirAll(workDir, 0o755); err != nil {
return fmt.Errorf("create payload temp dir %s: %v", workDir, err)
}
src := io.Reader(srcFile)
filename := file
if strings.HasSuffix(file, ".gz") {
src, err = gzip.NewReader(src)
if err != nil {
return fmt.Errorf("decompress payload %s: %v", file, err)
}
filename = strings.TrimSuffix(filename, ".gz")
}
destFile := filepath.Join(workDir, filepath.Base(filename))
_, err = os.Stat(destFile)
switch {
case errors.Is(err, os.ErrNotExist):
destFile, err := os.OpenFile(destFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
if err != nil {
return fmt.Errorf("write payload %s: %v", file, err)
}
defer destFile.Close()
if _, err := io.Copy(destFile, src); err != nil {
return fmt.Errorf("copy payload %s: %v", file, err)
}
case err != nil:
return fmt.Errorf("stat payload %s: %v", file, err)
}
}
return nil
}
func verifyDriverAccess() error {
if runtime.GOOS != "linux" {
return nil
}
// Only check ROCm access if we have the dynamic lib loaded
if rocmDynLibPresent() {
// Verify we have permissions - either running as root, or we have group access to the driver
fd, err := os.OpenFile("/dev/kfd", os.O_RDWR, 0666)
if err != nil {
if errors.Is(err, fs.ErrPermission) {
return fmt.Errorf("Radeon card detected, but permissions not set up properly. Either run ollama as root, or add you user account to the render group.")
} else if errors.Is(err, fs.ErrNotExist) {
// expected behavior without a radeon card
return nil
}
return fmt.Errorf("failed to check permission on /dev/kfd: %w", err)
}
fd.Close()
}
return nil
}

View File

@@ -1,8 +0,0 @@
package llm
import (
"embed"
)
//go:embed llama.cpp/ggml-metal.metal llama.cpp/build/darwin/x86_64/*/lib/*.dylib*
var libEmbed embed.FS

View File

@@ -1,8 +0,0 @@
package llm
import (
"embed"
)
//go:embed llama.cpp/ggml-metal.metal llama.cpp/build/darwin/arm64/*/lib/*.dylib*
var libEmbed embed.FS

View File

@@ -1,8 +0,0 @@
package llm
import (
"embed"
)
//go:embed llama.cpp/build/linux/*/*/lib/*.so*
var libEmbed embed.FS

View File

@@ -1,58 +0,0 @@
package llm
import (
"testing"
"github.com/jmorganca/ollama/gpu"
"github.com/stretchr/testify/assert"
)
func TestGetDynLibs(t *testing.T) {
availableDynLibs = map[string]string{
"cpu": "X_cpu",
}
assert.Equal(t, false, rocmDynLibPresent())
res := getDynLibs(gpu.GpuInfo{Library: "cpu"})
assert.Len(t, res, 1)
assert.Equal(t, availableDynLibs["cpu"], res[0])
variant := gpu.GetCPUVariant()
if variant != "" {
variant = "_" + variant
}
availableDynLibs = map[string]string{
"rocm_v5": "X_rocm_v5",
"rocm_v6": "X_rocm_v6",
"cpu" + variant: "X_cpu",
}
assert.Equal(t, true, rocmDynLibPresent())
res = getDynLibs(gpu.GpuInfo{Library: "rocm"})
assert.Len(t, res, 3)
assert.Equal(t, availableDynLibs["rocm_v5"], res[0])
assert.Equal(t, availableDynLibs["rocm_v6"], res[1])
assert.Equal(t, availableDynLibs["cpu"+variant], res[2])
res = getDynLibs(gpu.GpuInfo{Library: "rocm", Variant: "v6"})
assert.Len(t, res, 3)
assert.Equal(t, availableDynLibs["rocm_v6"], res[0])
assert.Equal(t, availableDynLibs["rocm_v5"], res[1])
assert.Equal(t, availableDynLibs["cpu"+variant], res[2])
res = getDynLibs(gpu.GpuInfo{Library: "cuda"})
assert.Len(t, res, 1)
assert.Equal(t, availableDynLibs["cpu"+variant], res[0])
res = getDynLibs(gpu.GpuInfo{Library: "default"})
assert.Len(t, res, 1)
assert.Equal(t, "default", res[0])
availableDynLibs = map[string]string{
"rocm": "X_rocm_v5",
"cpu" + variant: "X_cpu",
}
assert.Equal(t, true, rocmDynLibPresent())
res = getDynLibs(gpu.GpuInfo{Library: "rocm", Variant: "v6"})
assert.Len(t, res, 2)
assert.Equal(t, availableDynLibs["rocm"], res[0])
assert.Equal(t, availableDynLibs["cpu"+variant], res[1])
}

View File

@@ -1,8 +0,0 @@
package llm
import (
"embed"
)
//go:embed llama.cpp/build/windows/*/*/lib/*.dll*
var libEmbed embed.FS

71
llm/shim_darwin.go Normal file
View File

@@ -0,0 +1,71 @@
package llm
import (
"embed"
"errors"
"fmt"
"io"
"io/fs"
"log"
"os"
"path/filepath"
"github.com/jmorganca/ollama/api"
)
//go:embed llama.cpp/ggml-metal.metal
var libEmbed embed.FS
func newDynamicShimExtServer(library, model string, adapters, projectors []string, opts api.Options) (extServer, error) {
// should never happen...
return nil, fmt.Errorf("Dynamic library loading not supported on Mac")
}
func nativeInit(workdir string) error {
err := extractPayloadFiles(workdir, "llama.cpp/ggml-metal.metal")
if err != nil {
if err == payloadMissing {
// TODO perhaps consider this a hard failure on arm macs?
log.Printf("ggml-meta.metal payload missing")
return nil
}
return err
}
os.Setenv("GGML_METAL_PATH_RESOURCES", workdir)
return nil
}
func extractPayloadFiles(workDir, glob string) error {
files, err := fs.Glob(libEmbed, glob)
if err != nil || len(files) == 0 {
return payloadMissing
}
for _, file := range files {
srcFile, err := libEmbed.Open(file)
if err != nil {
return fmt.Errorf("read payload %s: %v", file, err)
}
defer srcFile.Close()
if err := os.MkdirAll(workDir, 0o755); err != nil {
return fmt.Errorf("create payload temp dir %s: %v", workDir, err)
}
destFile := filepath.Join(workDir, filepath.Base(file))
_, err = os.Stat(destFile)
switch {
case errors.Is(err, os.ErrNotExist):
destFile, err := os.OpenFile(destFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
if err != nil {
return fmt.Errorf("write payload %s: %v", file, err)
}
defer destFile.Close()
if _, err := io.Copy(destFile, srcFile); err != nil {
return fmt.Errorf("copy payload %s: %v", file, err)
}
case err != nil:
return fmt.Errorf("stat payload %s: %v", file, err)
}
}
return nil
}

193
llm/shim_ext_server.go Normal file
View File

@@ -0,0 +1,193 @@
//go:build !darwin
package llm
/*
#include <stdlib.h>
#include "dynamic_shim.h"
*/
import "C"
import (
"context"
"errors"
"fmt"
"io"
"io/fs"
"log"
"os"
"path/filepath"
"strings"
"sync"
"unsafe"
"github.com/jmorganca/ollama/api"
)
type shimExtServer struct {
s C.struct_dynamic_llama_server
options api.Options
}
// Note: current implementation does not support concurrent instantiations
var shimMutex sync.Mutex
var llm *shimExtServer
const pathComponentCount = 6
func (llm *shimExtServer) llama_server_init(sparams *C.ext_server_params_t, err *C.ext_server_resp_t) {
C.dynamic_shim_llama_server_init(llm.s, sparams, err)
}
func (llm *shimExtServer) llama_server_start() {
C.dynamic_shim_llama_server_start(llm.s)
}
func (llm *shimExtServer) llama_server_stop() {
C.dynamic_shim_llama_server_stop(llm.s)
}
func (llm *shimExtServer) llama_server_completion(json_req *C.char, resp *C.ext_server_resp_t) {
C.dynamic_shim_llama_server_completion(llm.s, json_req, resp)
}
func (llm *shimExtServer) llama_server_completion_next_result(task_id C.int, resp *C.ext_server_task_result_t) {
C.dynamic_shim_llama_server_completion_next_result(llm.s, task_id, resp)
}
func (llm *shimExtServer) llama_server_completion_cancel(task_id C.int, err *C.ext_server_resp_t) {
C.dynamic_shim_llama_server_completion_cancel(llm.s, task_id, err)
}
func (llm *shimExtServer) llama_server_release_task_result(result *C.ext_server_task_result_t) {
C.dynamic_shim_llama_server_release_task_result(llm.s, result)
}
func (llm *shimExtServer) llama_server_tokenize(json_req *C.char, json_resp **C.char, err *C.ext_server_resp_t) {
C.dynamic_shim_llama_server_tokenize(llm.s, json_req, json_resp, err)
}
func (llm *shimExtServer) llama_server_detokenize(json_req *C.char, json_resp **C.char, err *C.ext_server_resp_t) {
C.dynamic_shim_llama_server_detokenize(llm.s, json_req, json_resp, err)
}
func (llm *shimExtServer) llama_server_embedding(json_req *C.char, json_resp **C.char, err *C.ext_server_resp_t) {
C.dynamic_shim_llama_server_embedding(llm.s, json_req, json_resp, err)
}
func (llm *shimExtServer) llama_server_release_json_resp(json_resp **C.char) {
C.dynamic_shim_llama_server_release_json_resp(llm.s, json_resp)
}
func newDynamicShimExtServer(library, model string, adapters, projectors []string, opts api.Options) (extServer, error) {
shimMutex.Lock()
defer shimMutex.Unlock()
updatePath(filepath.Dir(library))
libPath := C.CString(library)
defer C.free(unsafe.Pointer(libPath))
resp := newExtServerResp(128)
defer freeExtServerResp(resp)
var srv C.struct_dynamic_llama_server
C.dynamic_shim_init(libPath, &srv, &resp)
if resp.id < 0 {
return nil, fmt.Errorf("Unable to load dynamic library: %s", C.GoString(resp.msg))
}
llm = &shimExtServer{
s: srv,
options: opts,
}
log.Printf("Loading Dynamic Shim llm server: %s", library)
return newExtServer(llm, model, adapters, projectors, opts)
}
func (llm *shimExtServer) Predict(ctx context.Context, pred PredictOpts, fn func(PredictResult)) error {
return predict(ctx, llm, pred, fn)
}
func (llm *shimExtServer) Encode(ctx context.Context, prompt string) ([]int, error) {
return encode(llm, ctx, prompt)
}
func (llm *shimExtServer) Decode(ctx context.Context, tokens []int) (string, error) {
return decode(llm, ctx, tokens)
}
func (llm *shimExtServer) Embedding(ctx context.Context, input string) ([]float64, error) {
return embedding(llm, ctx, input)
}
func (llm *shimExtServer) Close() {
close(llm)
}
func nativeInit(workdir string) error {
libs, err := extractDynamicLibs(workdir, "llama.cpp/build/*/*/lib/*")
if err != nil {
if err == payloadMissing {
log.Printf("%s", payloadMissing)
return nil
}
return err
}
for _, lib := range libs {
// The last dir component is the variant name
variant := filepath.Base(filepath.Dir(lib))
AvailableShims[variant] = lib
}
if err := verifyDriverAccess(); err != nil {
return err
}
// Report which dynamic libraries we have loaded to assist troubleshooting
variants := make([]string, len(AvailableShims))
i := 0
for variant := range AvailableShims {
variants[i] = variant
i++
}
log.Printf("Dynamic LLM variants %v", variants)
return nil
}
func extractDynamicLibs(workDir, glob string) ([]string, error) {
files, err := fs.Glob(libEmbed, glob)
if err != nil || len(files) == 0 {
return nil, payloadMissing
}
libs := []string{}
for _, file := range files {
pathComps := strings.Split(file, "/")
if len(pathComps) != pathComponentCount {
log.Printf("unexpected payload components: %v", pathComps)
continue
}
// llama.cpp/build/$OS/$VARIANT/lib/$LIBRARY
// Include the variant in the path to avoid conflicts between multiple server libs
targetDir := filepath.Join(workDir, pathComps[pathComponentCount-3])
srcFile, err := libEmbed.Open(file)
if err != nil {
return nil, fmt.Errorf("read payload %s: %v", file, err)
}
defer srcFile.Close()
if err := os.MkdirAll(targetDir, 0o755); err != nil {
return nil, fmt.Errorf("create payload temp dir %s: %v", workDir, err)
}
destFile := filepath.Join(targetDir, filepath.Base(file))
if strings.Contains(destFile, "server") {
libs = append(libs, destFile)
}
_, err = os.Stat(destFile)
switch {
case errors.Is(err, os.ErrNotExist):
destFile, err := os.OpenFile(destFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o755)
if err != nil {
return nil, fmt.Errorf("write payload %s: %v", file, err)
}
defer destFile.Close()
if _, err := io.Copy(destFile, srcFile); err != nil {
return nil, fmt.Errorf("copy payload %s: %v", file, err)
}
case err != nil:
return nil, fmt.Errorf("stat payload %s: %v", file, err)
}
}
return libs, nil
}

View File

@@ -0,0 +1,46 @@
package llm
import (
"embed"
"errors"
"fmt"
"io/fs"
"log"
"os"
"strings"
)
//go:embed llama.cpp/build/*/*/lib/*.so
var libEmbed embed.FS
func updatePath(dir string) {
pathComponents := strings.Split(os.Getenv("PATH"), ":")
for _, comp := range pathComponents {
if comp == dir {
return
}
}
newPath := strings.Join(append(pathComponents, dir), ":")
log.Printf("Updating PATH to %s", newPath)
os.Setenv("PATH", newPath)
}
func verifyDriverAccess() error {
// Only check ROCm access if we have the dynamic lib loaded
if _, rocmPresent := AvailableShims["rocm"]; rocmPresent {
// Verify we have permissions - either running as root, or we have group access to the driver
fd, err := os.OpenFile("/dev/kfd", os.O_RDWR, 0666)
if err != nil {
if errors.Is(err, fs.ErrPermission) {
return fmt.Errorf("Radeon card detected, but permissions not set up properly. Either run ollama as root, or add you user account to the render group.")
} else if errors.Is(err, fs.ErrNotExist) {
// expected behavior without a radeon card
return nil
}
return fmt.Errorf("failed to check permission on /dev/kfd: %w", err)
}
fd.Close()
}
return nil
}

View File

@@ -0,0 +1,36 @@
package llm
import (
"embed"
"log"
"os"
"path/filepath"
"strings"
)
//go:embed llama.cpp/build/windows/*/lib/*.dll
var libEmbed embed.FS
func updatePath(dir string) {
tmpDir := filepath.Dir(dir)
pathComponents := strings.Split(os.Getenv("PATH"), ";")
i := 0
for _, comp := range pathComponents {
if strings.EqualFold(comp, dir) {
return
}
// Remove any other prior paths to our temp dir
if !strings.HasPrefix(strings.ToLower(comp), strings.ToLower(tmpDir)) {
pathComponents[i] = comp
i++
}
}
newPath := strings.Join(append([]string{dir}, pathComponents...), ";")
log.Printf("Updating PATH to %s", newPath)
os.Setenv("PATH", newPath)
}
func verifyDriverAccess() error {
// TODO if applicable
return nil
}

View File

@@ -1,322 +0,0 @@
// openai package provides middleware for partial compatibility with the OpenAI REST API
package openai
import (
"bytes"
"encoding/json"
"fmt"
"io"
"math/rand"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/jmorganca/ollama/api"
)
type Error struct {
Message string `json:"message"`
Type string `json:"type"`
Param interface{} `json:"param"`
Code *string `json:"code"`
}
type ErrorResponse struct {
Error Error `json:"error"`
}
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
}
type Choice struct {
Index int `json:"index"`
Message Message `json:"message"`
FinishReason *string `json:"finish_reason"`
}
type ChunkChoice struct {
Index int `json:"index"`
Delta Message `json:"delta"`
FinishReason *string `json:"finish_reason"`
}
type Usage struct {
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
TotalTokens int `json:"total_tokens"`
}
type ResponseFormat struct {
Type string `json:"type"`
}
type ChatCompletionRequest struct {
Model string `json:"model"`
Messages []Message `json:"messages"`
Stream bool `json:"stream"`
MaxTokens *int `json:"max_tokens"`
Seed *int `json:"seed"`
Stop any `json:"stop"`
Temperature *float64 `json:"temperature"`
FrequencyPenalty *float64 `json:"frequency_penalty"`
PresencePenalty *float64 `json:"presence_penalty_penalty"`
TopP *float64 `json:"top_p"`
ResponseFormat *ResponseFormat `json:"response_format"`
}
type ChatCompletion struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
SystemFingerprint string `json:"system_fingerprint"`
Choices []Choice `json:"choices"`
Usage Usage `json:"usage,omitempty"`
}
type ChatCompletionChunk struct {
Id string `json:"id"`
Object string `json:"object"`
Created int64 `json:"created"`
Model string `json:"model"`
SystemFingerprint string `json:"system_fingerprint"`
Choices []ChunkChoice `json:"choices"`
}
func NewError(code int, message string) ErrorResponse {
var etype string
switch code {
case http.StatusBadRequest:
etype = "invalid_request_error"
case http.StatusNotFound:
etype = "not_found_error"
default:
etype = "api_error"
}
return ErrorResponse{Error{Type: etype, Message: message}}
}
func toChatCompletion(id string, r api.ChatResponse) ChatCompletion {
return ChatCompletion{
Id: id,
Object: "chat.completion",
Created: r.CreatedAt.Unix(),
Model: r.Model,
SystemFingerprint: "fp_ollama",
Choices: []Choice{{
Index: 0,
Message: Message{Role: r.Message.Role, Content: r.Message.Content},
FinishReason: func(done bool) *string {
if done {
reason := "stop"
return &reason
}
return nil
}(r.Done),
}},
Usage: Usage{
// TODO: ollama returns 0 for prompt eval if the prompt was cached, but openai returns the actual count
PromptTokens: r.PromptEvalCount,
CompletionTokens: r.EvalCount,
TotalTokens: r.PromptEvalCount + r.EvalCount,
},
}
}
func toChunk(id string, r api.ChatResponse) ChatCompletionChunk {
return ChatCompletionChunk{
Id: id,
Object: "chat.completion.chunk",
Created: time.Now().Unix(),
Model: r.Model,
SystemFingerprint: "fp_ollama",
Choices: []ChunkChoice{
{
Index: 0,
Delta: Message{Role: "assistant", Content: r.Message.Content},
FinishReason: func(done bool) *string {
if done {
reason := "stop"
return &reason
}
return nil
}(r.Done),
},
},
}
}
func fromRequest(r ChatCompletionRequest) api.ChatRequest {
var messages []api.Message
for _, msg := range r.Messages {
messages = append(messages, api.Message{Role: msg.Role, Content: msg.Content})
}
options := make(map[string]interface{})
switch stop := r.Stop.(type) {
case string:
options["stop"] = []string{stop}
case []interface{}:
var stops []string
for _, s := range stop {
if str, ok := s.(string); ok {
stops = append(stops, str)
}
}
options["stop"] = stops
}
if r.MaxTokens != nil {
options["num_predict"] = *r.MaxTokens
}
if r.Temperature != nil {
options["temperature"] = *r.Temperature * 2.0
} else {
options["temperature"] = 1.0
}
if r.Seed != nil {
options["seed"] = *r.Seed
// temperature=0 is required for reproducible outputs
options["temperature"] = 0.0
}
if r.FrequencyPenalty != nil {
options["frequency_penalty"] = *r.FrequencyPenalty * 2.0
}
if r.PresencePenalty != nil {
options["presence_penalty"] = *r.PresencePenalty * 2.0
}
if r.TopP != nil {
options["top_p"] = *r.TopP
} else {
options["top_p"] = 1.0
}
var format string
if r.ResponseFormat != nil && r.ResponseFormat.Type == "json_object" {
format = "json"
}
return api.ChatRequest{
Model: r.Model,
Messages: messages,
Format: format,
Options: options,
Stream: &r.Stream,
}
}
type writer struct {
stream bool
id string
gin.ResponseWriter
}
func (w *writer) writeError(code int, data []byte) (int, error) {
var serr api.StatusError
err := json.Unmarshal(data, &serr)
if err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(NewError(http.StatusInternalServerError, serr.Error()))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *writer) writeResponse(data []byte) (int, error) {
var chatResponse api.ChatResponse
err := json.Unmarshal(data, &chatResponse)
if err != nil {
return 0, err
}
// chat chunk
if w.stream {
d, err := json.Marshal(toChunk(w.id, chatResponse))
if err != nil {
return 0, err
}
w.ResponseWriter.Header().Set("Content-Type", "text/event-stream")
_, err = w.ResponseWriter.Write([]byte(fmt.Sprintf("data: %s\n\n", d)))
if err != nil {
return 0, err
}
if chatResponse.Done {
_, err = w.ResponseWriter.Write([]byte("data: [DONE]\n\n"))
if err != nil {
return 0, err
}
}
return len(data), nil
}
// chat completion
w.ResponseWriter.Header().Set("Content-Type", "application/json")
err = json.NewEncoder(w.ResponseWriter).Encode(toChatCompletion(w.id, chatResponse))
if err != nil {
return 0, err
}
return len(data), nil
}
func (w *writer) Write(data []byte) (int, error) {
code := w.ResponseWriter.Status()
if code != http.StatusOK {
return w.writeError(code, data)
}
return w.writeResponse(data)
}
func Middleware() gin.HandlerFunc {
return func(c *gin.Context) {
var req ChatCompletionRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, err.Error()))
return
}
if len(req.Messages) == 0 {
c.AbortWithStatusJSON(http.StatusBadRequest, NewError(http.StatusBadRequest, "[] is too short - 'messages'"))
return
}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(fromRequest(req)); err != nil {
c.AbortWithStatusJSON(http.StatusInternalServerError, NewError(http.StatusInternalServerError, err.Error()))
return
}
c.Request.Body = io.NopCloser(&b)
w := &writer{
ResponseWriter: c.Writer,
stream: req.Stream,
id: fmt.Sprintf("chatcmpl-%d", rand.Intn(999)),
}
c.Writer = w
c.Next()
}
}

View File

@@ -6,8 +6,7 @@ import (
"errors"
"fmt"
"io"
"log/slog"
"slices"
"log"
)
type Command struct {
@@ -57,20 +56,10 @@ func Parse(reader io.Reader) ([]Command, error) {
command.Args = string(bytes.TrimSpace(fields[1]))
case "EMBED":
return nil, fmt.Errorf("deprecated command: EMBED is no longer supported, use the /embed API endpoint instead")
case "MESSAGE":
command.Name = string(bytes.ToLower(fields[0]))
fields = bytes.SplitN(fields[1], []byte(" "), 2)
if len(fields) < 2 {
return nil, fmt.Errorf("should be in the format <role> <message>")
}
if !slices.Contains([]string{"system", "user", "assistant"}, string(bytes.ToLower(fields[0]))) {
return nil, fmt.Errorf("role must be one of \"system\", \"user\", or \"assistant\"")
}
command.Args = fmt.Sprintf("%s: %s", string(bytes.ToLower(fields[0])), string(fields[1]))
default:
if !bytes.HasPrefix(fields[0], []byte("#")) {
// log a warning for unknown commands
slog.Warn(fmt.Sprintf("Unknown command: %s", fields[0]))
log.Printf("WARNING: Unknown command: %s", fields[0])
}
continue
}

View File

@@ -61,38 +61,3 @@ PARAMETER param1
assert.ErrorContains(t, err, "missing value for [param1]")
}
func Test_Parser_Messages(t *testing.T) {
input := `
FROM foo
MESSAGE system You are a Parser. Always Parse things.
MESSAGE user Hey there!
MESSAGE assistant Hello, I want to parse all the things!
`
reader := strings.NewReader(input)
commands, err := Parse(reader)
assert.Nil(t, err)
expectedCommands := []Command{
{Name: "model", Args: "foo"},
{Name: "message", Args: "system: You are a Parser. Always Parse things."},
{Name: "message", Args: "user: Hey there!"},
{Name: "message", Args: "assistant: Hello, I want to parse all the things!"},
}
assert.Equal(t, expectedCommands, commands)
}
func Test_Parser_Messages_BadRole(t *testing.T) {
input := `
FROM foo
MESSAGE badguy I'm a bad guy!
`
reader := strings.NewReader(input)
_, err := Parse(reader)
assert.ErrorContains(t, err, "role must be one of \"system\", \"user\", or \"assistant\"")
}

View File

@@ -77,7 +77,7 @@ func (p *Progress) Add(key string, state State) {
p.states = append(p.states, state)
}
func (p *Progress) render() {
func (p *Progress) render() error {
p.mu.Lock()
defer p.mu.Unlock()
@@ -101,6 +101,8 @@ func (p *Progress) render() {
}
p.pos = len(p.states)
return nil
}
func (p *Progress) start() {

View File

@@ -133,6 +133,13 @@ func (b *Buffer) Size() int {
return b.Buf.Size()
}
func min(n, m int) int {
if n > m {
return m
}
return n
}
func (b *Buffer) Add(r rune) {
if b.Pos == b.Buf.Size() {
fmt.Printf("%c", r)

View File

@@ -23,7 +23,7 @@ type History struct {
func NewHistory() (*History, error) {
h := &History{
Buf: arraylist.New(),
Limit: 100, // resizeme
Limit: 100, //resizeme
Autosave: true,
Enabled: true,
}
@@ -49,7 +49,7 @@ func (h *History) Init() error {
h.Filename = path
f, err := os.OpenFile(path, os.O_CREATE|os.O_RDONLY, 0o600)
f, err := os.OpenFile(path, os.O_CREATE|os.O_RDONLY, 0600)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return nil
@@ -84,7 +84,7 @@ func (h *History) Add(l []rune) {
h.Compact()
h.Pos = h.Size()
if h.Autosave {
_ = h.Save()
h.Save()
}
}
@@ -132,7 +132,7 @@ func (h *History) Save() error {
tmpFile := h.Filename + ".tmp"
f, err := os.OpenFile(tmpFile, os.O_CREATE|os.O_WRONLY|os.O_TRUNC|os.O_APPEND, 0o600)
f, err := os.OpenFile(tmpFile, os.O_CREATE|os.O_WRONLY|os.O_TRUNC|os.O_APPEND, 0666)
if err != nil {
return err
}

View File

@@ -32,8 +32,6 @@ func (p *Prompt) placeholder() string {
type Terminal struct {
outchan chan rune
rawmode bool
termios any
}
type Instance struct {
@@ -62,16 +60,6 @@ func New(prompt Prompt) (*Instance, error) {
}
func (i *Instance) Readline() (string, error) {
if !i.Terminal.rawmode {
fd := int(syscall.Stdin)
termios, err := SetRawMode(fd)
if err != nil {
return "", err
}
i.Terminal.rawmode = true
i.Terminal.termios = termios
}
prompt := i.Prompt.prompt()
if i.Pasting {
// force alt prompt when pasting
@@ -79,12 +67,12 @@ func (i *Instance) Readline() (string, error) {
}
fmt.Print(prompt)
defer func() {
fd := int(syscall.Stdin)
// nolint: errcheck
UnsetRawMode(fd, i.Terminal.termios)
i.Terminal.rawmode = false
}()
fd := int(syscall.Stdin)
termios, err := SetRawMode(fd)
if err != nil {
return "", err
}
defer UnsetRawMode(fd, termios)
buf, _ := NewBuffer(i.Prompt)
@@ -216,8 +204,7 @@ func (i *Instance) Readline() (string, error) {
case CharCtrlW:
buf.DeleteWord()
case CharCtrlZ:
fd := int(syscall.Stdin)
return handleCharCtrlZ(fd, i.Terminal.termios)
return handleCharCtrlZ(fd, termios)
case CharEnter:
output := buf.String()
if output != "" {
@@ -248,16 +235,8 @@ func (i *Instance) HistoryDisable() {
}
func NewTerminal() (*Terminal, error) {
fd := int(syscall.Stdin)
termios, err := SetRawMode(fd)
if err != nil {
return nil, err
}
t := &Terminal{
outchan: make(chan rune),
rawmode: true,
termios: termios,
}
go t.ioloop()

View File

@@ -6,13 +6,12 @@ import (
"syscall"
)
func handleCharCtrlZ(fd int, termios any) (string, error) {
t := termios.(*Termios)
if err := UnsetRawMode(fd, t); err != nil {
func handleCharCtrlZ(fd int, termios *Termios) (string, error) {
if err := UnsetRawMode(fd, termios); err != nil {
return "", err
}
_ = syscall.Kill(0, syscall.SIGSTOP)
syscall.Kill(0, syscall.SIGSTOP)
// on resume...
return "", nil

View File

@@ -1,6 +1,6 @@
package readline
func handleCharCtrlZ(fd int, state any) (string, error) {
func handleCharCtrlZ(fd int, state *State) (string, error) {
// not supported
return "", nil
}

View File

@@ -25,9 +25,8 @@ func SetRawMode(fd int) (*Termios, error) {
return termios, setTermios(fd, &newTermios)
}
func UnsetRawMode(fd int, termios any) error {
t := termios.(*Termios)
return setTermios(fd, t)
func UnsetRawMode(fd int, termios *Termios) error {
return setTermios(fd, termios)
}
// IsTerminal returns true if the given file descriptor is a terminal.

View File

@@ -56,8 +56,7 @@ func SetRawMode(fd int) (*State, error) {
return &State{st}, nil
}
func UnsetRawMode(fd int, state any) error {
s := state.(*State)
_, _, err := syscall.SyscallN(procSetConsoleMode.Addr(), uintptr(fd), uintptr(s.mode), 0)
func UnsetRawMode(fd int, state *State) error {
_, _, err := syscall.SyscallN(procSetConsoleMode.Addr(), uintptr(fd), uintptr(state.mode), 0)
return err
}

View File

@@ -1,8 +1,8 @@
#!/bin/sh
set -e
set -eu
export VERSION=${VERSION:-$(git describe --tags --first-parent --abbrev=7 --long --dirty --always | sed -e "s/^v//g")}
export VERSION=${VERSION:-0.0.0}
export GOFLAGS="'-ldflags=-w -s \"-X=github.com/jmorganca/ollama/version.Version=$VERSION\" \"-X=github.com/jmorganca/ollama/server.mode=release\"'"
mkdir -p dist
@@ -11,36 +11,21 @@ for TARGETARCH in arm64 amd64; do
rm -rf llm/llama.cpp/build
GOOS=darwin GOARCH=$TARGETARCH go generate ./...
CGO_ENABLED=1 GOOS=darwin GOARCH=$TARGETARCH go build -o dist/ollama-darwin-$TARGETARCH
CGO_ENABLED=1 GOOS=darwin GOARCH=$TARGETARCH go build -cover -o dist/ollama-darwin-$TARGETARCH-cov
done
lipo -create -output dist/ollama dist/ollama-darwin-arm64 dist/ollama-darwin-amd64
rm -f dist/ollama-darwin-arm64 dist/ollama-darwin-amd64
if [ -n "$APPLE_IDENTITY" ]; then
codesign --deep --force --options=runtime --sign "$APPLE_IDENTITY" --timestamp dist/ollama
else
echo "Skipping code signing - set APPLE_IDENTITY"
fi
lipo -create -output dist/ollama dist/ollama-darwin-*
rm -f dist/ollama-darwin-*
codesign --deep --force --options=runtime --sign "$APPLE_IDENTITY" --timestamp dist/ollama
chmod +x dist/ollama
# build and optionally sign the mac app
# build and sign the mac app
npm install --prefix app
if [ -n "$APPLE_IDENTITY" ]; then
npm run --prefix app make:sign
else
npm run --prefix app make
fi
npm run --prefix app make:sign
cp app/out/make/zip/darwin/universal/Ollama-darwin-universal-$VERSION.zip dist/Ollama-darwin.zip
# sign the binary and rename it
if [ -n "$APPLE_IDENTITY" ]; then
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/ollama
else
echo "WARNING: Skipping code signing - set APPLE_IDENTITY"
fi
codesign -f --timestamp -s "$APPLE_IDENTITY" --identifier ai.ollama.ollama --options=runtime dist/ollama
ditto -c -k --keepParent dist/ollama dist/temp.zip
if [ -n "$APPLE_IDENTITY" ]; then
xcrun notarytool submit dist/temp.zip --wait --timeout 10m --apple-id $APPLE_ID --password $APPLE_PASSWORD --team-id $APPLE_TEAM_ID
fi
xcrun notarytool submit dist/temp.zip --wait --timeout 10m --apple-id $APPLE_ID --password $APPLE_PASSWORD --team-id $APPLE_TEAM_ID
mv dist/ollama dist/ollama-darwin
rm -f dist/temp.zip

View File

@@ -2,7 +2,7 @@
set -eu
export VERSION=${VERSION:-$(git describe --tags --first-parent --abbrev=7 --long --dirty --always | sed -e "s/^v//g")}
export VERSION=${VERSION:-0.0.0}
export GOFLAGS="'-ldflags=-w -s \"-X=github.com/jmorganca/ollama/version.Version=$VERSION\" \"-X=github.com/jmorganca/ollama/server.mode=release\"'"
docker build \
@@ -13,13 +13,3 @@ docker build \
-f Dockerfile \
-t ollama/ollama:$VERSION \
.
docker build \
--load \
--platform=linux/amd64 \
--build-arg=VERSION \
--build-arg=GOFLAGS \
--target runtime-rocm \
-f Dockerfile \
-t ollama/ollama:$VERSION-rocm \
.

View File

@@ -2,24 +2,13 @@
set -eu
export VERSION=${VERSION:-$(git describe --tags --first-parent --abbrev=7 --long --dirty --always | sed -e "s/^v//g")}
export VERSION=${VERSION:-0.0.0}
export GOFLAGS="'-ldflags=-w -s \"-X=github.com/jmorganca/ollama/version.Version=$VERSION\" \"-X=github.com/jmorganca/ollama/server.mode=release\"'"
BUILD_ARCH=${BUILD_ARCH:-"amd64 arm64"}
export AMDGPU_TARGETS=${AMDGPU_TARGETS:=""}
mkdir -p dist
for TARGETARCH in ${BUILD_ARCH}; do
docker build \
--platform=linux/$TARGETARCH \
--build-arg=GOFLAGS \
--build-arg=CGO_CFLAGS \
--build-arg=OLLAMA_CUSTOM_CPU_DEFS \
--build-arg=AMDGPU_TARGETS \
--target build-$TARGETARCH \
-f Dockerfile \
-t builder:$TARGETARCH \
.
for TARGETARCH in amd64 arm64; do
docker build --platform=linux/$TARGETARCH --build-arg=GOFLAGS --build-arg=CGO_CFLAGS -f Dockerfile.build -t builder:$TARGETARCH .
docker create --platform linux/$TARGETARCH --name builder-$TARGETARCH builder:$TARGETARCH
docker cp builder-$TARGETARCH:/go/src/github.com/jmorganca/ollama/ollama ./dist/ollama-linux-$TARGETARCH
docker rm builder-$TARGETARCH

View File

@@ -66,7 +66,3 @@ subprocess.check_call(['ssh', netloc, 'cd', path, ';', GoCmd, 'generate', './...
print("Building")
subprocess.check_call(['ssh', netloc, 'cd', path, ';', GoCmd, 'build', '.'])
print("Copying built result")
subprocess.check_call(['scp', netloc +":"+ path + "/ollama.exe", './dist/'])

View File

@@ -61,7 +61,7 @@ if [ -n "$NEEDS" ]; then
fi
status "Downloading ollama..."
curl --fail --show-error --location --progress-bar -o $TEMP_DIR/ollama "https://ollama.com/download/ollama-linux-$ARCH"
curl --fail --show-error --location --progress-bar -o $TEMP_DIR/ollama "https://ollama.ai/download/ollama-linux-$ARCH"
for BINDIR in /usr/local/bin /usr/bin /bin; do
echo $PATH | grep -q $BINDIR && break || continue
@@ -231,8 +231,8 @@ if ! check_gpu nvidia-smi || [ -z "$(nvidia-smi | grep -o "CUDA Version: [0-9]*\
case $OS_NAME in
centos|rhel) install_cuda_driver_yum 'rhel' $(echo $OS_VERSION | cut -d '.' -f 1) ;;
rocky) install_cuda_driver_yum 'rhel' $(echo $OS_VERSION | cut -c1) ;;
fedora) [ $OS_VERSION -lt '37' ] && install_cuda_driver_yum $OS_NAME $OS_VERSION || install_cuda_driver_yum $OS_NAME '37';;
amzn) install_cuda_driver_yum 'fedora' '37' ;;
fedora) install_cuda_driver_yum $OS_NAME $OS_VERSION ;;
amzn) install_cuda_driver_yum 'fedora' '35' ;;
debian) install_cuda_driver_apt $OS_NAME $OS_VERSION ;;
ubuntu) install_cuda_driver_apt $OS_NAME $(echo $OS_VERSION | sed 's/\.//') ;;
*) exit ;;

View File

@@ -1,43 +0,0 @@
#!/bin/sh
# Script for common Dockerfile dependency installation in redhat linux based images
set -ex
MACHINE=$(uname -m)
if grep -i "centos" /etc/system-release >/dev/null; then
# Centos 7 derivatives have too old of a git version to run our generate script
# uninstall and ignore failures
yum remove -y git
yum -y install epel-release centos-release-scl
yum -y install dnf
if [ "${MACHINE}" = "x86_64" ]; then
yum -y install https://repo.ius.io/ius-release-el7.rpm
dnf install -y git236
else
dnf install -y rh-git227-git
ln -s /opt/rh/rh-git227/root/usr/bin/git /usr/local/bin/git
fi
dnf install -y devtoolset-10-gcc devtoolset-10-gcc-c++
elif grep -i "rocky" /etc/system-release >/dev/null; then
dnf install -y git gcc-toolset-10-gcc gcc-toolset-10-gcc-c++
else
echo "ERROR Unexpected distro"
exit 1
fi
if [ -n "${CMAKE_VERSION}" ]; then
curl -s -L https://github.com/Kitware/CMake/releases/download/v${CMAKE_VERSION}/cmake-${CMAKE_VERSION}-linux-$(uname -m).tar.gz | tar -zx -C /usr --strip-components 1
fi
if [ -n "${GOLANG_VERSION}" ]; then
if [ "${MACHINE}" = "x86_64" ]; then
GO_ARCH="amd64"
else
GO_ARCH="arm64"
fi
mkdir -p /usr/local
curl -s -L https://dl.google.com/go/go${GOLANG_VERSION}.linux-${GO_ARCH}.tar.gz | tar xz -C /usr/local
ln -s /usr/local/go/bin/go /usr/local/bin/go
ln -s /usr/local/go/bin/gofmt /usr/local/bin/gofmt
fi

View File

@@ -10,7 +10,7 @@ import (
"encoding/json"
"fmt"
"io"
"log/slog"
"log"
"net/http"
"net/url"
"os"
@@ -86,7 +86,7 @@ func getAuthToken(ctx context.Context, redirData AuthRedirect) (string, error) {
rawKey, err := os.ReadFile(keyPath)
if err != nil {
slog.Info(fmt.Sprintf("Failed to load private key: %v", err))
log.Printf("Failed to load private key: %v", err)
return "", err
}
@@ -105,20 +105,14 @@ func getAuthToken(ctx context.Context, redirData AuthRedirect) (string, error) {
headers.Set("Authorization", sig)
resp, err := makeRequest(ctx, http.MethodGet, redirectURL, headers, nil, nil)
if err != nil {
slog.Info(fmt.Sprintf("couldn't get token: %q", err))
log.Printf("couldn't get token: %q", err)
return "", err
}
defer resp.Body.Close()
if resp.StatusCode >= http.StatusBadRequest {
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return "", fmt.Errorf("%d: %v", resp.StatusCode, err)
} else if len(responseBody) > 0 {
return "", fmt.Errorf("%d: %s", resp.StatusCode, responseBody)
}
return "", fmt.Errorf("%s", resp.Status)
body, _ := io.ReadAll(resp.Body)
return "", fmt.Errorf("on pull registry responded with code %d: %s", resp.StatusCode, body)
}
respBody, err := io.ReadAll(resp.Body)
@@ -153,7 +147,12 @@ func (s SignatureData) Bytes() []byte {
// SignData takes a SignatureData object and signs it with a raw private key
func (s SignatureData) Sign(rawKey []byte) (string, error) {
signer, err := ssh.ParsePrivateKey(rawKey)
privateKey, err := ssh.ParseRawPrivateKey(rawKey)
if err != nil {
return "", err
}
signer, err := ssh.NewSignerFromKey(privateKey)
if err != nil {
return "", err
}

View File

@@ -6,7 +6,7 @@ import (
"errors"
"fmt"
"io"
"log/slog"
"log"
"math"
"net/http"
"net/url"
@@ -25,11 +25,6 @@ import (
"github.com/jmorganca/ollama/format"
)
const maxRetries = 6
var errMaxRetriesExceeded = errors.New("max retries exceeded")
var errPartStalled = errors.New("part stalled")
var blobDownloadManager sync.Map
type blobDownload struct {
@@ -49,11 +44,10 @@ type blobDownload struct {
}
type blobDownloadPart struct {
N int
Offset int64
Size int64
Completed int64
lastUpdated time.Time
N int
Offset int64
Size int64
Completed int64
*blobDownload `json:"-"`
}
@@ -78,13 +72,6 @@ func (p *blobDownloadPart) StopsAt() int64 {
return p.Offset + p.Size
}
func (p *blobDownloadPart) Write(b []byte) (n int, err error) {
n = len(b)
p.blobDownload.Completed.Add(int64(n))
p.lastUpdated = time.Now()
return n, nil
}
func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *RegistryOptions) error {
partFilePaths, err := filepath.Glob(b.Name + "-partial-*")
if err != nil {
@@ -111,7 +98,7 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *R
b.Total, _ = strconv.ParseInt(resp.Header.Get("Content-Length"), 10, 64)
size := b.Total / numDownloadParts
var size = b.Total / numDownloadParts
switch {
case size < minDownloadPartSize:
size = minDownloadPartSize
@@ -133,7 +120,7 @@ func (b *blobDownload) Prepare(ctx context.Context, requestURL *url.URL, opts *R
}
}
slog.Info(fmt.Sprintf("downloading %s in %d %s part(s)", b.Digest[7:19], len(b.Parts), format.HumanBytes(b.Parts[0].Size)))
log.Printf("downloading %s in %d %s part(s)", b.Digest[7:19], len(b.Parts), format.HumanBytes(b.Parts[0].Size))
return nil
}
@@ -145,13 +132,13 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis
defer blobDownloadManager.Delete(b.Digest)
ctx, b.CancelFunc = context.WithCancel(ctx)
file, err := os.OpenFile(b.Name+"-partial", os.O_CREATE|os.O_RDWR, 0o644)
file, err := os.OpenFile(b.Name+"-partial", os.O_CREATE|os.O_RDWR, 0644)
if err != nil {
return err
}
defer file.Close()
_ = file.Truncate(b.Total)
file.Truncate(b.Total)
g, inner := errgroup.WithContext(ctx)
g.SetLimit(numDownloadParts)
@@ -170,12 +157,9 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis
case errors.Is(err, context.Canceled), errors.Is(err, syscall.ENOSPC):
// return immediately if the context is canceled or the device is out of space
return err
case errors.Is(err, errPartStalled):
try--
continue
case err != nil:
sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
slog.Info(fmt.Sprintf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep))
log.Printf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep)
time.Sleep(sleep)
continue
default:
@@ -211,54 +195,28 @@ func (b *blobDownload) run(ctx context.Context, requestURL *url.URL, opts *Regis
}
func (b *blobDownload) downloadChunk(ctx context.Context, requestURL *url.URL, w io.Writer, part *blobDownloadPart, opts *RegistryOptions) error {
g, ctx := errgroup.WithContext(ctx)
g.Go(func() error {
headers := make(http.Header)
headers.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts)
if err != nil {
return err
}
defer resp.Body.Close()
n, err := io.Copy(w, io.TeeReader(resp.Body, part))
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
// rollback progress
b.Completed.Add(-n)
return err
}
part.Completed += n
if err := b.writePart(part.Name(), part); err != nil {
return err
}
// return nil or context.Canceled or UnexpectedEOF (resumable)
headers := make(http.Header)
headers.Set("Range", fmt.Sprintf("bytes=%d-%d", part.StartsAt(), part.StopsAt()-1))
resp, err := makeRequestWithRetry(ctx, http.MethodGet, requestURL, headers, nil, opts)
if err != nil {
return err
})
}
defer resp.Body.Close()
g.Go(func() error {
ticker := time.NewTicker(time.Second)
for {
select {
case <-ticker.C:
if part.Completed >= part.Size {
return nil
}
n, err := io.Copy(w, io.TeeReader(resp.Body, b))
if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, io.ErrUnexpectedEOF) {
// rollback progress
b.Completed.Add(-n)
return err
}
if !part.lastUpdated.IsZero() && time.Since(part.lastUpdated) > 5*time.Second {
slog.Info(fmt.Sprintf("%s part %d stalled; retrying", b.Digest[7:19], part.N))
// reset last updated
part.lastUpdated = time.Time{}
return errPartStalled
}
case <-ctx.Done():
return ctx.Err()
}
}
})
part.Completed += n
if err := b.writePart(part.Name(), part); err != nil {
return err
}
return g.Wait()
// return nil or context.Canceled or UnexpectedEOF (resumable)
return err
}
func (b *blobDownload) newPart(offset, size int64) error {
@@ -288,7 +246,7 @@ func (b *blobDownload) readPart(partName string) (*blobDownloadPart, error) {
}
func (b *blobDownload) writePart(partName string, part *blobDownloadPart) error {
partFile, err := os.OpenFile(partName, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0o644)
partFile, err := os.OpenFile(partName, os.O_CREATE|os.O_RDWR|os.O_TRUNC, 0644)
if err != nil {
return err
}
@@ -297,6 +255,12 @@ func (b *blobDownload) writePart(partName string, part *blobDownloadPart) error
return json.NewEncoder(partFile).Encode(part)
}
func (b *blobDownload) Write(p []byte) (n int, err error) {
n = len(p)
b.Completed.Add(int64(n))
return n, nil
}
func (b *blobDownload) acquire() {
b.references.Add(1)
}
@@ -315,19 +279,20 @@ func (b *blobDownload) Wait(ctx context.Context, fn func(api.ProgressResponse))
for {
select {
case <-ticker.C:
fn(api.ProgressResponse{
Status: fmt.Sprintf("pulling %s", b.Digest[7:19]),
Digest: b.Digest,
Total: b.Total,
Completed: b.Completed.Load(),
})
if b.done || b.err != nil {
return b.err
}
case <-ctx.Done():
return ctx.Err()
}
fn(api.ProgressResponse{
Status: fmt.Sprintf("pulling %s", b.Digest[7:19]),
Digest: b.Digest,
Total: b.Total,
Completed: b.Completed.Load(),
})
if b.done || b.err != nil {
return b.err
}
}
}
@@ -338,6 +303,10 @@ type downloadOpts struct {
fn func(api.ProgressResponse)
}
const maxRetries = 6
var errMaxRetriesExceeded = errors.New("max retries exceeded")
// downloadBlob downloads a blob from the registry and stores it in the blobs directory
func downloadBlob(ctx context.Context, opts downloadOpts) error {
fp, err := GetBlobsPath(opts.digest)
@@ -371,7 +340,6 @@ func downloadBlob(ctx context.Context, opts downloadOpts) error {
return err
}
// nolint: contextcheck
go download.Run(context.Background(), requestURL, opts.regOpts)
}

View File

@@ -10,7 +10,6 @@ import (
"fmt"
"io"
"log"
"log/slog"
"net/http"
"net/url"
"os"
@@ -19,6 +18,7 @@ import (
"strconv"
"strings"
"text/template"
"text/template/parse"
"golang.org/x/exp/slices"
@@ -40,7 +40,7 @@ type Model struct {
Config ConfigV2
ShortName string
ModelPath string
ParentModel string
OriginalModel string
AdapterPaths []string
ProjectorPaths []string
Template string
@@ -49,12 +49,156 @@ type Model struct {
Digest string
Size int64
Options map[string]interface{}
Messages []Message
}
type Message struct {
Role string `json:"role"`
Content string `json:"content"`
type PromptVars struct {
System string
Prompt string
Response string
First bool
}
// extractParts extracts the parts of the template before and after the {{.Response}} node.
func extractParts(tmplStr string) (pre string, post string, err error) {
tmpl, err := template.New("").Parse(tmplStr)
if err != nil {
return "", "", err
}
var foundResponse bool
for _, node := range tmpl.Tree.Root.Nodes {
if node.Type() == parse.NodeAction && node.String() == "{{.Response}}" {
foundResponse = true
}
if !foundResponse {
pre += node.String()
} else {
post += node.String()
}
}
return pre, post, nil
}
func Prompt(promptTemplate string, p PromptVars) (string, error) {
var prompt strings.Builder
// Use the "missingkey=zero" option to handle missing variables without panicking
tmpl, err := template.New("").Option("missingkey=zero").Parse(promptTemplate)
if err != nil {
return "", err
}
vars := map[string]any{
"System": p.System,
"Prompt": p.Prompt,
"Response": p.Response,
"First": p.First,
}
var sb strings.Builder
if err := tmpl.Execute(&sb, vars); err != nil {
return "", err
}
prompt.WriteString(sb.String())
if !strings.Contains(prompt.String(), p.Response) {
// if the response is not in the prompt template, append it to the end
prompt.WriteString(p.Response)
}
return prompt.String(), nil
}
// PreResponsePrompt returns the prompt before the response tag
func (m *Model) PreResponsePrompt(p PromptVars) (string, error) {
if p.System == "" {
// use the default system prompt for this model if one is not specified
p.System = m.System
}
pre, _, err := extractParts(m.Template)
if err != nil {
return "", err
}
return Prompt(pre, p)
}
// PostResponseTemplate returns the template after the response tag
func (m *Model) PostResponseTemplate(p PromptVars) (string, error) {
if p.System == "" {
// use the default system prompt for this model if one is not specified
p.System = m.System
}
_, post, err := extractParts(m.Template)
if err != nil {
return "", err
}
if post == "" {
// if there is no post-response template, return the provided response
return p.Response, nil
}
return Prompt(post, p)
}
func (m *Model) ChatPrompt(msgs []api.Message) (string, []api.ImageData, error) {
// build the prompt from the list of messages
var prompt strings.Builder
var currentImages []api.ImageData
currentVars := PromptVars{
First: true,
System: m.System,
}
writePrompt := func() error {
p, err := Prompt(m.Template, currentVars)
if err != nil {
return err
}
prompt.WriteString(p)
currentVars = PromptVars{}
return nil
}
for _, msg := range msgs {
switch strings.ToLower(msg.Role) {
case "system":
if currentVars.System != "" {
if err := writePrompt(); err != nil {
return "", nil, err
}
}
currentVars.System = msg.Content
case "user":
if currentVars.Prompt != "" {
if err := writePrompt(); err != nil {
return "", nil, err
}
}
currentVars.Prompt = msg.Content
currentImages = msg.Images
case "assistant":
currentVars.Response = msg.Content
if err := writePrompt(); err != nil {
return "", nil, err
}
default:
return "", nil, fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
}
}
// Append the last set of vars if they are non-empty
if currentVars.Prompt != "" || currentVars.System != "" {
p, err := m.PreResponsePrompt(currentVars)
if err != nil {
return "", nil, fmt.Errorf("pre-response template: %w", err)
}
prompt.WriteString(p)
}
return prompt.String(), currentImages, nil
}
type ManifestV2 struct {
@@ -188,11 +332,11 @@ func GetModel(name string) (*Model, error) {
switch layer.MediaType {
case "application/vnd.ollama.image.model":
model.ModelPath = filename
model.ParentModel = layer.From
model.OriginalModel = layer.From
case "application/vnd.ollama.image.embed":
// Deprecated in versions > 0.1.2
// TODO: remove this warning in a future version
slog.Info("WARNING: model contains embeddings, but embeddings in modelfiles have been deprecated and will be ignored.")
log.Print("WARNING: model contains embeddings, but embeddings in modelfiles have been deprecated and will be ignored.")
case "application/vnd.ollama.image.adapter":
model.AdapterPaths = append(model.AdapterPaths, filename)
case "application/vnd.ollama.image.projector":
@@ -229,16 +373,6 @@ func GetModel(name string) (*Model, error) {
if err = json.NewDecoder(params).Decode(&model.Options); err != nil {
return nil, err
}
case "application/vnd.ollama.image.messages":
msgs, err := os.Open(filename)
if err != nil {
return nil, err
}
defer msgs.Close()
if err = json.NewDecoder(msgs).Decode(&model.Messages); err != nil {
return nil, err
}
case "application/vnd.ollama.image.license":
bts, err := os.ReadFile(filename)
if err != nil {
@@ -277,13 +411,6 @@ func realpath(mfDir, from string) string {
}
func CreateModel(ctx context.Context, name, modelFileDir string, commands []parser.Command, fn func(resp api.ProgressResponse)) error {
deleteMap := make(map[string]struct{})
if manifest, _, err := GetManifest(ParseModelPath(name)); err == nil {
for _, layer := range append(manifest.Layers, manifest.Config) {
deleteMap[layer.Digest] = struct{}{}
}
}
config := ConfigV2{
OS: "linux",
Architecture: "amd64",
@@ -292,13 +419,15 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
},
}
deleteMap := make(map[string]struct{})
var layers Layers
messages := []string{}
params := make(map[string][]string)
fromParams := make(map[string]any)
for _, c := range commands {
log.Printf("[%s] - %s", c.Name, c.Args)
mediatype := fmt.Sprintf("application/vnd.ollama.image.%s", c.Name)
switch c.Name {
@@ -472,37 +601,11 @@ func CreateModel(ctx context.Context, name, modelFileDir string, commands []pars
}
layers.Replace(layer)
case "message":
messages = append(messages, c.Args)
default:
params[c.Name] = append(params[c.Name], c.Args)
}
}
if len(messages) > 0 {
fn(api.ProgressResponse{Status: "creating parameters layer"})
msgs := make([]api.Message, 0)
for _, m := range messages {
// todo: handle images
msg := strings.SplitN(m, ": ", 2)
msgs = append(msgs, api.Message{Role: msg[0], Content: msg[1]})
}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(msgs); err != nil {
return err
}
layer, err := NewLayer(&b, "application/vnd.ollama.image.messages")
if err != nil {
return err
}
layers.Replace(layer)
}
if len(params) > 0 {
fn(api.ProgressResponse{Status: "creating parameters layer"})
@@ -644,7 +747,6 @@ func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]struct{},
// save (i.e. delete from the deleteMap) any files used in other manifests
manifest, _, err := GetManifest(fmp)
if err != nil {
// nolint: nilerr
return nil
}
@@ -664,16 +766,16 @@ func deleteUnusedLayers(skipModelPath *ModelPath, deleteMap map[string]struct{},
for k := range deleteMap {
fp, err := GetBlobsPath(k)
if err != nil {
slog.Info(fmt.Sprintf("couldn't get file path for '%s': %v", k, err))
log.Printf("couldn't get file path for '%s': %v", k, err)
continue
}
if !dryRun {
if err := os.Remove(fp); err != nil {
slog.Info(fmt.Sprintf("couldn't remove file '%s': %v", fp, err))
log.Printf("couldn't remove file '%s': %v", fp, err)
continue
}
} else {
slog.Info(fmt.Sprintf("wanted to remove: %s", fp))
log.Printf("wanted to remove: %s", fp)
}
}
@@ -689,7 +791,7 @@ func PruneLayers() error {
blobs, err := os.ReadDir(p)
if err != nil {
slog.Info(fmt.Sprintf("couldn't read dir '%s': %v", p, err))
log.Printf("couldn't read dir '%s': %v", p, err)
return err
}
@@ -703,14 +805,14 @@ func PruneLayers() error {
}
}
slog.Info(fmt.Sprintf("total blobs: %d", len(deleteMap)))
log.Printf("total blobs: %d", len(deleteMap))
err = deleteUnusedLayers(nil, deleteMap, false)
if err != nil {
return err
}
slog.Info(fmt.Sprintf("total unused blobs removed: %d", len(deleteMap)))
log.Printf("total unused blobs removed: %d", len(deleteMap))
return nil
}
@@ -772,7 +874,7 @@ func DeleteModel(name string) error {
}
err = os.Remove(fp)
if err != nil {
slog.Info(fmt.Sprintf("couldn't remove manifest file '%s': %v", fp, err))
log.Printf("couldn't remove manifest file '%s': %v", fp, err)
return err
}
@@ -799,8 +901,8 @@ func ShowModelfile(model *Model) (string, error) {
mt.Model = model
mt.From = model.ModelPath
if model.ParentModel != "" {
mt.From = model.ParentModel
if model.OriginalModel != "" {
mt.From = model.OriginalModel
}
modelFile := `# Modelfile generated by "ollama show"
@@ -826,14 +928,14 @@ PARAMETER {{ $k }} {{ printf "%#v" $parameter }}
tmpl, err := template.New("").Parse(modelFile)
if err != nil {
slog.Info(fmt.Sprintf("error parsing template: %q", err))
log.Printf("error parsing template: %q", err)
return "", err
}
var buf bytes.Buffer
if err = tmpl.Execute(&buf, mt); err != nil {
slog.Info(fmt.Sprintf("error executing template: %q", err))
log.Printf("error executing template: %q", err)
return "", err
}
@@ -860,7 +962,7 @@ func PushModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
for _, layer := range layers {
if err := uploadBlob(ctx, mp, layer, regOpts, fn); err != nil {
slog.Info(fmt.Sprintf("error uploading blob: %v", err))
log.Printf("error uploading blob: %v", err)
if errors.Is(err, errUnauthorized) {
return fmt.Errorf("unable to push %s, make sure this namespace exists and you are authorized to push to it", ParseModelPath(name).GetNamespaceRepository())
}
@@ -955,7 +1057,7 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
}
if err := os.Remove(fp); err != nil {
// log this, but return the original error
slog.Info(fmt.Sprintf("couldn't remove file with digest mismatch '%s': %v", fp, err))
log.Printf("couldn't remove file with digest mismatch '%s': %v", fp, err)
}
}
return err
@@ -979,7 +1081,7 @@ func PullModel(ctx context.Context, name string, regOpts *RegistryOptions, fn fu
err = os.WriteFile(fp, manifestJSON, 0o644)
if err != nil {
slog.Info(fmt.Sprintf("couldn't write to %s", fp))
log.Printf("couldn't write to %s", fp)
return err
}
@@ -1029,46 +1131,49 @@ func GetSHA256Digest(r io.Reader) (string, int64) {
var errUnauthorized = fmt.Errorf("unauthorized")
func makeRequestWithRetry(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.ReadSeeker, regOpts *RegistryOptions) (*http.Response, error) {
for i := 0; i < 2; i++ {
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
if err != nil {
if !errors.Is(err, context.Canceled) {
slog.Info(fmt.Sprintf("request failed: %v", err))
}
return nil, err
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
if err != nil {
if !errors.Is(err, context.Canceled) {
log.Printf("request failed: %v", err)
}
switch {
case resp.StatusCode == http.StatusUnauthorized:
// Handle authentication error with one retry
auth := resp.Header.Get("www-authenticate")
authRedir := ParseAuthRedirectString(auth)
token, err := getAuthToken(ctx, authRedir)
return nil, err
}
switch {
case resp.StatusCode == http.StatusUnauthorized:
// Handle authentication error with one retry
auth := resp.Header.Get("www-authenticate")
authRedir := ParseAuthRedirectString(auth)
token, err := getAuthToken(ctx, authRedir)
if err != nil {
return nil, err
}
regOpts.Token = token
if body != nil {
_, err = body.Seek(0, io.SeekStart)
if err != nil {
return nil, err
}
regOpts.Token = token
if body != nil {
_, err = body.Seek(0, io.SeekStart)
if err != nil {
return nil, err
}
}
case resp.StatusCode == http.StatusNotFound:
return nil, os.ErrNotExist
case resp.StatusCode >= http.StatusBadRequest:
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("%d: %s", resp.StatusCode, err)
}
return nil, fmt.Errorf("%d: %s", resp.StatusCode, responseBody)
default:
return resp, nil
}
resp, err := makeRequest(ctx, method, requestURL, headers, body, regOpts)
if resp.StatusCode == http.StatusUnauthorized {
return nil, errUnauthorized
}
return resp, err
case resp.StatusCode == http.StatusNotFound:
return nil, os.ErrNotExist
case resp.StatusCode >= http.StatusBadRequest:
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("%d: %s", resp.StatusCode, err)
}
return nil, fmt.Errorf("%d: %s", resp.StatusCode, responseBody)
}
return nil, errUnauthorized
return resp, nil
}
func makeRequest(ctx context.Context, method string, requestURL *url.URL, headers http.Header, body io.Reader, regOpts *RegistryOptions) (*http.Response, error) {

347
server/images_test.go Normal file
View File

@@ -0,0 +1,347 @@
package server
import (
"strings"
"testing"
"github.com/jmorganca/ollama/api"
)
func TestPrompt(t *testing.T) {
tests := []struct {
name string
template string
vars PromptVars
want string
wantErr bool
}{
{
name: "System Prompt",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
vars: PromptVars{
System: "You are a Wizard.",
Prompt: "What are the potion ingredients?",
},
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
},
{
name: "System Prompt with Response",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
vars: PromptVars{
System: "You are a Wizard.",
Prompt: "What are the potion ingredients?",
Response: "I don't know.",
},
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
},
{
name: "Conditional Logic Nodes",
template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
vars: PromptVars{
First: true,
System: "You are a Wizard.",
Prompt: "What are the potion ingredients?",
Response: "I don't know.",
},
want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := Prompt(tt.template, tt.vars)
if (err != nil) != tt.wantErr {
t.Errorf("Prompt() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("Prompt() got = %v, want %v", got, tt.want)
}
})
}
}
func TestModel_PreResponsePrompt(t *testing.T) {
tests := []struct {
name string
template string
vars PromptVars
want string
wantErr bool
}{
{
name: "No Response in Template",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
vars: PromptVars{
System: "You are a Wizard.",
Prompt: "What are the potion ingredients?",
},
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
},
{
name: "Response in Template",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
vars: PromptVars{
System: "You are a Wizard.",
Prompt: "What are the potion ingredients?",
},
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] ",
},
{
name: "Response in Template with Trailing Formatting",
template: "<|im_start|>user\n{{ .Prompt }}<|im_end|><|im_start|>assistant\n{{ .Response }}<|im_end|>",
vars: PromptVars{
Prompt: "What are the potion ingredients?",
},
want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\n",
},
{
name: "Response in Template with Alternative Formatting",
template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>",
vars: PromptVars{
Prompt: "What are the potion ingredients?",
},
want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\n",
},
}
for _, tt := range tests {
m := Model{Template: tt.template}
t.Run(tt.name, func(t *testing.T) {
got, err := m.PreResponsePrompt(tt.vars)
if (err != nil) != tt.wantErr {
t.Errorf("PreResponsePrompt() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("PreResponsePrompt() got = %v, want %v", got, tt.want)
}
})
}
}
func TestModel_PostResponsePrompt(t *testing.T) {
tests := []struct {
name string
template string
vars PromptVars
want string
wantErr bool
}{
{
name: "No Response in Template",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
vars: PromptVars{
Response: "I don't know.",
},
want: "I don't know.",
},
{
name: "Response in Template",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
vars: PromptVars{
Response: "I don't know.",
},
want: "I don't know.",
},
{
name: "Response in Template with Trailing Formatting",
template: "<|im_start|>user\n{{ .Prompt }}<|im_end|><|im_start|>assistant\n{{ .Response }}<|im_end|>",
vars: PromptVars{
Response: "I don't know.",
},
want: "I don't know.<|im_end|>",
},
{
name: "Response in Template with Alternative Formatting",
template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>",
vars: PromptVars{
Response: "I don't know.",
},
want: "I don't know.<|im_end|>",
},
}
for _, tt := range tests {
m := Model{Template: tt.template}
t.Run(tt.name, func(t *testing.T) {
got, err := m.PostResponseTemplate(tt.vars)
if (err != nil) != tt.wantErr {
t.Errorf("PostResponseTemplate() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("PostResponseTemplate() got = %v, want %v", got, tt.want)
}
})
}
}
func TestModel_PreResponsePrompt_PostResponsePrompt(t *testing.T) {
tests := []struct {
name string
template string
preVars PromptVars
postVars PromptVars
want string
wantErr bool
}{
{
name: "Response in Template",
template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n{{.Response}}<|im_end|>",
preVars: PromptVars{
Prompt: "What are the potion ingredients?",
},
postVars: PromptVars{
Prompt: "What are the potion ingredients?",
Response: "Sugar.",
},
want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\nSugar.<|im_end|>",
},
{
name: "No Response in Template",
template: "<|im_start|>user\n{{.Prompt}}<|im_end|><|im_start|>assistant\n",
preVars: PromptVars{
Prompt: "What are the potion ingredients?",
},
postVars: PromptVars{
Prompt: "What are the potion ingredients?",
Response: "Spice.",
},
want: "<|im_start|>user\nWhat are the potion ingredients?<|im_end|><|im_start|>assistant\nSpice.",
},
}
for _, tt := range tests {
m := Model{Template: tt.template}
t.Run(tt.name, func(t *testing.T) {
pre, err := m.PreResponsePrompt(tt.preVars)
if (err != nil) != tt.wantErr {
t.Errorf("PreResponsePrompt() error = %v, wantErr %v", err, tt.wantErr)
return
}
post, err := m.PostResponseTemplate(tt.postVars)
if err != nil {
t.Errorf("PostResponseTemplate() error = %v, wantErr %v", err, tt.wantErr)
return
}
result := pre + post
if result != tt.want {
t.Errorf("Prompt() got = %v, want %v", result, tt.want)
}
})
}
}
func TestChat(t *testing.T) {
tests := []struct {
name string
template string
msgs []api.Message
want string
wantErr string
}{
{
name: "Single Message",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
msgs: []api.Message{
{
Role: "system",
Content: "You are a Wizard.",
},
{
Role: "user",
Content: "What are the potion ingredients?",
},
},
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
},
{
name: "First Message",
template: "[INST] {{if .First}}Hello!{{end}} {{ .System }} {{ .Prompt }} [/INST]",
msgs: []api.Message{
{
Role: "system",
Content: "You are a Wizard.",
},
{
Role: "user",
Content: "What are the potion ingredients?",
},
{
Role: "assistant",
Content: "eye of newt",
},
{
Role: "user",
Content: "Anything else?",
},
},
want: "[INST] Hello! You are a Wizard. What are the potion ingredients? [/INST]eye of newt[INST] Anything else? [/INST]",
},
{
name: "Message History",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
msgs: []api.Message{
{
Role: "system",
Content: "You are a Wizard.",
},
{
Role: "user",
Content: "What are the potion ingredients?",
},
{
Role: "assistant",
Content: "sugar",
},
{
Role: "user",
Content: "Anything else?",
},
},
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]sugar[INST] Anything else? [/INST]",
},
{
name: "Assistant Only",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
msgs: []api.Message{
{
Role: "assistant",
Content: "everything nice",
},
},
want: "[INST] [/INST]everything nice",
},
{
name: "Invalid Role",
msgs: []api.Message{
{
Role: "not-a-role",
Content: "howdy",
},
},
wantErr: "invalid role: not-a-role",
},
}
for _, tt := range tests {
m := Model{
Template: tt.template,
}
t.Run(tt.name, func(t *testing.T) {
got, _, err := m.ChatPrompt(tt.msgs)
if tt.wantErr != "" {
if err == nil {
t.Errorf("ChatPrompt() expected error, got nil")
}
if !strings.Contains(err.Error(), tt.wantErr) {
t.Errorf("ChatPrompt() error = %v, wantErr %v", err, tt.wantErr)
}
}
if got != tt.want {
t.Errorf("ChatPrompt() got = %v, want %v", got, tt.want)
}
})
}
}

View File

@@ -26,9 +26,9 @@ func WriteManifest(name string, config *Layer, layers []*Layer) error {
return err
}
if err := os.MkdirAll(filepath.Dir(manifestPath), 0o755); err != nil {
if err := os.MkdirAll(filepath.Dir(manifestPath), 0755); err != nil {
return err
}
return os.WriteFile(manifestPath, b.Bytes(), 0o644)
return os.WriteFile(manifestPath, b.Bytes(), 0644)
}

View File

@@ -46,8 +46,7 @@ func ParseModelPath(name string) ModelPath {
name = after
}
name = strings.ReplaceAll(name, string(os.PathSeparator), "/")
parts := strings.Split(name, "/")
parts := strings.Split(name, string(os.PathSeparator))
switch len(parts) {
case 3:
mp.Registry = parts[0]

View File

@@ -1,224 +0,0 @@
package server
import (
"fmt"
"log/slog"
"strings"
"text/template"
"text/template/parse"
"github.com/jmorganca/ollama/api"
)
// isResponseNode checks if the node contains .Response
func isResponseNode(node *parse.ActionNode) bool {
for _, cmd := range node.Pipe.Cmds {
for _, arg := range cmd.Args {
if fieldNode, ok := arg.(*parse.FieldNode); ok && len(fieldNode.Ident) > 0 {
if fieldNode.Ident[0] == "Response" {
return true
}
}
}
}
return false
}
// formatTemplateForResponse formats the template AST to:
// 1. remove all nodes after the first .Response (if generate=true)
// 2. add a .Response node to the end if it doesn't exist
// TODO(jmorganca): this should recursively cut the template before the first .Response
func formatTemplateForResponse(tmpl *template.Template, generate bool) {
var found bool
for i, node := range tmpl.Tree.Root.Nodes {
if actionNode, ok := node.(*parse.ActionNode); ok {
if isResponseNode(actionNode) {
found = true
if generate {
tmpl.Tree.Root.Nodes = tmpl.Tree.Root.Nodes[:i+1]
break
}
}
}
}
if !found {
// add the response node if it doesn't exist
responseFieldNode := &parse.FieldNode{NodeType: parse.NodeField, Ident: []string{"Response"}}
responsePipeNode := &parse.PipeNode{NodeType: parse.NodePipe, Cmds: []*parse.CommandNode{{NodeType: parse.NodeCommand, Args: []parse.Node{responseFieldNode}}}}
responseActionNode := &parse.ActionNode{NodeType: parse.NodeAction, Pipe: responsePipeNode}
tmpl.Tree.Root.Nodes = append(tmpl.Tree.Root.Nodes, responseActionNode)
}
}
// Prompt renders a prompt from a template. If generate is set to true,
// the response and parts of the template following it are not rendered
func Prompt(tmpl, system, prompt, response string, generate bool) (string, error) {
parsed, err := template.New("").Option("missingkey=zero").Parse(tmpl)
if err != nil {
return "", err
}
formatTemplateForResponse(parsed, generate)
vars := map[string]any{
"System": system,
"Prompt": prompt,
"Response": response,
}
var sb strings.Builder
if err := parsed.Execute(&sb, vars); err != nil {
return "", err
}
return sb.String(), nil
}
func countTokens(tmpl string, system string, prompt string, response string, encode func(string) ([]int, error)) (int, error) {
rendered, err := Prompt(tmpl, system, prompt, response, false)
if err != nil {
return 0, err
}
tokens, err := encode(rendered)
if err != nil {
slog.Error("failed to encode prompt", "err", err)
return 0, err
}
return len(tokens), err
}
// ChatPrompt builds up a prompt from a series of messages, truncating based on context window size
func ChatPrompt(tmpl string, system string, messages []api.Message, window int, encode func(string) ([]int, error)) (string, error) {
type prompt struct {
System string
Prompt string
Response string
images []int
tokens int
}
var p prompt
// Set the first system prompt to the model's system prompt
if system != "" {
p.System = system
}
// iterate through messages to build up {system,user,response} prompts
var imgId int
var prompts []prompt
for _, msg := range messages {
switch strings.ToLower(msg.Role) {
case "system":
if p.System != "" || p.Prompt != "" || p.Response != "" {
prompts = append(prompts, p)
p = prompt{}
}
p.System = msg.Content
case "user":
if p.Prompt != "" || p.Response != "" {
prompts = append(prompts, p)
p = prompt{}
}
p.Prompt = msg.Content
for range msg.Images {
p.Prompt += fmt.Sprintf(" [img-%d]", imgId)
p.images = append(p.images, imgId)
imgId += 1
}
case "assistant":
if p.Response != "" {
prompts = append(prompts, p)
p = prompt{}
}
p.Response = msg.Content
default:
return "", fmt.Errorf("invalid role: %s, role must be one of [system, user, assistant]", msg.Role)
}
}
// add final prompt
if p.System != "" || p.Prompt != "" || p.Response != "" {
prompts = append(prompts, p)
}
// calculate token lengths for each prompt, estimating 768 tokens per images
for i, p := range prompts {
tokens, err := countTokens(tmpl, p.System, p.Prompt, p.Response, encode)
if err != nil {
return "", err
}
prompts[i].tokens = tokens + len(prompts[i].images)*768
}
// truncate images and prompts starting from the beginning of the list
// until either one prompt remains or the total tokens fits the context window
// TODO (jmorganca): this doesn't account for the context window room required for the response
for {
var required int
for _, p := range prompts {
required += p.tokens
}
required += 1 // for bos token
if required <= window {
slog.Debug("prompt now fits in context window", "required", required, "window", window)
break
}
prompt := &prompts[0]
if len(prompt.images) > 1 {
img := prompt.images[0]
slog.Debug("prompt longer than context window, removing image", "id", img, "required", required, "window", window)
prompt.images = prompt.images[1:]
prompt.Prompt = strings.Replace(prompt.Prompt, fmt.Sprintf(" [img-%d]", img), "", 1)
prompt.tokens -= 768
continue
}
if len(prompts) > 1 {
slog.Debug("required tokens longer than context window, removing first prompt", "prompt", prompts[0].tokens, "required", required, "window", window)
system := prompt.System
prompts = prompts[1:]
if system != "" && prompts[0].System == "" {
prompts[0].System = system
tokens, err := countTokens(tmpl, prompts[0].System, prompts[0].Prompt, prompts[0].Response, encode)
if err != nil {
return "", err
}
prompts[0].tokens = tokens + len(prompts[0].images)*768
}
continue
}
// stop truncating if there's only one prompt left
break
}
var sb strings.Builder
for i, p := range prompts {
// last prompt should leave the response unrendered (for completion)
rendered, err := Prompt(tmpl, p.System, p.Prompt, p.Response, i == len(prompts)-1)
if err != nil {
return "", err
}
sb.WriteString(rendered)
}
return sb.String(), nil
}

View File

@@ -1,234 +0,0 @@
package server
import (
"strings"
"testing"
"github.com/jmorganca/ollama/api"
)
func TestPrompt(t *testing.T) {
tests := []struct {
name string
template string
system string
prompt string
response string
generate bool
want string
}{
{
name: "simple prompt",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
system: "You are a Wizard.",
prompt: "What are the potion ingredients?",
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]",
},
{
name: "implicit response",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST]",
system: "You are a Wizard.",
prompt: "What are the potion ingredients?",
response: "I don't know.",
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST]I don't know.",
},
{
name: "response",
template: "[INST] {{ .System }} {{ .Prompt }} [/INST] {{ .Response }}",
system: "You are a Wizard.",
prompt: "What are the potion ingredients?",
response: "I don't know.",
want: "[INST] You are a Wizard. What are the potion ingredients? [/INST] I don't know.",
},
{
name: "cut",
template: "<system>{{ .System }}</system><user>{{ .Prompt }}</user><assistant>{{ .Response }}</assistant>",
system: "You are a Wizard.",
prompt: "What are the potion ingredients?",
response: "I don't know.",
generate: true,
want: "<system>You are a Wizard.</system><user>What are the potion ingredients?</user><assistant>I don't know.",
},
{
name: "nocut",
template: "<system>{{ .System }}</system><user>{{ .Prompt }}</user><assistant>{{ .Response }}</assistant>",
system: "You are a Wizard.",
prompt: "What are the potion ingredients?",
response: "I don't know.",
want: "<system>You are a Wizard.</system><user>What are the potion ingredients?</user><assistant>I don't know.</assistant>",
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got, err := Prompt(tc.template, tc.system, tc.prompt, tc.response, tc.generate)
if err != nil {
t.Errorf("error = %v", err)
}
if got != tc.want {
t.Errorf("got = %v, want %v", got, tc.want)
}
})
}
}
func TestChatPrompt(t *testing.T) {
tests := []struct {
name string
template string
system string
messages []api.Message
window int
want string
}{
{
name: "simple prompt",
template: "[INST] {{ .Prompt }} [/INST]",
messages: []api.Message{
{Role: "user", Content: "Hello"},
},
window: 1024,
want: "[INST] Hello [/INST]",
},
{
name: "with default system message",
system: "You are a Wizard.",
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
messages: []api.Message{
{Role: "user", Content: "Hello"},
},
window: 1024,
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]",
},
{
name: "with system message",
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
messages: []api.Message{
{Role: "system", Content: "You are a Wizard."},
{Role: "user", Content: "Hello"},
},
window: 1024,
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]",
},
{
name: "with response",
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }}",
messages: []api.Message{
{Role: "system", Content: "You are a Wizard."},
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: "I am?"},
},
window: 1024,
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST] I am?",
},
{
name: "with implicit response",
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST]",
messages: []api.Message{
{Role: "system", Content: "You are a Wizard."},
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: "I am?"},
},
window: 1024,
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> Hello [/INST]I am?",
},
{
name: "with conversation",
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ",
messages: []api.Message{
{Role: "system", Content: "You are a Wizard."},
{Role: "user", Content: "What are the potion ingredients?"},
{Role: "assistant", Content: "sugar"},
{Role: "user", Content: "Anything else?"},
},
window: 1024,
want: "[INST] <<SYS>>You are a Wizard.<</SYS>> What are the potion ingredients? [/INST] sugar [INST] Anything else? [/INST] ",
},
{
name: "with truncation",
template: "{{ .System }} {{ .Prompt }} {{ .Response }} ",
messages: []api.Message{
{Role: "system", Content: "You are a Wizard."},
{Role: "user", Content: "Hello"},
{Role: "assistant", Content: "I am?"},
{Role: "user", Content: "Why is the sky blue?"},
{Role: "assistant", Content: "The sky is blue from rayleigh scattering"},
},
window: 10,
want: "You are a Wizard. Why is the sky blue? The sky is blue from rayleigh scattering",
},
{
name: "images",
template: "{{ .System }} {{ .Prompt }}",
messages: []api.Message{
{Role: "system", Content: "You are a Wizard."},
{Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("base64")}},
},
window: 1024,
want: "You are a Wizard. Hello [img-0]",
},
{
name: "images truncated",
template: "{{ .System }} {{ .Prompt }}",
messages: []api.Message{
{Role: "system", Content: "You are a Wizard."},
{Role: "user", Content: "Hello", Images: []api.ImageData{[]byte("img1"), []byte("img2")}},
},
window: 1024,
want: "You are a Wizard. Hello [img-1]",
},
{
name: "empty list",
template: "{{ .System }} {{ .Prompt }}",
messages: []api.Message{},
window: 1024,
want: "",
},
{
name: "empty list default system",
system: "You are a Wizard.",
template: "{{ .System }} {{ .Prompt }}",
messages: []api.Message{},
window: 1024,
want: "You are a Wizard. ",
},
{
name: "empty user message",
system: "You are a Wizard.",
template: "{{ .System }} {{ .Prompt }}",
messages: []api.Message{
{Role: "user", Content: ""},
},
window: 1024,
want: "You are a Wizard. ",
},
{
name: "empty prompt",
template: "[INST] {{ if .System }}<<SYS>>{{ .System }}<</SYS>> {{ end }}{{ .Prompt }} [/INST] {{ .Response }} ",
messages: []api.Message{
{Role: "user", Content: ""},
},
window: 1024,
want: "",
},
}
encode := func(s string) ([]int, error) {
words := strings.Fields(s)
return make([]int, len(words)), nil
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got, err := ChatPrompt(tc.template, tc.system, tc.messages, tc.window, encode)
if err != nil {
t.Errorf("error = %v", err)
}
if got != tc.want {
t.Errorf("got = %v, want %v", got, tc.want)
}
})
}
}

View File

@@ -7,7 +7,7 @@ import (
"fmt"
"io"
"io/fs"
"log/slog"
"log"
"net"
"net/http"
"os"
@@ -15,6 +15,7 @@ import (
"path/filepath"
"reflect"
"runtime"
"strconv"
"strings"
"sync"
"syscall"
@@ -22,12 +23,10 @@ import (
"github.com/gin-contrib/cors"
"github.com/gin-gonic/gin"
"golang.org/x/exp/slices"
"github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/gpu"
"github.com/jmorganca/ollama/llm"
"github.com/jmorganca/ollama/openai"
"github.com/jmorganca/ollama/parser"
"github.com/jmorganca/ollama/version"
)
@@ -75,7 +74,7 @@ func load(c *gin.Context, model *Model, opts api.Options, sessionDuration time.D
if needLoad {
if loaded.runner != nil {
slog.Info("changing loaded model")
log.Println("changing loaded model")
loaded.runner.Close()
loaded.runner = nil
loaded.Model = nil
@@ -137,12 +136,6 @@ func modelOptions(model *Model, requestOpts map[string]interface{}) (api.Options
return opts, nil
}
func isSupportedImageType(image []byte) bool {
contentType := http.DetectContentType(image)
allowedTypes := []string{"image/jpeg", "image/jpg", "image/png"}
return slices.Contains(allowedTypes, contentType)
}
func GenerateHandler(c *gin.Context) {
loaded.mu.Lock()
defer loaded.mu.Unlock()
@@ -173,13 +166,6 @@ func GenerateHandler(c *gin.Context) {
return
}
for _, img := range req.Images {
if !isSupportedImageType(img) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
return
}
}
model, err := GetModel(req.Model)
if err != nil {
var pErr *fs.PathError
@@ -201,79 +187,61 @@ func GenerateHandler(c *gin.Context) {
return
}
var sessionDuration time.Duration
if req.KeepAlive == nil {
sessionDuration = defaultSessionDuration
} else {
sessionDuration = req.KeepAlive.Duration
}
sessionDuration := defaultSessionDuration
if err := load(c, model, opts, sessionDuration); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// an empty request loads the model
// note: for a short while template was used in lieu
// of `raw` mode so we need to check for it too
if req.Prompt == "" && req.Template == "" && req.System == "" {
c.JSON(http.StatusOK, api.GenerateResponse{
CreatedAt: time.Now().UTC(),
Model: req.Model,
Done: true,
})
Done: true})
return
}
checkpointLoaded := time.Now()
var prompt string
var promptVars PromptVars
switch {
case req.Raw:
prompt = req.Prompt
case req.Prompt != "":
if req.Template == "" {
req.Template = model.Template
if req.Template != "" {
// override the default model template
model.Template = req.Template
}
if req.System == "" {
req.System = model.System
}
slog.Debug("generate handler", "prompt", req.Prompt)
slog.Debug("generate handler", "template", req.Template)
slog.Debug("generate handler", "system", req.System)
var sb strings.Builder
var rebuild strings.Builder
if req.Context != nil {
prev, err := loaded.runner.Decode(c.Request.Context(), req.Context)
// TODO: context is deprecated, at some point the context logic within this conditional should be removed
prevCtx, err := loaded.runner.Decode(c.Request.Context(), req.Context)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
sb.WriteString(prev)
// Remove leading spaces from prevCtx if present
prevCtx = strings.TrimPrefix(prevCtx, " ")
rebuild.WriteString(prevCtx)
}
// write image tags
// TODO: limit the number of images to fit in the context similar to the chat endpoint
for i := range req.Images {
req.Prompt += fmt.Sprintf(" [img-%d]", i)
promptVars = PromptVars{
System: req.System,
Prompt: req.Prompt,
First: len(req.Context) == 0,
}
p, err := Prompt(req.Template, req.System, req.Prompt, "", true)
p, err := model.PreResponsePrompt(promptVars)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
sb.WriteString(p)
prompt = sb.String()
rebuild.WriteString(p)
prompt = rebuild.String()
}
slog.Debug("generate handler", "prompt", prompt)
ch := make(chan any)
var generated strings.Builder
go func() {
@@ -308,39 +276,30 @@ func GenerateHandler(c *gin.Context) {
resp.LoadDuration = checkpointLoaded.Sub(checkpointStart)
if !req.Raw {
p, err := Prompt(req.Template, req.System, req.Prompt, generated.String(), false)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// TODO (jmorganca): encode() should not strip special tokens
tokens, err := loaded.runner.Encode(c.Request.Context(), p)
// append the generated text to the history and template it if needed
promptVars.Response = generated.String()
result, err := model.PostResponseTemplate(promptVars)
if err != nil {
ch <- gin.H{"error": err.Error()}
return
}
resp.Context = append(req.Context, tokens...)
embd, err := loaded.runner.Encode(c.Request.Context(), prompt+result)
if err != nil {
ch <- gin.H{"error": err.Error()}
return
}
resp.Context = embd
}
}
ch <- resp
}
var images []llm.ImageData
for i := range req.Images {
images = append(images, llm.ImageData{
ID: i,
Data: req.Images[i],
})
}
// Start prediction
predictReq := llm.PredictOpts{
Prompt: prompt,
Format: req.Format,
Images: images,
Images: req.Images,
Options: opts,
}
if err := loaded.runner.Predict(c.Request.Context(), predictReq, fn); err != nil {
@@ -419,14 +378,7 @@ func EmbeddingHandler(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
var sessionDuration time.Duration
if req.KeepAlive == nil {
sessionDuration = defaultSessionDuration
} else {
sessionDuration = req.KeepAlive.Duration
}
sessionDuration := defaultSessionDuration
if err := load(c, model, opts, sessionDuration); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -439,7 +391,7 @@ func EmbeddingHandler(c *gin.Context) {
embedding, err := loaded.runner.Embedding(c.Request.Context(), req.Prompt)
if err != nil {
slog.Info(fmt.Sprintf("embedding generation failed: %v", err))
log.Printf("embedding generation failed: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate embedding"})
return
}
@@ -462,13 +414,8 @@ func PullModelHandler(c *gin.Context) {
return
}
var model string
if req.Model != "" {
model = req.Model
} else if req.Name != "" {
model = req.Name
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
if req.Name == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
return
}
@@ -486,7 +433,7 @@ func PullModelHandler(c *gin.Context) {
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()
if err := PullModel(ctx, model, regOpts, fn); err != nil {
if err := PullModel(ctx, req.Name, regOpts, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
@@ -511,13 +458,8 @@ func PushModelHandler(c *gin.Context) {
return
}
var model string
if req.Model != "" {
model = req.Model
} else if req.Name != "" {
model = req.Name
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
if req.Name == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
return
}
@@ -535,7 +477,7 @@ func PushModelHandler(c *gin.Context) {
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()
if err := PushModel(ctx, model, regOpts, fn); err != nil {
if err := PushModel(ctx, req.Name, regOpts, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
@@ -560,17 +502,12 @@ func CreateModelHandler(c *gin.Context) {
return
}
var model string
if req.Model != "" {
model = req.Model
} else if req.Name != "" {
model = req.Name
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
if req.Name == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
return
}
if err := ParseModelPath(model).Validate(); err != nil {
if err := ParseModelPath(req.Name).Validate(); err != nil {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
@@ -608,7 +545,7 @@ func CreateModelHandler(c *gin.Context) {
ctx, cancel := context.WithCancel(c.Request.Context())
defer cancel()
if err := CreateModel(ctx, model, filepath.Dir(req.Path), commands, fn); err != nil {
if err := CreateModel(ctx, req.Name, filepath.Dir(req.Path), commands, fn); err != nil {
ch <- gin.H{"error": err.Error()}
}
}()
@@ -633,19 +570,14 @@ func DeleteModelHandler(c *gin.Context) {
return
}
var model string
if req.Model != "" {
model = req.Model
} else if req.Name != "" {
model = req.Name
} else {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
if req.Name == "" {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "name is required"})
return
}
if err := DeleteModel(model); err != nil {
if err := DeleteModel(req.Name); err != nil {
if os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", model)})
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Name)})
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
@@ -678,19 +610,21 @@ func ShowModelHandler(c *gin.Context) {
return
}
if req.Model != "" {
// noop
} else if req.Name != "" {
req.Model = req.Name
} else {
switch {
case req.Model == "" && req.Name == "":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "model is required"})
return
case req.Model != "" && req.Name != "":
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "both model and name are set"})
return
case req.Model == "" && req.Name != "":
req.Model = req.Name
}
resp, err := GetModelInfo(req)
if err != nil {
if os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Model)})
c.JSON(http.StatusNotFound, gin.H{"error": fmt.Sprintf("model '%s' not found", req.Name)})
} else {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
}
@@ -707,7 +641,6 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
}
modelDetails := api.ModelDetails{
ParentModel: model.ParentModel,
Format: model.Config.ModelFormat,
Family: model.Config.ModelFamily,
Families: model.Config.ModelFamilies,
@@ -723,29 +656,38 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
model.Template = req.Template
}
msgs := make([]api.Message, 0)
for _, msg := range model.Messages {
msgs = append(msgs, api.Message{Role: msg.Role, Content: msg.Content})
}
resp := &api.ShowResponse{
License: strings.Join(model.License, "\n"),
System: model.System,
Template: model.Template,
Details: modelDetails,
Messages: msgs,
}
var params []string
cs := 30
for k, v := range model.Options {
switch val := v.(type) {
case string:
params = append(params, fmt.Sprintf("%-*s %s", cs, k, val))
case int:
params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.Itoa(val)))
case float64:
params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.FormatFloat(val, 'f', 0, 64)))
case bool:
params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.FormatBool(val)))
case []interface{}:
for _, nv := range val {
params = append(params, fmt.Sprintf("%-*s %#v", cs, k, nv))
switch nval := nv.(type) {
case string:
params = append(params, fmt.Sprintf("%-*s %s", cs, k, nval))
case int:
params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.Itoa(nval)))
case float64:
params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.FormatFloat(nval, 'f', 0, 64)))
case bool:
params = append(params, fmt.Sprintf("%-*s %s", cs, k, strconv.FormatBool(nval)))
}
}
default:
params = append(params, fmt.Sprintf("%-*s %#v", cs, k, v))
}
}
resp.Parameters = strings.Join(params, "\n")
@@ -768,7 +710,7 @@ func GetModelInfo(req api.ShowRequest) (*api.ShowResponse, error) {
func ListModelsHandler(c *gin.Context) {
models := make([]api.ModelResponse, 0)
manifestsPath, err := GetManifestPath()
fp, err := GetManifestPath()
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
@@ -789,7 +731,6 @@ func ListModelsHandler(c *gin.Context) {
}
return api.ModelResponse{
Model: model.ShortName,
Name: model.ShortName,
Size: model.Size,
Digest: model.Digest,
@@ -799,15 +740,13 @@ func ListModelsHandler(c *gin.Context) {
walkFunc := func(path string, info os.FileInfo, _ error) error {
if !info.IsDir() {
path, tag := filepath.Split(path)
model := strings.Trim(strings.TrimPrefix(path, manifestsPath), string(os.PathSeparator))
modelPath := strings.Join([]string{model, tag}, ":")
canonicalModelPath := strings.ReplaceAll(modelPath, string(os.PathSeparator), "/")
dir, file := filepath.Split(path)
dir = strings.Trim(strings.TrimPrefix(dir, fp), string(os.PathSeparator))
tag := strings.Join([]string{dir, file}, ":")
resp, err := modelResponse(canonicalModelPath)
resp, err := modelResponse(tag)
if err != nil {
slog.Info(fmt.Sprintf("skipping file: %s", canonicalModelPath))
// nolint: nilerr
log.Printf("skipping file: %s", fp)
return nil
}
@@ -818,7 +757,7 @@ func ListModelsHandler(c *gin.Context) {
return nil
}
if err := filepath.Walk(manifestsPath, walkFunc); err != nil {
if err := filepath.Walk(fp, walkFunc); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
@@ -951,9 +890,6 @@ func (s *Server) GenerateRoutes() http.Handler {
r.POST("/api/blobs/:digest", CreateBlobHandler)
r.HEAD("/api/blobs/:digest", HeadBlobHandler)
// Compatibility endpoints
r.POST("/v1/chat/completions", openai.Middleware(), ChatHandler)
for _, method := range []string{http.MethodGet, http.MethodHead} {
r.Handle(method, "/", func(c *gin.Context) {
c.String(http.StatusOK, "Ollama is running")
@@ -969,26 +905,6 @@ func (s *Server) GenerateRoutes() http.Handler {
}
func Serve(ln net.Listener) error {
level := slog.LevelInfo
if debug := os.Getenv("OLLAMA_DEBUG"); debug != "" {
level = slog.LevelDebug
}
handler := slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
Level: level,
AddSource: true,
ReplaceAttr: func(_ []string, attr slog.Attr) slog.Attr {
if attr.Key == slog.SourceKey {
source := attr.Value.Any().(*slog.Source)
source.File = filepath.Base(source.File)
}
return attr
},
})
slog.SetDefault(slog.New(handler))
if noprune := os.Getenv("OLLAMA_NOPRUNE"); noprune == "" {
// clean up unused layers and manifests
if err := PruneLayers(); err != nil {
@@ -1011,7 +927,7 @@ func Serve(ln net.Listener) error {
}
r := s.GenerateRoutes()
slog.Info(fmt.Sprintf("Listening on %s (version %s)", ln.Addr(), version.Version))
log.Printf("Listening on %s (version %s)", ln.Addr(), version.Version)
srvr := &http.Server{
Handler: r,
}
@@ -1034,7 +950,7 @@ func Serve(ln net.Listener) error {
if runtime.GOOS == "linux" { // TODO - windows too
// check compatibility to log warnings
if _, err := gpu.CheckVRAM(); err != nil {
slog.Info(err.Error())
log.Print(err.Error())
}
}
@@ -1076,14 +992,14 @@ func streamResponse(c *gin.Context, ch chan any) {
bts, err := json.Marshal(val)
if err != nil {
slog.Info(fmt.Sprintf("streamResponse: json.Marshal failed with %s", err))
log.Printf("streamResponse: json.Marshal failed with %s", err)
return false
}
// Delineate chunks with new-line delimiter
bts = append(bts, '\n')
if _, err := w.Write(bts); err != nil {
slog.Info(fmt.Sprintf("streamResponse: w.Write failed with %s", err))
log.Printf("streamResponse: w.Write failed with %s", err)
return false
}
@@ -1091,20 +1007,6 @@ func streamResponse(c *gin.Context, ch chan any) {
})
}
// ChatPrompt builds up a prompt from a series of messages for the currently `loaded` model
func chatPrompt(ctx context.Context, messages []api.Message) (string, error) {
encode := func(s string) ([]int, error) {
return loaded.runner.Encode(ctx, s)
}
prompt, err := ChatPrompt(loaded.Model.Template, loaded.Model.System, messages, loaded.Options.NumCtx, encode)
if err != nil {
return "", err
}
return prompt, nil
}
func ChatHandler(c *gin.Context) {
loaded.mu.Lock()
defer loaded.mu.Unlock()
@@ -1152,58 +1054,26 @@ func ChatHandler(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
var sessionDuration time.Duration
if req.KeepAlive == nil {
sessionDuration = defaultSessionDuration
} else {
sessionDuration = req.KeepAlive.Duration
}
sessionDuration := defaultSessionDuration
if err := load(c, model, opts, sessionDuration); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
// an empty request loads the model
if len(req.Messages) == 0 {
c.JSON(http.StatusOK, api.ChatResponse{CreatedAt: time.Now().UTC(), Model: req.Model, Done: true, Message: api.Message{Role: "assistant"}})
return
}
checkpointLoaded := time.Now()
prompt, err := chatPrompt(c.Request.Context(), req.Messages)
prompt, images, err := model.ChatPrompt(req.Messages)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// an empty request loads the model
if len(req.Messages) == 0 || prompt == "" {
resp := api.ChatResponse{
CreatedAt: time.Now().UTC(),
Model: req.Model,
Done: true,
Message: api.Message{Role: "assistant"},
}
c.JSON(http.StatusOK, resp)
return
}
// only send images that are in the prompt
var i int
var images []llm.ImageData
for _, m := range req.Messages {
for _, img := range m.Images {
if !isSupportedImageType(img) {
c.AbortWithStatusJSON(http.StatusBadRequest, gin.H{"error": "unsupported image format"})
return
}
if strings.Contains(prompt, fmt.Sprintf("[img-%d]", i)) {
images = append(images, llm.ImageData{Data: img, ID: i})
}
i += 1
}
}
slog.Debug("chat handler", "prompt", prompt, "images", len(images))
ch := make(chan any)
go func() {

View File

@@ -9,14 +9,12 @@ import (
"net/http"
"net/http/httptest"
"os"
"sort"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/jmorganca/ollama/api"
"github.com/jmorganca/ollama/llm"
"github.com/jmorganca/ollama/parser"
"github.com/jmorganca/ollama/version"
)
@@ -52,7 +50,7 @@ func Test_Routes(t *testing.T) {
createTestModel := func(t *testing.T, name string) {
fname := createTestFile(t, "ollama-model")
modelfile := strings.NewReader(fmt.Sprintf("FROM %s\nPARAMETER seed 42\nPARAMETER top_p 0.9\nPARAMETER stop foo\nPARAMETER stop bar", fname))
modelfile := strings.NewReader(fmt.Sprintf("FROM %s", fname))
commands, err := parser.Parse(modelfile)
assert.Nil(t, err)
fn := func(resp api.ProgressResponse) {
@@ -169,42 +167,6 @@ func Test_Routes(t *testing.T) {
assert.Equal(t, "beefsteak:latest", model.ShortName)
},
},
{
Name: "Show Model Handler",
Method: http.MethodPost,
Path: "/api/show",
Setup: func(t *testing.T, req *http.Request) {
createTestModel(t, "show-model")
showReq := api.ShowRequest{Model: "show-model"}
jsonData, err := json.Marshal(showReq)
assert.Nil(t, err)
req.Body = io.NopCloser(bytes.NewReader(jsonData))
},
Expected: func(t *testing.T, resp *http.Response) {
contentType := resp.Header.Get("Content-Type")
assert.Equal(t, contentType, "application/json; charset=utf-8")
body, err := io.ReadAll(resp.Body)
assert.Nil(t, err)
var showResp api.ShowResponse
err = json.Unmarshal(body, &showResp)
assert.Nil(t, err)
var params []string
paramsSplit := strings.Split(showResp.Parameters, "\n")
for _, p := range paramsSplit {
params = append(params, strings.Join(strings.Fields(p), " "))
}
sort.Strings(params)
expectedParams := []string{
"seed 42",
"stop \"bar\"",
"stop \"foo\"",
"top_p 0.9",
}
assert.Equal(t, expectedParams, params)
},
},
}
s, err := setupServer(t)
@@ -231,36 +193,13 @@ func Test_Routes(t *testing.T) {
}
resp, err := httpSrv.Client().Do(req)
assert.Nil(t, err)
defer resp.Body.Close()
assert.Nil(t, err)
if tc.Expected != nil {
tc.Expected(t, resp)
}
}
}
type MockLLM struct {
encoding []int
}
func (llm *MockLLM) Predict(ctx context.Context, pred llm.PredictOpts, fn func(llm.PredictResult)) error {
return nil
}
func (llm *MockLLM) Encode(ctx context.Context, prompt string) ([]int, error) {
return llm.encoding, nil
}
func (llm *MockLLM) Decode(ctx context.Context, tokens []int) (string, error) {
return "", nil
}
func (llm *MockLLM) Embedding(ctx context.Context, input string) ([]float64, error) {
return []float64{}, nil
}
func (llm *MockLLM) Close() {
// do nothing
}

View File

@@ -7,7 +7,7 @@ import (
"fmt"
"hash"
"io"
"log/slog"
"log"
"math"
"net/http"
"net/url"
@@ -88,7 +88,7 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *Reg
return nil
}
size := b.Total / numUploadParts
var size = b.Total / numUploadParts
switch {
case size < minUploadPartSize:
size = minUploadPartSize
@@ -107,7 +107,7 @@ func (b *blobUpload) Prepare(ctx context.Context, requestURL *url.URL, opts *Reg
offset += size
}
slog.Info(fmt.Sprintf("uploading %s in %d %s part(s)", b.Digest[7:19], len(b.Parts), format.HumanBytes(b.Parts[0].Size)))
log.Printf("uploading %s in %d %s part(s)", b.Digest[7:19], len(b.Parts), format.HumanBytes(b.Parts[0].Size))
requestURL, err = url.Parse(location)
if err != nil {
@@ -156,7 +156,7 @@ func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
return err
case err != nil:
sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
slog.Info(fmt.Sprintf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep))
log.Printf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep)
time.Sleep(sleep)
continue
}
@@ -200,7 +200,7 @@ func (b *blobUpload) Run(ctx context.Context, opts *RegistryOptions) {
break
} else if err != nil {
sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
slog.Info(fmt.Sprintf("%s complete upload attempt %d failed: %v, retrying in %s", b.Digest[7:19], try, err, sleep))
log.Printf("%s complete upload attempt %d failed: %v, retrying in %s", b.Digest[7:19], try, err, sleep)
time.Sleep(sleep)
continue
}
@@ -265,7 +265,7 @@ func (b *blobUpload) uploadPart(ctx context.Context, method string, requestURL *
return err
case err != nil:
sleep := time.Second * time.Duration(math.Pow(2, float64(try)))
slog.Info(fmt.Sprintf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep))
log.Printf("%s part %d attempt %d failed: %v, retrying in %s", b.Digest[7:19], part.N, try, err, sleep)
time.Sleep(sleep)
continue
}
@@ -395,7 +395,6 @@ func uploadBlob(ctx context.Context, mp ModelPath, layer *Layer, opts *RegistryO
return err
}
// nolint: contextcheck
go upload.Run(context.Background(), opts)
}