Compare commits

..

3 Commits

Author SHA1 Message Date
jmorganca
1ab7631377 fix lint error 2025-11-06 13:55:25 -08:00
jmorganca
fed3665c70 fix tests 2025-11-06 13:49:17 -08:00
jmorganca
0a84939c11 api: add omitempty to required tool function parameter type 2025-11-06 12:58:15 -08:00
549 changed files with 44710 additions and 53789 deletions

4
.gitattributes vendored
View File

@@ -15,12 +15,8 @@ ml/backend/**/*.cu linguist-vendored
ml/backend/**/*.cuh linguist-vendored
ml/backend/**/*.m linguist-vendored
ml/backend/**/*.metal linguist-vendored
ml/backend/**/*.comp linguist-vendored
ml/backend/**/*.glsl linguist-vendored
ml/backend/**/CMakeLists.txt linguist-vendored
app/webview linguist-vendored
llama/build-info.cpp linguist-generated
ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.s linguist-generated

View File

@@ -16,15 +16,13 @@ jobs:
outputs:
GOFLAGS: ${{ steps.goflags.outputs.GOFLAGS }}
VERSION: ${{ steps.goflags.outputs.VERSION }}
vendorsha: ${{ steps.changes.outputs.vendorsha }}
steps:
- uses: actions/checkout@v4
- name: Set environment
id: goflags
run: |
echo GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${GITHUB_REF_NAME#v}\" \"-X=github.com/ollama/ollama/server.mode=release\"'" | tee -a $GITHUB_OUTPUT
echo VERSION="${GITHUB_REF_NAME#v}" | tee -a $GITHUB_OUTPUT
echo vendorsha=$(make -f Makefile.sync print-base) | tee -a $GITHUB_OUTPUT
echo GOFLAGS="'-ldflags=-w -s \"-X=github.com/ollama/ollama/version.Version=${GITHUB_REF_NAME#v}\" \"-X=github.com/ollama/ollama/server.mode=release\"'" >>$GITHUB_OUTPUT
echo VERSION="${GITHUB_REF_NAME#v}" >>$GITHUB_OUTPUT
darwin-build:
runs-on: macos-14-xlarge
@@ -55,9 +53,6 @@ jobs:
- uses: actions/setup-go@v5
with:
go-version-file: go.mod
cache-dependency-path: |
go.sum
Makefile.sync
- run: |
./scripts/build_darwin.sh
- name: Log build results
@@ -109,13 +104,6 @@ jobs:
install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe
rocm-version: '6.2'
flags: '-DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"'
runner_dir: 'rocm'
- os: windows
arch: amd64
preset: Vulkan
install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe
flags: ''
runner_dir: 'vulkan'
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
environment: release
env:
@@ -125,14 +113,13 @@ jobs:
run: |
choco install -y --no-progress ccache ninja
ccache -o cache_dir=${{ github.workspace }}\.ccache
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'ROCm ') || startsWith(matrix.preset, 'Vulkan')
- if: startsWith(matrix.preset, 'CUDA ') || startsWith(matrix.preset, 'ROCm ')
id: cache-install
uses: actions/cache/restore@v4
with:
path: |
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
C:\Program Files\AMD\ROCm
C:\VulkanSDK
key: ${{ matrix.install }}
- if: startsWith(matrix.preset, 'CUDA ')
name: Install CUDA ${{ matrix.cuda-version }}
@@ -162,18 +149,6 @@ jobs:
echo "HIPCXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "HIP_PLATFORM=amd" | Out-File -FilePath $env:GITHUB_ENV -Append
echo "CMAKE_PREFIX_PATH=$hipPath" | Out-File -FilePath $env:GITHUB_ENV -Append
- if: matrix.preset == 'Vulkan'
name: Install Vulkan ${{ matrix.rocm-version }}
run: |
$ErrorActionPreference = "Stop"
if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') {
Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe"
Start-Process -FilePath .\install.exe -ArgumentList "-c","--am","--al","in" -NoNewWindow -Wait
}
$vulkanPath = (Resolve-Path "C:\VulkanSDK\*").path
echo "$vulkanPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
echo "VULKAN_SDK=$vulkanPath" >> $env:GITHUB_ENV
- if: matrix.preset == 'CPU'
run: |
echo "CC=clang.exe" | Out-File -FilePath $env:GITHUB_ENV -Append
@@ -184,20 +159,19 @@ jobs:
path: |
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
C:\Program Files\AMD\ROCm
C:\VulkanSDK
key: ${{ matrix.install }}
- uses: actions/checkout@v4
- uses: actions/cache@v4
with:
path: ${{ github.workspace }}\.ccache
key: ccache-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.preset }}-${{ needs.setup-environment.outputs.vendorsha }}
key: ccache-${{ matrix.os }}-${{ matrix.arch }}-${{ matrix.preset }}
- name: Build target "${{ matrix.preset }}"
run: |
Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
cmake --preset "${{ matrix.preset }}" ${{ matrix.flags }} --install-prefix "$((pwd).Path)\dist\${{ matrix.os }}-${{ matrix.arch }}"
cmake --build --parallel ([Environment]::ProcessorCount) --preset "${{ matrix.preset }}"
cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || startsWith(matrix.preset, 'Vulkan') && 'Vulkan' || 'CPU' }}" --strip
cmake --install build --component "${{ startsWith(matrix.preset, 'CUDA ') && 'CUDA' || startsWith(matrix.preset, 'ROCm ') && 'HIP' || 'CPU' }}" --strip
Remove-Item -Path dist\lib\ollama\rocm\rocblas\library\*gfx906* -ErrorAction SilentlyContinue
env:
CMAKE_GENERATOR: Ninja
@@ -254,9 +228,6 @@ jobs:
- uses: actions/setup-go@v5
with:
go-version-file: go.mod
cache-dependency-path: |
go.sum
Makefile.sync
- name: Verify gcc is actually clang
run: |
$ErrorActionPreference='Continue'
@@ -310,9 +281,6 @@ jobs:
- uses: actions/setup-go@v5
with:
go-version-file: go.mod
cache-dependency-path: |
go.sum
Makefile.sync
- uses: actions/download-artifact@v4
with:
pattern: depends-windows*
@@ -344,13 +312,13 @@ jobs:
include:
- os: linux
arch: amd64
target: archive
target: archive_novulkan
- os: linux
arch: amd64
target: rocm
- os: linux
arch: arm64
target: archive
target: archive_novulkan
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
environment: release
needs: setup-environment
@@ -377,7 +345,6 @@ jobs:
bin/ollama) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/*.so*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_v*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/vulkan*) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}.tar.in ;;
lib/ollama/cuda_jetpack5) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack5.tar.in ;;
lib/ollama/cuda_jetpack6) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-jetpack6.tar.in ;;
lib/ollama/rocm) echo $COMPONENT >>ollama-${{ matrix.os }}-${{ matrix.arch }}-rocm.tar.in ;;
@@ -407,12 +374,14 @@ jobs:
include:
- os: linux
arch: arm64
target: novulkan
build-args: |
CGO_CFLAGS
CGO_CXXFLAGS
GOFLAGS
- os: linux
arch: amd64
target: novulkan
build-args: |
CGO_CFLAGS
CGO_CXXFLAGS
@@ -425,6 +394,14 @@ jobs:
CGO_CXXFLAGS
GOFLAGS
FLAVOR=rocm
- os: linux
arch: amd64
suffix: '-vulkan'
target: default
build-args: |
CGO_CFLAGS
CGO_CXXFLAGS
GOFLAGS
runs-on: ${{ matrix.arch == 'arm64' && format('{0}-{1}', matrix.os, matrix.arch) || matrix.os }}
environment: release
needs: setup-environment
@@ -442,6 +419,7 @@ jobs:
with:
context: .
platforms: ${{ matrix.os }}/${{ matrix.arch }}
target: ${{ matrix.preset }}
build-args: ${{ matrix.build-args }}
outputs: type=image,name=${{ vars.DOCKER_REPO }},push-by-digest=true,name-canonical=true,push=true
cache-from: type=registry,ref=${{ vars.DOCKER_REPO }}:latest

View File

@@ -22,7 +22,6 @@ jobs:
runs-on: ubuntu-latest
outputs:
changed: ${{ steps.changes.outputs.changed }}
vendorsha: ${{ steps.changes.outputs.vendorsha }}
steps:
- uses: actions/checkout@v4
with:
@@ -38,7 +37,6 @@ jobs:
}
echo changed=$(changed 'llama/llama.cpp/**/*' 'ml/backend/ggml/ggml/**/*') | tee -a $GITHUB_OUTPUT
echo vendorsha=$(make -f Makefile.sync print-base) | tee -a $GITHUB_OUTPUT
linux:
needs: [changes]
@@ -85,7 +83,7 @@ jobs:
- uses: actions/cache@v4
with:
path: /github/home/.cache/ccache
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}-${{ needs.changes.outputs.vendorsha }}
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}
- run: |
cmake --preset ${{ matrix.preset }} ${{ matrix.flags }}
cmake --build --preset ${{ matrix.preset }} --parallel
@@ -174,13 +172,12 @@ jobs:
path: |
C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA
C:\Program Files\AMD\ROCm
C:\VulkanSDK
key: ${{ matrix.install }}
- uses: actions/checkout@v4
- uses: actions/cache@v4
with:
path: ${{ github.workspace }}\.ccache
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}-${{ needs.changes.outputs.vendorsha }}
key: ccache-${{ runner.os }}-${{ runner.arch }}-${{ matrix.preset }}
- run: |
Import-Module 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise\Common7\Tools\Microsoft.VisualStudio.DevShell.dll'
Enter-VsDevShell -VsInstallPath 'C:\Program Files\Microsoft Visual Studio\2022\Enterprise' -SkipAutomaticLocation -DevCmdArguments '-arch=x64 -no_logo'
@@ -208,9 +205,6 @@ jobs:
- uses: actions/setup-go@v5
with:
go-version-file: 'go.mod'
cache-dependency-path: |
go.sum
Makefile.sync
- uses: actions/setup-node@v4
with:
node-version: '20'
@@ -231,9 +225,12 @@ jobs:
if: always()
run: go test -count=1 -benchtime=1x ./...
- uses: golangci/golangci-lint-action@v9
# TODO(bmizerany): replace this heavy tool with just the
# tools/checks/binaries we want and then make them all run in parallel
# across jobs, not on a single tiny vm on Github Actions.
- uses: golangci/golangci-lint-action@v6
with:
only-new-issues: true
args: --timeout 10m0s -v
patches:
runs-on: ubuntu-latest
@@ -242,4 +239,4 @@ jobs:
- name: Verify patches apply cleanly and do not change files
run: |
make -f Makefile.sync clean checkout apply-patches sync
git diff --compact-summary --exit-code
git diff --compact-summary --exit-code

View File

@@ -1,4 +1,5 @@
version: "2"
run:
timeout: 5m
linters:
enable:
- asasalint
@@ -6,46 +7,35 @@ linters:
- bodyclose
- containedctx
- gocheckcompilerdirectives
- gofmt
- gofumpt
- gosimple
- govet
- ineffassign
- intrange
- makezero
- misspell
- nilerr
- nolintlint
- nosprintfhostport
- staticcheck
- unconvert
- usetesting
- wastedassign
- whitespace
disable:
- errcheck
- usestdlibvars
settings:
govet:
disable:
- unusedresult
staticcheck:
checks:
- all
- -QF* # disable quick fix suggestions
- -SA1019
- -ST1000 # package comment format
- -ST1003 # underscores in package names
- -ST1005 # error strings should not be capitalized
- -ST1012 # error var naming (ErrFoo)
- -ST1016 # receiver name consistency
- -ST1020 # comment on exported function format
- -ST1021 # comment on exported type format
- -ST1022 # comment on exported var format
- -ST1023 # omit type from declaration
- errcheck
linters-settings:
staticcheck:
checks:
- all
- -SA1019 # omit Deprecated check
severity:
default: error
default-severity: error
rules:
- linters:
- gofmt
- goimports
- intrange
severity: info
formatters:
enable:
- gofmt
- gofumpt

View File

@@ -16,7 +16,7 @@ See the [development documentation](./docs/development.md) for instructions on h
* New features: new features (e.g. API fields, environment variables) add surface area to Ollama and make it harder to maintain in the long run as they cannot be removed without potentially breaking users in the future.
* Refactoring: large code improvements are important, but can be harder or take longer to review and merge.
* Documentation: small updates to fill in or correct missing documentation are helpful, however large documentation additions can be hard to maintain over time.
* Documentation: small updates to fill in or correct missing documentation is helpful, however large documentation additions can be hard to maintain over time.
### Issues that may not be accepted
@@ -43,7 +43,7 @@ Tips for proposals:
* Explain how the change will be tested.
Additionally, for bonus points: Provide draft documentation you would expect to
see if the changes were accepted.
see if the change were accepted.
## Pull requests
@@ -66,6 +66,7 @@ Examples:
llm/backend/mlx: support the llama architecture
CONTRIBUTING: provide clarity on good commit messages, and bad
docs: simplify manual installation with shorter curl commands
Bad Examples:

View File

@@ -39,14 +39,14 @@ ENV CC=clang CXX=clang++
FROM base-${TARGETARCH} AS base
ARG CMAKEVERSION
RUN curl -fsSL https://github.com/Kitware/CMake/releases/download/v${CMAKEVERSION}/cmake-${CMAKEVERSION}-linux-$(uname -m).tar.gz | tar xz -C /usr/local --strip-components 1
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
ENV LDFLAGS=-s
FROM base AS cpu
RUN dnf install -y gcc-toolset-11-gcc gcc-toolset-11-gcc-c++
ENV PATH=/opt/rh/gcc-toolset-11/root/usr/bin:$PATH
ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CPU' \
&& cmake --build --parallel ${PARALLEL} --preset 'CPU' \
@@ -57,8 +57,6 @@ ARG CUDA11VERSION=11.8
RUN dnf install -y cuda-toolkit-${CUDA11VERSION//./-}
ENV PATH=/usr/local/cuda-11/bin:$PATH
ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 11' \
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 11' \
@@ -69,8 +67,6 @@ ARG CUDA12VERSION=12.8
RUN dnf install -y cuda-toolkit-${CUDA12VERSION//./-}
ENV PATH=/usr/local/cuda-12/bin:$PATH
ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 12' \
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 12' \
@@ -82,8 +78,6 @@ ARG CUDA13VERSION=13.0
RUN dnf install -y cuda-toolkit-${CUDA13VERSION//./-}
ENV PATH=/usr/local/cuda-13/bin:$PATH
ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'CUDA 13' \
&& cmake --build --parallel ${PARALLEL} --preset 'CUDA 13' \
@@ -93,8 +87,6 @@ RUN --mount=type=cache,target=/root/.ccache \
FROM base AS rocm-6
ENV PATH=/opt/rocm/hcc/bin:/opt/rocm/hip/bin:/opt/rocm/bin:/opt/rocm/hcc/bin:$PATH
ARG PARALLEL
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'ROCm 6' \
&& cmake --build --parallel ${PARALLEL} --preset 'ROCm 6' \
@@ -126,8 +118,6 @@ RUN --mount=type=cache,target=/root/.ccache \
&& cmake --install build --component CUDA --strip --parallel ${PARALLEL}
FROM base AS vulkan
COPY CMakeLists.txt CMakePresets.json .
COPY ml/backend/ggml/ggml ml/backend/ggml/ggml
RUN --mount=type=cache,target=/root/.ccache \
cmake --preset 'Vulkan' \
&& cmake --build --parallel --preset 'Vulkan' \
@@ -169,7 +159,32 @@ ARG VULKANVERSION
COPY --from=cpu dist/lib/ollama /lib/ollama
COPY --from=build /bin/ollama /bin/ollama
FROM ubuntu:24.04
# Temporary opt-out stages for Vulkan
FROM --platform=linux/amd64 scratch AS amd64_novulkan
# COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/
COPY --from=cuda-12 dist/lib/ollama /lib/ollama/
COPY --from=cuda-13 dist/lib/ollama /lib/ollama/
FROM arm64 AS arm64_novulkan
FROM ${FLAVOR}_novulkan AS archive_novulkan
COPY --from=cpu dist/lib/ollama /lib/ollama
COPY --from=build /bin/ollama /bin/ollama
FROM ubuntu:24.04 AS novulkan
RUN apt-get update \
&& apt-get install -y ca-certificates \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*
COPY --from=archive_novulkan /bin /usr/bin
ENV PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
COPY --from=archive_novulkan /lib/ollama /usr/lib/ollama
ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64
ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility
ENV NVIDIA_VISIBLE_DEVICES=all
ENV OLLAMA_HOST=0.0.0.0:11434
EXPOSE 11434
ENTRYPOINT ["/bin/ollama"]
CMD ["serve"]
FROM ubuntu:24.04 AS default
RUN apt-get update \
&& apt-get install -y ca-certificates libvulkan1 \
&& apt-get clean \

View File

@@ -1,6 +1,6 @@
UPSTREAM=https://github.com/ggml-org/llama.cpp.git
WORKDIR=llama/vendor
FETCH_HEAD=17f7f4baad8b3a716ee139da7bb56ae984e8c0fa
FETCH_HEAD=3cfa9c3f125763305b4226bc032f1954f08990dc
.PHONY: help
help:
@@ -57,7 +57,7 @@ checkout: $(WORKDIR)
$(WORKDIR):
git clone $(UPSTREAM) $(WORKDIR)
.PHONY: format-patches
.PHONE: format-patches
format-patches: llama/patches
git -C $(WORKDIR) format-patch \
--no-signature \
@@ -66,11 +66,7 @@ format-patches: llama/patches
-o $(realpath $<) \
$(FETCH_HEAD)
.PHONY: clean
.PHONE: clean
clean: checkout
@git -C $(WORKDIR) am --abort || true
$(RM) llama/patches/.*.patched
.PHONY: print-base
print-base:
@echo $(FETCH_HEAD)

View File

@@ -299,7 +299,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [LibreChat](https://github.com/danny-avila/LibreChat)
- [Bionic GPT](https://github.com/bionic-gpt/bionic-gpt)
- [HTML UI](https://github.com/rtcfirefly/ollama-ui)
- [AI-UI](https://github.com/bajahaw/ai-ui)
- [Saddle](https://github.com/jikkuatwork/saddle)
- [TagSpaces](https://www.tagspaces.org) (A platform for file-based apps, [utilizing Ollama](https://docs.tagspaces.org/ai/) for the generation of tags and descriptions)
- [Chatbot UI](https://github.com/ivanfioravanti/chatbot-ollama)
@@ -366,8 +365,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [PartCAD](https://github.com/openvmp/partcad/) (CAD model generation with OpenSCAD and CadQuery)
- [Ollama4j Web UI](https://github.com/ollama4j/ollama4j-web-ui) - Java-based Web UI for Ollama built with Vaadin, Spring Boot, and Ollama4j
- [PyOllaMx](https://github.com/kspviswa/pyOllaMx) - macOS application capable of chatting with both Ollama and Apple MLX models.
- [Cline](https://github.com/cline/cline) - Formerly known as Claude Dev is a VS Code extension for multi-file/whole-repo coding
- [Void](https://github.com/voideditor/void) (Open source AI code editor and Cursor alternative)
- [Cline](https://github.com/cline/cline) - Formerly known as Claude Dev is a VSCode extension for multi-file/whole-repo coding
- [Cherry Studio](https://github.com/kangfenmao/cherry-studio) (Desktop client with Ollama support)
- [ConfiChat](https://github.com/1runeberg/confichat) (Lightweight, standalone, multi-platform, and privacy-focused LLM chat interface with optional encryption)
- [Archyve](https://github.com/nickthecook/archyve) (RAG-enabling document library)
@@ -399,7 +397,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [aidful-ollama-model-delete](https://github.com/AidfulAI/aidful-ollama-model-delete) (User interface for simplified model cleanup)
- [Perplexica](https://github.com/ItzCrazyKns/Perplexica) (An AI-powered search engine & an open-source alternative to Perplexity AI)
- [Ollama Chat WebUI for Docker ](https://github.com/oslook/ollama-webui) (Support for local docker deployment, lightweight ollama webui)
- [AI Toolkit for Visual Studio Code](https://aka.ms/ai-tooklit/ollama-docs) (Microsoft-official VS Code extension to chat, test, evaluate models with Ollama support, and use them in your AI applications.)
- [AI Toolkit for Visual Studio Code](https://aka.ms/ai-tooklit/ollama-docs) (Microsoft-official VSCode extension to chat, test, evaluate models with Ollama support, and use them in your AI applications.)
- [MinimalNextOllamaChat](https://github.com/anilkay/MinimalNextOllamaChat) (Minimal Web UI for Chat and Model Control)
- [Chipper](https://github.com/TilmanGriesel/chipper) AI interface for tinkerers (Ollama, Haystack RAG, Python)
- [ChibiChat](https://github.com/CosmicEventHorizon/ChibiChat) (Kotlin-based Android app to chat with Ollama and Koboldcpp API endpoints)
@@ -428,7 +426,6 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Mayan EDMS](https://gitlab.com/mayan-edms/mayan-edms) (Open source document management system to organize, tag, search, and automate your files with powerful Ollama driven workflows.)
- [Serene Pub](https://github.com/doolijb/serene-pub) (Beginner friendly, open source AI Roleplaying App for Windows, Mac OS and Linux. Search, download and use models with Ollama all inside the app.)
- [Andes](https://github.com/aqerd/andes) (A Visual Studio Code extension that provides a local UI interface for Ollama models)
- [KDeps](https://github.com/kdeps/kdeps) (Kdeps is an offline-first AI framework for building Dockerized full-stack AI applications declaratively using Apple PKL and integrates APIs with Ollama on the backend.)
- [Clueless](https://github.com/KashyapTan/clueless) (Open Source & Local Cluely: A desktop application LLM assistant to help you talk to anything on your screen using locally served Ollama models. Also undetectable to screenshare)
- [ollama-co2](https://github.com/carbonatedWaterOrg/ollama-co2) (FastAPI web interface for monitoring and managing local and remote Ollama servers with real-time model monitoring and concurrent downloads)
- [Hillnote](https://hillnote.com) (A Markdown-first workspace designed to supercharge your AI workflow. Create documents ready to integrate with Claude, ChatGPT, Gemini, Cursor, and more - all while keeping your work on your device.)
@@ -555,7 +552,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [Parakeet](https://github.com/parakeet-nest/parakeet) is a GoLang library, made to simplify the development of small generative AI applications with Ollama.
- [Haverscript](https://github.com/andygill/haverscript) with [examples](https://github.com/andygill/haverscript/tree/main/examples)
- [Ollama for Swift](https://github.com/mattt/ollama-swift)
- [Swollama for Swift]([https://github.com/marcusziade/Swollama](https://github.com/guitaripod/Swollama) with [DocC]( https://guitaripod.github.io/Swollama/documentation/swollama)
- [Swollama for Swift](https://github.com/marcusziade/Swollama) with [DocC](https://marcusziade.github.io/Swollama/documentation/swollama/)
- [GoLamify](https://github.com/prasad89/golamify)
- [Ollama for Haskell](https://github.com/tusharad/ollama-haskell)
- [multi-llm-ts](https://github.com/nbonamy/multi-llm-ts) (A Typescript/JavaScript library allowing access to different LLM in a unified API)
@@ -618,7 +615,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [LSP-AI](https://github.com/SilasMarvin/lsp-ai) (Open-source language server for AI-powered functionality)
- [QodeAssist](https://github.com/Palm1r/QodeAssist) (AI-powered coding assistant plugin for Qt Creator)
- [Obsidian Quiz Generator plugin](https://github.com/ECuiDev/obsidian-quiz-generator)
- [AI Summary Helper plugin](https://github.com/philffm/ai-summary-helper)
- [AI Summmary Helper plugin](https://github.com/philffm/ai-summary-helper)
- [TextCraft](https://github.com/suncloudsmoon/TextCraft) (Copilot in Word alternative using Ollama)
- [Alfred Ollama](https://github.com/zeitlings/alfred-ollama) (Alfred Workflow)
- [TextLLaMA](https://github.com/adarshM84/TextLLaMA) A Chrome Extension that helps you write emails, correct grammar, and translate into any language
@@ -626,7 +623,7 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [LLM Telegram Bot](https://github.com/innightwolfsleep/llm_telegram_bot) (telegram bot, primary for RP. Oobabooga-like buttons, [A1111](https://github.com/AUTOMATIC1111/stable-diffusion-webui) API integration e.t.c)
- [mcp-llm](https://github.com/sammcj/mcp-llm) (MCP Server to allow LLMs to call other LLMs)
- [SimpleOllamaUnity](https://github.com/HardCodeDev777/SimpleOllamaUnity) (Unity Engine extension for communicating with Ollama in a few lines of code. Also works at runtime)
- [UnityCodeLama](https://github.com/HardCodeDev777/UnityCodeLama) (Unity Editor tool to analyze scripts via Ollama)
- [UnityCodeLama](https://github.com/HardCodeDev777/UnityCodeLama) (Unity Edtior tool to analyze scripts via Ollama)
- [NativeMind](https://github.com/NativeMindBrowser/NativeMindExtension) (Private, on-device AI Assistant, no cloud dependencies)
- [GMAI - Gradle Managed AI](https://gmai.premex.se/) (Gradle plugin for automated Ollama lifecycle management during build phases)
- [NOMYO Router](https://github.com/nomyo-ai/nomyo-router) (A transparent Ollama proxy with model deployment aware routing which auto-manages multiple Ollama instances in a given network)
@@ -636,12 +633,9 @@ See the [API documentation](./docs/api.md) for all endpoints.
- [llama.cpp](https://github.com/ggml-org/llama.cpp) project founded by Georgi Gerganov.
### Observability
- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native integration to Ollama.
- [Opik](https://www.comet.com/docs/opik/cookbook/ollama) is an open-source platform to debug, evaluate, and monitor your LLM applications, RAG systems, and agentic workflows with comprehensive tracing, automated evaluations, and production-ready dashboards. Opik supports native intergration to Ollama.
- [Lunary](https://lunary.ai/docs/integrations/ollama) is the leading open-source LLM observability platform. It provides a variety of enterprise-grade features such as real-time analytics, prompt templates management, PII masking, and comprehensive agent tracing.
- [OpenLIT](https://github.com/openlit/openlit) is an OpenTelemetry-native tool for monitoring Ollama Applications & GPUs using traces and metrics.
- [HoneyHive](https://docs.honeyhive.ai/integrations/ollama) is an AI observability and evaluation platform for AI agents. Use HoneyHive to evaluate agent performance, interrogate failures, and monitor quality in production.
- [Langfuse](https://langfuse.com/docs/integrations/ollama) is an open source LLM observability platform that enables teams to collaboratively monitor, evaluate and debug AI applications.
- [MLflow Tracing](https://mlflow.org/docs/latest/llms/tracing/index.html#automatic-tracing) is an open source LLM observability tool with a convenient API to log and visualize traces, making it easy to debug and evaluate GenAI applications.
### Security
- [Ollama Fortress](https://github.com/ParisNeo/ollama_proxy_server)

View File

@@ -14,7 +14,7 @@ Please include the following details in your report:
## Security best practices
While the maintainer team does its best to secure Ollama, users are encouraged to implement their own security best practices, such as:
While the maintainer team does their best to secure Ollama, users are encouraged to implement their own security best practices, such as:
- Regularly updating to the latest version of Ollama
- Securing access to hosted instances of Ollama

View File

@@ -226,14 +226,7 @@ func (c *Client) stream(ctx context.Context, method, path string, data any, fn f
bts := scanner.Bytes()
if err := json.Unmarshal(bts, &errorResponse); err != nil {
if response.StatusCode >= http.StatusBadRequest {
return StatusError{
StatusCode: response.StatusCode,
Status: response.Status,
ErrorMessage: string(bts),
}
}
return errors.New(string(bts))
return fmt.Errorf("unmarshal: %w", err)
}
if response.StatusCode == http.StatusUnauthorized {

View File

@@ -55,7 +55,6 @@ func TestClientFromEnvironment(t *testing.T) {
type testError struct {
message string
statusCode int
raw bool // if true, write message as-is instead of JSON encoding
}
func (e testError) Error() string {
@@ -112,20 +111,6 @@ func TestClientStream(t *testing.T) {
},
},
},
{
name: "plain text error response",
responses: []any{
"internal server error",
},
wantErr: "internal server error",
},
{
name: "HTML error page",
responses: []any{
"<html><body>404 Not Found</body></html>",
},
wantErr: "404 Not Found",
},
}
for _, tc := range testCases {
@@ -150,12 +135,6 @@ func TestClientStream(t *testing.T) {
return
}
if str, ok := resp.(string); ok {
fmt.Fprintln(w, str)
flusher.Flush()
continue
}
if err := json.NewEncoder(w).Encode(resp); err != nil {
t.Fatalf("failed to encode response: %v", err)
}
@@ -194,10 +173,9 @@ func TestClientStream(t *testing.T) {
func TestClientDo(t *testing.T) {
testCases := []struct {
name string
response any
wantErr string
wantStatusCode int
name string
response any
wantErr string
}{
{
name: "immediate error response",
@@ -205,8 +183,7 @@ func TestClientDo(t *testing.T) {
message: "test error message",
statusCode: http.StatusBadRequest,
},
wantErr: "test error message",
wantStatusCode: http.StatusBadRequest,
wantErr: "test error message",
},
{
name: "server error response",
@@ -214,8 +191,7 @@ func TestClientDo(t *testing.T) {
message: "internal error",
statusCode: http.StatusInternalServerError,
},
wantErr: "internal error",
wantStatusCode: http.StatusInternalServerError,
wantErr: "internal error",
},
{
name: "successful response",
@@ -227,26 +203,6 @@ func TestClientDo(t *testing.T) {
Success: true,
},
},
{
name: "plain text error response",
response: testError{
message: "internal server error",
statusCode: http.StatusInternalServerError,
raw: true,
},
wantErr: "internal server error",
wantStatusCode: http.StatusInternalServerError,
},
{
name: "HTML error page",
response: testError{
message: "<html><body>404 Not Found</body></html>",
statusCode: http.StatusNotFound,
raw: true,
},
wantErr: "<html><body>404 Not Found</body></html>",
wantStatusCode: http.StatusNotFound,
},
}
for _, tc := range testCases {
@@ -254,16 +210,11 @@ func TestClientDo(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if errResp, ok := tc.response.(testError); ok {
w.WriteHeader(errResp.statusCode)
if !errResp.raw {
err := json.NewEncoder(w).Encode(map[string]string{
"error": errResp.message,
})
if err != nil {
t.Fatal("failed to encode error response:", err)
}
} else {
// Write raw message (simulates non-JSON error responses)
fmt.Fprint(w, errResp.message)
err := json.NewEncoder(w).Encode(map[string]string{
"error": errResp.message,
})
if err != nil {
t.Fatal("failed to encode error response:", err)
}
return
}
@@ -290,15 +241,6 @@ func TestClientDo(t *testing.T) {
if err.Error() != tc.wantErr {
t.Errorf("error message mismatch: got %q, want %q", err.Error(), tc.wantErr)
}
if tc.wantStatusCode != 0 {
if statusErr, ok := err.(StatusError); ok {
if statusErr.StatusCode != tc.wantStatusCode {
t.Errorf("status code mismatch: got %d, want %d", statusErr.StatusCode, tc.wantStatusCode)
}
} else {
t.Errorf("expected StatusError, got %T", err)
}
}
return
}

View File

@@ -15,19 +15,19 @@ func main() {
}
messages := []api.Message{
{
api.Message{
Role: "system",
Content: "Provide very brief, concise responses",
},
{
api.Message{
Role: "user",
Content: "Name some unusual animals",
},
{
api.Message{
Role: "assistant",
Content: "Monotreme, platypus, echidna",
},
{
api.Message{
Role: "user",
Content: "which of these is the most dangerous?",
},

View File

@@ -117,14 +117,6 @@ type GenerateRequest struct {
// DebugRenderOnly is a debug option that, when set to true, returns the rendered
// template instead of calling the model.
DebugRenderOnly bool `json:"_debug_render_only,omitempty"`
// Logprobs specifies whether to return log probabilities of the output tokens.
Logprobs bool `json:"logprobs,omitempty"`
// TopLogprobs is the number of most likely tokens to return at each token position,
// each with an associated log probability. Only applies when Logprobs is true.
// Valid values are 0-20. Default is 0 (only return the selected token's logprob).
TopLogprobs int `json:"top_logprobs,omitempty"`
}
// ChatRequest describes a request sent by [Client.Chat].
@@ -167,14 +159,6 @@ type ChatRequest struct {
// DebugRenderOnly is a debug option that, when set to true, returns the rendered
// template instead of calling the model.
DebugRenderOnly bool `json:"_debug_render_only,omitempty"`
// Logprobs specifies whether to return log probabilities of the output tokens.
Logprobs bool `json:"logprobs,omitempty"`
// TopLogprobs is the number of most likely tokens to return at each token position,
// each with an associated log probability. Only applies when Logprobs is true.
// Valid values are 0-20. Default is 0 (only return the selected token's logprob).
TopLogprobs int `json:"top_logprobs,omitempty"`
}
type Tools []Tool
@@ -359,27 +343,6 @@ func (t *ToolFunction) String() string {
return string(bts)
}
// TokenLogprob represents log probability information for a single token alternative.
type TokenLogprob struct {
// Token is the text representation of the token.
Token string `json:"token"`
// Logprob is the log probability of this token.
Logprob float64 `json:"logprob"`
// Bytes contains the raw byte representation of the token
Bytes []int `json:"bytes,omitempty"`
}
// Logprob contains log probability information for a generated token.
type Logprob struct {
TokenLogprob
// TopLogprobs contains the most likely tokens and their log probabilities
// at this position, if requested via TopLogprobs parameter.
TopLogprobs []TokenLogprob `json:"top_logprobs,omitempty"`
}
// ChatResponse is the response returned by [Client.Chat]. Its fields are
// similar to [GenerateResponse].
type ChatResponse struct {
@@ -406,10 +369,6 @@ type ChatResponse struct {
DebugInfo *DebugInfo `json:"_debug_info,omitempty"`
// Logprobs contains log probability information for the generated tokens,
// if requested via the Logprobs parameter.
Logprobs []Logprob `json:"logprobs,omitempty"`
Metrics
}
@@ -718,10 +677,6 @@ type GenerateResponse struct {
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
DebugInfo *DebugInfo `json:"_debug_info,omitempty"`
// Logprobs contains log probability information for the generated tokens,
// if requested via the Logprobs parameter.
Logprobs []Logprob `json:"logprobs,omitempty"`
}
// ModelDetails provides details about a model.

View File

@@ -48,6 +48,16 @@ The `-dev` flag enables:
- CORS headers for cross-origin requests
- Hot-reload support for UI development
#### Run Storybook
Inside the `ui/app` directory, run:
```bash
npm run storybook
```
For now we're writing stories as siblings of the component they're testing. So for example, `src/components/Message.stories.tsx` is the story for `src/components/Message.tsx`.
## Build

View File

@@ -273,6 +273,10 @@ func main() {
Handler: uiServer.Handler(),
}
if _, err := uiServer.UserData(ctx); err != nil {
slog.Warn("failed to load user data", "error", err)
}
// Start the UI server
slog.Info("starting ui server", "port", port)
go func() {
@@ -316,17 +320,6 @@ func main() {
slog.Debug("no URL scheme request to handle")
}
go func() {
slog.Debug("waiting for ollama server to be ready")
if err := ui.WaitForServer(ctx, 10*time.Second); err != nil {
slog.Warn("ollama server not ready, continuing anyway", "error", err)
}
if _, err := uiServer.UserData(ctx); err != nil {
slog.Warn("failed to load user data", "error", err)
}
}()
osRun(cancel, hasCompletedFirstRun, startHidden)
slog.Info("shutting down desktop server")
@@ -368,7 +361,7 @@ func checkUserLoggedIn(uiServerPort int) bool {
return false
}
resp, err := http.Post(fmt.Sprintf("http://127.0.0.1:%d/api/me", uiServerPort), "application/json", nil)
resp, err := http.Get(fmt.Sprintf("http://127.0.0.1:%d/api/v1/me", uiServerPort))
if err != nil {
slog.Debug("failed to call local auth endpoint", "error", err)
return false
@@ -404,8 +397,8 @@ func checkUserLoggedIn(uiServerPort int) bool {
// handleConnectURLScheme fetches the connect URL and opens it in the browser
func handleConnectURLScheme() {
if checkUserLoggedIn(uiServerPort) {
slog.Info("user is already logged in, opening app instead")
showWindow(wv.webview.Window())
slog.Info("user is already logged in, opening settings instead")
sendUIRequestMessage("/")
return
}
@@ -441,30 +434,37 @@ func openInBrowser(url string) {
}
}
// parseURLScheme parses an ollama:// URL and validates it
// Supports: ollama:// (open app) and ollama://connect (OAuth)
func parseURLScheme(urlSchemeRequest string) (isConnect bool, err error) {
// parseURLScheme parses an ollama:// URL and returns whether it's a connect URL and the UI path
func parseURLScheme(urlSchemeRequest string) (isConnect bool, uiPath string, err error) {
parsedURL, err := url.Parse(urlSchemeRequest)
if err != nil {
return false, fmt.Errorf("invalid URL: %w", err)
return false, "", err
}
// Check if this is a connect URL
if parsedURL.Host == "connect" || strings.TrimPrefix(parsedURL.Path, "/") == "connect" {
return true, nil
return true, "", nil
}
// Allow bare ollama:// or ollama:/// to open the app
if (parsedURL.Host == "" && parsedURL.Path == "") || parsedURL.Path == "/" {
return false, nil
// Extract the UI path
path := "/"
if parsedURL.Path != "" && parsedURL.Path != "/" {
// For URLs like ollama:///settings, use the path directly
path = parsedURL.Path
} else if parsedURL.Host != "" {
// For URLs like ollama://settings (without triple slash),
// the "settings" part is parsed as the host, not the path.
// We need to convert it to a path by prepending "/"
// This also handles ollama://settings/ where Windows adds a trailing slash
path = "/" + parsedURL.Host
}
return false, fmt.Errorf("unsupported ollama:// URL path: %s", urlSchemeRequest)
return false, path, nil
}
// handleURLSchemeInCurrentInstance processes URL scheme requests in the current instance
func handleURLSchemeInCurrentInstance(urlSchemeRequest string) {
isConnect, err := parseURLScheme(urlSchemeRequest)
isConnect, uiPath, err := parseURLScheme(urlSchemeRequest)
if err != nil {
slog.Error("failed to parse URL scheme request", "url", urlSchemeRequest, "error", err)
return
@@ -473,8 +473,6 @@ func handleURLSchemeInCurrentInstance(urlSchemeRequest string) {
if isConnect {
handleConnectURLScheme()
} else {
if wv.webview != nil {
showWindow(wv.webview.Window())
}
sendUIRequestMessage(uiPath)
}
}

View File

@@ -24,14 +24,27 @@ bool firstTimeRun,startHidden; // Set in run before initialization
for (NSURL *url in urls) {
if ([url.scheme isEqualToString:@"ollama"]) {
NSString *path = url.path;
if (path && ([path isEqualToString:@"/connect"] || [url.host isEqualToString:@"connect"])) {
if (!path || [path isEqualToString:@""]) {
// For URLs like ollama://settings (without triple slash),
// the "settings" part is parsed as the host, not the path.
// We need to convert it to a path by prepending "/"
if (url.host && ![url.host isEqualToString:@""]) {
path = [@"/" stringByAppendingString:url.host];
} else {
path = @"/";
}
}
if ([path isEqualToString:@"/connect"] || [url.host isEqualToString:@"connect"]) {
// Special case: handle connect by opening browser instead of app
handleConnectURL();
} else {
// Set app to be active and visible
[NSApp setActivationPolicy:NSApplicationActivationPolicyRegular];
[NSApp activateIgnoringOtherApps:YES];
// Open the path with the UI
[self uiRequest:path];
}
break;
@@ -247,7 +260,7 @@ bool firstTimeRun,startHidden; // Set in run before initialization
}
- (void)openHelp:(id)sender {
NSURL *url = [NSURL URLWithString:@"https://docs.ollama.com/"];
NSURL *url = [NSURL URLWithString:@"https://github.com/ollama/ollama/tree/main/docs"];
[[NSWorkspace sharedWorkspace] openURL:url];
}

View File

@@ -138,7 +138,7 @@ func (app *appCallbacks) HandleURLScheme(urlScheme string) {
// handleURLSchemeRequest processes URL scheme requests from other instances
func handleURLSchemeRequest(urlScheme string) {
isConnect, err := parseURLScheme(urlScheme)
isConnect, uiPath, err := parseURLScheme(urlScheme)
if err != nil {
slog.Error("failed to parse URL scheme request", "url", urlScheme, "error", err)
return
@@ -147,9 +147,7 @@ func handleURLSchemeRequest(urlScheme string) {
if isConnect {
handleConnectURLScheme()
} else {
if wv.webview != nil {
showWindow(wv.webview.Window())
}
sendUIRequestMessage(uiPath)
}
}

View File

@@ -282,7 +282,7 @@ func (w *Webview) Run(path string) unsafe.Pointer {
"go", "rs", "swift", "kt", "scala", "sh", "bat", "yaml", "yml", "toml", "ini",
"cfg", "conf", "log", "rtf",
}
imageExts := []string{"png", "jpg", "jpeg", "webp"}
imageExts := []string{"png", "jpg", "jpeg"}
allowedExts := append(textExts, imageExts...)
// Use native multiple file selection with extension filtering

View File

@@ -469,24 +469,26 @@ export class HealthResponse {
}
export class User {
id: string;
email: string;
name: string;
bio?: string;
avatarurl?: string;
firstname?: string;
lastname?: string;
plan?: string;
email: string;
avatarURL: string;
plan: string;
bio: string;
firstName: string;
lastName: string;
overThreshold: boolean;
constructor(source: any = {}) {
if ('string' === typeof source) source = JSON.parse(source);
this.id = source["id"];
this.email = source["email"];
this.name = source["name"];
this.bio = source["bio"];
this.avatarurl = source["avatarurl"];
this.firstname = source["firstname"];
this.lastname = source["lastname"];
this.email = source["email"];
this.avatarURL = source["avatarURL"];
this.plan = source["plan"];
this.bio = source["bio"];
this.firstName = source["firstName"];
this.lastName = source["lastName"];
this.overThreshold = source["overThreshold"];
}
}
export class Attachment {

File diff suppressed because it is too large Load Diff

View File

@@ -34,7 +34,6 @@
"rehype-raw": "^7.0.0",
"rehype-sanitize": "^6.0.0",
"remark-math": "^6.0.0",
"streamdown": "^1.4.0",
"unist-builder": "^4.0.0",
"unist-util-parents": "^3.0.0"
},

View File

@@ -15,7 +15,6 @@ import {
import { parseJsonlFromResponse } from "./util/jsonl-parsing";
import { ollamaClient as ollama } from "./lib/ollama-client";
import type { ModelResponse } from "ollama/browser";
import { API_BASE, OLLAMA_DOT_COM } from "./lib/config";
// Extend Model class with utility methods
declare module "@/gotypes" {
@@ -27,6 +26,9 @@ declare module "@/gotypes" {
Model.prototype.isCloud = function (): boolean {
return this.model.endsWith("cloud");
};
const API_BASE = import.meta.env.DEV ? "http://127.0.0.1:3001" : "";
// Helper function to convert Uint8Array to base64
function uint8ArrayToBase64(uint8Array: Uint8Array): string {
const chunkSize = 0x8000; // 32KB chunks to avoid stack overflow
@@ -41,50 +43,44 @@ function uint8ArrayToBase64(uint8Array: Uint8Array): string {
}
export async function fetchUser(): Promise<User | null> {
const response = await fetch(`${API_BASE}/api/me`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
});
try {
const response = await fetch(`${API_BASE}/api/v1/me`, {
method: "GET",
headers: {
"Content-Type": "application/json",
},
});
if (response.ok) {
const userData: User = await response.json();
if (userData.avatarurl && !userData.avatarurl.startsWith("http")) {
userData.avatarurl = `${OLLAMA_DOT_COM}${userData.avatarurl}`;
if (response.ok) {
const userData: User = await response.json();
return userData;
}
return userData;
}
if (response.status === 401 || response.status === 403) {
return null;
} catch (error) {
console.error("Error fetching user:", error);
return null;
}
throw new Error(`Failed to fetch user: ${response.status}`);
}
export async function fetchConnectUrl(): Promise<string> {
const response = await fetch(`${API_BASE}/api/me`, {
method: "POST",
const response = await fetch(`${API_BASE}/api/v1/connect`, {
method: "GET",
headers: {
"Content-Type": "application/json",
},
});
if (response.status === 401) {
const data = await response.json();
if (data.signin_url) {
return data.signin_url;
}
if (!response.ok) {
throw new Error("Failed to fetch connect URL");
}
throw new Error("Failed to fetch connect URL");
const data = await response.json();
return data.connect_url;
}
export async function disconnectUser(): Promise<void> {
const response = await fetch(`${API_BASE}/api/signout`, {
const response = await fetch(`${API_BASE}/api/v1/disconnect`, {
method: "POST",
headers: {
"Content-Type": "application/json",
@@ -209,11 +205,6 @@ export async function* sendMessage(
data: uint8ArrayToBase64(att.data),
}));
// Send think parameter when it's explicitly set (true, false, or a non-empty string).
const shouldSendThink =
think !== undefined &&
(typeof think === "boolean" || (typeof think === "string" && think !== ""));
const response = await fetch(`${API_BASE}/api/v1/chat/${chatId}`, {
method: "POST",
headers: {
@@ -231,7 +222,7 @@ export async function* sendMessage(
web_search: webSearch ?? false,
file_tools: fileTools ?? false,
...(forceUpdate !== undefined ? { forceUpdate } : {}),
...(shouldSendThink ? { think } : {}),
...(think !== undefined ? { think } : {}),
}),
),
signal,
@@ -394,8 +385,7 @@ export async function getInferenceCompute(): Promise<InferenceCompute[]> {
export async function fetchHealth(): Promise<boolean> {
try {
// Use the /api/version endpoint as a health check
const response = await fetch(`${API_BASE}/api/version`, {
const response = await fetch(`${API_BASE}/api/v1/health`, {
method: "GET",
headers: {
"Content-Type": "application/json",
@@ -404,8 +394,7 @@ export async function fetchHealth(): Promise<boolean> {
if (response.ok) {
const data = await response.json();
// If we get a version back, the server is healthy
return !!data.version;
return data.healthy || false;
}
return false;

View File

@@ -299,9 +299,9 @@ export default function Settings() {
</Button>
</div>
</div>
{user?.avatarurl && (
{user?.avatarURL && (
<img
src={user.avatarurl}
src={user.avatarURL}
alt={user?.name}
className="h-10 w-10 rounded-full bg-neutral-200 dark:bg-neutral-700 flex-shrink-0"
onError={(e) => {

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,522 @@
import { expect, test, suite } from "vitest";
import { processStreamingMarkdown } from "@/utils/processStreamingMarkdown";
suite("common llm outputs that cause issues", () => {
test("prefix of bolded list item shouldn't make a horizontal line", () => {
// we're going to go in order of incrementally adding characters. This
// happens really commonly with LLMs that like to make lists like so:
//
// * **point 1**: explanatory text
// * **point 2**: more explanatory text
//
// Partial rendering of `*` (A), followed by `* *` (B), followed by `* **`
// (C) is a total mess. (A) renders as a single bullet point in an
// otherwise empty list, (B) renders as two nested lists (and therefore
// two bullet points, styled differently by default in html), and (C)
// renders as a horizontal line because in markdown apparently `***` or `*
// * *` horizontal rules don't have as strict whitespace rules as I
// expected them to
// these are alone (i.e., they would be the first list item)
expect(processStreamingMarkdown("*")).toBe("");
expect(processStreamingMarkdown("* *")).toBe("");
expect(processStreamingMarkdown("* **")).toBe("");
// expect(processStreamingMarkdown("* **b")).toBe("* **b**");
// with a list item before them
expect(
processStreamingMarkdown(
// prettier-ignore
[
"* abc",
"*"
].join("\n"),
),
).toBe("* abc");
expect(
processStreamingMarkdown(
// prettier-ignore
[
"* abc",
"* *"
].join("\n"),
),
).toBe("* abc");
expect(
processStreamingMarkdown(
// prettier-ignore
[
"* abc",
"* **"
].join("\n"),
),
).toBe("* abc");
});
test("bolded list items with text should be rendered properly", () => {
expect(processStreamingMarkdown("* **abc**")).toBe("* **abc**");
});
test("partially bolded list items should be autoclosed", () => {
expect(processStreamingMarkdown("* **abc")).toBe("* **abc**");
});
suite(
"partially bolded list items should be autoclosed, even if the last node isn't a text node",
() => {
test("inline code", () => {
expect(
processStreamingMarkdown("* **Asynchronous Function `async`*"),
).toBe("* **Asynchronous Function `async`**");
});
},
);
});
suite("autoclosing bold", () => {
suite("endings with no asterisks", () => {
test("should autoclose bold", () => {
expect(processStreamingMarkdown("**abc")).toBe("**abc**");
expect(processStreamingMarkdown("abc **abc")).toBe("abc **abc**");
});
suite("should autoclose, even if the last node isn't a text node", () => {
test("inline code", () => {
expect(
processStreamingMarkdown("* **Asynchronous Function `async`"),
).toBe("* **Asynchronous Function `async`**");
});
test("opening ** is at the end of the text", () => {
expect(processStreamingMarkdown("abc **`def` jhk [lmn](opq)")).toBe(
"abc **`def` jhk [lmn](opq)**",
);
});
test("if there's a space after the **, it should NOT be autoclosed", () => {
expect(processStreamingMarkdown("abc ** `def` jhk [lmn](opq)")).toBe(
"abc \\*\\* `def` jhk [lmn](opq)",
);
});
});
test("should autoclose bold, even if the last node isn't a text node", () => {
expect(
processStreamingMarkdown("* **Asynchronous Function ( `async`"),
).toBe("* **Asynchronous Function ( `async`**");
});
test("whitespace fakeouts should not be modified", () => {
expect(processStreamingMarkdown("** abc")).toBe("\\*\\* abc");
});
// TODO(drifkin): arguably this should just be removed entirely, but empty
// isn't so bad
test("should handle empty bolded items", () => {
expect(processStreamingMarkdown("**")).toBe("");
});
});
suite("partially closed bolded items", () => {
test("simple partial", () => {
expect(processStreamingMarkdown("**abc*")).toBe("**abc**");
});
test("partial with non-text node at end", () => {
expect(processStreamingMarkdown("**abc`def`*")).toBe("**abc`def`**");
});
test("partial with multiply nested ending nodes", () => {
expect(processStreamingMarkdown("**abc[abc](`def`)*")).toBe(
"**abc[abc](`def`)**",
);
});
test("normal emphasis should not be affected", () => {
expect(processStreamingMarkdown("*abc*")).toBe("*abc*");
});
test("normal emphasis with nested code should not be affected", () => {
expect(processStreamingMarkdown("*`abc`*")).toBe("*`abc`*");
});
});
test.skip("shouldn't autoclose immediately if there's a space before the closing *", () => {
expect(processStreamingMarkdown("**abc *")).toBe("**abc**");
});
// skipping for now because this requires partial link completion as well
suite.skip("nested blocks that each need autoclosing", () => {
test("emph nested in link nested in strong nested in list item", () => {
expect(processStreamingMarkdown("* **[abc **def")).toBe(
"* **[abc **def**]()**",
);
});
test("* **[ab *`def`", () => {
expect(processStreamingMarkdown("* **[ab *`def`")).toBe(
"* **[ab *`def`*]()**",
);
});
});
});
suite("numbered list items", () => {
test("should remove trailing numbers", () => {
expect(processStreamingMarkdown("1. First\n2")).toBe("1. First");
});
test("should remove trailing numbers with breaks before", () => {
expect(processStreamingMarkdown("1. First \n2")).toBe("1. First");
});
test("should remove trailing numbers that form a new paragraph", () => {
expect(processStreamingMarkdown("1. First\n\n2")).toBe("1. First");
});
test("but should leave list items separated by two newlines", () => {
expect(processStreamingMarkdown("1. First\n\n2. S")).toBe(
"1. First\n\n2. S",
);
});
});
// TODO(drifkin):slop tests ahead, some are decent, but need to manually go
// through them as I implement
/*
describe("StreamingMarkdownContent - processStreamingMarkdown", () => {
describe("Ambiguous endings removal", () => {
it("should remove list markers at the end", () => {
expect(processStreamingMarkdown("Some text\n* ")).toBe("Some text");
expect(processStreamingMarkdown("Some text\n*")).toBe("Some text");
expect(processStreamingMarkdown("* Item 1\n- ")).toBe("* Item 1");
expect(processStreamingMarkdown("* Item 1\n-")).toBe("* Item 1");
expect(processStreamingMarkdown("Text\n+ ")).toBe("Text");
expect(processStreamingMarkdown("Text\n+")).toBe("Text");
expect(processStreamingMarkdown("1. First\n2. ")).toBe("1. First");
});
it("should remove heading markers at the end", () => {
expect(processStreamingMarkdown("Some text\n# ")).toBe("Some text");
expect(processStreamingMarkdown("Some text\n#")).toBe("Some text\n#"); // # without space is not removed
expect(processStreamingMarkdown("# Title\n## ")).toBe("# Title");
expect(processStreamingMarkdown("# Title\n##")).toBe("# Title\n##"); // ## without space is not removed
});
it("should remove ambiguous bold markers at the end", () => {
expect(processStreamingMarkdown("Text **")).toBe("Text ");
expect(processStreamingMarkdown("Some text\n**")).toBe("Some text");
});
it("should remove code block markers at the end", () => {
expect(processStreamingMarkdown("Text\n```")).toBe("Text");
expect(processStreamingMarkdown("```")).toBe("");
});
it("should remove single backtick at the end", () => {
expect(processStreamingMarkdown("Text `")).toBe("Text ");
expect(processStreamingMarkdown("`")).toBe("");
});
it("should remove single asterisk at the end", () => {
expect(processStreamingMarkdown("Text *")).toBe("Text ");
expect(processStreamingMarkdown("*")).toBe("");
});
it("should handle empty content", () => {
expect(processStreamingMarkdown("")).toBe("");
});
it("should handle single line removals correctly", () => {
expect(processStreamingMarkdown("* ")).toBe("");
expect(processStreamingMarkdown("# ")).toBe("");
expect(processStreamingMarkdown("**")).toBe("");
expect(processStreamingMarkdown("`")).toBe("");
});
it("shouldn't have this regexp capture group bug", () => {
expect(
processStreamingMarkdown("Here's a shopping list:\n*"),
).not.toContain("0*");
expect(processStreamingMarkdown("Here's a shopping list:\n*")).toBe(
"Here's a shopping list:",
);
});
});
describe("List markers", () => {
it("should preserve complete list items", () => {
expect(processStreamingMarkdown("* Complete item")).toBe(
"* Complete item",
);
expect(processStreamingMarkdown("- Another item")).toBe("- Another item");
expect(processStreamingMarkdown("+ Plus item")).toBe("+ Plus item");
expect(processStreamingMarkdown("1. Numbered item")).toBe(
"1. Numbered item",
);
});
it("should handle indented list markers", () => {
expect(processStreamingMarkdown(" * ")).toBe(" ");
expect(processStreamingMarkdown(" - ")).toBe(" ");
expect(processStreamingMarkdown("\t+ ")).toBe("\t");
});
});
describe("Heading markers", () => {
it("should preserve complete headings", () => {
expect(processStreamingMarkdown("# Complete Heading")).toBe(
"# Complete Heading",
);
expect(processStreamingMarkdown("## Subheading")).toBe("## Subheading");
expect(processStreamingMarkdown("### H3 Title")).toBe("### H3 Title");
});
it("should not affect # in other contexts", () => {
expect(processStreamingMarkdown("C# programming")).toBe("C# programming");
expect(processStreamingMarkdown("Issue #123")).toBe("Issue #123");
});
});
describe("Bold text", () => {
it("should close incomplete bold text", () => {
expect(processStreamingMarkdown("This is **bold text")).toBe(
"This is **bold text**",
);
expect(processStreamingMarkdown("Start **bold and more")).toBe(
"Start **bold and more**",
);
expect(processStreamingMarkdown("**just bold")).toBe("**just bold**");
});
it("should not affect complete bold text", () => {
expect(processStreamingMarkdown("**complete bold**")).toBe(
"**complete bold**",
);
expect(processStreamingMarkdown("Text **bold** more")).toBe(
"Text **bold** more",
);
});
it("should handle nested bold correctly", () => {
expect(processStreamingMarkdown("**bold** and **another")).toBe(
"**bold** and **another**",
);
});
});
describe("Italic text", () => {
it("should close incomplete italic text", () => {
expect(processStreamingMarkdown("This is *italic text")).toBe(
"This is *italic text*",
);
expect(processStreamingMarkdown("Start *italic and more")).toBe(
"Start *italic and more*",
);
});
it("should differentiate between list markers and italic", () => {
expect(processStreamingMarkdown("* Item\n* ")).toBe("* Item");
expect(processStreamingMarkdown("Some *italic text")).toBe(
"Some *italic text*",
);
expect(processStreamingMarkdown("*just italic")).toBe("*just italic*");
});
it("should not affect complete italic text", () => {
expect(processStreamingMarkdown("*complete italic*")).toBe(
"*complete italic*",
);
expect(processStreamingMarkdown("Text *italic* more")).toBe(
"Text *italic* more",
);
});
});
describe("Code blocks", () => {
it("should close incomplete code blocks", () => {
expect(processStreamingMarkdown("```javascript\nconst x = 42;")).toBe(
"```javascript\nconst x = 42;\n```",
);
expect(processStreamingMarkdown("```\ncode here")).toBe(
"```\ncode here\n```",
);
});
it("should not affect complete code blocks", () => {
expect(processStreamingMarkdown("```\ncode\n```")).toBe("```\ncode\n```");
expect(processStreamingMarkdown("```js\nconst x = 1;\n```")).toBe(
"```js\nconst x = 1;\n```",
);
});
it("should handle nested code blocks correctly", () => {
expect(processStreamingMarkdown("```\ncode\n```\n```python")).toBe(
"```\ncode\n```\n```python\n```",
);
});
it("should not process markdown inside code blocks", () => {
expect(processStreamingMarkdown("```\n* not a list\n**not bold**")).toBe(
"```\n* not a list\n**not bold**\n```",
);
});
});
describe("Inline code", () => {
it("should close incomplete inline code", () => {
expect(processStreamingMarkdown("This is `inline code")).toBe(
"This is `inline code`",
);
expect(processStreamingMarkdown("Use `console.log")).toBe(
"Use `console.log`",
);
});
it("should not affect complete inline code", () => {
expect(processStreamingMarkdown("`complete code`")).toBe(
"`complete code`",
);
expect(processStreamingMarkdown("Use `code` here")).toBe(
"Use `code` here",
);
});
it("should handle multiple inline codes correctly", () => {
expect(processStreamingMarkdown("`code` and `more")).toBe(
"`code` and `more`",
);
});
it("should not confuse inline code with code blocks", () => {
expect(processStreamingMarkdown("```\nblock\n```\n`inline")).toBe(
"```\nblock\n```\n`inline`",
);
});
});
describe("Complex streaming scenarios", () => {
it("should handle progressive streaming of a heading", () => {
const steps = [
{ input: "#", expected: "#" }, // # alone is not removed (needs space)
{ input: "# ", expected: "" },
{ input: "# H", expected: "# H" },
{ input: "# Hello", expected: "# Hello" },
];
steps.forEach(({ input, expected }) => {
expect(processStreamingMarkdown(input)).toBe(expected);
});
});
it("should handle progressive streaming of bold text", () => {
const steps = [
{ input: "*", expected: "" },
{ input: "**", expected: "" },
{ input: "**b", expected: "**b**" },
{ input: "**bold", expected: "**bold**" },
{ input: "**bold**", expected: "**bold**" },
];
steps.forEach(({ input, expected }) => {
expect(processStreamingMarkdown(input)).toBe(expected);
});
});
it("should handle multiline content with various patterns", () => {
const multiline = `# Title
This is a paragraph with **bold text** and *italic text*.
* Item 1
* Item 2
* `;
const expected = `# Title
This is a paragraph with **bold text** and *italic text*.
* Item 1
* Item 2`;
expect(processStreamingMarkdown(multiline)).toBe(expected);
});
it("should only fix the last line", () => {
expect(processStreamingMarkdown("# Complete\n# Another\n# ")).toBe(
"# Complete\n# Another",
);
expect(processStreamingMarkdown("* Item 1\n* Item 2\n* ")).toBe(
"* Item 1\n* Item 2",
);
});
it("should handle mixed content correctly", () => {
const input = `# Header
This has **bold** text and *italic* text.
\`\`\`js
const x = 42;
\`\`\`
Now some \`inline code\` and **unclosed bold`;
const expected = `# Header
This has **bold** text and *italic* text.
\`\`\`js
const x = 42;
\`\`\`
Now some \`inline code\` and **unclosed bold**`;
expect(processStreamingMarkdown(input)).toBe(expected);
});
});
describe("Edge cases with escaping", () => {
it("should handle escaped asterisks (future enhancement)", () => {
// Note: Current implementation doesn't handle escaping
// This is a known limitation - escaped characters still trigger closing
expect(processStreamingMarkdown("Text \\*not italic")).toBe(
"Text \\*not italic*",
);
});
it("should handle escaped backticks (future enhancement)", () => {
// Note: Current implementation doesn't handle escaping
// This is a known limitation - escaped characters still trigger closing
expect(processStreamingMarkdown("Text \\`not code")).toBe(
"Text \\`not code`",
);
});
});
describe("Code block edge cases", () => {
it("should handle triple backticks in the middle of lines", () => {
expect(processStreamingMarkdown("Text ``` in middle")).toBe(
"Text ``` in middle\n```",
);
expect(processStreamingMarkdown("```\nText ``` in code\nmore")).toBe(
"```\nText ``` in code\nmore\n```",
);
});
it("should properly close code blocks with language specifiers", () => {
expect(processStreamingMarkdown("```typescript")).toBe(
"```typescript\n```",
);
expect(processStreamingMarkdown("```typescript\nconst x = 1")).toBe(
"```typescript\nconst x = 1\n```",
);
});
it("should remove a completely empty partial code block", () => {
expect(processStreamingMarkdown("```\n")).toBe("");
});
});
});
*/

View File

@@ -1,123 +1,66 @@
import React from "react";
import { Streamdown, defaultRemarkPlugins } from "streamdown";
import Markdown from "react-markdown";
import remarkGfm from "remark-gfm";
import remarkMath from "remark-math";
import rehypeRaw from "rehype-raw";
import rehypeSanitize, { defaultSchema } from "rehype-sanitize";
import rehypePrismPlus from "rehype-prism-plus";
import rehypeKatex from "rehype-katex";
import remarkStreamingMarkdown, {
type LastNodeInfo,
} from "@/utils/remarkStreamingMarkdown";
import type { PluggableList } from "unified";
import remarkCitationParser from "@/utils/remarkCitationParser";
import CopyButton from "./CopyButton";
import type { BundledLanguage } from "shiki";
import { highlighter } from "@/lib/highlighter";
interface StreamingMarkdownContentProps {
content: string;
isStreaming?: boolean;
size?: "sm" | "md" | "lg";
onLastNode?: (info: LastNodeInfo) => void;
browserToolResult?: any; // TODO: proper type
}
// Helper to extract text from React nodes
const extractText = (node: React.ReactNode): string => {
if (typeof node === "string") return node;
if (typeof node === "number") return String(node);
if (!node) return "";
if (React.isValidElement(node)) {
const props = node.props as any;
if (props?.children) {
return extractText(props.children as React.ReactNode);
}
}
if (Array.isArray(node)) {
return node.map(extractText).join("");
}
return "";
};
const CodeBlock = React.memo(
({ children }: React.HTMLAttributes<HTMLPreElement>) => {
// Extract code and language from children
const codeElement = children as React.ReactElement<{
className?: string;
children: React.ReactNode;
}>;
const language =
codeElement.props.className?.replace(/language-/, "") || "";
const codeText = extractText(codeElement.props.children);
({ children, className, ...props }: React.HTMLAttributes<HTMLPreElement>) => {
const extractText = React.useCallback((node: React.ReactNode): string => {
if (typeof node === "string") return node;
if (typeof node === "number") return String(node);
if (!node) return "";
// Synchronously highlight code using the pre-loaded highlighter
const tokens = React.useMemo(() => {
if (!highlighter) return null;
try {
return {
light: highlighter.codeToTokensBase(codeText, {
lang: language as BundledLanguage,
theme: "one-light" as any,
}),
dark: highlighter.codeToTokensBase(codeText, {
lang: language as BundledLanguage,
theme: "one-dark" as any,
}),
};
} catch (error) {
console.error("Failed to highlight code:", error);
return null;
if (React.isValidElement(node)) {
if (
node.props &&
typeof node.props === "object" &&
"children" in node.props
) {
return extractText(node.props.children as React.ReactNode);
}
}
}, [codeText, language]);
if (Array.isArray(node)) {
return node.map(extractText).join("");
}
return "";
}, []);
const language = className?.replace(/language-/, "") || "";
return (
<div className="relative bg-neutral-100 dark:bg-neutral-800 rounded-2xl overflow-hidden my-6">
<div className="flex select-none">
{language && (
<div className="text-[13px] text-neutral-500 dark:text-neutral-400 font-mono px-4 py-2">
{language}
</div>
)}
<div className="flex justify-between select-none">
<div className="text-[13px] text-neutral-500 dark:text-neutral-400 font-mono px-4 py-2">
{language}
</div>
<CopyButton
content={codeText}
content={extractText(children)}
showLabels={true}
className="copy-button text-neutral-500 dark:text-neutral-400 bg-neutral-100 dark:bg-neutral-800 ml-auto"
className="copy-button text-neutral-500 dark:text-neutral-400 bg-neutral-100 dark:bg-neutral-800"
/>
</div>
{/* Light mode */}
<pre className="dark:hidden m-0 bg-neutral-100 text-sm overflow-x-auto p-4">
<code className="font-mono text-sm">
{tokens?.light
? tokens.light.map((line: any, i: number) => (
<React.Fragment key={i}>
{line.map((token: any, j: number) => (
<span
key={j}
style={{
color: token.color,
}}
>
{token.content}
</span>
))}
{i < tokens.light.length - 1 && "\n"}
</React.Fragment>
))
: codeText}
</code>
</pre>
{/* Dark mode */}
<pre className="hidden dark:block m-0 bg-neutral-800 text-sm overflow-x-auto p-4">
<code className="font-mono text-sm">
{tokens?.dark
? tokens.dark.map((line: any, i: number) => (
<React.Fragment key={i}>
{line.map((token: any, j: number) => (
<span
key={j}
style={{
color: token.color,
}}
>
{token.content}
</span>
))}
{i < tokens.dark.length - 1 && "\n"}
</React.Fragment>
))
: codeText}
</code>
<pre className={className} {...props}>
{children}
</pre>
</div>
);
@@ -125,19 +68,65 @@ const CodeBlock = React.memo(
);
const StreamingMarkdownContent: React.FC<StreamingMarkdownContentProps> =
React.memo(({ content, isStreaming = false, size, browserToolResult }) => {
// Build the remark plugins array - keep default GFM and Math, add citations
const remarkPlugins = React.useMemo(() => {
return [
defaultRemarkPlugins.gfm,
defaultRemarkPlugins.math,
remarkCitationParser,
];
}, []);
React.memo(
({ content, isStreaming = false, size, onLastNode, browserToolResult }) => {
// Build the remark plugins array
const remarkPlugins = React.useMemo(() => {
const plugins: PluggableList = [
remarkGfm,
[remarkMath, { singleDollarTextMath: false }],
remarkCitationParser,
];
return (
<div
className={`
// Add streaming plugin when in streaming mode
if (isStreaming) {
plugins.push([remarkStreamingMarkdown, { debug: true, onLastNode }]);
}
return plugins;
}, [isStreaming, onLastNode]);
// Create a custom sanitization schema that allows math elements
const sanitizeSchema = React.useMemo(() => {
return {
...defaultSchema,
attributes: {
...defaultSchema.attributes,
span: [
...(defaultSchema.attributes?.span || []),
["className", /^katex/],
],
div: [
...(defaultSchema.attributes?.div || []),
["className", /^katex/],
],
"ol-citation": ["cursor", "start", "end"],
},
tagNames: [
...(defaultSchema.tagNames || []),
"math",
"mrow",
"mi",
"mo",
"mn",
"msup",
"msub",
"mfrac",
"mover",
"munder",
"msqrt",
"mroot",
"merror",
"mspace",
"mpadded",
"ol-citation",
],
};
}, []);
return (
<div
className={`
max-w-full
${size === "sm" ? "prose-sm" : size === "lg" ? "prose-lg" : ""}
prose
@@ -155,27 +144,7 @@ const StreamingMarkdownContent: React.FC<StreamingMarkdownContentProps> =
prose-pre:my-0
prose-pre:max-w-full
prose-pre:pt-1
[&_table]:border-collapse
[&_table]:w-full
[&_table]:border
[&_table]:border-neutral-200
[&_table]:rounded-lg
[&_table]:overflow-hidden
[&_th]:px-3
[&_th]:py-2
[&_th]:text-left
[&_th]:font-semibold
[&_th]:border-b
[&_th]:border-r
[&_th]:border-neutral-200
[&_th:last-child]:border-r-0
[&_td]:px-3
[&_td]:py-2
[&_td]:border-r
[&_td]:border-neutral-200
[&_td:last-child]:border-r-0
[&_tbody_tr:not(:last-child)_td]:border-b
[&_code:not(pre_code)]:text-neutral-700
[&_code:not(pre_code)]:text-neutral-700
[&_code:not(pre_code)]:bg-neutral-100
[&_code:not(pre_code)]:font-normal
[&_code:not(pre_code)]:px-1.5
@@ -191,10 +160,6 @@ const StreamingMarkdownContent: React.FC<StreamingMarkdownContentProps> =
dark:prose-strong:text-neutral-200
dark:prose-pre:text-neutral-200
dark:prose:pre:text-neutral-200
dark:[&_table]:border-neutral-700
dark:[&_thead]:bg-neutral-800
dark:[&_th]:border-neutral-700
dark:[&_td]:border-neutral-700
dark:[&_code:not(pre_code)]:text-neutral-200
dark:[&_code:not(pre_code)]:bg-neutral-800
dark:[&_code:not(pre_code)]:font-normal
@@ -202,86 +167,104 @@ const StreamingMarkdownContent: React.FC<StreamingMarkdownContentProps> =
dark:prose-li:marker:text-neutral-300
break-words
`}
>
<StreamingMarkdownErrorBoundary
content={content}
isStreaming={isStreaming}
>
<Streamdown
parseIncompleteMarkdown={isStreaming}
isAnimating={isStreaming}
remarkPlugins={remarkPlugins}
controls={false}
components={{
pre: CodeBlock,
table: ({
children,
...props
}: React.HTMLAttributes<HTMLTableElement>) => (
<div className="overflow-x-auto max-w-full">
<table
{...props}
className="border-collapse w-full border border-neutral-200 dark:border-neutral-700 rounded-lg overflow-hidden"
>
{children}
</table>
</div>
),
// @ts-expect-error: custom citation type
"ol-citation": ({
cursor,
}: {
cursor: number;
start: number;
end: number;
}) => {
const pageStack = browserToolResult?.page_stack;
const hasValidPage = pageStack && cursor < pageStack.length;
const pageUrl = hasValidPage ? pageStack[cursor] : null;
const getPageTitle = (url: string) => {
if (url.startsWith("search_results_")) {
const searchTerm = url.substring("search_results_".length);
return `Search: ${searchTerm}`;
}
try {
const urlObj = new URL(url);
return urlObj.hostname;
} catch {
return url;
}
};
const citationElement = (
<span className="text-xs text-neutral-500 dark:text-neutral-400 bg-neutral-100 dark:bg-neutral-800 rounded-full px-2 py-1 ml-1">
[{cursor}]
</span>
);
if (pageUrl && pageUrl.startsWith("http")) {
return (
<a
href={pageUrl}
target="_blank"
rel="noopener noreferrer"
className="inline-flex items-center hover:opacity-80 transition-opacity no-underline"
title={getPageTitle(pageUrl)}
>
{citationElement}
</a>
);
}
return citationElement;
},
}}
<StreamingMarkdownErrorBoundary
content={content}
isStreaming={isStreaming}
>
{content}
</Streamdown>
</StreamingMarkdownErrorBoundary>
</div>
);
});
<Markdown
remarkPlugins={remarkPlugins}
rehypePlugins={
[
[rehypeRaw, { allowDangerousHtml: true }],
[rehypeSanitize, sanitizeSchema],
[rehypePrismPlus, { ignoreMissing: true }],
[
rehypeKatex,
{
errorColor: "#000000", // Black instead of red for errors
strict: false, // Be more lenient with parsing
throwOnError: false,
},
],
] as PluggableList
}
components={{
pre: CodeBlock,
table: ({
children,
...props
}: React.HTMLAttributes<HTMLTableElement>) => (
<div className="overflow-x-auto max-w-full">
<table {...props}>{children}</table>
</div>
),
// @ts-expect-error: custom type
"ol-citation": ({
cursor,
// start,
// end,
}: {
cursor: number;
start: number;
end: number;
}) => {
// Check if we have a page_stack and if the cursor is valid
const pageStack = browserToolResult?.page_stack;
const hasValidPage = pageStack && cursor < pageStack.length;
const pageUrl = hasValidPage ? pageStack[cursor] : null;
// Extract a readable title from the URL if possible
const getPageTitle = (url: string) => {
if (url.startsWith("search_results_")) {
const searchTerm = url.substring(
"search_results_".length,
);
return `Search: ${searchTerm}`;
}
// For regular URLs, try to extract domain or use full URL
try {
const urlObj = new URL(url);
return urlObj.hostname;
} catch {
// If not a valid URL, return as is
return url;
}
};
const citationElement = (
<span className="text-xs text-neutral-500 dark:text-neutral-400 bg-neutral-100 dark:bg-neutral-800 rounded-full px-2 py-1 ml-1">
[{cursor}]
</span>
);
// If we have a valid page URL, wrap in a link
if (pageUrl && pageUrl.startsWith("http")) {
return (
<a
href={pageUrl}
target="_blank"
rel="noopener noreferrer"
className="inline-flex items-center hover:opacity-80 transition-opacity no-underline"
title={getPageTitle(pageUrl)}
>
{citationElement}
</a>
);
}
// Otherwise, just return the citation without a link
return citationElement;
},
}}
>
{content}
</Markdown>
</StreamingMarkdownErrorBoundary>
</div>
);
},
);
interface StreamingMarkdownErrorBoundaryProps {
content: string;

View File

@@ -50,33 +50,21 @@ export default function Thinking({
// Position content to show bottom when collapsed
useEffect(() => {
if (isCollapsed && contentRef.current && wrapperRef.current) {
requestAnimationFrame(() => {
if (!contentRef.current || !wrapperRef.current) return;
const contentHeight = contentRef.current.scrollHeight;
const wrapperHeight = wrapperRef.current.clientHeight;
if (contentHeight > wrapperHeight) {
const translateY = -(contentHeight - wrapperHeight);
contentRef.current.style.transform = `translateY(${translateY}px)`;
setHasOverflow(true);
} else {
contentRef.current.style.transform = "translateY(0)";
setHasOverflow(false);
}
});
const contentHeight = contentRef.current.scrollHeight;
const wrapperHeight = wrapperRef.current.clientHeight;
if (contentHeight > wrapperHeight) {
const translateY = -(contentHeight - wrapperHeight);
contentRef.current.style.transform = `translateY(${translateY}px)`;
setHasOverflow(true);
} else {
setHasOverflow(false);
}
} else if (contentRef.current) {
contentRef.current.style.transform = "translateY(0)";
setHasOverflow(false);
}
}, [thinking, isCollapsed]);
useEffect(() => {
if (activelyThinking && wrapperRef.current && !isCollapsed) {
// When expanded and actively thinking, scroll to bottom
wrapperRef.current.scrollTop = wrapperRef.current.scrollHeight;
}
}, [thinking, activelyThinking, isCollapsed]);
const handleToggle = () => {
setIsCollapsed(!isCollapsed);
setHasUserInteracted(true);
@@ -85,9 +73,8 @@ export default function Thinking({
// Calculate max height for smooth animations
const getMaxHeight = () => {
if (isCollapsed) {
return finishedThinking ? "0px" : "12rem";
return finishedThinking ? "0px" : "12rem"; // 8rem = 128px (same as max-h-32)
}
// When expanded, use the content height or grow naturally
return contentHeight ? `${contentHeight}px` : "none";
};
@@ -144,11 +131,10 @@ export default function Thinking({
</div>
<div
ref={wrapperRef}
className={`text-xs text-neutral-500 dark:text-neutral-500 rounded-md
transition-[max-height,opacity] duration-300 ease-in-out relative ml-6 mt-2
${isCollapsed ? "overflow-hidden" : "overflow-y-auto"}`}
className={`text-xs text-neutral-500 dark:text-neutral-500 rounded-md overflow-hidden
transition-[max-height,opacity] duration-300 ease-in-out relative ml-6 mt-2`}
style={{
maxHeight: isCollapsed ? getMaxHeight() : undefined,
maxHeight: getMaxHeight(),
opacity: isCollapsed && finishedThinking ? 0 : 1,
}}
>

View File

@@ -7,7 +7,6 @@ import { createQueryBatcher } from "./useQueryBatcher";
import { useRefetchModels } from "./useModels";
import { useStreamingContext } from "@/contexts/StreamingContext";
import { useSettings } from "./useSettings";
import { getModelCapabilities } from "@/api";
export const useChats = () => {
return useQuery({
@@ -607,24 +606,6 @@ export const useSendMessage = (chatId: string) => {
queryClient.setQueryData(["staleModels"], newStaleMap);
queryClient.invalidateQueries({ queryKey: ["models"] });
// Fetch fresh capabilities for the downloaded model
getModelCapabilities(selectedModel.model)
.then((capabilities) => {
queryClient.setQueryData(
["modelCapabilities", selectedModel.model],
capabilities,
);
})
.catch((error) => {
console.error(
"Failed to fetch capabilities after download:",
error,
);
queryClient.invalidateQueries({
queryKey: ["modelCapabilities", selectedModel.model],
});
});
}
break;
}

View File

@@ -0,0 +1,114 @@
import { useMutation, useQueryClient } from "@tanstack/react-query";
import { useState } from "react";
import { pullModel } from "@/api";
import { useSelectedModel } from "./useSelectedModel";
import { useSettings } from "./useSettings";
interface DownloadProgress {
status: string;
digest?: string;
total?: number;
completed?: number;
done?: boolean;
}
export function useDownloadModel(chatId?: string) {
const queryClient = useQueryClient();
const { selectedModel } = useSelectedModel(chatId);
const { setSettings } = useSettings();
const [downloadProgress, setDownloadProgress] =
useState<DownloadProgress | null>(null);
const [abortController, setAbortController] =
useState<AbortController | null>(null);
const [downloadingChatIds, setDownloadingChatIds] = useState<Set<string>>(
new Set(),
);
const mutation = useMutation({
mutationFn: async (modelName: string) => {
const controller = new AbortController();
setAbortController(controller);
setDownloadProgress({ status: "Starting download..." });
if (chatId) {
setDownloadingChatIds((prev) => new Set(prev).add(chatId));
}
try {
for await (const progress of pullModel(modelName, controller.signal)) {
setDownloadProgress(progress);
if (progress.status === "success") {
// Update selected model to indicate it's now available locally
if (selectedModel && selectedModel.model === modelName) {
setSettings({ SelectedModel: modelName });
}
// Invalidate models query to refresh the list
await queryClient.invalidateQueries({ queryKey: ["models"] });
break;
}
}
} finally {
setAbortController(null);
if (chatId) {
setDownloadingChatIds((prev) => {
const newSet = new Set(prev);
newSet.delete(chatId);
return newSet;
});
}
}
},
onSuccess: () => {
setDownloadProgress(null);
if (chatId) {
setDownloadingChatIds((prev) => {
const newSet = new Set(prev);
newSet.delete(chatId);
return newSet;
});
}
},
onError: (error: Error) => {
const status =
error.name === "AbortError" ? "Download cancelled" : "Download failed";
setDownloadProgress({ status, done: true });
// Clear error message after delay
const delay = error.name === "AbortError" ? 1500 : 3000;
setTimeout(() => {
setDownloadProgress(null);
if (chatId) {
setDownloadingChatIds((prev) => {
const newSet = new Set(prev);
newSet.delete(chatId);
return newSet;
});
}
}, delay);
},
});
const cancelDownload = () => {
if (abortController) {
abortController.abort();
setAbortController(null);
if (chatId) {
setDownloadingChatIds((prev) => {
const newSet = new Set(prev);
newSet.delete(chatId);
return newSet;
});
}
}
};
return {
downloadModel: mutation.mutate,
isDownloading:
mutation.isPending && chatId ? downloadingChatIds.has(chatId) : false,
downloadProgress:
chatId && downloadingChatIds.has(chatId) ? downloadProgress : null,
error: mutation.error,
cancelDownload,
};
}

View File

@@ -1,20 +1,29 @@
import { useQuery, useMutation, useQueryClient } from "@tanstack/react-query";
import { useEffect, useState } from "react";
import { fetchUser, fetchConnectUrl, disconnectUser } from "@/api";
export function useUser() {
const queryClient = useQueryClient();
const [initialDataLoaded, setInitialDataLoaded] = useState(false);
// Wait for initial data to be loaded
useEffect(() => {
const initialPromise = window.__initialUserDataPromise;
if (initialPromise) {
initialPromise.finally(() => {
setInitialDataLoaded(true);
});
} else {
setInitialDataLoaded(true);
}
}, []);
const userQuery = useQuery({
queryKey: ["user"],
queryFn: async () => {
const result = await fetchUser();
return result;
},
queryFn: () => fetchUser(),
staleTime: 5 * 60 * 1000, // Consider data stale after 5 minutes
gcTime: 10 * 60 * 1000, // Keep in cache for 10 minutes
retry: 10,
retryDelay: (attemptIndex) => Math.min(500 * attemptIndex, 2000),
refetchOnMount: true, // Always fetch when component mounts
initialData: null, // Start with null to prevent flashing
});
// Mutation to refresh user data
@@ -40,15 +49,14 @@ export function useUser() {
},
});
const isLoading = userQuery.isLoading || userQuery.isFetching;
const isAuthenticated = Boolean(userQuery.data?.name);
return {
user: userQuery.data,
isLoading,
isLoading:
!initialDataLoaded ||
(userQuery.isLoading && userQuery.data === undefined), // Show loading until initial data is loaded
isError: userQuery.isError,
error: userQuery.error,
isAuthenticated,
isAuthenticated: Boolean(userQuery.data?.name),
refreshUser: refreshUser.mutate,
isRefreshing: refreshUser.isPending,
refetchUser: userQuery.refetch,

View File

@@ -16,6 +16,793 @@
--text-color: #ffffff;
}
}
@media (prefers-color-scheme: light) {
.prose {
/**
* One Light theme for prism.js
* Based on Atom's One Light theme: https://github.com/atom/atom/tree/master/packages/one-light-syntax
*/
/**
* One Light colours (accurate as of commit eb064bf on 19 Feb 2021)
* From colors.less
* --mono-1: hsl(230, 8%, 24%);
* --mono-2: hsl(230, 6%, 44%);
* --mono-3: hsl(230, 4%, 64%)
* --hue-1: hsl(198, 99%, 37%);
* --hue-2: hsl(221, 87%, 60%);
* --hue-3: hsl(301, 63%, 40%);
* --hue-4: hsl(119, 34%, 47%);
* --hue-5: hsl(5, 74%, 59%);
* --hue-5-2: hsl(344, 84%, 43%);
* --hue-6: hsl(35, 99%, 36%);
* --hue-6-2: hsl(35, 99%, 40%);
* --syntax-fg: hsl(230, 8%, 24%);
* --syntax-bg: hsl(230, 1%, 98%);
* --syntax-gutter: hsl(230, 1%, 62%);
* --syntax-guide: hsla(230, 8%, 24%, 0.2);
* --syntax-accent: hsl(230, 100%, 66%);
* From syntax-variables.less
* --syntax-selection-color: hsl(230, 1%, 90%);
* --syntax-gutter-background-color-selected: hsl(230, 1%, 90%);
* --syntax-cursor-line: hsla(230, 8%, 24%, 0.05);
*/
.token.comment,
.token.prolog,
.token.cdata {
color: hsl(230, 4%, 64%);
}
.token.doctype,
.token.punctuation,
.token.entity {
color: hsl(230, 8%, 24%);
}
.token.attr-name,
.token.class-name,
.token.boolean,
.token.constant,
.token.number,
.token.atrule {
color: hsl(35, 99%, 36%);
}
.token.keyword {
color: hsl(301, 63%, 40%);
}
.token.property,
.token.tag,
.token.symbol,
.token.deleted,
.token.important {
color: hsl(5, 74%, 59%);
}
.token.selector,
.token.string,
.token.char,
.token.builtin,
.token.inserted,
.token.regex,
.token.attr-value,
.token.attr-value > .token.punctuation {
color: hsl(119, 34%, 47%);
}
.token.variable,
.token.operator,
.token.function {
color: hsl(221, 87%, 60%);
}
.token.url {
color: hsl(198, 99%, 37%);
}
/* HTML overrides */
.token.attr-value > .token.punctuation.attr-equals,
.token.special-attr > .token.attr-value > .token.value.css {
color: hsl(230, 8%, 24%);
}
/* CSS overrides */
.language-css .token.selector {
color: hsl(5, 74%, 59%);
}
.language-css .token.property {
color: hsl(230, 8%, 24%);
}
.language-css .token.function,
.language-css .token.url > .token.function {
color: hsl(198, 99%, 37%);
}
.language-css .token.url > .token.string.url {
color: hsl(119, 34%, 47%);
}
.language-css .token.important,
.language-css .token.atrule .token.rule {
color: hsl(301, 63%, 40%);
}
/* JS overrides */
.language-javascript .token.operator {
color: hsl(301, 63%, 40%);
}
.language-javascript
.token.template-string
> .token.interpolation
> .token.interpolation-punctuation.punctuation {
color: hsl(344, 84%, 43%);
}
/* JSON overrides */
.language-json .token.operator {
color: hsl(230, 8%, 24%);
}
.language-json .token.null.keyword {
color: hsl(35, 99%, 36%);
}
/* MD overrides */
.language-markdown .token.url,
.language-markdown .token.url > .token.operator,
.language-markdown .token.url-reference.url > .token.string {
color: hsl(230, 8%, 24%);
}
.language-markdown .token.url > .token.content {
color: hsl(221, 87%, 60%);
}
.language-markdown .token.url > .token.url,
.language-markdown .token.url-reference.url {
color: hsl(198, 99%, 37%);
}
.language-markdown .token.blockquote.punctuation,
.language-markdown .token.hr.punctuation {
color: hsl(230, 4%, 64%);
font-style: italic;
}
.language-markdown .token.code-snippet {
color: hsl(119, 34%, 47%);
}
.language-markdown .token.bold .token.content {
color: hsl(35, 99%, 36%);
}
.language-markdown .token.italic .token.content {
color: hsl(301, 63%, 40%);
}
.language-markdown .token.strike .token.content,
.language-markdown .token.strike .token.punctuation,
.language-markdown .token.list.punctuation,
.language-markdown .token.title.important > .token.punctuation {
color: hsl(5, 74%, 59%);
}
/* General */
.token.bold {
font-weight: bold;
}
.token.comment,
.token.italic {
font-style: italic;
}
.token.entity {
cursor: help;
}
.token.namespace {
opacity: 0.8;
}
/* Plugin overrides */
/* Selectors should have higher specificity than those in the plugins' default stylesheets */
/* Show Invisibles plugin overrides */
.token.token.tab:not(:empty):before,
.token.token.cr:before,
.token.token.lf:before,
.token.token.space:before {
color: hsla(230, 8%, 24%, 0.2);
}
/* Toolbar plugin overrides */
/* Space out all buttons and move them away from the right edge of the code block */
div.code-toolbar > .toolbar.toolbar > .toolbar-item {
margin-right: 0.4em;
}
/* Styling the buttons */
div.code-toolbar > .toolbar.toolbar > .toolbar-item > button,
div.code-toolbar > .toolbar.toolbar > .toolbar-item > a,
div.code-toolbar > .toolbar.toolbar > .toolbar-item > span {
background: hsl(230, 1%, 90%);
color: hsl(230, 6%, 44%);
padding: 0.1em 0.4em;
border-radius: 0.3em;
}
div.code-toolbar > .toolbar.toolbar > .toolbar-item > button:hover,
div.code-toolbar > .toolbar.toolbar > .toolbar-item > button:focus,
div.code-toolbar > .toolbar.toolbar > .toolbar-item > a:hover,
div.code-toolbar > .toolbar.toolbar > .toolbar-item > a:focus,
div.code-toolbar > .toolbar.toolbar > .toolbar-item > span:hover,
div.code-toolbar > .toolbar.toolbar > .toolbar-item > span:focus {
background: hsl(230, 1%, 78%); /* custom: darken(--syntax-bg, 20%) */
color: hsl(230, 8%, 24%);
}
/* Line Highlight plugin overrides */
/* The highlighted line itself */
.line-highlight.line-highlight {
background: hsla(230, 8%, 24%, 0.05);
}
/* Default line numbers in Line Highlight plugin */
.line-highlight.line-highlight:before,
.line-highlight.line-highlight[data-end]:after {
background: hsl(230, 1%, 90%);
color: hsl(230, 8%, 24%);
padding: 0.1em 0.6em;
border-radius: 0.3em;
box-shadow: 0 2px 0 0 rgba(0, 0, 0, 0.2); /* same as Toolbar plugin default */
}
/* Hovering over a linkable line number (in the gutter area) */
/* Requires Line Numbers plugin as well */
pre[id].linkable-line-numbers.linkable-line-numbers
span.line-numbers-rows
> span:hover:before {
background-color: hsla(230, 8%, 24%, 0.05);
}
/* Line Numbers and Command Line plugins overrides */
/* Line separating gutter from coding area */
.line-numbers.line-numbers .line-numbers-rows,
.command-line .command-line-prompt {
border-right-color: hsla(230, 8%, 24%, 0.2);
}
/* Stuff in the gutter */
.line-numbers .line-numbers-rows > span:before,
.command-line .command-line-prompt > span:before {
color: hsl(230, 1%, 62%);
}
/* Match Braces plugin overrides */
/* Note: Outline colour is inherited from the braces */
.rainbow-braces .token.token.punctuation.brace-level-1,
.rainbow-braces .token.token.punctuation.brace-level-5,
.rainbow-braces .token.token.punctuation.brace-level-9 {
color: hsl(5, 74%, 59%);
}
.rainbow-braces .token.token.punctuation.brace-level-2,
.rainbow-braces .token.token.punctuation.brace-level-6,
.rainbow-braces .token.token.punctuation.brace-level-10 {
color: hsl(119, 34%, 47%);
}
.rainbow-braces .token.token.punctuation.brace-level-3,
.rainbow-braces .token.token.punctuation.brace-level-7,
.rainbow-braces .token.token.punctuation.brace-level-11 {
color: hsl(221, 87%, 60%);
}
.rainbow-braces .token.token.punctuation.brace-level-4,
.rainbow-braces .token.token.punctuation.brace-level-8,
.rainbow-braces .token.token.punctuation.brace-level-12 {
color: hsl(301, 63%, 40%);
}
/* Diff Highlight plugin overrides */
/* Taken from https://github.com/atom/github/blob/master/styles/variables.less */
pre.diff-highlight > code .token.token.deleted:not(.prefix),
pre > code.diff-highlight .token.token.deleted:not(.prefix) {
background-color: hsla(353, 100%, 66%, 0.15);
}
pre.diff-highlight > code .token.token.deleted:not(.prefix)::-moz-selection,
pre.diff-highlight
> code
.token.token.deleted:not(.prefix)
*::-moz-selection,
pre > code.diff-highlight .token.token.deleted:not(.prefix)::-moz-selection,
pre
> code.diff-highlight
.token.token.deleted:not(.prefix)
*::-moz-selection {
background-color: hsla(353, 95%, 66%, 0.25);
}
pre.diff-highlight > code .token.token.deleted:not(.prefix)::selection,
pre.diff-highlight > code .token.token.deleted:not(.prefix) *::selection,
pre > code.diff-highlight .token.token.deleted:not(.prefix)::selection,
pre > code.diff-highlight .token.token.deleted:not(.prefix) *::selection {
background-color: hsla(353, 95%, 66%, 0.25);
}
pre.diff-highlight > code .token.token.inserted:not(.prefix),
pre > code.diff-highlight .token.token.inserted:not(.prefix) {
background-color: hsla(137, 100%, 55%, 0.15);
}
pre.diff-highlight
> code
.token.token.inserted:not(.prefix)::-moz-selection,
pre.diff-highlight
> code
.token.token.inserted:not(.prefix)
*::-moz-selection,
pre
> code.diff-highlight
.token.token.inserted:not(.prefix)::-moz-selection,
pre
> code.diff-highlight
.token.token.inserted:not(.prefix)
*::-moz-selection {
background-color: hsla(135, 73%, 55%, 0.25);
}
pre.diff-highlight > code .token.token.inserted:not(.prefix)::selection,
pre.diff-highlight > code .token.token.inserted:not(.prefix) *::selection,
pre > code.diff-highlight .token.token.inserted:not(.prefix)::selection,
pre > code.diff-highlight .token.token.inserted:not(.prefix) *::selection {
background-color: hsla(135, 73%, 55%, 0.25);
}
/* Previewers plugin overrides */
/* Based on https://github.com/atom-community/atom-ide-datatip/blob/master/styles/atom-ide-datatips.less and https://github.com/atom/atom/blob/master/packages/one-light-ui */
/* Border around popup */
.prism-previewer.prism-previewer:before,
.prism-previewer-gradient.prism-previewer-gradient div {
border-color: hsl(0, 0, 95%);
}
/* Angle and time should remain as circles and are hence not included */
.prism-previewer-color.prism-previewer-color:before,
.prism-previewer-gradient.prism-previewer-gradient div,
.prism-previewer-easing.prism-previewer-easing:before {
border-radius: 0.3em;
}
/* Triangles pointing to the code */
.prism-previewer.prism-previewer:after {
border-top-color: hsl(0, 0, 95%);
}
.prism-previewer-flipped.prism-previewer-flipped.after {
border-bottom-color: hsl(0, 0, 95%);
}
/* Background colour within the popup */
.prism-previewer-angle.prism-previewer-angle:before,
.prism-previewer-time.prism-previewer-time:before,
.prism-previewer-easing.prism-previewer-easing {
background: hsl(0, 0%, 100%);
}
/* For angle, this is the positive area (eg. 90deg will display one quadrant in this colour) */
/* For time, this is the alternate colour */
.prism-previewer-angle.prism-previewer-angle circle,
.prism-previewer-time.prism-previewer-time circle {
stroke: hsl(230, 8%, 24%);
stroke-opacity: 1;
}
/* Stroke colours of the handle, direction point, and vector itself */
.prism-previewer-easing.prism-previewer-easing circle,
.prism-previewer-easing.prism-previewer-easing path,
.prism-previewer-easing.prism-previewer-easing line {
stroke: hsl(230, 8%, 24%);
}
/* Fill colour of the handle */
.prism-previewer-easing.prism-previewer-easing circle {
fill: transparent;
}
}
}
@media (prefers-color-scheme: dark) {
.prose {
.token.comment,
.token.prolog,
.token.cdata {
color: hsl(220, 10%, 40%);
}
.token.doctype,
.token.punctuation,
.token.entity {
color: hsl(220, 14%, 71%);
}
.token.attr-name,
.token.class-name,
.token.boolean,
.token.constant,
.token.number,
.token.atrule {
color: hsl(29, 54%, 61%);
}
.token.keyword {
color: hsl(286, 60%, 67%);
}
.token.property,
.token.tag,
.token.symbol,
.token.deleted,
.token.important {
color: hsl(355, 65%, 65%);
}
.token.selector,
.token.string,
.token.char,
.token.builtin,
.token.inserted,
.token.regex,
.token.attr-value,
.token.attr-value > .token.punctuation {
color: hsl(95, 38%, 62%);
}
.token.variable,
.token.operator,
.token.function {
color: hsl(207, 82%, 66%);
}
.token.url {
color: hsl(187, 47%, 55%);
}
/* HTML overrides */
.token.attr-value > .token.punctuation.attr-equals,
.token.special-attr > .token.attr-value > .token.value.css {
color: hsl(220, 14%, 71%);
}
/* CSS overrides */
.language-css .token.selector {
color: hsl(355, 65%, 65%);
}
.language-css .token.property {
color: hsl(220, 14%, 71%);
}
.language-css .token.function,
.language-css .token.url > .token.function {
color: hsl(187, 47%, 55%);
}
.language-css .token.url > .token.string.url {
color: hsl(95, 38%, 62%);
}
.language-css .token.important,
.language-css .token.atrule .token.rule {
color: hsl(286, 60%, 67%);
}
/* JS overrides */
.language-javascript .token.operator {
color: hsl(286, 60%, 67%);
}
.language-javascript
.token.template-string
> .token.interpolation
> .token.interpolation-punctuation.punctuation {
color: hsl(5, 48%, 51%);
}
/* JSON overrides */
.language-json .token.operator {
color: hsl(220, 14%, 71%);
}
.language-json .token.null.keyword {
color: hsl(29, 54%, 61%);
}
/* MD overrides */
.language-markdown .token.url,
.language-markdown .token.url > .token.operator,
.language-markdown .token.url-reference.url > .token.string {
color: hsl(220, 14%, 71%);
}
.language-markdown .token.url > .token.content {
color: hsl(207, 82%, 66%);
}
.language-markdown .token.url > .token.url,
.language-markdown .token.url-reference.url {
color: hsl(187, 47%, 55%);
}
.language-markdown .token.blockquote.punctuation,
.language-markdown .token.hr.punctuation {
color: hsl(220, 10%, 40%);
font-style: italic;
}
.language-markdown .token.code-snippet {
color: hsl(95, 38%, 62%);
}
.language-markdown .token.bold .token.content {
color: hsl(29, 54%, 61%);
}
.language-markdown .token.italic .token.content {
color: hsl(286, 60%, 67%);
}
.language-markdown .token.strike .token.content,
.language-markdown .token.strike .token.punctuation,
.language-markdown .token.list.punctuation,
.language-markdown .token.title.important > .token.punctuation {
color: hsl(355, 65%, 65%);
}
/* General */
.token.bold {
font-weight: bold;
}
.token.comment,
.token.italic {
font-style: italic;
}
.token.entity {
cursor: help;
}
.token.namespace {
opacity: 0.8;
}
/* Plugin overrides */
/* Selectors should have higher specificity than those in the plugins' default stylesheets */
/* Show Invisibles plugin overrides */
.token.token.tab:not(:empty):before,
.token.token.cr:before,
.token.token.lf:before,
.token.token.space:before {
color: hsla(220, 14%, 71%, 0.15);
text-shadow: none;
}
/* Toolbar plugin overrides */
/* Space out all buttons and move them away from the right edge of the code block */
div.code-toolbar > .toolbar.toolbar > .toolbar-item {
margin-right: 0.4em;
}
/* Styling the buttons */
div.code-toolbar > .toolbar.toolbar > .toolbar-item > button,
div.code-toolbar > .toolbar.toolbar > .toolbar-item > a,
div.code-toolbar > .toolbar.toolbar > .toolbar-item > span {
background: hsl(220, 13%, 26%);
color: hsl(220, 9%, 55%);
padding: 0.1em 0.4em;
border-radius: 0.3em;
}
div.code-toolbar > .toolbar.toolbar > .toolbar-item > button:hover,
div.code-toolbar > .toolbar.toolbar > .toolbar-item > button:focus,
div.code-toolbar > .toolbar.toolbar > .toolbar-item > a:hover,
div.code-toolbar > .toolbar.toolbar > .toolbar-item > a:focus,
div.code-toolbar > .toolbar.toolbar > .toolbar-item > span:hover,
div.code-toolbar > .toolbar.toolbar > .toolbar-item > span:focus {
background: hsl(220, 13%, 28%);
color: hsl(220, 14%, 71%);
}
/* Line Highlight plugin overrides */
/* The highlighted line itself */
.line-highlight.line-highlight {
background: hsla(220, 100%, 80%, 0.04);
}
/* Default line numbers in Line Highlight plugin */
.line-highlight.line-highlight:before,
.line-highlight.line-highlight[data-end]:after {
background: hsl(220, 13%, 26%);
color: hsl(220, 14%, 71%);
padding: 0.1em 0.6em;
border-radius: 0.3em;
box-shadow: 0 2px 0 0 rgba(0, 0, 0, 0.2); /* same as Toolbar plugin default */
}
/* Hovering over a linkable line number (in the gutter area) */
/* Requires Line Numbers plugin as well */
pre[id].linkable-line-numbers.linkable-line-numbers
span.line-numbers-rows
> span:hover:before {
background-color: hsla(220, 100%, 80%, 0.04);
}
/* Line Numbers and Command Line plugins overrides */
/* Line separating gutter from coding area */
.line-numbers.line-numbers .line-numbers-rows,
.command-line .command-line-prompt {
border-right-color: hsla(220, 14%, 71%, 0.15);
}
/* Stuff in the gutter */
.line-numbers .line-numbers-rows > span:before,
.command-line .command-line-prompt > span:before {
color: hsl(220, 14%, 45%);
}
/* Match Braces plugin overrides */
/* Note: Outline colour is inherited from the braces */
.rainbow-braces .token.token.punctuation.brace-level-1,
.rainbow-braces .token.token.punctuation.brace-level-5,
.rainbow-braces .token.token.punctuation.brace-level-9 {
color: hsl(355, 65%, 65%);
}
.rainbow-braces .token.token.punctuation.brace-level-2,
.rainbow-braces .token.token.punctuation.brace-level-6,
.rainbow-braces .token.token.punctuation.brace-level-10 {
color: hsl(95, 38%, 62%);
}
.rainbow-braces .token.token.punctuation.brace-level-3,
.rainbow-braces .token.token.punctuation.brace-level-7,
.rainbow-braces .token.token.punctuation.brace-level-11 {
color: hsl(207, 82%, 66%);
}
.rainbow-braces .token.token.punctuation.brace-level-4,
.rainbow-braces .token.token.punctuation.brace-level-8,
.rainbow-braces .token.token.punctuation.brace-level-12 {
color: hsl(286, 60%, 67%);
}
/* Diff Highlight plugin overrides */
/* Taken from https://github.com/atom/github/blob/master/styles/variables.less */
pre.diff-highlight > code .token.token.deleted:not(.prefix),
pre > code.diff-highlight .token.token.deleted:not(.prefix) {
background-color: hsla(353, 100%, 66%, 0.15);
}
pre.diff-highlight > code .token.token.deleted:not(.prefix)::-moz-selection,
pre.diff-highlight
> code
.token.token.deleted:not(.prefix)
*::-moz-selection,
pre > code.diff-highlight .token.token.deleted:not(.prefix)::-moz-selection,
pre
> code.diff-highlight
.token.token.deleted:not(.prefix)
*::-moz-selection {
background-color: hsla(353, 95%, 66%, 0.25);
}
pre.diff-highlight > code .token.token.deleted:not(.prefix)::selection,
pre.diff-highlight > code .token.token.deleted:not(.prefix) *::selection,
pre > code.diff-highlight .token.token.deleted:not(.prefix)::selection,
pre > code.diff-highlight .token.token.deleted:not(.prefix) *::selection {
background-color: hsla(353, 95%, 66%, 0.25);
}
pre.diff-highlight > code .token.token.inserted:not(.prefix),
pre > code.diff-highlight .token.token.inserted:not(.prefix) {
background-color: hsla(137, 100%, 55%, 0.15);
}
pre.diff-highlight
> code
.token.token.inserted:not(.prefix)::-moz-selection,
pre.diff-highlight
> code
.token.token.inserted:not(.prefix)
*::-moz-selection,
pre
> code.diff-highlight
.token.token.inserted:not(.prefix)::-moz-selection,
pre
> code.diff-highlight
.token.token.inserted:not(.prefix)
*::-moz-selection {
background-color: hsla(135, 73%, 55%, 0.25);
}
pre.diff-highlight > code .token.token.inserted:not(.prefix)::selection,
pre.diff-highlight > code .token.token.inserted:not(.prefix) *::selection,
pre > code.diff-highlight .token.token.inserted:not(.prefix)::selection,
pre > code.diff-highlight .token.token.inserted:not(.prefix) *::selection {
background-color: hsla(135, 73%, 55%, 0.25);
}
/* Previewers plugin overrides */
/* Based on https://github.com/atom-community/atom-ide-datatip/blob/master/styles/atom-ide-datatips.less and https://github.com/atom/atom/blob/master/packages/one-dark-ui */
/* Border around popup */
.prism-previewer.prism-previewer:before,
.prism-previewer-gradient.prism-previewer-gradient div {
border-color: hsl(224, 13%, 17%);
}
/* Angle and time should remain as circles and are hence not included */
.prism-previewer-color.prism-previewer-color:before,
.prism-previewer-gradient.prism-previewer-gradient div,
.prism-previewer-easing.prism-previewer-easing:before {
border-radius: 0.3em;
}
/* Triangles pointing to the code */
.prism-previewer.prism-previewer:after {
border-top-color: hsl(224, 13%, 17%);
}
.prism-previewer-flipped.prism-previewer-flipped.after {
border-bottom-color: hsl(224, 13%, 17%);
}
/* Background colour within the popup */
.prism-previewer-angle.prism-previewer-angle:before,
.prism-previewer-time.prism-previewer-time:before,
.prism-previewer-easing.prism-previewer-easing {
background: hsl(219, 13%, 22%);
}
/* For angle, this is the positive area (eg. 90deg will display one quadrant in this colour) */
/* For time, this is the alternate colour */
.prism-previewer-angle.prism-previewer-angle circle,
.prism-previewer-time.prism-previewer-time circle {
stroke: hsl(220, 14%, 71%);
stroke-opacity: 1;
}
/* Stroke colours of the handle, direction point, and vector itself */
.prism-previewer-easing.prism-previewer-easing circle,
.prism-previewer-easing.prism-previewer-easing path,
.prism-previewer-easing.prism-previewer-easing line {
stroke: hsl(220, 14%, 71%);
}
/* Fill colour of the handle */
.prism-previewer-easing.prism-previewer-easing circle {
fill: transparent;
}
}
}
.prose pre {
contain: layout style;
}
/* Or more aggressively */
.prose pre code {
contain: layout style paint;
}
/* messaging-style typing indicator animation */
@keyframes typing {

View File

@@ -1,13 +0,0 @@
// API configuration
const DEV_API_URL = "http://127.0.0.1:3001";
// Base URL for fetch API calls (can be relative in production)
export const API_BASE = import.meta.env.DEV ? DEV_API_URL : "";
// Full host URL for Ollama client (needs full origin in production)
export const OLLAMA_HOST = import.meta.env.DEV
? DEV_API_URL
: window.location.origin;
export const OLLAMA_DOT_COM =
import.meta.env.VITE_OLLAMA_DOT_COM_URL || "https://ollama.com";

View File

@@ -1,156 +0,0 @@
import { createHighlighter } from "shiki";
import type { ThemeRegistration } from "shiki";
const oneLightTheme: ThemeRegistration = {
name: "one-light",
type: "light",
colors: {
"editor.background": "#fafafa",
"editor.foreground": "#383a42",
},
tokenColors: [
{
scope: ["comment", "punctuation.definition.comment"],
settings: { foreground: "#a0a1a7" },
},
{
scope: ["keyword", "storage.type", "storage.modifier"],
settings: { foreground: "#a626a4" },
},
{ scope: ["string", "string.quoted"], settings: { foreground: "#50a14f" } },
{
scope: ["function", "entity.name.function", "support.function"],
settings: { foreground: "#4078f2" },
},
{
scope: [
"constant.numeric",
"constant.language",
"constant.character",
"number",
],
settings: { foreground: "#c18401" },
},
{
scope: ["variable", "support.variable"],
settings: { foreground: "#e45649" },
},
{
scope: ["entity.name.tag", "entity.name.type", "entity.name.class"],
settings: { foreground: "#e45649" },
},
{
scope: ["entity.other.attribute-name"],
settings: { foreground: "#c18401" },
},
{
scope: ["keyword.operator", "operator"],
settings: { foreground: "#a626a4" },
},
{ scope: ["punctuation"], settings: { foreground: "#383a42" } },
{
scope: ["markup.heading"],
settings: { foreground: "#e45649", fontStyle: "bold" },
},
{
scope: ["markup.bold"],
settings: { foreground: "#c18401", fontStyle: "bold" },
},
{
scope: ["markup.italic"],
settings: { foreground: "#a626a4", fontStyle: "italic" },
},
],
};
const oneDarkTheme: ThemeRegistration = {
name: "one-dark",
type: "dark",
colors: {
"editor.background": "#282c34",
"editor.foreground": "#abb2bf",
},
tokenColors: [
{
scope: ["comment", "punctuation.definition.comment"],
settings: { foreground: "#5c6370" },
},
{
scope: ["keyword", "storage.type", "storage.modifier"],
settings: { foreground: "#c678dd" },
},
{ scope: ["string", "string.quoted"], settings: { foreground: "#98c379" } },
{
scope: ["function", "entity.name.function", "support.function"],
settings: { foreground: "#61afef" },
},
{
scope: [
"constant.numeric",
"constant.language",
"constant.character",
"number",
],
settings: { foreground: "#d19a66" },
},
{
scope: ["variable", "support.variable"],
settings: { foreground: "#e06c75" },
},
{
scope: ["entity.name.tag", "entity.name.type", "entity.name.class"],
settings: { foreground: "#e06c75" },
},
{
scope: ["entity.other.attribute-name"],
settings: { foreground: "#d19a66" },
},
{
scope: ["keyword.operator", "operator"],
settings: { foreground: "#c678dd" },
},
{ scope: ["punctuation"], settings: { foreground: "#abb2bf" } },
{
scope: ["markup.heading"],
settings: { foreground: "#e06c75", fontStyle: "bold" },
},
{
scope: ["markup.bold"],
settings: { foreground: "#d19a66", fontStyle: "bold" },
},
{
scope: ["markup.italic"],
settings: { foreground: "#c678dd", fontStyle: "italic" },
},
],
};
export let highlighter: Awaited<ReturnType<typeof createHighlighter>> | null =
null;
export const highlighterPromise = createHighlighter({
themes: [oneLightTheme, oneDarkTheme],
langs: [
"javascript",
"typescript",
"python",
"bash",
"shell",
"json",
"html",
"css",
"tsx",
"jsx",
"go",
"rust",
"java",
"c",
"cpp",
"sql",
"yaml",
"markdown",
],
}).then((h) => {
highlighter = h;
return h;
});

View File

@@ -1,5 +1,4 @@
import { Ollama } from "ollama/browser";
import { OLLAMA_HOST } from "./config";
let _ollamaClient: Ollama | null = null;
@@ -7,7 +6,7 @@ export const ollamaClient = new Proxy({} as Ollama, {
get(_target, prop) {
if (!_ollamaClient) {
_ollamaClient = new Ollama({
host: OLLAMA_HOST,
host: window.location.origin,
});
}
const value = _ollamaClient[prop as keyof Ollama];

View File

@@ -5,6 +5,13 @@ import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
import { routeTree } from "./routeTree.gen";
import { fetchUser } from "./api";
import { StreamingProvider } from "./contexts/StreamingContext";
import { User } from "@/gotypes";
declare global {
interface Window {
__initialUserDataPromise?: Promise<User | null>;
}
}
const queryClient = new QueryClient({
defaultOptions: {
@@ -17,11 +24,27 @@ const queryClient = new QueryClient({
},
});
fetchUser().then((userData) => {
if (userData) {
// Track initial user data fetch
let initialUserDataPromise: Promise<User | null> | null = null;
// Initialize user data on app startup
const initializeUserData = async () => {
try {
const userData = await fetchUser();
queryClient.setQueryData(["user"], userData);
return userData;
} catch (error) {
console.error("Error initializing user data:", error);
queryClient.setQueryData(["user"], null);
return null;
}
});
};
// Start initialization immediately and track the promise
initialUserDataPromise = initializeUserData();
// Export the promise so hooks can await it
window.__initialUserDataPromise = initialUserDataPromise;
const router = createRouter({
routeTree,

View File

@@ -1,97 +0,0 @@
import { describe, it, expect } from "vitest";
import { IMAGE_EXTENSIONS, validateFile } from "./fileValidation";
describe("fileValidation", () => {
describe("IMAGE_EXTENSIONS", () => {
it("should include all supported image formats including WebP", () => {
expect(IMAGE_EXTENSIONS).toContain("png");
expect(IMAGE_EXTENSIONS).toContain("jpg");
expect(IMAGE_EXTENSIONS).toContain("jpeg");
expect(IMAGE_EXTENSIONS).toContain("webp");
});
});
describe("validateFile", () => {
const createMockFile = (
name: string,
size: number,
type: string,
): File => {
const blob = new Blob(["test content"], { type });
return new File([blob], name, { type });
};
it("should accept WebP images when vision capability is enabled", () => {
const file = createMockFile("test.webp", 1024, "image/webp");
const result = validateFile(file, {
hasVisionCapability: true,
});
expect(result.valid).toBe(true);
});
it("should reject WebP images when vision capability is disabled", () => {
const file = createMockFile("test.webp", 1024, "image/webp");
const result = validateFile(file, {
hasVisionCapability: false,
});
expect(result.valid).toBe(false);
expect(result.error).toBe("This model does not support images");
});
it("should accept PNG images when vision capability is enabled", () => {
const file = createMockFile("test.png", 1024, "image/png");
const result = validateFile(file, {
hasVisionCapability: true,
});
expect(result.valid).toBe(true);
});
it("should accept JPEG images when vision capability is enabled", () => {
const file = createMockFile("test.jpg", 1024, "image/jpeg");
const result = validateFile(file, {
hasVisionCapability: true,
});
expect(result.valid).toBe(true);
});
it("should reject files that are too large", () => {
// Create a file with size property set correctly
const largeSize = 11 * 1024 * 1024; // 11MB
const content = new Uint8Array(largeSize);
const blob = new Blob([content], { type: "image/webp" });
const file = new File([blob], "large.webp", { type: "image/webp" });
const result = validateFile(file, {
hasVisionCapability: true,
maxFileSize: 10, // 10MB limit
});
expect(result.valid).toBe(false);
expect(result.error).toBe("File too large");
});
it("should reject unsupported file types", () => {
const file = createMockFile("test.xyz", 1024, "application/xyz");
const result = validateFile(file, {
hasVisionCapability: true,
});
expect(result.valid).toBe(false);
expect(result.error).toBe("File type not supported");
});
it("should respect custom validators", () => {
const file = createMockFile("test.webp", 1024, "image/webp");
const result = validateFile(file, {
hasVisionCapability: true,
customValidator: () => ({
valid: false,
error: "Custom error",
}),
});
expect(result.valid).toBe(false);
expect(result.error).toBe("Custom error");
});
});
// Note: processFiles tests are skipped because FileReader is not available in the Node.js test environment
// These functions are tested in browser environment via integration tests
});

View File

@@ -41,7 +41,7 @@ export const TEXT_FILE_EXTENSIONS = [
"rtf",
];
export const IMAGE_EXTENSIONS = ["png", "jpg", "jpeg", "webp"];
export const IMAGE_EXTENSIONS = ["png", "jpg", "jpeg"];
export interface FileValidationOptions {
maxFileSize?: number; // in MB

View File

@@ -0,0 +1,24 @@
import { remark } from "remark";
import remarkStringify from "remark-stringify";
import remarkStreamingMarkdown from "./remarkStreamingMarkdown";
/**
* Process markdown content for streaming display using the remark plugin.
* This is primarily used for testing the remark plugin with string inputs/outputs.
*/
export function processStreamingMarkdown(content: string): string {
if (!content) return content;
const result = remark()
.use(remarkStreamingMarkdown, { debug: false })
.use(remarkStringify)
.processSync(content);
// remove trailing newline to keep tests cleaner
let output = result.toString();
if (output.endsWith("\n")) {
output = output.slice(0, -1);
}
return output;
}

View File

@@ -0,0 +1,447 @@
import { parents, type Proxy } from "unist-util-parents";
import type { Plugin } from "unified";
import type {
Emphasis,
Node,
Parent,
Root,
RootContent,
Text,
Strong,
PhrasingContent,
Paragraph,
} from "mdast";
import { u } from "unist-builder";
declare module "unist" {
interface Node {
/** Added by `unist-util-parents` (or your own walk). */
parent?: Proxy & Parent;
}
}
// interface SimpleTextRule {
// pattern: RegExp;
// transform: (matches: RegExpExecArray[], lastNode: Proxy) => void;
// }
// const simpleTextRules: SimpleTextRule[] = [
// // TODO(drifkin): generalize this for `__`/`_`/`~~`/`~` etc.
// {
// pattern: /(\*\*)(?=\S|$)/g,
// transform: (matchesIterator, lastNode) => {
// const textNode = lastNode.node as Text;
// const matches = [...matchesIterator];
// const lastMatch = matches[matches.length - 1];
// const origValue = textNode.value;
// const start = lastMatch.index;
// const sep = lastMatch[1];
// const before = origValue.slice(0, start);
// const after = origValue.slice(start + sep.length);
// if (lastNode.parent) {
// const index = (lastNode.parent.node as Parent).children.indexOf(
// lastNode.node as RootContent,
// );
// const shouldRemove = before.length === 0;
// if (!shouldRemove) {
// textNode.value = before;
// }
// const newNode = u("strong", {
// children: [u("text", { value: after })],
// });
// (lastNode.parent.node as Parent).children.splice(
// index + (shouldRemove ? 0 : 1),
// shouldRemove ? 1 : 0,
// newNode,
// );
// }
// },
// },
// ];
interface Options {
debug?: boolean;
onLastNode?: (info: LastNodeInfo) => void;
}
export interface LastNodeInfo {
path: string[];
type: string;
value?: string;
lastChars?: string;
fullNode: Node;
}
/**
* Removes `child` from `parent` in-place.
* @returns `true` if the child was found and removed; `false` otherwise.
*/
export function removeChildFromParent(
child: RootContent,
parent: Node,
): boolean {
if (!isParent(parent)) return false; // parent isnt a Parent → nothing to do
const idx = parent.children.indexOf(child);
if (idx < 0) return false; // not a child → nothing to remove
parent.children.splice(idx, 1);
return true; // removal successful
}
/** Narrow a generic `Node` to a `Parent` (i.e. one that really has children). */
function isParent(node: Node): node is Parent {
// A `Parent` always has a `children` array; make sure it's an array first.
return Array.isArray((node as Partial<Parent>).children);
}
/**
* Follow “last-child” pointers until you reach a leaf.
* Returns the right-most, deepest node in source order.
*/
export function findRightmostDeepestNode(root: Node): Node {
let current: Node = root;
// While the current node *is* a Parent and has at least one child…
while (isParent(current) && current.children.length > 0) {
const lastIndex = current.children.length - 1;
current = current.children[lastIndex];
}
return current; // Leaf: no further children
}
const remarkStreamingMarkdown: Plugin<[Options?], Root> = () => {
return (tree) => {
const treeWithParents = parents(tree);
const lastNode = findRightmostDeepestNode(treeWithParents) as Proxy;
const parentNode = lastNode.parent;
const grandparentNode = parentNode?.parent;
let ruleMatched = false;
// handling `* *` -> ``
//
// if the last node is part of a <list item (otherwise empty)> ->
// <list (otherwise empty)> -> <list item (last node, empty)>, then we need to
// remove everything up to and including the first list item. This happens
// when we have `* *`, which can become a bolded list item OR a horizontal
// line
if (
lastNode.type === "listItem" &&
parentNode &&
grandparentNode &&
parentNode.type === "list" &&
grandparentNode.type === "listItem" &&
parentNode.children.length === 1 &&
grandparentNode.children.length === 1
) {
ruleMatched = true;
if (grandparentNode.parent) {
removeChildFromParent(
grandparentNode.node as RootContent,
grandparentNode.parent.node,
);
}
// Handle `*` -> ``:
//
// if the last node is just an empty list item, we need to remove it
// because it could become something else (e.g., a horizontal line)
} else if (
lastNode.type === "listItem" &&
parentNode &&
parentNode.type === "list"
) {
ruleMatched = true;
removeChildFromParent(lastNode.node as RootContent, parentNode.node);
} else if (lastNode.type === "thematicBreak") {
ruleMatched = true;
const parent = lastNode.parent;
if (parent) {
removeChildFromParent(lastNode.node as RootContent, parent.node);
}
} else if (lastNode.type === "text") {
const textNode = lastNode.node as Text;
if (textNode.value.endsWith("**")) {
ruleMatched = true;
textNode.value = textNode.value.slice(0, -2);
// if there's a newline then a number, this is very very likely a
// numbered list item. Let's just hide it until the period comes (or
// other text disambiguates it)
} else {
const match = textNode.value.match(/^([0-9]+)$/m);
if (match) {
const number = match[1];
textNode.value = textNode.value.slice(0, -number.length - 1);
ruleMatched = true;
// if the text node is now empty, then we might want to remove other
// elements, like a now-empty containing paragraph, or a break that
// might disappear once more tokens come in
if (textNode.value.length === 0) {
if (
lastNode.parent?.type === "paragraph" &&
lastNode.parent.children.length === 1
) {
// remove the whole paragraph if it's now empty (otherwise it'll
// cause an extra newline that might not last)
removeChildFromParent(
lastNode.parent.node as Paragraph,
lastNode.parent.parent?.node as Node,
);
} else {
const prev = prevSibling(lastNode);
if (prev?.type === "break") {
removeChildFromParent(
prev.node as RootContent,
lastNode.parent?.node as Node,
);
removeChildFromParent(
lastNode.node as RootContent,
lastNode.parent?.node as Node,
);
}
}
}
}
}
}
if (ruleMatched) {
return tree;
}
// we need to
// a case like
// - *def `abc` [abc **def**](abc)*
// is pretty tricky, because if we land just after def, then we actually
// have two separate tags to process at two different parents. Maybe we
// need to keep iterating up until we find a paragraph, but process each
// parent on the way up. Hmm, well actually after `def` we won't even be a proper link yet
// TODO(drifkin): it's really if the last node's parent is a paragraph, for which the following is a sub-cas where the lastNode is a text node.
// And instead of just processing simple text rules, they need to operate on the whole paragraph
// like `**[abc](def)` needs to become `**[abc](def)**`
// if we're just text at the end, then we should remove some ambiguous characters
if (lastNode.parent) {
const didChange = processParent(lastNode.parent as Parent & Proxy);
if (didChange) {
// TODO(drifkin): need to fix up the tree, but not sure lastNode will still exist? Check all the transforms to see if it's safe to find the last node again
//
// need to regen the tree w/ parents since reparenting could've happened
// treeWithParents = parents(tree);
}
}
const grandparent = lastNode.parent?.parent;
// TODO(drifkin): let's go arbitrarily high up the tree, but limiting it
// to 2 levels for now until I think more about the stop condition
if (grandparent) {
processParent(grandparent as Parent & Proxy);
}
// console.log("ruleMatched", ruleMatched);
// } else if (lastNode.parent?.type === "paragraph") {
// console.log("!!! paragraph");
// console.log("lastNode.parent", lastNode.parent);
// // Handle `**abc*` -> `**abc**`:
// // We detect this when the last child is an emphasis node, and it's preceded by a text node that ends with `*`
// const paragraph = lastNode.parent as Proxy & Paragraph;
// if (paragraph.children.length >= 2) {
// const lastChild = paragraph.children[paragraph.children.length - 1];
// if (lastChild.type === "emphasis") {
// const sibling = paragraph.children[paragraph.children.length - 2];
// if (sibling.type === "text") {
// const siblingText = sibling as Text & Proxy;
// if (siblingText.value.endsWith("*")) {
// ruleMatched = true;
// const textNode = (lastNode as Proxy).node as Text;
// textNode.value = textNode.value.slice(0, -1);
// paragraph.node.type = "strong";
// }
// }
// }
// }
// } else if (lastNode.type === "text") {
// // Handle `**abc*` -> `**abc**`:
// //
// // this gets parsed as a text node ending in `*` followed by an emphasis
// // node. So if we're in text, we need to check if our parent is emphasis,
// // and then get our parent's sibling before it and check if it ends with
// // `*`
// const parent = lastNode.parent;
// if (parent && parent.type === "emphasis") {
// const grandparent = parent.parent;
// if (grandparent) {
// const index = (grandparent.node as Parent).children.indexOf(
// parent.node as RootContent,
// );
// if (index > 0) {
// const prevNode = grandparent.children[index - 1];
// if (
// prevNode.type === "text" &&
// (prevNode as Text).value.endsWith("*")
// ) {
// ruleMatched = true;
// const textNode = (prevNode as Proxy).node as Text;
// textNode.value = textNode.value.slice(0, -1);
// parent.node.type = "strong";
// }
// }
// }
// }
// if (!ruleMatched) {
// // if the last node is just text, then we process it in order to fix up certain unclosed items
// // e.g., `**abc` -> `**abc**`
// const textNode = lastNode.node as Text;
// for (const rule of simpleTextRules) {
// const matchesIterator = textNode.value.matchAll(rule.pattern);
// const matches = [...matchesIterator];
// if (matches.length > 0) {
// rule.transform(matches, lastNode);
// ruleMatched = true;
// break;
// }
// }
// }
// } else if (!ruleMatched) {
// // console.log("no rule matched", lastNode);
// }
return tree;
};
};
function processParent(parent: Parent & Proxy): boolean {
if (parent.type === "emphasis") {
// Handle `**abc*` -> `**abc**`:
// We detect this when we end with an emphasis node, and it's preceded by
// a text node that ends with `*`
// TODO(drifkin): the last node can be more deeply nested (e.g., a code
// literal in a link), so we probably need to walk up the tree until we
// find an emphasis node or a block? For now we'll just go up one layer to
// catch the most common cases
const emphasisNode = parent as Emphasis & Proxy;
const grandparent = emphasisNode.parent;
if (grandparent) {
const indexOfEmphasisNode = (grandparent.node as Parent).children.indexOf(
emphasisNode.node as RootContent,
);
if (indexOfEmphasisNode >= 0) {
const nodeBefore = grandparent.children[indexOfEmphasisNode - 1] as
| (Node & Proxy)
| undefined;
if (nodeBefore?.type === "text") {
const textNode = nodeBefore.node as Text;
if (textNode.value.endsWith("*")) {
const strBefore = textNode.value.slice(0, -1);
textNode.value = strBefore;
const strongNode = u("strong", {
children: emphasisNode.children,
});
(grandparent.node as Parent).children.splice(
indexOfEmphasisNode,
1,
strongNode,
);
return true;
}
}
}
}
}
// Let's check if we have any bold items to close
for (let i = parent.children.length - 1; i >= 0; i--) {
const child = parent.children[i];
if (child.type === "text") {
const textNode = child as Text & Proxy;
const sep = "**";
const index = textNode.value.lastIndexOf(sep);
if (index >= 0) {
let isValidOpening = false;
if (index + sep.length < textNode.value.length) {
const charAfter = textNode.value[index + sep.length];
if (!isWhitespace(charAfter)) {
isValidOpening = true;
}
} else {
if (i < parent.children.length - 1) {
// TODO(drifkin): I'm not sure that this check is strict enough.
// We're trying to detect cases like `**[abc]()` where the char
// after the opening ** is indeed a non-whitespace character. We're
// using the heuristic that there's another item after the current
// one, but I'm not sure if that is good enough. In a well
// constructed tree, there aren't two text nodes in a row, so this
// _seems_ good, but I should think through it more
isValidOpening = true;
}
}
if (isValidOpening) {
// TODO(drifkin): close the bold
const strBefore = textNode.value.slice(0, index);
const strAfter = textNode.value.slice(index + sep.length);
(textNode.node as Text).value = strBefore;
// TODO(drifkin): the node above could be empty in which case we probably want to delete it
const children: PhrasingContent[] = [
...(strAfter.length > 0 ? [u("text", { value: strAfter })] : []),
];
const strongNode: Strong = u("strong", {
children,
});
const nodesAfter = (parent.node as Parent).children.splice(
i + 1,
parent.children.length - i - 1,
strongNode,
);
// TODO(drifkin): this cast seems iffy, should see if we can cast the
// parent instead, which would also help us check some of our
// assumptions
strongNode.children.push(...(nodesAfter as PhrasingContent[]));
return true;
}
}
}
}
return false;
}
function prevSibling(node: Node & Proxy): (Node & Proxy) | null {
const parent = node.parent;
if (parent) {
const index = parent.children.indexOf(node);
return parent.children[index - 1] as Node & Proxy;
}
return null;
}
function isWhitespace(str: string) {
return str.trim() === "";
}
// function debugPrintTreeNoPos(tree: Node) {
// console.log(
// JSON.stringify(
// tree,
// (key, value) => {
// if (key === "position") {
// return undefined;
// }
// return value;
// },
// 2,
// ),
// );
// }
export default remarkStreamingMarkdown;

View File

@@ -101,14 +101,15 @@ type HealthResponse struct {
}
type User struct {
ID string `json:"id"`
Email string `json:"email"`
Name string `json:"name"`
Bio string `json:"bio,omitempty"`
AvatarURL string `json:"avatarurl,omitempty"`
FirstName string `json:"firstname,omitempty"`
LastName string `json:"lastname,omitempty"`
Plan string `json:"plan,omitempty"`
ID string `json:"id"`
Name string `json:"name"`
Email string `json:"email"`
AvatarURL string `json:"avatarURL"`
Plan string `json:"plan"`
Bio string `json:"bio"`
FirstName string `json:"firstName"`
LastName string `json:"lastName"`
OverThreshold bool `json:"overThreshold"`
}
type Attachment struct {

View File

@@ -23,6 +23,7 @@ import (
"github.com/google/uuid"
"github.com/ollama/ollama/api"
"github.com/ollama/ollama/app/auth"
"github.com/ollama/ollama/app/server"
"github.com/ollama/ollama/app/store"
"github.com/ollama/ollama/app/tools"
@@ -263,10 +264,11 @@ func (s *Server) Handler() http.Handler {
ollamaProxy := s.ollamaProxy()
mux.Handle("GET /api/tags", ollamaProxy)
mux.Handle("POST /api/show", ollamaProxy)
mux.Handle("GET /api/version", ollamaProxy)
mux.Handle("HEAD /api/version", ollamaProxy)
mux.Handle("POST /api/me", ollamaProxy)
mux.Handle("POST /api/signout", ollamaProxy)
mux.Handle("GET /api/v1/me", handle(s.me))
mux.Handle("POST /api/v1/disconnect", handle(s.disconnect))
mux.Handle("GET /api/v1/connect", handle(s.connectURL))
mux.Handle("GET /api/v1/health", handle(s.health))
// React app - catch all non-API routes and serve the React app
mux.Handle("GET /", s.appHandler())
@@ -336,7 +338,7 @@ func (s *Server) doSelfSigned(ctx context.Context, method, path string) (*http.R
}
// UserData fetches user data from ollama.com API for the current ollama key
func (s *Server) UserData(ctx context.Context) (*api.UserResponse, error) {
func (s *Server) UserData(ctx context.Context) (*responses.User, error) {
resp, err := s.doSelfSigned(ctx, http.MethodPost, "/api/me")
if err != nil {
return nil, fmt.Errorf("failed to call ollama.com/api/me: %w", err)
@@ -347,7 +349,7 @@ func (s *Server) UserData(ctx context.Context) (*api.UserResponse, error) {
return nil, fmt.Errorf("unexpected status code: %d", resp.StatusCode)
}
var user api.UserResponse
var user responses.User
if err := json.NewDecoder(resp.Body).Decode(&user); err != nil {
return nil, fmt.Errorf("failed to parse user response: %w", err)
}
@@ -366,27 +368,29 @@ func (s *Server) UserData(ctx context.Context) (*api.UserResponse, error) {
return &user, nil
}
// WaitForServer waits for the Ollama server to be ready
func WaitForServer(ctx context.Context, timeout time.Duration) error {
deadline := time.Now().Add(timeout)
for time.Now().Before(deadline) {
func waitForServer(ctx context.Context) error {
timeout := time.Now().Add(10 * time.Second)
// TODO: this avoids an error on first load of the app
// however we should either show a loading state or
// wait for the Ollama server to be ready before redirecting
for {
c, err := api.ClientFromEnvironment()
if err != nil {
return err
}
if _, err := c.Version(ctx); err == nil {
slog.Debug("ollama server is ready")
return nil
break
}
if time.Now().After(timeout) {
return fmt.Errorf("timeout waiting for Ollama server to be ready")
}
time.Sleep(10 * time.Millisecond)
}
return errors.New("timeout waiting for Ollama server to be ready")
return nil
}
func (s *Server) createChat(w http.ResponseWriter, r *http.Request) error {
if err := WaitForServer(r.Context(), 10*time.Second); err != nil {
return err
}
waitForServer(r.Context())
id, err := uuid.NewV7()
if err != nil {
@@ -1434,6 +1438,129 @@ func (s *Server) settings(w http.ResponseWriter, r *http.Request) error {
})
}
func (s *Server) me(w http.ResponseWriter, r *http.Request) error {
if r.Method != http.MethodGet {
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
return nil
}
user, err := s.UserData(r.Context())
if err != nil {
// If fetching from API fails, try to return cached user data if available
if cachedUser, cacheErr := s.Store.User(); cacheErr == nil && cachedUser != nil {
s.log().Info("API request failed, returning cached user data", "error", err)
responseUser := &responses.User{
Name: cachedUser.Name,
Email: cachedUser.Email,
Plan: cachedUser.Plan,
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
return json.NewEncoder(w).Encode(responseUser)
}
s.log().Error("failed to get user data", "error", err)
w.WriteHeader(http.StatusInternalServerError)
return json.NewEncoder(w).Encode(responses.Error{
Error: "failed to get user data",
})
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
return json.NewEncoder(w).Encode(user)
}
func (s *Server) disconnect(w http.ResponseWriter, r *http.Request) error {
if r.Method != http.MethodPost {
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
return nil
}
if err := s.Store.ClearUser(); err != nil {
s.log().Warn("failed to clear cached user data", "error", err)
}
// Get the SSH public key to encode for the delete request
pubKey, err := ollamaAuth.GetPublicKey()
if err != nil {
s.log().Error("failed to get public key", "error", err)
w.WriteHeader(http.StatusInternalServerError)
return json.NewEncoder(w).Encode(responses.Error{
Error: "failed to get public key",
})
}
// Encode the key using base64 URL encoding
encodedKey := base64.RawURLEncoding.EncodeToString([]byte(pubKey))
// Call the /api/user/keys/{encodedKey} endpoint with DELETE
resp, err := s.doSelfSigned(r.Context(), http.MethodDelete, fmt.Sprintf("/api/user/keys/%s", encodedKey))
if err != nil {
s.log().Error("failed to call ollama.com/api/user/keys", "error", err)
w.WriteHeader(http.StatusInternalServerError)
return json.NewEncoder(w).Encode(responses.Error{
Error: "failed to disconnect from ollama.com",
})
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
s.log().Error("disconnect request failed", "status", resp.StatusCode)
w.WriteHeader(http.StatusInternalServerError)
return json.NewEncoder(w).Encode(responses.Error{
Error: "failed to disconnect from ollama.com",
})
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
return json.NewEncoder(w).Encode(map[string]string{"status": "disconnected"})
}
func (s *Server) connectURL(w http.ResponseWriter, r *http.Request) error {
if r.Method != http.MethodGet {
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
return nil
}
connectURL, err := auth.BuildConnectURL(OllamaDotCom)
if err != nil {
s.log().Error("failed to build connect URL", "error", err)
w.WriteHeader(http.StatusInternalServerError)
return json.NewEncoder(w).Encode(responses.Error{
Error: "failed to build connect URL",
})
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
return json.NewEncoder(w).Encode(map[string]string{
"connect_url": connectURL,
})
}
func (s *Server) health(w http.ResponseWriter, r *http.Request) error {
if r.Method != http.MethodGet {
http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed)
return nil
}
healthy := false
c, err := api.ClientFromEnvironment()
if err == nil {
if _, err := c.Version(r.Context()); err == nil {
healthy = true
}
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
return json.NewEncoder(w).Encode(responses.HealthResponse{
Healthy: healthy,
})
}
func (s *Server) getInferenceCompute(w http.ResponseWriter, r *http.Request) error {
ctx, cancel := context.WithTimeout(r.Context(), 500*time.Millisecond)
defer cancel()
@@ -1578,7 +1705,7 @@ func getStringFromMap(m map[string]any, key, defaultValue string) string {
// isImageAttachment checks if a filename is an image file
func isImageAttachment(filename string) bool {
ext := strings.ToLower(filename)
return strings.HasSuffix(ext, ".png") || strings.HasSuffix(ext, ".jpg") || strings.HasSuffix(ext, ".jpeg") || strings.HasSuffix(ext, ".webp")
return strings.HasSuffix(ext, ".png") || strings.HasSuffix(ext, ".jpg") || strings.HasSuffix(ext, ".jpeg")
}
// ptr is a convenience function for &literal
@@ -1667,14 +1794,13 @@ func (s *Server) buildChatRequest(chat *store.Chat, model string, think any, ava
var thinkValue *api.ThinkValue
if think != nil {
// Only set Think if it's actually requesting thinking
if boolValue, ok := think.(bool); ok {
if boolValue {
thinkValue = &api.ThinkValue{Value: boolValue}
thinkValue = &api.ThinkValue{
Value: boolValue,
}
} else if stringValue, ok := think.(string); ok {
if stringValue != "" && stringValue != "none" {
thinkValue = &api.ThinkValue{Value: stringValue}
thinkValue = &api.ThinkValue{
Value: stringValue,
}
}
}

View File

@@ -1,115 +0,0 @@
Ollama Benchmark Tool
---------------------
A Go-based command-line tool for benchmarking Ollama models with configurable parameters and multiple output formats.
## Features
* Benchmark multiple models in a single run
* Support for both text and image prompts
* Configurable generation parameters (temperature, max tokens, seed, etc.)
* Supports benchstat and CSV output formats
* Detailed performance metrics (prefill, generate, load, total durations)
## Building from Source
```
go build -o ollama-bench bench.go
./ollama-bench -model gpt-oss:20b -epochs 6 -format csv
```
Using Go Run (without building)
```
go run bench.go -model gpt-oss:20b -epochs 3
```
## Usage
### Basic Example
```
./ollama-bench -model gemma3 -epochs 6
```
### Benchmark Multiple Models
```
./ollama-bench -model gemma3,gemma3n -epochs 6 -max-tokens 100 -p "Write me a short story" | tee gemma.bench
benchstat -col /name gemma.bench
```
### With Image Prompt
```
./ollama-bench -model qwen3-vl -image photo.jpg -epochs 6 -max-tokens 100 -p "Describe this image"
```
### Advanced Example
```
./ollama-bench -model llama3 -epochs 10 -temperature 0.7 -max-tokens 500 -seed 42 -format csv -output results.csv
```
## Command Line Options
| Option | Description | Default |
|----------|-------------|---------|
| -model | Comma-separated list of models to benchmark | (required) |
| -epochs | Number of iterations per model | 1 |
| -max-tokens | Maximum tokens for model response | 0 (unlimited) |
| -temperature | Temperature parameter | 0.0 |
| -seed | Random seed | 0 (random) |
| -timeout | Timeout in seconds | 300 |
| -p | Prompt text | "Write a long story." |
| -image | Image file to include in prompt | |
| -k | Keep-alive duration in seconds | 0 |
| -format | Output format (benchstat, csv) | benchstat |
| -output | Output file for results | "" (stdout) |
| -v | Verbose mode | false |
| -debug | Show debug information | false |
## Output Formats
### Markdown Format
The default markdown format is suitable for copying and pasting into a GitHub issue and will look like:
```
Model | Step | Count | Duration | nsPerToken | tokensPerSec |
|-------|------|-------|----------|------------|--------------|
| gpt-oss:20b | prefill | 124 | 30.006458ms | 241987.56 | 4132.44 |
| gpt-oss:20b | generate | 200 | 2.646843954s | 13234219.77 | 75.56 |
| gpt-oss:20b | load | 1 | 121.674208ms | - | - |
| gpt-oss:20b | total | 1 | 2.861047625s | - | - |
```
### Benchstat Format
Compatible with Go's benchstat tool for statistical analysis:
```
BenchmarkModel/name=gpt-oss:20b/step=prefill 128 78125.00 ns/token 12800.00 token/sec
BenchmarkModel/name=gpt-oss:20b/step=generate 512 19531.25 ns/token 51200.00 token/sec
BenchmarkModel/name=gpt-oss:20b/step=load 1 1500000000 ns/request
```
### CSV Format
Machine-readable comma-separated values:
```
NAME,STEP,COUNT,NS_PER_COUNT,TOKEN_PER_SEC
gpt-oss:20b,prefill,128,78125.00,12800.00
gpt-oss:20b,generate,512,19531.25,51200.00
gpt-oss:20b,load,1,1500000000,0
```
## Metrics Explained
The tool reports four types of metrics for each model:
* prefill: Time spent processing the prompt
* generate: Time spent generating the response
* load: Model loading time (one-time cost)
* total: Total request duration

View File

@@ -1,321 +0,0 @@
package main
import (
"cmp"
"context"
"flag"
"fmt"
"io"
"os"
"runtime"
"slices"
"strings"
"sync"
"time"
"github.com/ollama/ollama/api"
)
type flagOptions struct {
models *string
epochs *int
maxTokens *int
temperature *float64
seed *int
timeout *int
prompt *string
imageFile *string
keepAlive *float64
format *string
outputFile *string
debug *bool
verbose *bool
}
type Metrics struct {
Model string
Step string
Count int
Duration time.Duration
}
var once sync.Once
const DefaultPrompt = `Please write a descriptive story about a llama named Alonso who grows up to be President of the Land of Llamas. Include details about Alonso's childhood, adolescent years, and how he grew up to be a political mover and shaker. Write the story with a sense of whimsy.`
func OutputMetrics(w io.Writer, format string, metrics []Metrics, verbose bool) {
switch format {
case "benchstat":
if verbose {
printHeader := func() {
fmt.Fprintf(w, "sysname: %s\n", runtime.GOOS)
fmt.Fprintf(w, "machine: %s\n", runtime.GOARCH)
}
once.Do(printHeader)
}
for _, m := range metrics {
if m.Step == "generate" || m.Step == "prefill" {
if m.Count > 0 {
nsPerToken := float64(m.Duration.Nanoseconds()) / float64(m.Count)
tokensPerSec := float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s %d %.2f ns/token %.2f token/sec\n",
m.Model, m.Step, m.Count, nsPerToken, tokensPerSec)
} else {
fmt.Fprintf(w, "BenchmarkModel/name=%s/step=%s %d 0 ns/token 0 token/sec\n",
m.Model, m.Step, m.Count)
}
} else {
var suffix string
if m.Step == "load" {
suffix = "/step=load"
}
fmt.Fprintf(w, "BenchmarkModel/name=%s%s 1 %d ns/request\n",
m.Model, suffix, m.Duration.Nanoseconds())
}
}
case "csv":
printHeader := func() {
headings := []string{"NAME", "STEP", "COUNT", "NS_PER_COUNT", "TOKEN_PER_SEC"}
fmt.Fprintln(w, strings.Join(headings, ","))
}
once.Do(printHeader)
for _, m := range metrics {
if m.Step == "generate" || m.Step == "prefill" {
var nsPerToken float64
var tokensPerSec float64
if m.Count > 0 {
nsPerToken = float64(m.Duration.Nanoseconds()) / float64(m.Count)
tokensPerSec = float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
}
fmt.Fprintf(w, "%s,%s,%d,%.2f,%.2f\n", m.Model, m.Step, m.Count, nsPerToken, tokensPerSec)
} else {
fmt.Fprintf(w, "%s,%s,1,%d,0\n", m.Model, m.Step, m.Duration.Nanoseconds())
}
}
case "markdown":
printHeader := func() {
fmt.Fprintln(w, "| Model | Step | Count | Duration | nsPerToken | tokensPerSec |")
fmt.Fprintln(w, "|-------|------|-------|----------|------------|--------------|")
}
once.Do(printHeader)
for _, m := range metrics {
var nsPerToken, tokensPerSec float64
var nsPerTokenStr, tokensPerSecStr string
if m.Step == "generate" || m.Step == "prefill" {
nsPerToken = float64(m.Duration.Nanoseconds()) / float64(m.Count)
tokensPerSec = float64(m.Count) / (float64(m.Duration.Nanoseconds()) + 1e-12) * 1e9
nsPerTokenStr = fmt.Sprintf("%.2f", nsPerToken)
tokensPerSecStr = fmt.Sprintf("%.2f", tokensPerSec)
} else {
nsPerTokenStr = "-"
tokensPerSecStr = "-"
}
fmt.Fprintf(w, "| %s | %s | %d | %v | %s | %s |\n",
m.Model, m.Step, m.Count, m.Duration, nsPerTokenStr, tokensPerSecStr)
}
default:
fmt.Fprintf(os.Stderr, "Unknown output format '%s'\n", format)
}
}
func BenchmarkChat(fOpt flagOptions) error {
models := strings.Split(*fOpt.models, ",")
// todo - add multi-image support
var imgData api.ImageData
var err error
if *fOpt.imageFile != "" {
imgData, err = readImage(*fOpt.imageFile)
if err != nil {
fmt.Fprintf(os.Stderr, "ERROR: Couldn't read image '%s': %v\n", *fOpt.imageFile, err)
return err
}
}
if *fOpt.debug && imgData != nil {
fmt.Fprintf(os.Stderr, "Read file '%s'\n", *fOpt.imageFile)
}
client, err := api.ClientFromEnvironment()
if err != nil {
fmt.Fprintf(os.Stderr, "ERROR: Couldn't create ollama client: %v\n", err)
return err
}
var out io.Writer = os.Stdout
if fOpt.outputFile != nil && *fOpt.outputFile != "" {
f, err := os.OpenFile(*fOpt.outputFile, os.O_CREATE|os.O_WRONLY, 0o644)
if err != nil {
fmt.Fprintf(os.Stderr, "ERROR: cannot open output file %s: %v\n", *fOpt.outputFile, err)
return err
}
defer f.Close()
out = f
}
for _, model := range models {
for range *fOpt.epochs {
options := make(map[string]interface{})
if *fOpt.maxTokens > 0 {
options["num_predict"] = *fOpt.maxTokens
}
options["temperature"] = *fOpt.temperature
if fOpt.seed != nil && *fOpt.seed > 0 {
options["seed"] = *fOpt.seed
}
var keepAliveDuration *api.Duration
if *fOpt.keepAlive > 0 {
duration := api.Duration{Duration: time.Duration(*fOpt.keepAlive * float64(time.Second))}
keepAliveDuration = &duration
}
req := &api.ChatRequest{
Model: model,
Messages: []api.Message{
{
Role: "user",
Content: *fOpt.prompt,
},
},
Options: options,
KeepAlive: keepAliveDuration,
}
if imgData != nil {
req.Messages[0].Images = []api.ImageData{imgData}
}
var responseMetrics *api.Metrics
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(*fOpt.timeout)*time.Second)
defer cancel()
err = client.Chat(ctx, req, func(resp api.ChatResponse) error {
if *fOpt.debug {
fmt.Fprintf(os.Stderr, "%s", cmp.Or(resp.Message.Thinking, resp.Message.Content))
}
if resp.Done {
responseMetrics = &resp.Metrics
}
return nil
})
if *fOpt.debug {
fmt.Fprintln(os.Stderr)
}
if err != nil {
if ctx.Err() == context.DeadlineExceeded {
fmt.Fprintf(os.Stderr, "ERROR: Chat request timed out with model '%s' after %vs\n", model, 1)
continue
}
fmt.Fprintf(os.Stderr, "ERROR: Couldn't chat with model '%s': %v\n", model, err)
continue
}
if responseMetrics == nil {
fmt.Fprintf(os.Stderr, "ERROR: No metrics received for model '%s'\n", model)
continue
}
metrics := []Metrics{
{
Model: model,
Step: "prefill",
Count: responseMetrics.PromptEvalCount,
Duration: responseMetrics.PromptEvalDuration,
},
{
Model: model,
Step: "generate",
Count: responseMetrics.EvalCount,
Duration: responseMetrics.EvalDuration,
},
{
Model: model,
Step: "load",
Count: 1,
Duration: responseMetrics.LoadDuration,
},
{
Model: model,
Step: "total",
Count: 1,
Duration: responseMetrics.TotalDuration,
},
}
OutputMetrics(out, *fOpt.format, metrics, *fOpt.verbose)
if *fOpt.keepAlive > 0 {
time.Sleep(time.Duration(*fOpt.keepAlive*float64(time.Second)) + 200*time.Millisecond)
}
}
}
return nil
}
func readImage(filePath string) (api.ImageData, error) {
file, err := os.Open(filePath)
if err != nil {
return nil, err
}
defer file.Close()
data, err := io.ReadAll(file)
if err != nil {
return nil, err
}
return api.ImageData(data), nil
}
func main() {
fOpt := flagOptions{
models: flag.String("model", "", "Model to benchmark"),
epochs: flag.Int("epochs", 6, "Number of epochs (iterations) per model"),
maxTokens: flag.Int("max-tokens", 200, "Maximum tokens for model response"),
temperature: flag.Float64("temperature", 0, "Temperature parameter"),
seed: flag.Int("seed", 0, "Random seed"),
timeout: flag.Int("timeout", 60*5, "Timeout in seconds (default 300s)"),
prompt: flag.String("p", DefaultPrompt, "Prompt to use"),
imageFile: flag.String("image", "", "Filename for an image to include"),
keepAlive: flag.Float64("k", 0, "Keep alive duration in seconds"),
format: flag.String("format", "markdown", "Output format [benchstat|csv] (default benchstat)"),
outputFile: flag.String("output", "", "Output file for results (stdout if empty)"),
verbose: flag.Bool("v", false, "Show system information"),
debug: flag.Bool("debug", false, "Show debug information"),
}
flag.Usage = func() {
fmt.Fprintf(os.Stderr, "Usage: %s [OPTIONS]\n\n", os.Args[0])
fmt.Fprintf(os.Stderr, "Description:\n")
fmt.Fprintf(os.Stderr, " Model benchmarking tool with configurable parameters\n\n")
fmt.Fprintf(os.Stderr, "Options:\n")
flag.PrintDefaults()
fmt.Fprintf(os.Stderr, "\nExamples:\n")
fmt.Fprintf(os.Stderr, " bench -model gpt-oss:20b -epochs 3 -temperature 0.7\n")
}
flag.Parse()
if !slices.Contains([]string{"markdown", "benchstat", "csv"}, *fOpt.format) {
fmt.Fprintf(os.Stderr, "ERROR: Unknown format '%s'\n", *fOpt.format)
os.Exit(1)
}
if len(*fOpt.models) == 0 {
fmt.Fprintf(os.Stderr, "ERROR: No model(s) specified to benchmark.\n")
flag.Usage()
return
}
BenchmarkChat(fOpt)
}

View File

@@ -1,463 +0,0 @@
package main
import (
"bytes"
"crypto/rand"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"
"github.com/ollama/ollama/api"
)
func createTestFlagOptions() flagOptions {
models := "test-model"
format := "benchstat"
epochs := 1
maxTokens := 100
temperature := 0.7
seed := 42
timeout := 30
prompt := "test prompt"
imageFile := ""
keepAlive := 5.0
verbose := false
debug := false
return flagOptions{
models: &models,
format: &format,
epochs: &epochs,
maxTokens: &maxTokens,
temperature: &temperature,
seed: &seed,
timeout: &timeout,
prompt: &prompt,
imageFile: &imageFile,
keepAlive: &keepAlive,
verbose: &verbose,
debug: &debug,
}
}
func captureOutput(f func()) string {
oldStdout := os.Stdout
oldStderr := os.Stderr
defer func() {
os.Stdout = oldStdout
os.Stderr = oldStderr
}()
r, w, _ := os.Pipe()
os.Stdout = w
os.Stderr = w
f()
w.Close()
var buf bytes.Buffer
io.Copy(&buf, r)
return buf.String()
}
func createMockOllamaServer(t *testing.T, responses []api.ChatResponse) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path != "/api/chat" {
t.Errorf("Expected path /api/chat, got %s", r.URL.Path)
http.Error(w, "Not found", http.StatusNotFound)
return
}
if r.Method != "POST" {
t.Errorf("Expected POST method, got %s", r.Method)
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
for _, resp := range responses {
jsonData, err := json.Marshal(resp)
if err != nil {
t.Errorf("Failed to marshal response: %v", err)
return
}
w.Write(jsonData)
w.Write([]byte("\n"))
if f, ok := w.(http.Flusher); ok {
f.Flush()
}
time.Sleep(10 * time.Millisecond) // Simulate some delay
}
}))
}
func TestBenchmarkChat_Success(t *testing.T) {
fOpt := createTestFlagOptions()
mockResponses := []api.ChatResponse{
{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "test response part 1",
},
Done: false,
},
{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "test response part 2",
},
Done: true,
Metrics: api.Metrics{
PromptEvalCount: 10,
PromptEvalDuration: 100 * time.Millisecond,
EvalCount: 50,
EvalDuration: 500 * time.Millisecond,
TotalDuration: 600 * time.Millisecond,
LoadDuration: 50 * time.Millisecond,
},
},
}
server := createMockOllamaServer(t, mockResponses)
defer server.Close()
t.Setenv("OLLAMA_HOST", server.URL)
output := captureOutput(func() {
err := BenchmarkChat(fOpt)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
if !strings.Contains(output, "BenchmarkModel/name=test-model/step=prefill") {
t.Errorf("Expected output to contain prefill metrics, got: %s", output)
}
if !strings.Contains(output, "BenchmarkModel/name=test-model/step=generate") {
t.Errorf("Expected output to contain generate metrics, got: %s", output)
}
if !strings.Contains(output, "ns/token") {
t.Errorf("Expected output to contain ns/token metric, got: %s", output)
}
}
func TestBenchmarkChat_ServerError(t *testing.T) {
fOpt := createTestFlagOptions()
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Error(w, "Internal server error", http.StatusInternalServerError)
}))
defer server.Close()
t.Setenv("OLLAMA_HOST", server.URL)
output := captureOutput(func() {
err := BenchmarkChat(fOpt)
if err != nil {
t.Errorf("Expected error to be handled internally, got returned error: %v", err)
}
})
if !strings.Contains(output, "ERROR: Couldn't chat with model") {
t.Errorf("Expected error message about chat failure, got: %s", output)
}
}
func TestBenchmarkChat_Timeout(t *testing.T) {
fOpt := createTestFlagOptions()
shortTimeout := 1 // Very short timeout
fOpt.timeout = &shortTimeout
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Simulate a long delay that will cause timeout
time.Sleep(2 * time.Second)
w.Header().Set("Content-Type", "application/json")
response := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "test response",
},
Done: true,
Metrics: api.Metrics{
PromptEvalCount: 10,
PromptEvalDuration: 100 * time.Millisecond,
EvalCount: 50,
EvalDuration: 500 * time.Millisecond,
TotalDuration: 600 * time.Millisecond,
LoadDuration: 50 * time.Millisecond,
},
}
jsonData, _ := json.Marshal(response)
w.Write(jsonData)
}))
defer server.Close()
t.Setenv("OLLAMA_HOST", server.URL)
output := captureOutput(func() {
err := BenchmarkChat(fOpt)
if err != nil {
t.Errorf("Expected timeout to be handled internally, got returned error: %v", err)
}
})
if !strings.Contains(output, "ERROR: Chat request timed out") {
t.Errorf("Expected timeout error message, got: %s", output)
}
}
func TestBenchmarkChat_NoMetrics(t *testing.T) {
fOpt := createTestFlagOptions()
mockResponses := []api.ChatResponse{
{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "test response",
},
Done: false, // Never sends Done=true
},
}
server := createMockOllamaServer(t, mockResponses)
defer server.Close()
t.Setenv("OLLAMA_HOST", server.URL)
output := captureOutput(func() {
err := BenchmarkChat(fOpt)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
if !strings.Contains(output, "ERROR: No metrics received") {
t.Errorf("Expected no metrics error message, got: %s", output)
}
}
func TestBenchmarkChat_MultipleModels(t *testing.T) {
fOpt := createTestFlagOptions()
models := "model1,model2"
epochs := 2
fOpt.models = &models
fOpt.epochs = &epochs
callCount := 0
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
callCount++
w.Header().Set("Content-Type", "application/json")
var req api.ChatRequest
body, _ := io.ReadAll(r.Body)
json.Unmarshal(body, &req)
response := api.ChatResponse{
Model: req.Model,
Message: api.Message{
Role: "assistant",
Content: "test response for " + req.Model,
},
Done: true,
Metrics: api.Metrics{
PromptEvalCount: 10,
PromptEvalDuration: 100 * time.Millisecond,
EvalCount: 50,
EvalDuration: 500 * time.Millisecond,
TotalDuration: 600 * time.Millisecond,
LoadDuration: 50 * time.Millisecond,
},
}
jsonData, _ := json.Marshal(response)
w.Write(jsonData)
}))
defer server.Close()
t.Setenv("OLLAMA_HOST", server.URL)
output := captureOutput(func() {
err := BenchmarkChat(fOpt)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
// Should be called 4 times (2 models × 2 epochs)
if callCount != 4 {
t.Errorf("Expected 4 API calls, got %d", callCount)
}
if !strings.Contains(output, "BenchmarkModel/name=model1") || !strings.Contains(output, "BenchmarkModel/name=model2") {
t.Errorf("Expected output for both models, got: %s", output)
}
}
func TestBenchmarkChat_WithImage(t *testing.T) {
fOpt := createTestFlagOptions()
tmpfile, err := os.CreateTemp(t.TempDir(), "testimage")
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
defer os.Remove(tmpfile.Name())
content := []byte("fake image data")
if _, err := tmpfile.Write(content); err != nil {
t.Fatalf("Failed to write to temp file: %v", err)
}
tmpfile.Close()
tmpfileName := tmpfile.Name()
fOpt.imageFile = &tmpfileName
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Verify the request contains image data
var req api.ChatRequest
body, _ := io.ReadAll(r.Body)
json.Unmarshal(body, &req)
if len(req.Messages) == 0 || len(req.Messages[0].Images) == 0 {
t.Error("Expected request to contain images")
}
w.Header().Set("Content-Type", "application/json")
response := api.ChatResponse{
Model: "test-model",
Message: api.Message{
Role: "assistant",
Content: "test response with image",
},
Done: true,
Metrics: api.Metrics{
PromptEvalCount: 10,
PromptEvalDuration: 100 * time.Millisecond,
EvalCount: 50,
EvalDuration: 500 * time.Millisecond,
TotalDuration: 600 * time.Millisecond,
LoadDuration: 50 * time.Millisecond,
},
}
jsonData, _ := json.Marshal(response)
w.Write(jsonData)
}))
defer server.Close()
t.Setenv("OLLAMA_HOST", server.URL)
output := captureOutput(func() {
err := BenchmarkChat(fOpt)
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
})
if !strings.Contains(output, "BenchmarkModel/name=test-model") {
t.Errorf("Expected benchmark output, got: %s", output)
}
}
func TestBenchmarkChat_ImageError(t *testing.T) {
randFileName := func() string {
const charset = "abcdefghijklmnopqrstuvwxyz0123456789"
const length = 8
result := make([]byte, length)
rand.Read(result) // Fill with random bytes
for i := range result {
result[i] = charset[result[i]%byte(len(charset))]
}
return string(result) + ".txt"
}
fOpt := createTestFlagOptions()
imageFile := randFileName()
fOpt.imageFile = &imageFile
output := captureOutput(func() {
err := BenchmarkChat(fOpt)
if err == nil {
t.Error("Expected error from image reading, got nil")
}
})
if !strings.Contains(output, "ERROR: Couldn't read image") {
t.Errorf("Expected image read error message, got: %s", output)
}
}
func TestReadImage_Success(t *testing.T) {
tmpfile, err := os.CreateTemp(t.TempDir(), "testimage")
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
defer os.Remove(tmpfile.Name())
content := []byte("fake image data")
if _, err := tmpfile.Write(content); err != nil {
t.Fatalf("Failed to write to temp file: %v", err)
}
tmpfile.Close()
imgData, err := readImage(tmpfile.Name())
if err != nil {
t.Errorf("Expected no error, got %v", err)
}
if imgData == nil {
t.Error("Expected image data, got nil")
}
expected := api.ImageData(content)
if string(imgData) != string(expected) {
t.Errorf("Expected image data %v, got %v", expected, imgData)
}
}
func TestReadImage_FileNotFound(t *testing.T) {
imgData, err := readImage("nonexistentfile.jpg")
if err == nil {
t.Error("Expected error for non-existent file, got nil")
}
if imgData != nil {
t.Error("Expected nil image data for non-existent file")
}
}
func TestOptionsMapCreation(t *testing.T) {
fOpt := createTestFlagOptions()
options := make(map[string]interface{})
if *fOpt.maxTokens > 0 {
options["num_predict"] = *fOpt.maxTokens
}
options["temperature"] = *fOpt.temperature
if fOpt.seed != nil && *fOpt.seed > 0 {
options["seed"] = *fOpt.seed
}
if options["num_predict"] != *fOpt.maxTokens {
t.Errorf("Expected num_predict %d, got %v", *fOpt.maxTokens, options["num_predict"])
}
if options["temperature"] != *fOpt.temperature {
t.Errorf("Expected temperature %f, got %v", *fOpt.temperature, options["temperature"])
}
if options["seed"] != *fOpt.seed {
t.Errorf("Expected seed %d, got %v", *fOpt.seed, options["seed"])
}
}

View File

@@ -1430,7 +1430,7 @@ func chat(cmd *cobra.Command, opts runOptions) (*api.Message, error) {
latest.Summary()
}
return &api.Message{Role: role, Thinking: thinkingContent.String(), Content: fullResponse.String()}, nil
return &api.Message{Role: role, Content: fullResponse.String()}, nil
}
func generate(cmd *cobra.Command, opts runOptions) error {

View File

@@ -202,16 +202,10 @@ func ConvertModel(fsys fs.FS, f *os.File) error {
conv = &qwen3VLModel{}
case "BertModel":
conv = &bertModel{}
case "NomicBertModel", "NomicBertMoEModel":
conv = &nomicbertModel{}
case "CohereForCausalLM":
conv = &commandrModel{}
case "GptOssForCausalLM":
conv = &gptossModel{}
case "DeepseekOCRForCausalLM":
conv = &deepseekocr{}
case "DeepseekV3ForCausalLM":
conv = &deepseek2Model{}
default:
return fmt.Errorf("unsupported architecture %q", p.Architectures[0])
}

View File

@@ -1,173 +0,0 @@
package convert
import (
"cmp"
"fmt"
"log/slog"
"regexp"
"strconv"
"github.com/ollama/ollama/fs/ggml"
)
type deepseek2Model struct {
ModelParameters // architectures, vocab_size
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
HiddenLayers uint32 `json:"num_hidden_layers"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RMSNormEPS float32 `json:"rms_norm_eps"`
RopeTheta float32 `json:"rope_theta"`
QKNopeHeadDim uint32 `json:"qk_nope_head_dim"`
QKRopeHeadDim uint32 `json:"qk_rope_head_dim"`
KVLoraRank uint32 `json:"kv_lora_rank"`
QLoraRank uint32 `json:"q_lora_rank"`
VHeadDim uint32 `json:"v_head_dim"`
ExpertCount uint32 `json:"n_routed_experts"`
ExpertSharedCount uint32 `json:"n_shared_experts"`
ExpertIntermediateSize uint32 `json:"moe_intermediate_size"`
ExpertUsedCount uint32 `json:"num_experts_per_tok"`
ExpertWeightsNorm bool `json:"norm_topk_prob"`
ExpertWeightsScale float32 `json:"routed_scaling_factor"`
ScoringFunc string `json:"scoring_func"`
LeadingDenseBlockCount uint32 `json:"first_k_dense_replace"`
RopeScaling struct {
Factor float32 `json:"factor"`
OriginalMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
Type string `json:"type"`
MScaleAllDim float32 `json:"mscale_all_dim"`
} `json:"rope_scaling"`
Architecture string
}
func (p *deepseek2Model) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
kv["general.architecture"] = "deepseek2"
kv["general.type"] = "model"
kv["deepseek2.block_count"] = p.HiddenLayers
numHeads := p.NumAttentionHeads
numKVHeads := p.NumKeyValueHeads
kv["deepseek2.attention.head_count"] = numHeads
kv["deepseek2.attention.head_count_kv"] = numKVHeads
kv["deepseek2.attention.key_length"] = p.QKNopeHeadDim + p.QKRopeHeadDim
kv["deepseek2.attention.kv_lora_rank"] = p.KVLoraRank
kv["deepseek2.attention.layer_norm_rms_epsilon"] = p.RMSNormEPS
kv["deepseek2.attention.q_lora_rank"] = p.QLoraRank
kv["deepseek2.attention.value_length"] = p.VHeadDim
kv["deepseek2.context_length"] = p.MaxPositionEmbeddings
kv["deepseek2.embedding_length"] = p.HiddenSize
kv["deepseek2.expert_count"] = p.ExpertCount
kv["deepseek2.expert_feed_forward_length"] = p.ExpertIntermediateSize
kv["deepseek2.expert_shared_count"] = p.ExpertSharedCount
var scoringFunc uint32
switch p.ScoringFunc {
case "softmax":
// not currently supported in the model, but needed for Deepseek-OCR
scoringFunc = 1
case "sigmoid":
scoringFunc = 2
}
kv["deepseek2.expert_gating_func"] = scoringFunc
kv["deepseek2.expert_used_count"] = p.ExpertUsedCount
kv["deepseek2.expert_weights_norm"] = p.ExpertWeightsNorm
kv["deepseek2.expert_weights_scale"] = p.ExpertWeightsScale
kv["deepseek2.feed_forward_length"] = p.IntermediateSize
kv["deepseek2.leading_dense_block_count"] = p.LeadingDenseBlockCount
kv["deepseek2.rope.dimension_count"] = p.QKRopeHeadDim
kv["deepseek2.rope.freq_base"] = cmp.Or(p.RopeTheta, 10000.0)
kv["deepseek2.rope.scaling.factor"] = p.RopeScaling.Factor
kv["deepseek2.rope.scaling.original_context_length"] = p.RopeScaling.OriginalMaxPositionEmbeddings
kv["deepseek2.rope.scaling.type"] = p.RopeScaling.Type
kv["deepseek2.rope.scaling.yarn_log_multiplier"] = 0.1 * p.RopeScaling.MScaleAllDim
kv["tokenizer.ggml.pre"] = "deepseek-v3"
return kv
}
func (p *deepseek2Model) Replacements() []string {
return []string{
"lm_head", "output",
"model.embed_tokens", "token_embd",
"model.norm", "output_norm",
"language_model.", "",
"model.layers", "blk",
"input_layernorm", "attn_norm",
"self_attn.kv_a_proj_with_mqa", "attn_kv_a_mqa",
"self_attn.kv_a_layernorm", "attn_kv_a_norm",
"self_attn.kv_b_proj", "attn_kv_b",
"self_attn.q_a_proj", "attn_q_a",
"self_attn.q_a_layernorm", "attn_q_a_norm",
"self_attn.q_b_proj", "attn_q_b",
"self_attn.o_proj", "attn_output",
"post_attention_layernorm", "ffn_norm",
"mlp.shared_experts.down_proj", "ffn_down_shexp",
"mlp.shared_experts.gate_proj", "ffn_gate_shexp",
"mlp.shared_experts.up_proj", "ffn_up_shexp",
"mlp.gate_proj", "ffn_gate",
"mlp.down_proj", "ffn_down",
"mlp.up_proj", "ffn_up",
"mlp.gate.e_score_correction_bias", "exp_probs_b.bias",
"mlp.gate", "ffn_gate_inp",
}
}
func (p *deepseek2Model) Tensors(s []Tensor) (out []*ggml.Tensor) {
merges := make([]merge, p.HiddenLayers*3)
for i := range p.HiddenLayers {
merges[i*3+0] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.gate_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
}
merges[i*3+1] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.up_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
}
merges[i*3+2] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.down_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
}
}
skipLayer := func(n string, minValue uint32) bool {
re := regexp.MustCompile(`^blk\.(\d+)`)
matches := re.FindStringSubmatch(n)
if matches == nil {
return false
}
blkNum, err := strconv.Atoi(matches[1])
if err != nil {
return false
}
return uint32(blkNum) >= minValue
}
out, s = mergeTensors(s, merges...)
for _, t := range s {
// skip any additional layers (such as the Multi-Token Prediction layer)
if skipLayer(t.Name(), p.HiddenLayers) {
slog.Debug("skipping layer", "name", t.Name())
continue
}
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}

View File

@@ -1,136 +0,0 @@
package convert
import (
"fmt"
"github.com/ollama/ollama/fs/ggml"
)
type deepseekocr struct {
ModelParameters
LanguageConfig struct {
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
HiddenLayers uint32 `json:"num_hidden_layers"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
NumRoutedExperts uint32 `json:"n_routed_experts"`
NumSharedExperts uint32 `json:"n_shared_experts"`
NumExpertsPerToken uint32 `json:"num_experts_per_tok"`
FirstKDenseReplace uint32 `json:"first_k_dense_replace"`
} `json:"language_config"`
VisionConfig struct {
ImageSize uint32 `json:"image_size"`
Width struct {
Vision struct {
Heads uint32 `json:"heads"`
ImageSize uint32 `json:"image_size"`
Layers uint32 `json:"layers"`
PatchSize uint32 `json:"patch_size"`
Width uint32 `json:"width"`
} `json:"clip-l-14-224"`
Sam struct {
GlobalAttentionIndexes []int32 `json:"global_attn_indexes"`
Heads uint32 `json:"heads"`
Layers uint32 `json:"layers"`
Width uint32 `json:"width"`
} `json:"sam_vit_b"`
}
} `json:"vision_config"`
}
func (m *deepseekocr) KV(t *Tokenizer) ggml.KV {
kv := m.ModelParameters.KV(t)
kv["general.architecture"] = "deepseekocr"
kv["block_count"] = m.LanguageConfig.HiddenLayers
kv["context_length"] = m.LanguageConfig.MaxPositionEmbeddings
kv["embedding_length"] = m.LanguageConfig.HiddenSize
kv["feed_forward_length"] = m.LanguageConfig.IntermediateSize
kv["attention.head_count"] = m.LanguageConfig.NumAttentionHeads
kv["attention.head_count_kv"] = m.LanguageConfig.NumKeyValueHeads
kv["expert_count"] = m.LanguageConfig.NumRoutedExperts
kv["expert_used_count"] = m.LanguageConfig.NumExpertsPerToken
kv["leading_dense_block_count"] = m.LanguageConfig.FirstKDenseReplace
kv["vision.block_count"] = m.VisionConfig.Width.Vision.Layers
kv["vision.embedding_length"] = m.VisionConfig.Width.Vision.Width
kv["vision.head_count"] = m.VisionConfig.Width.Vision.Heads
kv["vision.image_size"] = m.VisionConfig.Width.Vision.ImageSize
kv["vision.patch_size"] = m.VisionConfig.Width.Vision.PatchSize
kv["sam.block_count"] = m.VisionConfig.Width.Sam.Layers
kv["sam.embedding_length"] = m.VisionConfig.Width.Sam.Width
kv["sam.head_count"] = m.VisionConfig.Width.Sam.Heads
kv["sam.global_attention_indexes"] = m.VisionConfig.Width.Sam.GlobalAttentionIndexes
return kv
}
func (m *deepseekocr) Tensors(s []Tensor) (out []*ggml.Tensor) {
merges := make([]merge, m.LanguageConfig.HiddenLayers*3)
for i := range m.LanguageConfig.HiddenLayers {
merges[i*3+0] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.gate_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_gate_exps.weight", i),
}
merges[i*3+1] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.up_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_up_exps.weight", i),
}
merges[i*3+2] = merge{
fmt.Sprintf("blk.%d.mlp.experts.*.down_proj.weight", i),
fmt.Sprintf("blk.%d.ffn_down_exps.weight", i),
}
}
out, s = mergeTensors(s, merges...)
for _, t := range s {
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (m *deepseekocr) Replacements() []string {
return []string{
"model.embed_tokens", "token_embd",
"model.layers", "blk",
"input_layernorm", "attn_norm",
"self_attn.q_proj", "attn_q",
"self_attn.k_proj", "attn_k",
"self_attn.v_proj", "attn_v",
"self_attn.o_proj", "attn_output",
"post_attention_layernorm", "ffn_norm",
"mlp.gate_proj", "ffn_gate",
"mlp.up_proj", "ffn_up",
"mlp.down_proj", "ffn_down",
"mlp.gate", "ffn_gate_inp",
"mlp.shared_experts.gate_proj", "ffn_gate_shexp",
"mlp.shared_experts.up_proj", "ffn_up_shexp",
"mlp.shared_experts.down_proj", "ffn_down_shexp",
"model.norm", "output_norm",
"lm_head", "output",
"model.vision_model", "v",
"embeddings.patch_embedding", "patch_embd",
"embeddings.class_embedding", "class_embd",
"embeddings.position_embedding", "position_embd",
"transformer.layers", "blk",
"model.projector", "mm",
"model.image_newline", "mm.image_newline",
//nolint:misspell // this misspelling is upstream. fixing it breaks the model
"model.view_seperator", "mm.view_seperator",
"model.sam_model.patch_embed.proj", "s.patch_embd",
"model.sam_model.pos_embed", "s.position_embd",
"model.sam_model.blocks", "s.blk",
"model.sam_model.neck", "s.neck",
"model.sam_model.net_", "s.net_",
}
}

View File

@@ -2,7 +2,6 @@ package convert
import (
"cmp"
"slices"
"github.com/ollama/ollama/fs/ggml"
)
@@ -27,26 +26,16 @@ type gemma3Model struct {
NumChannels uint32 `json:"num_channels"` // num_channels 3
PatchSize uint32 `json:"patch_size"` // patch_size 14
} `json:"vision_config"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RMSNormEPS float32 `json:"rms_norm_eps"`
HeadDim uint32 `json:"head_dim"`
FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
RopeLocalTheta float32 `json:"rope_local_base_freq"`
RopeTheta float32 `json:"rope_theta"`
SlidingWindow uint32 `json:"sliding_window"`
SlidingWindowPattern *uint32 `json:"sliding_window_pattern"`
LayerTypes []string `json:"layer_types"`
MultiModalTokensPerImage uint32 `json:"mm_tokens_per_image"`
RopeScaling *struct {
Type string `json:"rope_type"`
Factor float32 `json:"factor"`
OriginalMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
ExtrapolationFactor float32 `json:"extrapolation_factor"`
BetaFast float32 `json:"beta_fast"`
BetaSlow float32 `json:"beta_slow"`
} `json:"rope_scaling"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
RMSNormEPS float32 `json:"rms_norm_eps"`
HeadDim uint32 `json:"head_dim"`
FinalLogitSoftcap float32 `json:"final_logit_softcapping"`
RopeLocalTheta float32 `json:"rope_local_base_freq"`
RopeGlobalTheta float32 `json:"rope_global_base_freq"`
SlidingWindow uint32 `json:"sliding_window"`
MultiModalTokensPerImage uint32 `json:"mm_tokens_per_image"`
}
const (
@@ -92,38 +81,9 @@ func (p *gemma3Model) KV(t *Tokenizer) ggml.KV {
kv["gemma3.attention.key_length"] = p.HeadDim
kv["gemma3.attention.value_length"] = p.HeadDim
kv["gemma3.attention.sliding_window"] = p.SlidingWindow
// The sliding window pattern is either provided as the sliding_window_pattern
// key (an int) or as the layer_types key (a list of strings).
if p.SlidingWindowPattern != nil || len(p.LayerTypes) > 0 {
kv["gemma3.attention.sliding_window_pattern"] = slices.Collect(func(yield func(bool) bool) {
for i := range numBlocks {
var isLocal bool
if len(p.LayerTypes) > 0 && int(i) < len(p.LayerTypes) {
isLocal = p.LayerTypes[i] == "sliding_attention"
} else if p.SlidingWindowPattern != nil && *p.SlidingWindowPattern > 0 {
isLocal = (i+1)%*p.SlidingWindowPattern != 0
}
if !yield(isLocal) {
break
}
}
})
}
if p.FinalLogitSoftcap > 0 {
kv["gemma3.final_logit_softcapping"] = p.FinalLogitSoftcap
}
kv["gemma3.final_logit_softcapping"] = cmp.Or(p.FinalLogitSoftcap, 30)
kv["gemma3.rope.local.freq_base"] = cmp.Or(p.RopeLocalTheta, 10000.0)
kv["gemma3.rope.freq_base"] = cmp.Or(p.RopeTheta, 1000000.0)
if p.RopeScaling != nil && p.RopeScaling.Type == "yarn" && p.RopeScaling.Factor > 0 {
kv["gemma3.rope.scaling.type"] = "yarn"
kv["gemma3.rope.scaling.factor"] = p.RopeScaling.Factor
kv["gemma3.rope.scaling.original_context_length"] = p.RopeScaling.OriginalMaxPositionEmbeddings
kv["gemma3.rope.scaling.extrapolation_factor"] = cmp.Or(p.RopeScaling.ExtrapolationFactor, float32(1.0))
kv["gemma3.rope.scaling.beta_fast"] = cmp.Or(p.RopeScaling.BetaFast, float32(64.0))
kv["gemma3.rope.scaling.beta_slow"] = cmp.Or(p.RopeScaling.BetaSlow, float32(1.0))
}
kv["gemma3.rope.global.freq_base"] = cmp.Or(p.RopeGlobalTheta, 1000000.0)
kv["gemma3.embedding_length"] = p.HiddenSize
kv["gemma3.feed_forward_length"] = p.IntermediateSize
default:

View File

@@ -110,12 +110,9 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
for name, mxfp4 := range mxfp4s {
dims := mxfp4.blocks.Shape()
if !strings.HasSuffix(name, ".weight") {
name = name + ".weight"
}
if strings.Contains(name, "ffn_down_exps") {
out = append(out, &ggml.Tensor{
Name: name,
Name: name + ".weight",
Kind: uint32(ggml.TensorTypeMXFP4),
Shape: []uint64{dims[0], dims[1], dims[2] * dims[3] * 2},
WriterTo: mxfp4,
@@ -124,12 +121,12 @@ func (m *gptossModel) Tensors(ts []Tensor) []*ggml.Tensor {
// gate_up_exps is interleaved, need to split into gate_exps and up_exps
// e.g. gate_exps, up_exps = gate_up_exps[:, 0::2, ...], gate_up_exps[:, 1::2, ...]
out = append(out, &ggml.Tensor{
Name: strings.Replace(name, "gate_up", "gate", 1),
Name: strings.Replace(name, "gate_up", "gate", 1) + ".weight",
Kind: uint32(ggml.TensorTypeMXFP4),
Shape: []uint64{dims[0], dims[1] / 2, dims[2] * dims[3] * 2},
WriterTo: mxfp4.slice(1, 0, int(dims[1]), 2),
}, &ggml.Tensor{
Name: strings.Replace(name, "gate_up", "up", 1),
Name: strings.Replace(name, "gate_up", "up", 1) + ".weight",
Kind: uint32(ggml.TensorTypeMXFP4),
Shape: []uint64{dims[0], dims[1] / 2, dims[2] * dims[3] * 2},
WriterTo: mxfp4.slice(1, 1, int(dims[1]), 2),

View File

@@ -29,15 +29,6 @@ type mistral3Model struct {
SlidingWindow *uint32 `json:"sliding_window"`
HiddenAct string `json:"hidden_act"`
VocabSize uint32 `json:"vocab_size"`
RopeParameters struct {
BetaFast float32 `json:"beta_fast"`
BetaSlow float32 `json:"beta_slow"`
Factor float32 `json:"factor"`
ScalingBeta float32 `json:"llama_4_scaling_beta"`
OrigMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"`
RopeType string `json:"rope_type"`
RopeTheta float32 `json:"rope_theta"`
} `json:"rope_parameters"`
} `json:"text_config"`
VisionModel struct {
NumAttentionHeads uint32 `json:"num_attention_heads"`
@@ -70,13 +61,8 @@ func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
kv["mistral3.attention.layer_norm_rms_epsilon"] = p.TextModel.RMSNormEPS
kv["mistral3.attention.key_length"] = p.TextModel.HeadDim
kv["mistral3.attention.value_length"] = p.TextModel.HeadDim
kv["mistral3.rope.dimension_count"] = cmp.Or(p.TextModel.HeadDim, p.TextModel.HiddenSize/p.TextModel.NumAttentionHeads)
kv["mistral3.rope.freq_base"] = cmp.Or(p.TextModel.RopeTheta, p.TextModel.RopeParameters.RopeTheta)
if p.TextModel.RopeParameters.OrigMaxPositionEmbeddings > 0 {
kv["mistral3.rope.scaling.original_context_length"] = p.TextModel.RopeParameters.OrigMaxPositionEmbeddings
kv["mistral3.rope.scaling_beta"] = p.TextModel.RopeParameters.ScalingBeta
}
kv["mistral3.rope.dimension_count"] = p.TextModel.HiddenSize / p.TextModel.NumHiddenLayers
kv["mistral3.rope.freq_base"] = p.TextModel.RopeTheta
// Vision configuration
kv["mistral3.vision.block_count"] = p.VisionModel.NumHiddenLayers

View File

@@ -1,213 +0,0 @@
package convert
import (
"cmp"
"encoding/json"
"io/fs"
"path/filepath"
"slices"
"strings"
"github.com/ollama/ollama/fs/ggml"
)
type nomicbertModel struct {
ModelParameters
NLayers uint32 `json:"n_layers"`
NumHiddenLayers uint32 `json:"num_hidden_layers"`
MaxPositionEmbeddings uint32 `json:"max_position_embeddings"`
HiddenSize uint32 `json:"hidden_size"`
IntermediateSize uint32 `json:"intermediate_size"`
NumAttentionHeads uint32 `json:"num_attention_heads"`
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
LayerNormEPS float32 `json:"layer_norm_eps"`
LayerNormEpsilon float32 `json:"layer_norm_epsilon"`
RopeFreqBase float32 `json:"rope_theta"`
normalizeEmbeddings bool
PoolingType uint32
// MoE parameters (only present in v2 models)
NumExperts uint32 `json:"num_local_experts"`
NumExpertsUsed uint32 `json:"num_experts_per_tok"`
MoEEveryNLayers uint32 `json:"moe_every_n_layers"`
}
var (
_ ModelConverter = (*nomicbertModel)(nil)
_ moreParser = (*nomicbertModel)(nil)
)
func (p *nomicbertModel) parseMore(fsys fs.FS) error {
bts, err := fs.ReadFile(fsys, "modules.json")
if err != nil {
return err
}
var modules []struct {
Type string `json:"type"`
Path string `json:"path"`
}
if err := json.Unmarshal(bts, &modules); err != nil {
return err
}
var pooling string
for _, m := range modules {
switch m.Type {
case "sentence_transformers.models.Pooling":
pooling = m.Path
case "sentence_transformers.models.Normalize":
p.normalizeEmbeddings = true
}
}
if pooling != "" {
bts, err := fs.ReadFile(fsys, filepath.Join(pooling, "config.json"))
if err != nil {
return err
}
var pc struct {
PoolingModeCLSToken bool `json:"pooling_mode_cls_token"`
PoolingModeMeanTokens bool `json:"pooling_mode_mean_tokens"`
}
if err := json.Unmarshal(bts, &pc); err != nil {
return err
}
if pc.PoolingModeMeanTokens {
p.PoolingType = 1
} else if pc.PoolingModeCLSToken {
p.PoolingType = 2
}
}
return nil
}
func (p *nomicbertModel) KV(t *Tokenizer) ggml.KV {
kv := p.ModelParameters.KV(t)
// Determine architecture based on MoE parameters (following qwen3 pattern)
arch := "nomic-bert"
if p.MoEEveryNLayers > 0 {
arch += "-moe"
}
kv["general.architecture"] = arch
kv["attention.causal"] = false
kv["pooling_type"] = p.PoolingType
kv["normalize_embeddings"] = p.normalizeEmbeddings
kv["block_count"] = cmp.Or(p.NLayers, p.NumHiddenLayers)
if contextLength := p.MaxPositionEmbeddings; contextLength > 0 {
kv["context_length"] = contextLength
}
if embeddingLength := p.HiddenSize; embeddingLength > 0 {
kv["embedding_length"] = p.HiddenSize
}
if feedForwardLength := p.IntermediateSize; feedForwardLength > 0 {
kv["feed_forward_length"] = p.IntermediateSize
}
if headCount := p.NumAttentionHeads; headCount > 0 {
kv["attention.head_count"] = p.NumAttentionHeads
}
if kvHeadCount := p.NumKeyValueHeads; kvHeadCount > 0 {
kv["attention.head_count_kv"] = p.NumKeyValueHeads
}
if layerNormEpsilon := cmp.Or(p.LayerNormEPS, p.LayerNormEpsilon); layerNormEpsilon > 0 {
kv["attention.layer_norm_epsilon"] = layerNormEpsilon
}
if p.RopeFreqBase > 0 {
kv["rope.freq_base"] = p.RopeFreqBase
}
// MoE specific parameters (only if MoE is enabled)
if p.NumExperts > 0 {
kv["expert_count"] = p.NumExperts
}
if p.NumExpertsUsed > 0 {
kv["expert_used_count"] = p.NumExpertsUsed
}
if p.MoEEveryNLayers > 0 {
kv["moe_every_n_layers"] = p.MoEEveryNLayers
}
kv["tokenizer.ggml.model"] = "bert"
kv["tokenizer.ggml.token_type_count"] = uint32(2)
// convert to phantom space tokens
for i, e := range t.Tokens {
switch {
case strings.HasPrefix(e, "[") && strings.HasSuffix(e, "]"):
// noop - keep special tokens as-is
case strings.HasPrefix(e, "##"):
t.Tokens[i] = e[2:]
default:
t.Tokens[i] = "\u2581" + e
}
}
kv["tokenizer.ggml.tokens"] = t.Tokens
return kv
}
func (p *nomicbertModel) Tensors(ts []Tensor) []*ggml.Tensor {
out := make([]*ggml.Tensor, 0, len(ts))
for _, t := range ts {
if slices.Contains([]string{
"embeddings.position_ids",
"pooler.dense.weight",
"pooler.dense.bias",
}, t.Name()) {
continue
}
out = append(out, &ggml.Tensor{
Name: t.Name(),
Kind: t.Kind(),
Shape: t.Shape(),
WriterTo: t,
})
}
return out
}
func (nomicbertModel) Replacements() []string {
return []string{
"encoder.layer", "blk",
"encoder.layers", "blk",
"embeddings.word_embeddings", "token_embd",
"embeddings.token_type_embeddings", "token_types",
"embeddings.LayerNorm", "token_embd_norm",
"attention.self.qkv", "attn_qkv",
"attention.output.dense", "attn_output",
"attention.output.LayerNorm", "attn_output_norm",
"mlp.up", "ffn_up",
"mlp.down", "ffn_down",
"mlp.router", "ffn_gate_inp",
"mlp.experts.up", "ffn_up_exps",
"mlp.experts.down", "ffn_down_exps",
"intermediate.dense", "ffn_up",
"output.dense", "ffn_down",
"output.LayerNorm", "layer_output_norm",
}
}

View File

@@ -44,10 +44,7 @@ func (t tensorBase) Kind() uint32 {
t.name == "v.positional_embedding_vlm" ||
t.name == "v.tile_position_embd.weight" ||
t.name == "v.pre_tile_position_embd.weight" ||
t.name == "v.post_tile_position_embd.weight" ||
t.name == "s.position_embd" ||
strings.HasSuffix(t.name, "rel_pos_h") ||
strings.HasSuffix(t.name, "rel_pos_w") {
t.name == "v.post_tile_position_embd.weight" {
// these tensors are always F32
return tensorKindFP32
}

View File

@@ -37,10 +37,6 @@ func parseSafetensors(fsys fs.FS, replacer *strings.Replacer, ps ...string) ([]T
return nil, err
}
if n <= 0 || n > 100<<20 {
return nil, fmt.Errorf("invalid safetensors file %q (header size: %d): file may be corrupted or a Git LFS pointer", p, n)
}
b := bytes.NewBuffer(make([]byte, 0, n))
if _, err = io.CopyN(b, f, n); err != nil {
return nil, err
@@ -100,10 +96,7 @@ type safetensor struct {
func (st safetensor) Kind() uint32 {
kind := st.tensorBase.Kind()
if st.dtype == "BF16" &&
!strings.HasPrefix(st.name, "v.") &&
!strings.HasPrefix(st.name, "s.") &&
kind != tensorKindFP32 {
if !strings.HasPrefix(st.name, "v.") && st.dtype == "BF16" && kind != tensorKindFP32 {
kind = tensorKindBF16
}

View File

@@ -2,12 +2,10 @@ package convert
import (
"cmp"
"errors"
"io"
"iter"
"path"
"slices"
"strconv"
"strings"
"github.com/pdevine/tensor"
@@ -96,26 +94,6 @@ func mergeTensors(unmatched []Tensor, merges ...merge) (out []*ggml.Tensor, _ []
return matched
})
slices.SortStableFunc(matched, func(a, b Tensor) int {
x := strings.Split(a.Name(), ".")
y := strings.Split(b.Name(), ".")
if len(x) != len(y) {
return cmp.Compare(len(x), len(y))
}
vals := make([]int, len(x))
for i := range x {
vals[i] = strings.Compare(x[i], y[i])
m, err := strconv.ParseInt(x[i], 0, 0)
n, err2 := strconv.ParseInt(y[i], 0, 0)
if errors.Join(err, err2) == nil {
vals[i] = cmp.Compare(m, n)
}
}
return cmp.Or(vals...)
})
if len(matched) > 0 {
out = append(out, &ggml.Tensor{
Name: merges[i].name,

View File

@@ -3,10 +3,8 @@ package convert
import (
"bytes"
"encoding/binary"
"fmt"
"io"
"iter"
"math/rand/v2"
"slices"
"strings"
"testing"
@@ -953,45 +951,3 @@ func TestMerge(t *testing.T) {
}
})
}
func TestMergeOrder(t *testing.T) {
for range 8 {
t.Run("", func(t *testing.T) {
tensors := make([]Tensor, 16)
for i := range tensors {
tensors[i] = &fakeTensor{
name: fmt.Sprintf("layer.%d.weight", i),
shape: []uint64{1},
data: []float32{float32(i)},
}
}
rand.Shuffle(len(tensors), func(i, j int) {
tensors[i], tensors[j] = tensors[j], tensors[i]
})
matched, unmatched := mergeTensors(tensors, merge{"layer.*.weight", "layer.weight"})
if len(unmatched) != 0 {
t.Error("expected no remaining tensors, got", len(unmatched))
}
if len(matched) != 1 {
t.Error("expected 1 merged tensor, got", len(matched))
}
var b bytes.Buffer
if _, err := matched[0].WriteTo(&b); err != nil {
t.Fatal(err)
}
var f32s [16]float32
if err := binary.Read(&b, binary.LittleEndian, &f32s); err != nil {
t.Fatal(err)
}
if !slices.IsSorted(f32s[:]) {
t.Errorf("merged tensor data is not in order: %+v", f32s)
}
})
}
}

View File

@@ -2,7 +2,6 @@ package discover
import (
"bufio"
"errors"
"fmt"
"io"
"log/slog"
@@ -11,21 +10,12 @@ import (
"reflect"
"regexp"
"sort"
"strconv"
"strings"
"github.com/ollama/ollama/format"
)
func GetCPUMem() (memInfo, error) {
mem, err := getCPUMem()
if err != nil {
return memInfo{}, err
}
return getCPUMemByCgroups(mem), nil
}
func getCPUMem() (memInfo, error) {
var mem memInfo
var total, available, free, buffers, cached, freeSwap uint64
f, err := os.Open("/proc/meminfo")
@@ -66,32 +56,6 @@ func getCPUMem() (memInfo, error) {
return mem, nil
}
func getCPUMemByCgroups(mem memInfo) memInfo {
total, err := getUint64ValueFromFile("/sys/fs/cgroup/memory.max")
if err == nil {
mem.TotalMemory = total
}
used, err := getUint64ValueFromFile("/sys/fs/cgroup/memory.current")
if err == nil {
mem.FreeMemory = mem.TotalMemory - used
}
return mem
}
func getUint64ValueFromFile(path string) (uint64, error) {
f, err := os.Open(path)
if err != nil {
return 0, err
}
defer f.Close()
s := bufio.NewScanner(f)
for s.Scan() {
line := s.Text()
return strconv.ParseUint(line, 10, 64)
}
return 0, errors.New("empty file content")
}
const CpuInfoFilename = "/proc/cpuinfo"
type linuxCpuInfo struct {
@@ -110,41 +74,7 @@ func GetCPUDetails() []CPU {
return nil
}
defer file.Close()
cpus := linuxCPUDetails(file)
return overwriteThreadCountByLinuxCgroups(cpus)
}
func overwriteThreadCountByLinuxCgroups(cpus []CPU) []CPU {
file, err := os.Open("/sys/fs/cgroup/cpu.max")
if err != nil {
return cpus
}
defer file.Close()
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := scanner.Text()
if sl := strings.Split(line, " "); len(sl) == 2 {
allowdUs, err := strconv.ParseInt(sl[0], 10, 64)
if err != nil {
slog.Warn("failed to parse CPU allowed micro secs", "error", err)
return cpus
}
unitUs, err := strconv.ParseInt(sl[1], 10, 64)
if err != nil {
slog.Warn("failed to parse CPU unit micro secs", "error", err)
return cpus
}
threads := int(max(allowdUs/unitUs, 1))
cpu := cpus[0]
cpu.CoreCount = threads
cpu.ThreadCount = threads
return []CPU{cpu}
}
}
return cpus
return linuxCPUDetails(file)
}
func linuxCPUDetails(file io.Reader) []CPU {

View File

@@ -65,11 +65,6 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
}
slog.Info("discovering available GPUs...")
detectIncompatibleLibraries()
// Warn if any user-overrides are set which could lead to incorrect GPU discovery
overrideWarnings()
requested := envconfig.LLMLibrary()
jetpack := cudaJetpack()
@@ -95,16 +90,10 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
var dirs []string
if dir != "" {
if requested != "" && filepath.Base(dir) != requested {
slog.Debug("skipping available library at user's request", "requested", requested, "libDir", dir)
slog.Debug("skipping available library at users request", "requested", requested, "libDir", dir)
continue
} else if jetpack != "" && filepath.Base(dir) != "cuda_"+jetpack {
continue
} else if jetpack == "" && strings.Contains(filepath.Base(dir), "cuda_jetpack") {
slog.Debug("jetpack not detected (set JETSON_JETPACK or OLLAMA_LLM_LIBRARY to override), skipping", "libDir", dir)
continue
} else if !envconfig.EnableVulkan() && strings.Contains(filepath.Base(dir), "vulkan") {
slog.Info("experimental Vulkan support disabled. To enable, set OLLAMA_VULKAN=1")
continue
}
dirs = []string{ml.LibOllamaPath, dir}
} else {
@@ -121,7 +110,7 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
// In the second pass, we more deeply initialize the GPUs to weed out devices that
// aren't supported by a given library. We run this phase in parallel to speed up discovery.
// Only devices that need verification are included in this pass
slog.Debug("evaluating which, if any, devices to filter out", "initial_count", len(devices))
slog.Debug("evluating which if any devices to filter out", "initial_count", len(devices))
ctx2ndPass, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
var wg sync.WaitGroup
@@ -129,25 +118,15 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
supportedMu := sync.Mutex{}
supported := make(map[string]map[string]map[string]int) // [Library][libDir][ID] = pre-deletion devices index
for i := range devices {
libDir := devices[i].LibraryPath[len(devices[i].LibraryPath)-1]
if !devices[i].NeedsInitValidation() {
// No need to validate, add to the supported map
supportedMu.Lock()
if _, ok := supported[devices[i].Library]; !ok {
supported[devices[i].Library] = make(map[string]map[string]int)
}
if _, ok := supported[devices[i].Library][libDir]; !ok {
supported[devices[i].Library][libDir] = make(map[string]int)
}
supported[devices[i].Library][libDir][devices[i].ID] = i
supportedMu.Unlock()
continue
}
slog.Debug("verifying if device is supported", "library", libDir, "description", devices[i].Description, "compute", devices[i].Compute(), "id", devices[i].ID, "pci_id", devices[i].PCIID)
libDir := devices[i].LibraryPath[len(devices[i].LibraryPath)-1]
slog.Debug("verifying device is supported", "library", libDir, "description", devices[i].Description, "compute", devices[i].Compute(), "id", devices[i].ID, "pci_id", devices[i].PCIID)
wg.Add(1)
go func(i int) {
defer wg.Done()
extraEnvs := ml.GetVisibleDevicesEnv(devices[i:i+1], true)
extraEnvs := ml.GetVisibleDevicesEnv(devices[i : i+1])
devices[i].AddInitValidation(extraEnvs)
if len(bootstrapDevices(ctx2ndPass, devices[i].LibraryPath, extraEnvs)) == 0 {
slog.Debug("filtering device which didn't fully initialize",
@@ -333,8 +312,7 @@ func GPUDevices(ctx context.Context, runners []ml.FilteredRunnerDiscovery) []ml.
defer cancel()
// Apply any dev filters to avoid re-discovering unsupported devices, and get IDs correct
// We avoid CUDA filters here to keep ROCm from failing to discover GPUs in a mixed environment
devFilter := ml.GetVisibleDevicesEnv(devices, false)
devFilter := ml.GetVisibleDevicesEnv(devices)
for dir := range libDirs {
updatedDevices := bootstrapDevices(ctx, []string{ml.LibOllamaPath, dir}, devFilter)
@@ -468,37 +446,3 @@ func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs map
return devices
}
func overrideWarnings() {
anyFound := false
m := envconfig.AsMap()
for _, k := range []string{
"CUDA_VISIBLE_DEVICES",
"HIP_VISIBLE_DEVICES",
"ROCR_VISIBLE_DEVICES",
"GGML_VK_VISIBLE_DEVICES",
"GPU_DEVICE_ORDINAL",
"HSA_OVERRIDE_GFX_VERSION",
} {
if e, found := m[k]; found && e.Value != "" {
anyFound = true
slog.Warn("user overrode visible devices", k, e.Value)
}
}
if anyFound {
slog.Warn("if GPUs are not correctly discovered, unset and try again")
}
}
func detectIncompatibleLibraries() {
if runtime.GOOS != "windows" {
return
}
basePath, err := exec.LookPath("ggml-base.dll")
if err != nil || basePath == "" {
return
}
if !strings.HasPrefix(basePath, ml.LibOllamaPath) {
slog.Warn("potentially incompatible library detected in PATH", "location", basePath)
}
}

View File

@@ -12,7 +12,7 @@
### Reference
* [API Reference](https://docs.ollama.com/api)
* [Modelfile Reference](https://docs.ollama.com/modelfile)
* [Modelfile Reference](./modelfile.md)
* [OpenAI Compatibility](https://docs.ollama.com/api/openai-compatibility)
### Resources

View File

@@ -13,23 +13,9 @@ Embeddings turn text into numeric vectors you can store in a vector database, se
## Generate embeddings
Use `/api/embed` with a single string.
<Tabs>
<Tab title="CLI">
Generate embeddings directly from the command line:
```shell
ollama run embeddinggemma "Hello world"
```
You can also pipe text to generate embeddings:
```shell
echo "Hello world" | ollama run embeddinggemma
```
Output is a JSON array.
</Tab>
<Tab title="cURL">
```shell
curl -X POST http://localhost:11434/api/embed \

View File

@@ -15,7 +15,7 @@ Also known as "single-shot" tool calling.
```shell
curl -s http://localhost:11434/api/chat -H "Content-Type: application/json" -d '{
"model": "qwen3",
"messages": [{"role": "user", "content": "What is the temperature in New York?"}],
"messages": [{"role": "user", "content": "What's the temperature in New York?"}],
"stream": false,
"tools": [
{
@@ -41,7 +41,7 @@ Also known as "single-shot" tool calling.
curl -s http://localhost:11434/api/chat -H "Content-Type: application/json" -d '{
"model": "qwen3",
"messages": [
{"role": "user", "content": "What is the temperature in New York?"},
{"role": "user", "content": "What's the temperature in New York?"},
{
"role": "assistant",
"tool_calls": [
@@ -90,7 +90,7 @@ Also known as "single-shot" tool calling.
}
return temperatures.get(city, "Unknown")
messages = [{"role": "user", "content": "What is the temperature in New York?"}]
messages = [{"role": "user", "content": "What's the temperature in New York?"}]
# pass functions directly as tools in the tools list or as a JSON schema
response = chat(model="qwen3", messages=messages, tools=[get_temperature], think=True)
@@ -146,7 +146,7 @@ Also known as "single-shot" tool calling.
},
]
const messages = [{ role: 'user', content: "What is the temperature in New York?" }]
const messages = [{ role: 'user', content: "What's the temperature in New York?" }]
const response = await ollama.chat({
model: 'qwen3',
@@ -609,7 +609,7 @@ def get_temperature(city: str) -> str:
return temperatures.get(city, 'Unknown')
messages = [{'role': 'user', 'content': "What is the temperature in New York?"}]
messages = [{'role': 'user', 'content': "What's the temperature in New York?"}]
while True:
stream = chat(
@@ -684,7 +684,7 @@ const getTemperatureTool = {
}
async function agentLoop() {
const messages = [{ role: 'user', content: "What is the temperature in New York?" }]
const messages = [{ role: 'user', content: "What's the temperature in New York?" }]
while (true) {
const stream = await ollama.chat({

View File

@@ -9,9 +9,15 @@ sidebarTitle: Cloud
Ollama's cloud models are a new kind of model in Ollama that can run without a powerful GPU. Instead, cloud models are automatically offloaded to Ollama's cloud service while offering the same capabilities as local models, making it possible to keep using your local tools while running larger models that wouldn't fit on a personal computer.
### Supported models
Ollama currently supports the following cloud models, with more coming soon:
For a list of supported models, see Ollama's [model library](https://ollama.com/search?c=cloud).
- `deepseek-v3.1:671b-cloud`
- `gpt-oss:20b-cloud`
- `gpt-oss:120b-cloud`
- `kimi-k2:1t-cloud`
- `qwen3-coder:480b-cloud`
- `glm-4.6:cloud`
- `minimax-m2:cloud`
### Running Cloud models

View File

@@ -49,8 +49,6 @@ Install prerequisites:
- [Ninja](https://github.com/ninja-build/ninja/releases)
- (Optional) NVIDIA GPU support
- [CUDA SDK](https://developer.nvidia.com/cuda-downloads?target_os=Windows&target_arch=x86_64&target_version=11&target_type=exe_network)
- (Optional) VULKAN GPU support
- [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs
Then, configure and build the project:
@@ -59,17 +57,6 @@ cmake -B build
cmake --build build --config Release
```
> Building for Vulkan requires VULKAN_SDK environment variable:
>
> PowerShell
> ```powershell
> $env:VULKAN_SDK="C:\VulkanSDK\<version>"
> ```
> CMD
> ```cmd
> set VULKAN_SDK=C:\VulkanSDK\<version>
> ```
> [!IMPORTANT]
> Building for ROCm requires additional flags:
> ```
@@ -78,7 +65,6 @@ cmake --build build --config Release
> ```
Lastly, run Ollama:
```shell
@@ -98,9 +84,7 @@ Install prerequisites:
- [ROCm](https://rocm.docs.amd.com/projects/install-on-linux/en/latest/install/quick-start.html)
- (Optional) NVIDIA GPU support
- [CUDA SDK](https://developer.nvidia.com/cuda-downloads)
- (Optional) VULKAN GPU support
- [VULKAN SDK](https://vulkan.lunarg.com/sdk/home) - useful for AMD/Intel GPUs
- Or install via package manager: `sudo apt install vulkan-sdk` (Ubuntu/Debian) or `sudo dnf install vulkan-sdk` (Fedora/CentOS)
> [!IMPORTANT]
> Ensure prerequisites are in `PATH` before running CMake.

View File

@@ -68,15 +68,6 @@ To run Ollama using Docker with AMD GPUs, use the `rocm` tag and the following c
docker run -d --device /dev/kfd --device /dev/dri -v ollama:/root/.ollama -p 11434:11434 --name ollama ollama/ollama:rocm
```
## Vulkan Support
Vulkan is bundled into the `ollama/ollama` image.
```shell
docker run -d --device /dev/kfd --device /dev/dri -v ollama:/root/.ollama -p 11434:11434 -e OLLAMA_VULKAN=1 --name ollama ollama/ollama
```
## Run model locally
Now you can run a model:
@@ -88,4 +79,3 @@ docker exec -it ollama ollama run llama3.2
## Try different models
More models can be found on the [Ollama library](https://ollama.com/library).

View File

@@ -63,10 +63,6 @@
{
"source": "/api/openai",
"destination": "/api/openai-compatibility"
},
{
"source": "/api",
"destination": "/api/introduction"
}
],
"navigation": {
@@ -134,7 +130,7 @@
{
"group": "API Reference",
"pages": [
"/api/introduction",
"/api/index",
"/api/authentication",
"/api/streaming",
"/api/usage",

View File

@@ -57,13 +57,8 @@ ollama ps
```
<Info>
**Output**:
```
NAME ID SIZE PROCESSOR UNTIL
llama3:70b bcfb190ca3a7 42 GB 100% GPU 4 minutes from now
```
**Output**: ``` NAME ID SIZE PROCESSOR UNTIL llama3:70b bcfb190ca3a7 42 GB
100% GPU 4 minutes from now ```
</Info>
The `Processor` column will show which memory the model was loaded in to:
@@ -228,7 +223,7 @@ Refer to the section [above](#how-do-i-configure-ollama-server) for how to set e
## How can I use Ollama in Visual Studio Code?
There is already a large collection of plugins available for VS Code as well as other editors that leverage Ollama. See the list of [extensions & plugins](https://github.com/ollama/ollama#extensions--plugins) at the bottom of the main repository readme.
There is already a large collection of plugins available for VSCode as well as other editors that leverage Ollama. See the list of [extensions & plugins](https://github.com/ollama/ollama#extensions--plugins) at the bottom of the main repository readme.
## How do I use Ollama with GPU acceleration in Docker?
@@ -381,13 +376,3 @@ ollama signin
<Note>
Replace &lt;username&gt; with your actual Windows user name.
</Note>
## How can I stop Ollama from starting when I login to my computer
Ollama for Windows and macOS register as a login item during installation. You can disable this if you prefer not to have Ollama automatically start. Ollama will respect this setting across upgrades, unless you uninstall the application.
**Windows**
- In `Task Manager` go to the `Startup apps` tab, search for `ollama` then click `Disable`
**MacOS**
- Open `Settings` and search for "Login Items", find the `Ollama` entry under "Allow in the Background`, then click the slider to disable.

View File

@@ -3,35 +3,34 @@ title: Hardware support
---
## Nvidia
Ollama supports Nvidia GPUs with compute capability 5.0+ and driver version 531 and newer.
Ollama supports Nvidia GPUs with compute capability 5.0+.
Check your compute compatibility to see if your card is supported:
[https://developer.nvidia.com/cuda-gpus](https://developer.nvidia.com/cuda-gpus)
| Compute Capability | Family | Cards |
| ------------------ | ------------------- | ------------------------------------------------------------------------------------------------------------------------------ |
| 12.0 | GeForce RTX 50xx | `RTX 5060` `RTX 5060 Ti` `RTX 5070` `RTX 5070 Ti` `RTX 5080` `RTX 5090` |
| | NVIDIA Professional | `RTX PRO 4000 Blackwell` `RTX PRO 4500 Blackwell` `RTX PRO 5000 Blackwell` `RTX PRO 6000 Blackwell` |
| 9.0 | NVIDIA | `H200` `H100` |
| 8.9 | GeForce RTX 40xx | `RTX 4090` `RTX 4080 SUPER` `RTX 4080` `RTX 4070 Ti SUPER` `RTX 4070 Ti` `RTX 4070 SUPER` `RTX 4070` `RTX 4060 Ti` `RTX 4060` |
| | NVIDIA Professional | `L4` `L40` `RTX 6000` |
| 8.6 | GeForce RTX 30xx | `RTX 3090 Ti` `RTX 3090` `RTX 3080 Ti` `RTX 3080` `RTX 3070 Ti` `RTX 3070` `RTX 3060 Ti` `RTX 3060` `RTX 3050 Ti` `RTX 3050` |
| | NVIDIA Professional | `A40` `RTX A6000` `RTX A5000` `RTX A4000` `RTX A3000` `RTX A2000` `A10` `A16` `A2` |
| 8.0 | NVIDIA | `A100` `A30` |
| 7.5 | GeForce GTX/RTX | `GTX 1650 Ti` `TITAN RTX` `RTX 2080 Ti` `RTX 2080` `RTX 2070` `RTX 2060` |
| | NVIDIA Professional | `T4` `RTX 5000` `RTX 4000` `RTX 3000` `T2000` `T1200` `T1000` `T600` `T500` |
| | Quadro | `RTX 8000` `RTX 6000` `RTX 5000` `RTX 4000` |
| 7.0 | NVIDIA | `TITAN V` `V100` `Quadro GV100` |
| 6.1 | NVIDIA TITAN | `TITAN Xp` `TITAN X` |
| | GeForce GTX | `GTX 1080 Ti` `GTX 1080` `GTX 1070 Ti` `GTX 1070` `GTX 1060` `GTX 1050 Ti` `GTX 1050` |
| | Quadro | `P6000` `P5200` `P4200` `P3200` `P5000` `P4000` `P3000` `P2200` `P2000` `P1000` `P620` `P600` `P500` `P520` |
| | Tesla | `P40` `P4` |
| 6.0 | NVIDIA | `Tesla P100` `Quadro GP100` |
| 5.2 | GeForce GTX | `GTX TITAN X` `GTX 980 Ti` `GTX 980` `GTX 970` `GTX 960` `GTX 950` |
| | Quadro | `M6000 24GB` `M6000` `M5000` `M5500M` `M4000` `M2200` `M2000` `M620` |
| | Tesla | `M60` `M40` |
| 5.0 | GeForce GTX | `GTX 750 Ti` `GTX 750` `NVS 810` |
| | Quadro | `K2200` `K1200` `K620` `M1200` `M520` `M5000M` `M4000M` `M3000M` `M2000M` `M1000M` `K620M` `M600M` `M500M` |
| Compute Capability | Family | Cards |
| ------------------ | ------------------- | ----------------------------------------------------------------------------------------------------------------------------- |
| 9.0 | NVIDIA | `H200` `H100` |
| 8.9 | GeForce RTX 40xx | `RTX 4090` `RTX 4080 SUPER` `RTX 4080` `RTX 4070 Ti SUPER` `RTX 4070 Ti` `RTX 4070 SUPER` `RTX 4070` `RTX 4060 Ti` `RTX 4060` |
| | NVIDIA Professional | `L4` `L40` `RTX 6000` |
| 8.6 | GeForce RTX 30xx | `RTX 3090 Ti` `RTX 3090` `RTX 3080 Ti` `RTX 3080` `RTX 3070 Ti` `RTX 3070` `RTX 3060 Ti` `RTX 3060` `RTX 3050 Ti` `RTX 3050` |
| | NVIDIA Professional | `A40` `RTX A6000` `RTX A5000` `RTX A4000` `RTX A3000` `RTX A2000` `A10` `A16` `A2` |
| 8.0 | NVIDIA | `A100` `A30` |
| 7.5 | GeForce GTX/RTX | `GTX 1650 Ti` `TITAN RTX` `RTX 2080 Ti` `RTX 2080` `RTX 2070` `RTX 2060` |
| | NVIDIA Professional | `T4` `RTX 5000` `RTX 4000` `RTX 3000` `T2000` `T1200` `T1000` `T600` `T500` |
| | Quadro | `RTX 8000` `RTX 6000` `RTX 5000` `RTX 4000` |
| 7.0 | NVIDIA | `TITAN V` `V100` `Quadro GV100` |
| 6.1 | NVIDIA TITAN | `TITAN Xp` `TITAN X` |
| | GeForce GTX | `GTX 1080 Ti` `GTX 1080` `GTX 1070 Ti` `GTX 1070` `GTX 1060` `GTX 1050 Ti` `GTX 1050` |
| | Quadro | `P6000` `P5200` `P4200` `P3200` `P5000` `P4000` `P3000` `P2200` `P2000` `P1000` `P620` `P600` `P500` `P520` |
| | Tesla | `P40` `P4` |
| 6.0 | NVIDIA | `Tesla P100` `Quadro GP100` |
| 5.2 | GeForce GTX | `GTX TITAN X` `GTX 980 Ti` `GTX 980` `GTX 970` `GTX 960` `GTX 950` |
| | Quadro | `M6000 24GB` `M6000` `M5000` `M5500M` `M4000` `M2200` `M2000` `M620` |
| | Tesla | `M60` `M40` |
| 5.0 | GeForce GTX | `GTX 750 Ti` `GTX 750` `NVS 810` |
| | Quadro | `K2200` `K1200` `K620` `M1200` `M520` `M5000M` `M4000M` `M3000M` `M2000M` `M1000M` `K620M` `M600M` `M500M` |
For building locally to support older GPUs, see [developer.md](./development.md#linux-cuda-nvidia)
@@ -52,28 +51,24 @@ sudo modprobe nvidia_uvm`
## AMD Radeon
Ollama supports the following AMD GPUs via the ROCm library:
> [!NOTE]
> Additional AMD GPU support is provided by the Vulkan Library - see below.
Ollama supports the following AMD GPUs:
### Linux Support
| Family | Cards and accelerators |
| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------- |
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` `Vega 64` |
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` `V420` `V340` `V320` `Vega II Duo` `Vega II` `SSG` |
| AMD Instinct | `MI300X` `MI300A` `MI300` `MI250X` `MI250` `MI210` `MI200` `MI100` `MI60` |
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` `Vega 64` `Vega 56` |
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` `V420` `V340` `V320` `Vega II Duo` `Vega II` `VII` `SSG` |
| AMD Instinct | `MI300X` `MI300A` `MI300` `MI250X` `MI250` `MI210` `MI200` `MI100` `MI60` `MI50` |
### Windows Support
With ROCm v6.1, the following GPUs are supported on Windows.
| Family | Cards and accelerators |
| -------------- | -------------------------------------------------------------------------------------------------------------------- |
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` |
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` |
| Family | Cards and accelerators |
| -------------- | ------------------------------------------------------------------------------------------------------------------- |
| AMD Radeon RX | `7900 XTX` `7900 XT` `7900 GRE` `7800 XT` `7700 XT` `7600 XT` `7600` `6950 XT` `6900 XTX` `6900XT` `6800 XT` `6800` |
| AMD Radeon PRO | `W7900` `W7800` `W7700` `W7600` `W7500` `W6900X` `W6800X Duo` `W6800X` `W6800` `V620` |
### Overrides on Linux
@@ -95,6 +90,8 @@ At this time, the known supported GPU types on linux are the following LLVM Targ
This table shows some example GPUs that map to these LLVM targets:
| **LLVM Target** | **An Example GPU** |
|-----------------|---------------------|
| gfx900 | Radeon RX Vega 56 |
| gfx906 | Radeon Instinct MI50 |
| gfx908 | Radeon Instinct MI100 |
| gfx90a | Radeon Instinct MI210 |
| gfx940 | Radeon Instinct MI300 |
@@ -125,42 +122,6 @@ In some Linux distributions, SELinux can prevent containers from
accessing the AMD GPU devices. On the host system you can run
`sudo setsebool container_use_devices=1` to allow containers to use devices.
## Metal (Apple GPUs)
### Metal (Apple GPUs)
Ollama supports GPU acceleration on Apple devices via the Metal API.
## Vulkan GPU Support
> [!NOTE]
> Vulkan is currently an Experimental feature. To enable, you must set OLLAMA_VULKAN=1 for the Ollama server as
described in the [FAQ](faq.md#how-do-i-configure-ollama-server)
Additional GPU support on Windows and Linux is provided via
[Vulkan](https://www.vulkan.org/). On Windows most GPU vendors drivers come
bundled with Vulkan support and require no additional setup steps. Most Linux
distributions require installing additional components, and you may have
multiple options for Vulkan drivers between Mesa and GPU Vendor specific packages
- Linux Intel GPU Instructions - https://dgpu-docs.intel.com/driver/client/overview.html
- Linux AMD GPU Instructions - https://amdgpu-install.readthedocs.io/en/latest/install-script.html#specifying-a-vulkan-implementation
For AMD GPUs on some Linux distributions, you may need to add the `ollama` user to the `render` group.
The Ollama scheduler leverages available VRAM data reported by the GPU libraries to
make optimal scheduling decisions. Vulkan requires additional capabilities or
running as root to expose this available VRAM data. If neither root access or this
capability are granted, Ollama will use approximate sizes of the models
to make best effort scheduling decisions.
```bash
sudo setcap cap_perfmon+ep /usr/local/bin/ollama
```
### GPU Selection
To select specific Vulkan GPU(s), you can set the environment variable
`GGML_VK_VISIBLE_DEVICES` to one or more numeric IDs on the Ollama server as
described in the [FAQ](faq.md#how-do-i-configure-ollama-server). If you
encounter any problems with Vulkan based GPUs, you can disable all Vulkan GPUs
by setting `GGML_VK_VISIBLE_DEVICES=-1`

View File

@@ -25,23 +25,8 @@ Install [n8n](https://docs.n8n.io/choose-n8n/).
width="75%"
/>
</div>
3. Confirm Base URL is set to `http://localhost:11434` if running locally or `http://host.docker.internal:11434` if running through docker and click **Save**
<Note>
In environments that don't use Docker Desktop (ie, Linux server installations), `host.docker.internal` is not automatically added.
Run n8n in docker with `--add-host=host.docker.internal:host-gateway`
or add the following to a docker compose file:
```yaml
extra_hosts:
- "host.docker.internal:host-gateway"
```
</Note>
You should see a `Connection tested successfully` message.
3. Confirm Base URL is set to `http://localhost:11434` and click **Save**
<Note> If connecting to `http://localhost:11434` fails, use `http://127.0.0.1:11434`</Note>
4. When creating a new workflow, select **Add a first step** and select an **Ollama node**
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img

View File

@@ -1,34 +1,34 @@
---
title: VS Code
title: VS Code
---
## Install
Install [VS Code](https://code.visualstudio.com/download).
Install [VSCode](https://code.visualstudio.com/download).
## Usage with Ollama
## Usage with Ollama
1. Open Copilot side bar found in top right window
<div style={{ display: "flex", justifyContent: "center" }}>
<img
src="/images/vscode-sidebar.png"
alt="VS Code chat Sidebar"
width="75%"
/>
</div>
2. Select the model dropdown > **Manage models**
<div style={{ display: "flex", justifyContent: "center" }}>
<img
src="/images/vscode-models.png"
alt="VS Code model picker"
width="75%"
/>
</div>
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/vscode-sidebar.png"
alt="VSCode chat Sidebar"
width="75%"
/>
</div>
2. Select the model drowpdown > **Manage models**
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/vscode-models.png"
alt="VSCode model picker"
width="75%"
/>
</div>
3. Enter **Ollama** under **Provider Dropdown** and select desired models (e.g `qwen3, qwen3-coder:480b-cloud`)
<div style={{ display: "flex", justifyContent: "center" }}>
<img
src="/images/vscode-model-options.png"
alt="VS Code model options dropdown"
width="75%"
/>
</div>
<div style={{ display: 'flex', justifyContent: 'center' }}>
<img
src="/images/vscode-model-options.png"
alt="VSCode model options dropdown"
width="75%"
/>
</div>

View File

@@ -149,6 +149,9 @@ PARAMETER <parameter> <parametervalue>
| Parameter | Description | Value Type | Example Usage |
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | ---------- | -------------------- |
| mirostat | Enable Mirostat sampling for controlling perplexity. (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0) | int | mirostat 0 |
| mirostat_eta | Influences how quickly the algorithm responds to feedback from the generated text. A lower learning rate will result in slower adjustments, while a higher learning rate will make the algorithm more responsive. (Default: 0.1) | float | mirostat_eta 0.1 |
| mirostat_tau | Controls the balance between coherence and diversity of the output. A lower value will result in more focused and coherent text. (Default: 5.0) | float | mirostat_tau 5.0 |
| num_ctx | Sets the size of the context window used to generate the next token. (Default: 2048) | int | num_ctx 4096 |
| repeat_last_n | Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx) | int | repeat_last_n 64 |
| repeat_penalty | Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1) | float | repeat_penalty 1.1 |

View File

@@ -2,15 +2,12 @@ openapi: 3.1.0
info:
title: Ollama API
version: 0.1.0
license:
name: MIT
url: https://opensource.org/licenses/MIT
description: |
OpenAPI specification for the Ollama HTTP API
servers:
- url: http://localhost:11434
description: Ollama
security: []
description: Local Ollama instance
components:
securitySchemes:
bearerAuth:
@@ -96,11 +93,8 @@ components:
type: boolean
default: true
think:
oneOf:
- type: boolean
- type: string
enum: [high, medium, low]
description: When true, returns separate thinking output in addition to content. Can be a boolean (true/false) or a string ("high", "medium", "low") for supported models.
type: boolean
description: When true, returns separate thinking output in addition to content
raw:
type: boolean
description: When true, returns the raw response from the model without any prompt templating
@@ -111,12 +105,6 @@ components:
description: Model keep-alive duration (for example `5m` or `0` to unload immediately)
options:
$ref: "#/components/schemas/ModelOptions"
logprobs:
type: boolean
description: Whether to return log probabilities of the output tokens
top_logprobs:
type: integer
description: Number of most likely tokens to return at each token position when logprobs are enabled
GenerateResponse:
type: object
properties:
@@ -156,11 +144,6 @@ components:
eval_duration:
type: integer
description: Time spent generating tokens in nanoseconds
logprobs:
type: array
items:
$ref: "#/components/schemas/Logprob"
description: Log probability information for the generated tokens when logprobs are enabled
GenerateStreamEvent:
type: object
properties:
@@ -288,22 +271,13 @@ components:
type: boolean
default: true
think:
oneOf:
- type: boolean
- type: string
enum: [high, medium, low]
description: When true, returns separate thinking output in addition to content. Can be a boolean (true/false) or a string ("high", "medium", "low") for supported models.
type: boolean
description: When true, returns separate thinking output in addition to content
keep_alive:
oneOf:
- type: string
- type: number
description: Model keep-alive duration (for example `5m` or `0` to unload immediately)
logprobs:
type: boolean
description: Whether to return log probabilities of the output tokens
top_logprobs:
type: integer
description: Number of most likely tokens to return at each token position when logprobs are enabled
ChatResponse:
type: object
properties:
@@ -336,6 +310,7 @@ components:
type: array
items:
type: string
nullable: true
description: Optional base64-encoded images in the response
done:
type: boolean
@@ -361,11 +336,6 @@ components:
eval_duration:
type: integer
description: Time spent generating tokens in nanoseconds
logprobs:
type: array
items:
$ref: "#/components/schemas/Logprob"
description: Log probability information for the generated tokens when logprobs are enabled
ChatStreamEvent:
type: object
properties:
@@ -397,6 +367,7 @@ components:
type: array
items:
type: string
nullable: true
description: Partial base64-encoded images, when present
done:
type: boolean
@@ -572,9 +543,6 @@ components:
license:
type: string
description: The license of the model
modified_at:
type: string
description: Last modified timestamp in ISO 8601 format
details:
type: object
description: High-level model details
@@ -654,9 +622,6 @@ components:
size_vram:
type: integer
description: VRAM usage in bytes
context_length:
type: integer
description: Context length for the running model
PsResponse:
type: object
properties:
@@ -728,41 +693,6 @@ components:
version:
type: string
description: Version of Ollama
TokenLogprob:
type: object
description: Log probability information for a single token alternative
properties:
token:
type: string
description: The text representation of the token
logprob:
type: number
description: The log probability of this token
bytes:
type: array
items:
type: integer
description: The raw byte representation of the token
Logprob:
type: object
description: Log probability information for a generated token
properties:
token:
type: string
description: The text representation of the token
logprob:
type: number
description: The log probability of this token
bytes:
type: array
items:
type: integer
description: The raw byte representation of the token
top_logprobs:
type: array
items:
$ref: "#/components/schemas/TokenLogprob"
description: Most likely tokens and their log probabilities at this position
ErrorResponse:
type: object
properties:
@@ -1345,9 +1275,6 @@ paths:
example:
source: gemma3
destination: gemma3-backup
responses:
"200":
description: Model successfully copied
/api/pull:
post:
summary: Pull a model
@@ -1455,7 +1382,16 @@ paths:
model: gemma3
responses:
"200":
description: Model successfully deleted
description: Deletion status updates.
content:
application/json:
schema:
$ref: "#/components/schemas/StatusResponse"
example:
status: "success"
application/x-ndjson:
schema:
$ref: "#/components/schemas/StatusEvent"
/api/version:
get:
summary: Get version

View File

@@ -196,6 +196,8 @@ var (
NoPrune = Bool("OLLAMA_NOPRUNE")
// SchedSpread allows scheduling models across all GPUs.
SchedSpread = Bool("OLLAMA_SCHED_SPREAD")
// IntelGPU enables experimental Intel GPU detection.
IntelGPU = Bool("OLLAMA_INTEL_GPU")
// MultiUserCache optimizes prompt caching for multi-user scenarios
MultiUserCache = Bool("OLLAMA_MULTIUSER_CACHE")
// Enable the new Ollama engine
@@ -204,8 +206,6 @@ var (
ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096)
// Auth enables authentication between the Ollama client and server
UseAuth = Bool("OLLAMA_AUTH")
// Enable Vulkan backend
EnableVulkan = Bool("OLLAMA_VULKAN")
)
func String(s string) func() string {
@@ -314,7 +314,7 @@ func AsMap() map[string]EnvVar {
ret["GGML_VK_VISIBLE_DEVICES"] = EnvVar{"GGML_VK_VISIBLE_DEVICES", VkVisibleDevices(), "Set which Vulkan devices are visible by numeric ID"}
ret["GPU_DEVICE_ORDINAL"] = EnvVar{"GPU_DEVICE_ORDINAL", GpuDeviceOrdinal(), "Set which AMD devices are visible by numeric ID"}
ret["HSA_OVERRIDE_GFX_VERSION"] = EnvVar{"HSA_OVERRIDE_GFX_VERSION", HsaOverrideGfxVersion(), "Override the gfx used for all detected AMD GPUs"}
ret["OLLAMA_VULKAN"] = EnvVar{"OLLAMA_VULKAN", EnableVulkan(), "Enable experimental Vulkan support"}
ret["OLLAMA_INTEL_GPU"] = EnvVar{"OLLAMA_INTEL_GPU", IntelGPU(), "Enable experimental Intel GPU detection"}
}
return ret

View File

@@ -249,9 +249,6 @@ func (kv KV) OllamaEngineRequired() bool {
"qwen25vl",
"qwen3", "qwen3moe",
"qwen3vl", "qwen3vlmoe",
"deepseekocr",
"deepseek2",
"nomic-bert",
}, kv.Architecture())
}
@@ -800,6 +797,73 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri
return
}
func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
if llm.KV().Uint("vision.block_count") == 0 {
return
}
for name, layer := range llm.Tensors().GroupLayers() {
if name == "v" || strings.HasPrefix(name, "v.") {
for _, tensor := range layer {
weights += tensor.Size()
}
}
}
imageSize := uint64(llm.KV().Uint("vision.image_size"))
patchSize := uint64(llm.KV().Uint("vision.patch_size"))
if patchSize == 0 {
slog.Warn("unknown patch size for vision model")
return
}
numChannels := uint64(llm.KV().Uint("vision.num_channels"))
numPatches := (imageSize / patchSize) * (imageSize / patchSize)
if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
numPatches++
}
headCount := uint64(llm.KV().Uint("vision.attention.head_count"))
embeddingLength := uint64(llm.KV().Uint("vision.embedding_length"))
switch llm.KV().Architecture() {
case "mllama":
numPaddedPatches := numPatches + 8 - (numPatches%8)%8
maxNumTiles := uint64(llm.KV().Uint("vision.max_num_tiles"))
graphSize = 4 * (8 +
imageSize*imageSize*numChannels*maxNumTiles +
embeddingLength*numPatches*maxNumTiles +
9*embeddingLength*numPaddedPatches*maxNumTiles +
numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount)
case "gemma3", "mistral3":
graphSize = 4 * (imageSize*imageSize*numChannels +
embeddingLength*patchSize +
numPatches*numPatches*headCount)
case "qwen25vl":
maxPixels := uint64(llm.KV().Uint("vision.max_pixels", 28*28*1280))
numPatches := maxPixels / (patchSize * patchSize)
graphSize = 4 * (maxPixels*numChannels + // Original image storage
// Normalized pixels
maxPixels*numChannels +
// Patches storage (numPatches * channels * patchSize^2)
numPatches*numChannels*patchSize*patchSize +
// Self-attention calculations
numPatches*numPatches*headCount +
// Additional buffer for processing
embeddingLength*numPatches)
case "llama4":
// vision graph is computed independently in the same schedule
// and is negligible compared to the worst case text graph
}
return weights, graphSize
}
// SupportsKVCacheType checks if the requested cache type is supported
func (f GGML) SupportsKVCacheType(cacheType string) bool {
if cacheType == "" || cacheType == "f16" {
@@ -831,7 +895,6 @@ func (f GGML) FlashAttention() bool {
return slices.Contains([]string{
"gemma3",
"gptoss", "gpt-oss",
"mistral3",
"qwen3", "qwen3moe",
"qwen3vl", "qwen3vlmoe",
}, f.KV().String("general.architecture"))

View File

@@ -305,7 +305,7 @@ func readGGUFV1StringsData(llm *gguf, r io.Reader, a *array[string]) (any, error
a.values[i] = e
} else {
_ = discardGGUFString(llm, r)
discardGGUFString(llm, r)
}
}
@@ -568,6 +568,7 @@ func WriteGGUF(f *os.File, kv KV, ts []*Tensor) error {
g.SetLimit(runtime.GOMAXPROCS(0))
// TODO consider reducing if tensors size * gomaxprocs is larger than free memory
for _, t := range ts {
t := t
w := io.NewOffsetWriter(f, offset+int64(t.Offset))
g.Go(func() error {
_, err := t.WriteTo(w)
@@ -597,10 +598,6 @@ func ggufWriteKV(ws io.WriteSeeker, arch, k string, v any) error {
var err error
switch v := v.(type) {
case int32:
err = writeGGUF(ws, ggufTypeInt32, v)
case int64:
err = writeGGUF(ws, ggufTypeInt64, v)
case uint32, FileType:
err = writeGGUF(ws, ggufTypeUint32, v)
case uint64:
@@ -615,10 +612,6 @@ func ggufWriteKV(ws io.WriteSeeker, arch, k string, v any) error {
err = writeGGUFArray(ws, ggufTypeInt32, v)
case *array[int32]:
err = writeGGUFArray(ws, ggufTypeInt32, v.values)
case []int64:
err = writeGGUFArray(ws, ggufTypeInt64, v)
case *array[int64]:
err = writeGGUFArray(ws, ggufTypeInt64, v.values)
case []uint32:
err = writeGGUFArray(ws, ggufTypeUint32, v)
case *array[uint32]:

View File

@@ -42,10 +42,6 @@ func TestWriteGGUF(t *testing.T) {
"general.architecture": "test",
"general.alignment": uint32(16),
"test.key": "value",
"test.int32_key": int32(-42),
"test.int64_key": int64(-9223372036854775808),
"test.int32_array": []int32{-1, 0, 1, 2147483647, -2147483648},
"test.int64_array": []int64{-1, 0, 1, 9223372036854775807, -9223372036854775808},
"attention.key": "value2",
"tokenizer.key": "value3",
"adapter.key": "value4",
@@ -59,7 +55,7 @@ func TestWriteGGUF(t *testing.T) {
}
defer r.Close()
ff, err := Decode(r, -1)
ff, err := Decode(r, 0)
if err != nil {
t.Fatal(err)
}
@@ -69,19 +65,15 @@ func TestWriteGGUF(t *testing.T) {
"general.alignment": uint32(16),
"general.parameter_count": uint64(54),
"test.key": "value",
"test.int32_key": int32(-42),
"test.int64_key": int64(-9223372036854775808),
"test.int32_array": &array[int32]{size: 5, values: []int32{-1, 0, 1, 2147483647, -2147483648}},
"test.int64_array": &array[int64]{size: 5, values: []int64{-1, 0, 1, 9223372036854775807, -9223372036854775808}},
"test.attention.key": "value2",
"tokenizer.key": "value3",
"adapter.key": "value4",
}, ff.KV(), cmp.AllowUnexported(array[int32]{}, array[int64]{})); diff != "" {
}, ff.KV()); diff != "" {
t.Errorf("Mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(Tensors{
Offset: 992,
Offset: 800,
items: []*Tensor{
{Name: "blk.0.attn_k.weight", Offset: 0, Shape: []uint64{2, 3}},
{Name: "blk.0.attn_norm.weight", Offset: 32, Shape: []uint64{2, 3}},

1
go.mod
View File

@@ -17,6 +17,7 @@ require (
github.com/x448/float16 v0.8.4
golang.org/x/sync v0.12.0
golang.org/x/sys v0.36.0
)
require (

View File

@@ -388,9 +388,9 @@ func NewFunctionNameMap() *FunctionNameMap {
}
}
// Init initializes the handler with tools, optional last message, and think value
// Init initializes the handler with tools and optional last message
// Implements the Parser interface
func (h *HarmonyMessageHandler) Init(tools []api.Tool, lastMessage *api.Message, thinkValue *api.ThinkValue) []api.Tool {
func (h *HarmonyMessageHandler) Init(tools []api.Tool, lastMessage *api.Message) []api.Tool {
// Initialize the harmony parser
if h.HarmonyParser == nil {
h.HarmonyParser = &HarmonyParser{

View File

@@ -14,23 +14,6 @@ import (
"github.com/ollama/ollama/api"
)
func assertBytesMatchToken(t *testing.T, label, token string, ints []int) {
t.Helper()
raw := []byte(token)
if len(ints) != len(raw) {
t.Errorf("%s expected %d bytes for token %q, got %d (%v)", label, len(raw), token, len(ints), ints)
return
}
for i, b := range raw {
if ints[i] != int(b) {
t.Errorf("%s byte[%d] mismatch for token %q: got %d want %d", label, i, token, ints[i], int(b))
return
}
}
}
func TestAPIGenerate(t *testing.T) {
initialTimeout := 60 * time.Second
streamTimeout := 30 * time.Second
@@ -398,182 +381,3 @@ func TestAPIShowModel(t *testing.T) {
t.Errorf("%s missing modified_at: %#v", modelName, resp)
}
}
func TestAPIGenerateLogprobs(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, smol); err != nil {
t.Fatalf("pull failed %s", err)
}
enableLogprobs := true
noStream := false
tests := []struct {
name string
logprobs *bool
topLogprobs int
expectCount int
}{
{
name: "no_logprobs",
logprobs: nil,
topLogprobs: 0,
expectCount: 0,
},
{
name: "logprobs_only",
logprobs: &enableLogprobs,
topLogprobs: 0,
expectCount: 1,
},
{
name: "logprobs_with_top_5",
logprobs: &enableLogprobs,
topLogprobs: 5,
expectCount: 1,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
req := api.GenerateRequest{
Model: smol,
Prompt: "Why is the sky blue?",
Stream: &noStream,
Logprobs: test.logprobs != nil && *test.logprobs,
TopLogprobs: test.topLogprobs,
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
"num_predict": 10,
},
}
var response api.GenerateResponse
err := client.Generate(ctx, &req, func(resp api.GenerateResponse) error {
if resp.Done {
response = resp
}
return nil
})
if err != nil {
t.Fatalf("generate failed: %s", err)
}
// Check logprobs based on expectation
if test.expectCount == 0 {
if len(response.Logprobs) > 0 {
t.Errorf("expected no logprobs but got %d", len(response.Logprobs))
}
} else {
if len(response.Logprobs) == 0 {
t.Errorf("expected logprobs but got none")
}
// Validate each logprob entry
for i, lp := range response.Logprobs {
if lp.Token == "" {
t.Errorf("logprob[%d] has empty token", i)
}
if lp.Logprob > 0 {
t.Errorf("logprob[%d] has positive logprob %f (should be <= 0)", i, lp.Logprob)
}
assertBytesMatchToken(t, fmt.Sprintf("generate logprob[%d]", i), lp.Token, lp.Bytes)
// Check top_logprobs if requested
if test.topLogprobs > 0 {
if len(lp.TopLogprobs) == 0 {
t.Errorf("logprob[%d] expected top_logprobs but got none", i)
}
if len(lp.TopLogprobs) > test.topLogprobs {
t.Errorf("logprob[%d] has %d top_logprobs, expected max %d", i, len(lp.TopLogprobs), test.topLogprobs)
}
// Verify top_logprobs are sorted by probability (descending)
for j := 1; j < len(lp.TopLogprobs); j++ {
if lp.TopLogprobs[j-1].Logprob < lp.TopLogprobs[j].Logprob {
t.Errorf("logprob[%d].top_logprobs not sorted: %f < %f", i, lp.TopLogprobs[j-1].Logprob, lp.TopLogprobs[j].Logprob)
}
}
for j, top := range lp.TopLogprobs {
assertBytesMatchToken(t, fmt.Sprintf("generate logprob[%d].top[%d]", i, j), top.Token, top.Bytes)
}
} else if len(lp.TopLogprobs) > 0 {
t.Errorf("logprob[%d] has top_logprobs but none were requested", i)
}
}
}
})
}
}
func TestAPIChatLogprobs(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
if err := PullIfMissing(ctx, client, smol); err != nil {
t.Fatalf("pull failed %s", err)
}
enableLogprobs := true
noStream := false
req := api.ChatRequest{
Model: smol,
Messages: []api.Message{
{Role: "user", Content: "Say hello in one word"},
},
Stream: &noStream,
Logprobs: enableLogprobs,
TopLogprobs: 3,
Options: map[string]interface{}{
"temperature": 0,
"seed": 123,
"num_predict": 5,
},
}
var response api.ChatResponse
err := client.Chat(ctx, &req, func(resp api.ChatResponse) error {
if resp.Done {
response = resp
}
return nil
})
if err != nil {
t.Fatalf("chat failed: %s", err)
}
if len(response.Logprobs) == 0 {
t.Fatal("expected logprobs in response but got none")
}
t.Logf("received %d logprobs for chat response", len(response.Logprobs))
for i, lp := range response.Logprobs {
if lp.Token == "" {
t.Errorf("logprob[%d] has empty token", i)
}
if lp.Logprob > 0 {
t.Errorf("logprob[%d] has positive logprob %f", i, lp.Logprob)
}
assertBytesMatchToken(t, fmt.Sprintf("chat logprob[%d]", i), lp.Token, lp.Bytes)
if len(lp.TopLogprobs) == 0 {
t.Errorf("logprob[%d] expected top_logprobs but got none", i)
}
if len(lp.TopLogprobs) > 3 {
t.Errorf("logprob[%d] has %d top_logprobs, expected max 3", i, len(lp.TopLogprobs))
}
for j, top := range lp.TopLogprobs {
assertBytesMatchToken(t, fmt.Sprintf("chat logprob[%d].top[%d]", i, j), top.Token, top.Bytes)
}
}
}

View File

@@ -4,9 +4,7 @@ package integration
import (
"context"
"errors"
"math"
"strings"
"testing"
"time"
@@ -206,8 +204,8 @@ func TestAllMiniLMEmbed(t *testing.T) {
t.Fatalf("expected %v, got %v (similarity: %f)", expected[0:5], res.Embeddings[0][0:5], sim)
}
if res.PromptEvalCount != 8 {
t.Fatalf("expected 8 prompt tokens, got %d", res.PromptEvalCount)
if res.PromptEvalCount != 6 {
t.Fatalf("expected 6 prompt tokens, got %d", res.PromptEvalCount)
}
}
@@ -253,8 +251,8 @@ func TestAllMiniLMBatchEmbed(t *testing.T) {
t.Fatalf("expected %v, got %v (similarity: %f)", expected[1][0:5], res.Embeddings[1][0:5], sim)
}
if res.PromptEvalCount != 16 {
t.Fatalf("expected 16 prompt tokens, got %d", res.PromptEvalCount)
if res.PromptEvalCount != 12 {
t.Fatalf("expected 12 prompt tokens, got %d", res.PromptEvalCount)
}
}
@@ -277,7 +275,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
cases := []struct {
name string
request api.EmbedRequest
check func(*testing.T, *api.EmbedResponse, error)
check func(*api.EmbedResponse, error)
}{
{
name: "target truncation",
@@ -285,7 +283,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
Model: "all-minilm",
Input: "why",
},
check: func(t *testing.T, got *api.EmbedResponse, err error) {
check: func(got *api.EmbedResponse, err error) {
if err != nil {
t.Fatal(err)
}
@@ -302,11 +300,10 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
Input: "why is the sky blue?",
Options: map[string]any{"num_ctx": 3},
},
check: func(t *testing.T, got *api.EmbedResponse, err error) {
check: func(got *api.EmbedResponse, err error) {
if err != nil {
t.Fatal(err)
}
t.Logf("PromptEvalCount: want=%d got=%d", want.PromptEvalCount, got.PromptEvalCount)
if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
t.Errorf("embedding mismatch (-want +got):\n%s", diff)
}
@@ -320,11 +317,10 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
Truncate: &truncTrue,
Options: map[string]any{"num_ctx": 3},
},
check: func(t *testing.T, got *api.EmbedResponse, err error) {
check: func(got *api.EmbedResponse, err error) {
if err != nil {
t.Fatal(err)
}
t.Logf("PromptEvalCount: want=%d got=%d", want.PromptEvalCount, got.PromptEvalCount)
if diff := cmp.Diff(want.Embeddings[0], got.Embeddings[0]); diff != "" {
t.Errorf("embedding mismatch (-want +got):\n%s", diff)
}
@@ -338,21 +334,21 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
Truncate: &truncFalse,
Options: map[string]any{"num_ctx": 3},
},
check: func(t *testing.T, res *api.EmbedResponse, err error) {
if err.Error() != "the input length exceeds the context length" {
check: func(res *api.EmbedResponse, err error) {
if err.Error() != "input exceeds maximum context length" {
t.Fatalf("expected truncation error, got: %v", err)
}
},
},
{
name: "input after truncate error with context length of 1",
name: "input after truncate error",
request: api.EmbedRequest{
Model: "all-minilm",
Input: "why is the sky blue?",
Truncate: &truncTrue,
Options: map[string]any{"num_ctx": 1},
},
check: func(t *testing.T, res *api.EmbedResponse, err error) {
check: func(res *api.EmbedResponse, err error) {
if err.Error() != "input after truncation exceeds maximum context length" {
t.Fatalf("expected truncation error, got: %v", err)
}
@@ -366,7 +362,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
Truncate: &truncTrue,
Options: map[string]any{"num_ctx": 0},
},
check: func(t *testing.T, res *api.EmbedResponse, err error) {
check: func(res *api.EmbedResponse, err error) {
if err.Error() != "input after truncation exceeds maximum context length" {
t.Fatalf("expected truncation error, got: %v", err)
}
@@ -379,7 +375,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
Input: "why is the sky blue? Why is the sky blue? hi there my",
Options: map[string]any{"num_ctx": 16},
},
check: func(t *testing.T, res *api.EmbedResponse, err error) {
check: func(res *api.EmbedResponse, err error) {
if err != nil {
t.Fatal(err)
}
@@ -389,8 +385,7 @@ func TestAllMiniLMEmbedTruncate(t *testing.T) {
for _, req := range cases {
t.Run(req.name, func(t *testing.T) {
resp, err := embedTestHelper(ctx, client, t, req.request)
req.check(t, resp, err)
req.check(embedTestHelper(ctx, client, t, req.request))
})
}
}
@@ -414,173 +409,3 @@ func embedTestHelper(ctx context.Context, client *api.Client, t *testing.T, req
return client.Embed(ctx, &req)
}
func TestEmbedTruncation(t *testing.T) {
// Use test deadline if set, otherwise default to 2 minutes
timeout := 2 * time.Minute
if deadline, ok := t.Deadline(); ok {
timeout = time.Until(deadline) - 10*time.Second // Reserve 10s buffer
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
for _, model := range libraryEmbedModels {
model := model
t.Run(model, func(t *testing.T) {
// Check if we're running out of time (reserve 20s for current model)
if deadline, ok := t.Deadline(); ok && time.Until(deadline) < 20*time.Second {
t.Skip("skipping remaining tests to avoid timeout")
}
// Give each model its own budget to account for first-time pulls/loads
mctx, mcancel := context.WithTimeout(ctx, 3*time.Minute)
defer mcancel()
t.Run("truncation batch", func(t *testing.T) {
truncTrue := true
req := api.EmbedRequest{
Model: model,
Input: []string{"short", strings.Repeat("long ", 100), "medium text"},
Truncate: &truncTrue,
Options: map[string]any{"num_ctx": 30},
}
res, err := embedTestHelper(mctx, client, t, req)
if err != nil {
t.Fatal(err)
}
if len(res.Embeddings) != 3 {
t.Fatalf("expected 3 embeddings, got %d", len(res.Embeddings))
}
if res.PromptEvalCount > 90 {
t.Fatalf("expected tokens <= 90 (3 × 30 max), got %d", res.PromptEvalCount)
}
})
t.Run("runner token count accuracy", func(t *testing.T) {
baseline := api.EmbedRequest{Model: model, Input: "test"}
baseRes, err := embedTestHelper(mctx, client, t, baseline)
if err != nil {
t.Fatal(err)
}
batch := api.EmbedRequest{
Model: model,
Input: []string{"test", "test", "test"},
}
batchRes, err := embedTestHelper(mctx, client, t, batch)
if err != nil {
t.Fatal(err)
}
expectedCount := baseRes.PromptEvalCount * 3
if batchRes.PromptEvalCount < expectedCount-2 || batchRes.PromptEvalCount > expectedCount+2 {
t.Fatalf("expected ~%d tokens (3 × %d), got %d",
expectedCount, baseRes.PromptEvalCount, batchRes.PromptEvalCount)
}
})
})
}
}
// TestEmbedStatusCode tests that errors from the embedding endpoint
// properly preserve their HTTP status codes when returned to the client.
// This test specifically checks the error handling path in EmbedHandler
// where api.StatusError errors should maintain their original status code.
func TestEmbedStatusCode(t *testing.T) {
// Use test deadline if set, otherwise default to 2 minutes
timeout := 2 * time.Minute
if deadline, ok := t.Deadline(); ok {
timeout = time.Until(deadline) - 10*time.Second // Reserve 10s buffer
}
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
client, _, cleanup := InitServerConnection(ctx, t)
defer cleanup()
for _, model := range libraryEmbedModels {
model := model
t.Run(model, func(t *testing.T) {
// Check if we're running out of time (reserve 20s for current model)
if deadline, ok := t.Deadline(); ok && time.Until(deadline) < 20*time.Second {
t.Skip("skipping remaining tests to avoid timeout")
}
mctx, mcancel := context.WithTimeout(ctx, 3*time.Minute)
defer mcancel()
// Pull the model if needed
if err := PullIfMissing(mctx, client, model); err != nil {
t.Fatal(err)
}
t.Run("truncation error status code", func(t *testing.T) {
truncFalse := false
longInput := strings.Repeat("word ", 100)
req := api.EmbedRequest{
Model: model,
Input: longInput,
Truncate: &truncFalse,
Options: map[string]any{"num_ctx": 10},
}
_, err := embedTestHelper(mctx, client, t, req)
if err == nil {
t.Fatal("expected error when truncate=false with long input")
}
// Check that it's a StatusError with the correct status code
var statusErr api.StatusError
if !errors.As(err, &statusErr) {
t.Fatalf("expected api.StatusError, got %T: %v", err, err)
}
// The error should be a 4xx client error (likely 400 Bad Request)
// not a 500 Internal Server Error
if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 {
t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode)
}
// Verify the error message is meaningful
if !strings.Contains(err.Error(), "context length") {
t.Errorf("expected error message to mention context length, got: %v", err)
}
})
t.Run("batch truncation error status code", func(t *testing.T) {
truncFalse := false
req := api.EmbedRequest{
Model: model,
Input: []string{
"short input",
strings.Repeat("very long input ", 100),
"another short input",
},
Truncate: &truncFalse,
Options: map[string]any{"num_ctx": 10},
}
_, err := embedTestHelper(mctx, client, t, req)
if err == nil {
t.Fatal("expected error when one input exceeds context with truncate=false")
}
// Check that it's a StatusError with the correct status code
var statusErr api.StatusError
if !errors.As(err, &statusErr) {
t.Fatalf("expected api.StatusError, got %T: %v", err, err)
}
// The error should be a 4xx client error, not a 500 Internal Server Error
if statusErr.StatusCode < 400 || statusErr.StatusCode >= 500 {
t.Errorf("expected 4xx status code, got %d", statusErr.StatusCode)
}
})
})
}
}

View File

@@ -33,9 +33,6 @@ func TestVisionModels(t *testing.T) {
// Qwen 3 VL mixture of experts
model: "qwen3-vl:30b",
},
{
model: "ministral-3",
},
}
for _, v := range testCases {

View File

@@ -30,7 +30,6 @@ func TestAPIToolCalling(t *testing.T) {
"mistral": 6,
"qwen2.5": 6,
"qwen2": 6,
"ministral-3": 20,
"mistral-nemo": 9,
"mistral-small": 16,
"mixtral:8x22b": 80,

View File

@@ -38,7 +38,6 @@ var (
// Note: add newer models at the top of the list to test them first
ollamaEngineChatModels = []string{
"ministral-3",
"qwen3-coder:30b",
"gpt-oss:20b",
"gemma3n:e2b",
@@ -168,7 +167,6 @@ var (
"medllama2",
"megadolphin",
"minicpm-v",
"ministral-3",
"mistral-large",
"mistral-nemo",
"mistral-openorca",
@@ -272,7 +270,6 @@ var (
"mistral",
"qwen2.5",
"qwen2",
"ministral-3",
"mistral-nemo",
"mistral-small",
"mixtral:8x22b",

View File

@@ -3,6 +3,7 @@ package kvcache
import (
"errors"
"fmt"
"log/slog"
"math"
"slices"
@@ -39,18 +40,18 @@ type Causal struct {
// ** current forward pass **
// the active layer for Get and Put
curLayer int
// starting location for data storage for this batch
curLoc int
// size of the current batch
curBatchSize int
// locations for data storage for this batch
curLoc ml.Tensor
// mask of the cache as used by this batch
curMask ml.Tensor
// the active layer for Get and Put
curLayer int
// locations in the cache that are needed for this batch
curCellRange cellRange
@@ -205,47 +206,45 @@ func (c *Causal) StartForward(ctx ml.Context, batch input.Batch, reserve bool) e
c.curPositions = batch.Positions
c.opts.Except = nil
var locs []int32
if !reserve {
c.updateSlidingWindow()
var err error
locs, err = c.findLocs()
c.curLoc, err = c.findStartLoc()
if errors.Is(err, ErrKvCacheFull) {
c.defrag()
c.curLoc, err = c.findStartLoc()
}
if err != nil {
return err
}
for i, pos := range batch.Positions {
seq := batch.Sequences[i]
loc := int(locs[i])
c.cells[loc] = cacheCell{pos: pos, sequences: []int{seq}}
c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
seqRange, ok := c.cellRanges[seq]
if !ok {
seqRange = newRange()
}
seqRange.min = min(seqRange.min, loc)
c.curCellRange.min = min(c.curCellRange.min, loc)
seqRange.min = min(seqRange.min, c.curLoc+i)
c.curCellRange.min = min(c.curCellRange.min, c.curLoc+i)
seqRange.max = max(seqRange.max, loc)
c.curCellRange.max = max(c.curCellRange.max, loc)
seqRange.max = max(seqRange.max, c.curLoc+i)
c.curCellRange.max = max(c.curCellRange.max, c.curLoc+i)
c.cellRanges[seq] = seqRange
}
} else {
// If we are reserving memory, don't update any of the cache metadata but set the size
// to the worst case.
locs = make([]int32, c.curBatchSize)
for i := range locs {
locs[i] = int32(i)
}
c.curLoc = 0
c.curCellRange.min = 0
c.curCellRange.max = len(c.cells) - 1
}
c.curLoc = ctx.Input().FromInts(locs, len(locs))
c.curMask = c.buildMask(ctx)
return nil
@@ -258,20 +257,22 @@ func newRange() cellRange {
}
}
// Returns a slice of locations where each token in the batch should be stored
func (c *Causal) findLocs() ([]int32, error) {
loc := make([]int32, 0, c.curBatchSize)
// Find the first contiguous block of at least curBatchSize
func (c *Causal) findStartLoc() (int, error) {
var start, count int
for i := range c.cells {
if len(c.cells[i].sequences) == 0 {
loc = append(loc, int32(i))
if len(loc) >= c.curBatchSize {
return loc, nil
count++
if count >= c.curBatchSize {
return start, nil
}
} else {
start = i + 1
count = 0
}
}
return nil, fmt.Errorf("%w (cache: %v batch: %v)", ErrKvCacheFull, len(c.cells), c.curBatchSize)
return 0, fmt.Errorf("%w (cache: %v batch: %v)", ErrKvCacheFull, len(c.cells), c.curBatchSize)
}
func (c *Causal) updateSlidingWindow() {
@@ -401,6 +402,145 @@ func (c *Causal) buildMask(ctx ml.Context) ml.Tensor {
return maskTensor
}
func (c *Causal) moveCells(ctx ml.Context, src, dst, length int) {
for i, key := range c.keys {
if key == nil {
continue
}
kHeadDim := key.Dim(0)
numKVHeads := key.Dim(1)
rowSize := key.Stride(2)
kSrcView := key.View(ctx, rowSize*src, kHeadDim*numKVHeads*length)
kDstView := key.View(ctx, rowSize*dst, kHeadDim*numKVHeads*length)
value := c.values[i]
var vSrcView, vDstView ml.Tensor
if c.config.PermutedV {
vHeadDim := value.Dim(1)
elemSize := value.Stride(0)
vSrcView = value.View(ctx, elemSize*src, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
vDstView = value.View(ctx, elemSize*dst, length, len(c.cells)*elemSize, vHeadDim*numKVHeads)
} else {
vHeadDim := value.Dim(0)
rowSize := value.Stride(2)
vSrcView = value.View(ctx, rowSize*src, vHeadDim*numKVHeads*length)
vDstView = value.View(ctx, rowSize*dst, vHeadDim*numKVHeads*length)
}
ctx.Forward(
kSrcView.Copy(ctx, kDstView),
vSrcView.Copy(ctx, vDstView),
)
}
}
func (c *Causal) defrag() {
slog.Debug("defragmenting kv cache")
// Defrag strategy:
// - Search for empty holes at the beginning of the cache,
// filling them with active data starting at the end
// - If there are contiguous elements that need to be moved,
// combine them into a single operation by holding new moves
// until we see that the next one is non-contiguous
// - Fill up the context with the maximum number of operations it
// can hold then compute that and continue with a new context
//
// We could try to optimize placement by grouping blocks from
// the same sequences together but most likely the next forward
// pass will disrupt this anyways, so the real world benefit
// seems limited as this time.
ctx := c.backend.NewContext()
// For every move, 6 tensors are required per layer (2 views and a
// copy for each of k and v). We also need to refer to the original
// k and v cache tensors - once per layer, not per move.
layers := 0
for _, key := range c.keys {
if key == nil {
continue
}
layers++
}
maxMoves := (ctx.MaxGraphNodes() - 2*layers) / (6 * layers)
moves := 0
var pendingSrc, pendingDst, pendingLen int
src := len(c.cells) - 1
for dst := 0; dst < src; dst++ {
if len(c.cells[dst].sequences) == 0 {
for ; src > dst; src-- {
if len(c.cells[src].sequences) != 0 {
c.cells[dst] = c.cells[src]
c.cells[src] = cacheCell{}
if pendingLen > 0 {
if src == pendingSrc-pendingLen && dst == pendingDst+pendingLen {
pendingSrc = src
pendingLen++
break
} else {
c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
moves++
}
}
pendingSrc = src
pendingDst = dst
pendingLen = 1
break
}
}
}
if moves >= maxMoves {
ctx.Compute()
ctx.Close()
ctx = c.backend.NewContext()
moves = 0
}
}
if pendingLen > 0 {
c.moveCells(ctx, pendingSrc, pendingDst, pendingLen)
moves++
}
if moves > 0 {
ctx.Compute()
}
ctx.Close()
// Reset range metadata
for seq := range c.cellRanges {
seqRange := newRange()
for i, cell := range c.cells {
if slices.Contains(cell.sequences, seq) {
if i < seqRange.min {
seqRange.min = i
}
if i > seqRange.max {
seqRange.max = i
}
}
}
c.cellRanges[seq] = seqRange
}
c.updateSlidingWindow()
}
func (c *Causal) SetLayer(layer int) {
c.curLayer = layer
}
@@ -485,25 +625,18 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
}
}
key = key.Reshape(ctx, kHeadDim*numKVHeads, batchSize)
keyCache := c.keys[c.curLayer]
keyCache = keyCache.Reshape(ctx, kHeadDim*numKVHeads, len(c.cells))
ctx.Forward(keyCache.SetRows(ctx, key, c.curLoc))
rowSize := c.keys[c.curLayer].Stride(2)
ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, kHeadDim*numKVHeads*batchSize)))
if c.config.PermutedV {
value = value.Reshape(ctx, vHeadDim*numKVHeads, 1, batchSize)
value = value.Permute(ctx, 2, 0, 1, 3)
elemSize := c.values[c.curLayer].Stride(0)
valueCache := c.values[c.curLayer]
valueCache = valueCache.Reshape(ctx, 1, len(c.cells), vHeadDim*numKVHeads)
ctx.Forward(valueCache.SetRows(ctx, value, c.curLoc))
value = value.Permute(ctx, 1, 2, 0, 3)
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, batchSize, len(c.cells)*elemSize, vHeadDim*numKVHeads)))
} else {
value = value.Reshape(ctx, vHeadDim*numKVHeads, batchSize)
valueCache := c.values[c.curLayer]
valueCache = valueCache.Reshape(ctx, vHeadDim*numKVHeads, len(c.cells))
rowSize := c.values[c.curLayer].Stride(2)
ctx.Forward(valueCache.SetRows(ctx, value, c.curLoc))
ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, vHeadDim*numKVHeads*batchSize)))
}
}

File diff suppressed because it is too large Load Diff

2
llama/build-info.cpp generated vendored
View File

@@ -1,4 +1,4 @@
int LLAMA_BUILD_NUMBER = 0;
char const *LLAMA_COMMIT = "17f7f4baad8b3a716ee139da7bb56ae984e8c0fa";
char const *LLAMA_COMMIT = "3cfa9c3f125763305b4226bc032f1954f08990dc";
char const *LLAMA_COMPILER = "";
char const *LLAMA_BUILD_TARGET = "";

View File

@@ -22,9 +22,6 @@ include /src/llama.*
include /src/llama-*.*
include /src/unicode-data.*
include /src/unicode.*
include /src/models/
include /src/models/*.h
include /src/models/*.cpp
include /vendor/
include /vendor/miniaudio/
include /vendor/miniaudio/*.h

View File

@@ -8,7 +8,6 @@
#include "common.h"
#include "log.h"
#include "llama.h"
#include "sampling.h"
#include <algorithm>
#include <cinttypes>
@@ -27,6 +26,7 @@
#include <sstream>
#include <string>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <vector>
@@ -60,14 +60,6 @@
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
common_time_meas::common_time_meas(int64_t & t_acc, bool disable) : t_start_us(disable ? -1 : ggml_time_us()), t_acc(t_acc) {}
common_time_meas::~common_time_meas() {
if (t_start_us >= 0) {
t_acc += ggml_time_us() - t_start_us;
}
}
//
// CPU utils
//
@@ -363,7 +355,11 @@ bool parse_cpu_mask(const std::string & mask, bool (&boolmask)[GGML_MAX_N_THREAD
}
void common_init() {
llama_log_set(common_log_default_callback, NULL);
llama_log_set([](ggml_log_level level, const char * text, void * /*user_data*/) {
if (LOG_DEFAULT_LLAMA <= common_log_verbosity_thold) {
common_log_add(common_log_main(), level, "%s", text);
}
}, NULL);
#ifdef NDEBUG
const char * build_type = "";
@@ -694,7 +690,7 @@ bool string_parse_kv_override(const char * data, std::vector<llama_model_kv_over
// Validate if a filename is safe to use
// To validate a full path, split the path by the OS-specific path separator, and validate each part with this function
bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
bool fs_validate_filename(const std::string & filename) {
if (!filename.length()) {
// Empty filename invalid
return false;
@@ -754,14 +750,10 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
|| (c >= 0xD800 && c <= 0xDFFF) // UTF-16 surrogate pairs
|| c == 0xFFFD // Replacement Character (UTF-8)
|| c == 0xFEFF // Byte Order Mark (BOM)
|| c == ':' || c == '*' // Illegal characters
|| c == '/' || c == '\\' || c == ':' || c == '*' // Illegal characters
|| c == '?' || c == '"' || c == '<' || c == '>' || c == '|') {
return false;
}
if (!allow_subdirs && (c == '/' || c == '\\')) {
// Subdirectories not allowed, reject path separators
return false;
}
}
// Reject any leading or trailing ' ', or any trailing '.', these are stripped on Windows and will cause a different filename
@@ -786,29 +778,11 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
#include <iostream>
#ifdef _WIN32
static std::wstring utf8_to_wstring(const std::string & str) {
if (str.empty()) {
return std::wstring();
}
int size = MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), NULL, 0);
if (size <= 0) {
return std::wstring();
}
std::wstring wstr(size, 0);
MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), &wstr[0], size);
return wstr;
}
#endif
// returns true if successful, false otherwise
bool fs_create_directory_with_parents(const std::string & path) {
#ifdef _WIN32
std::wstring wpath = utf8_to_wstring(path);
std::wstring_convert<std::codecvt_utf8<wchar_t>> converter;
std::wstring wpath = converter.from_bytes(path);
// if the path already exists, check whether it's a directory
const DWORD attributes = GetFileAttributesW(wpath.c_str());
@@ -881,11 +855,6 @@ bool fs_create_directory_with_parents(const std::string & path) {
#endif // _WIN32
}
bool fs_is_directory(const std::string & path) {
std::filesystem::path dir(path);
return std::filesystem::exists(dir) && std::filesystem::is_directory(dir);
}
std::string fs_get_cache_directory() {
std::string cache_directory = "";
auto ensure_trailing_slash = [](std::string p) {
@@ -920,8 +889,6 @@ std::string fs_get_cache_directory() {
cache_directory = std::getenv("HOME") + std::string("/Library/Caches/");
#elif defined(_WIN32)
cache_directory = std::getenv("LOCALAPPDATA");
#elif defined(__EMSCRIPTEN__)
GGML_ABORT("not implemented on this platform");
#else
# error Unknown architecture
#endif
@@ -941,130 +908,11 @@ std::string fs_get_cache_file(const std::string & filename) {
return cache_directory + filename;
}
std::vector<common_file_info> fs_list(const std::string & path, bool include_directories) {
std::vector<common_file_info> files;
if (path.empty()) return files;
std::filesystem::path dir(path);
if (!std::filesystem::exists(dir) || !std::filesystem::is_directory(dir)) {
return files;
}
for (const auto & entry : std::filesystem::directory_iterator(dir)) {
try {
// Only include regular files (skip directories)
const auto & p = entry.path();
if (std::filesystem::is_regular_file(p)) {
common_file_info info;
info.path = p.string();
info.name = p.filename().string();
info.is_dir = false;
try {
info.size = static_cast<size_t>(std::filesystem::file_size(p));
} catch (const std::filesystem::filesystem_error &) {
info.size = 0;
}
files.push_back(std::move(info));
} else if (include_directories && std::filesystem::is_directory(p)) {
common_file_info info;
info.path = p.string();
info.name = p.filename().string();
info.size = 0; // Directories have no size
info.is_dir = true;
files.push_back(std::move(info));
}
} catch (const std::filesystem::filesystem_error &) {
// skip entries we cannot inspect
continue;
}
}
return files;
}
//
// TTY utils
//
bool tty_can_use_colors() {
// Check NO_COLOR environment variable (https://no-color.org/)
if (const char * no_color = std::getenv("NO_COLOR")) {
if (no_color[0] != '\0') {
return false;
}
}
// Check TERM environment variable
if (const char * term = std::getenv("TERM")) {
if (std::strcmp(term, "dumb") == 0) {
return false;
}
}
// Check if stdout and stderr are connected to a terminal
// We check both because log messages can go to either
bool stdout_is_tty = isatty(fileno(stdout));
bool stderr_is_tty = isatty(fileno(stderr));
return stdout_is_tty || stderr_is_tty;
}
//
// Model utils
//
static inline void common_init_sampler_from_model(
const llama_model * model,
common_params_sampling & sparams) {
const uint64_t config = sparams.user_sampling_config;
auto get_int32 = [&](const char * key, int32_t & dst, uint64_t user_config) {
if (config & user_config) return;
char buf[64] = {0};
if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
char * end = nullptr;
int32_t v = strtol(buf, &end, 10);
if (end && end != buf) dst = v;
}
};
auto get_float = [&](const char * key, float & dst, uint64_t user_config) {
if (config & user_config) return;
char buf[128] = {0};
if (llama_model_meta_val_str(model, key, buf, sizeof(buf)) > 0) {
char * end = nullptr;
float v = strtof(buf, &end);
if (end && end != buf) dst = v;
}
};
// Sampling sequence
if (!(config & common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS)) {
char buf[512] = {0};
if (llama_model_meta_val_str(model, llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE), buf, sizeof(buf)) > 0) {
const std::vector<std::string> sampler_names = string_split<std::string>(std::string(buf), ';');
if (!sampler_names.empty()) {
sparams.samplers = common_sampler_types_from_names(sampler_names, true);
}
}
}
get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TOP_K), sparams.top_k, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_K);
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TOP_P), sparams.top_p, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TOP_P);
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIN_P), sparams.min_p, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIN_P);
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_XTC_PROBABILITY), sparams.xtc_probability, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY);
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_XTC_THRESHOLD), sparams.xtc_threshold, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD);
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_TEMP), sparams.temp, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_TEMP);
get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_LAST_N), sparams.penalty_last_n, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N);
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_REPEAT), sparams.penalty_repeat, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT);
get_int32(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT), sparams.mirostat, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT);
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_TAU), sparams.mirostat_tau, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU);
get_float(llama_model_meta_key_str(LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA), sparams.mirostat_eta, common_params_sampling_config::COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA);
}
struct common_init_result common_init_from_params(common_params & params) {
common_init_result iparams;
auto mparams = common_model_params_to_llama(params);
@@ -1076,8 +924,6 @@ struct common_init_result common_init_from_params(common_params & params) {
return iparams;
}
common_init_sampler_from_model(model, params.sampling);
const llama_vocab * vocab = llama_model_get_vocab(model);
auto cparams = common_context_params_to_llama(params);

View File

@@ -2,19 +2,17 @@
#pragma once
#include "ggml-opt.h"
#include "llama-cpp.h"
#include <set>
#include <sstream>
#include <string>
#include <string_view>
#include <vector>
#include <map>
#include <sstream>
#include <cmath>
#if defined(_WIN32) && !defined(_WIN32_WINNT)
#define _WIN32_WINNT 0x0A00
#endif
#include "ggml-opt.h"
#include "llama-cpp.h"
#ifdef _WIN32
#define DIRECTORY_SEPARATOR '\\'
@@ -30,14 +28,7 @@
fprintf(stderr, "%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET); \
} while(0)
struct common_time_meas {
common_time_meas(int64_t & t_acc, bool disable = false);
~common_time_meas();
const int64_t t_start_us;
int64_t & t_acc;
};
#define DEFAULT_MODEL_PATH "models/7B/ggml-model-f16.gguf"
struct common_adapter_lora_info {
std::string path;
@@ -142,22 +133,6 @@ struct common_grammar_trigger {
llama_token token = LLAMA_TOKEN_NULL;
};
enum common_params_sampling_config : uint64_t {
COMMON_PARAMS_SAMPLING_CONFIG_SAMPLERS = 1 << 0,
COMMON_PARAMS_SAMPLING_CONFIG_TOP_K = 1 << 1,
COMMON_PARAMS_SAMPLING_CONFIG_TOP_P = 1 << 2,
COMMON_PARAMS_SAMPLING_CONFIG_MIN_P = 1 << 3,
COMMON_PARAMS_SAMPLING_CONFIG_XTC_PROBABILITY = 1 << 4,
COMMON_PARAMS_SAMPLING_CONFIG_XTC_THRESHOLD = 1 << 5,
COMMON_PARAMS_SAMPLING_CONFIG_TEMP = 1 << 6,
COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_LAST_N = 1 << 7,
COMMON_PARAMS_SAMPLING_CONFIG_PENALTY_REPEAT = 1 << 8,
COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT = 1 << 9,
COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_TAU = 1 << 10,
COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA = 1 << 11,
};
// sampling parameters
struct common_params_sampling {
uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler
@@ -190,8 +165,6 @@ struct common_params_sampling {
bool no_perf = false; // disable performance metrics
bool timing_per_token = false;
uint64_t user_sampling_config = 0; // bitfield to track user-specified samplers
std::vector<std::string> dry_sequence_breakers = {"\n", ":", "\"", "*"}; // default sequence breakers for DRY
@@ -225,7 +198,6 @@ struct common_params_model {
std::string hf_repo = ""; // HF repo // NOLINT
std::string hf_file = ""; // HF file // NOLINT
std::string docker_repo = ""; // Docker repo // NOLINT
std::string name = ""; // in format <user>/<model>[:<tag>] (tag is optional) // NOLINT
};
struct common_params_speculative {
@@ -372,7 +344,7 @@ struct common_params {
std::vector<common_control_vector_load_info> control_vectors; // control vector with user defined scale
int32_t verbosity = 3; // LOG_LEVEL_INFO
int32_t verbosity = 0;
int32_t control_vector_layer_start = -1; // layer range for control vector
int32_t control_vector_layer_end = -1; // layer range for control vector
bool offline = false;
@@ -434,8 +406,6 @@ struct common_params {
bool mmproj_use_gpu = true; // use GPU for multimodal model
bool no_mmproj = false; // explicitly disable multimodal model
std::vector<std::string> image; // path to image file(s)
int image_min_tokens = -1;
int image_max_tokens = -1;
// finetune
struct lr_opt lr;
@@ -481,21 +451,14 @@ struct common_params {
bool endpoint_props = false; // only control POST requests, not GET
bool endpoint_metrics = false;
// router server configs
std::string models_dir = ""; // directory containing models for the router server
int models_max = 4; // maximum number of models to load simultaneously
bool models_autoload = true; // automatically load models when requested via the router server
bool log_json = false;
std::string slot_save_path;
std::string media_path; // path to directory for loading media files
float slot_prompt_similarity = 0.1f;
// batched-bench params
bool is_pp_shared = false;
bool is_tg_separate = false;
bool is_pp_shared = false;
std::vector<int32_t> n_pp;
std::vector<int32_t> n_tg;
@@ -542,10 +505,6 @@ struct common_params {
// return false from callback to abort model loading or true to continue
llama_progress_callback load_progress_callback = NULL;
void * load_progress_callback_user_data = NULL;
bool has_speculative() const {
return !speculative.model.path.empty() || !speculative.model.hf_repo.empty();
}
};
// call once at the start of a program if it uses libcommon
@@ -640,28 +599,12 @@ std::string string_from(const struct llama_context * ctx, const struct llama_bat
// Filesystem utils
//
bool fs_validate_filename(const std::string & filename, bool allow_subdirs = false);
bool fs_validate_filename(const std::string & filename);
bool fs_create_directory_with_parents(const std::string & path);
bool fs_is_directory(const std::string & path);
std::string fs_get_cache_directory();
std::string fs_get_cache_file(const std::string & filename);
struct common_file_info {
std::string path;
std::string name;
size_t size = 0; // in bytes
bool is_dir = false;
};
std::vector<common_file_info> fs_list(const std::string & path, bool include_directories);
//
// TTY utils
//
// Auto-detect if colors can be enabled based on terminal and environment
bool tty_can_use_colors();
//
// Model utils
//

View File

@@ -268,10 +268,10 @@ static bool is_reserved_name(const std::string & name) {
}
std::regex INVALID_RULE_CHARS_RE("[^a-zA-Z0-9-]+");
std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"\\\\]");
std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"]");
std::regex GRAMMAR_RANGE_LITERAL_ESCAPE_RE("[\r\n\"\\]\\-\\\\]");
std::unordered_map<char, std::string> GRAMMAR_LITERAL_ESCAPES = {
{'\r', "\\r"}, {'\n', "\\n"}, {'"', "\\\""}, {'-', "\\-"}, {']', "\\]"}, {'\\', "\\\\"}
{'\r', "\\r"}, {'\n', "\\n"}, {'"', "\\\""}, {'-', "\\-"}, {']', "\\]"}
};
std::unordered_set<char> NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'};
@@ -303,8 +303,6 @@ static std::string format_literal(const std::string & literal) {
return "\"" + escaped + "\"";
}
std::string gbnf_format_literal(const std::string & literal) { return format_literal(literal); }
class SchemaConverter {
private:
friend std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options);
@@ -603,10 +601,7 @@ private:
}
std::string _resolve_ref(const std::string & ref) {
auto it = ref.find('#');
std::string ref_fragment = it != std::string::npos ? ref.substr(it + 1) : ref;
static const std::regex nonalphanumeric_regex(R"([^a-zA-Z0-9-]+)");
std::string ref_name = "ref" + std::regex_replace(ref_fragment, nonalphanumeric_regex, "-");
std::string ref_name = ref.substr(ref.find_last_of('/') + 1);
if (_rules.find(ref_name) == _rules.end() && _refs_being_resolved.find(ref) == _refs_being_resolved.end()) {
_refs_being_resolved.insert(ref);
json resolved = _refs[ref];
@@ -779,24 +774,11 @@ public:
std::vector<std::string> tokens = string_split(pointer, "/");
for (size_t i = 1; i < tokens.size(); ++i) {
std::string sel = tokens[i];
if (target.is_object() && target.contains(sel)) {
target = target[sel];
} else if (target.is_array()) {
size_t sel_index;
try {
sel_index = std::stoul(sel);
} catch (const std::invalid_argument & e) {
sel_index = target.size();
}
if (sel_index >= target.size()) {
_errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump());
return;
}
target = target[sel_index];
} else {
if (target.is_null() || !target.contains(sel)) {
_errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump());
return;
}
target = target[sel];
}
_refs[ref] = target;
}
@@ -974,7 +956,7 @@ public:
void check_errors() {
if (!_errors.empty()) {
throw std::invalid_argument("JSON schema conversion failed:\n" + string_join(_errors, "\n"));
throw std::runtime_error("JSON schema conversion failed:\n" + string_join(_errors, "\n"));
}
if (!_warnings.empty()) {
fprintf(stderr, "WARNING: JSON schema conversion was incomplete: %s\n", string_join(_warnings, "; ").c_str());

View File

@@ -18,6 +18,4 @@ struct common_grammar_options {
bool dotall = false;
};
std::string gbnf_format_literal(const std::string & literal);
std::string build_grammar(const std::function<void(const common_grammar_builder &)> & cb, const common_grammar_options & options = {});

View File

@@ -1,4 +1,3 @@
#include "common.h"
#include "log.h"
#include <chrono>
@@ -27,6 +26,30 @@ void common_log_set_verbosity_thold(int verbosity) {
common_log_verbosity_thold = verbosity;
}
// Auto-detect if colors should be enabled based on terminal and environment
static bool common_log_should_use_colors_auto() {
// Check NO_COLOR environment variable (https://no-color.org/)
if (const char * no_color = std::getenv("NO_COLOR")) {
if (no_color[0] != '\0') {
return false;
}
}
// Check TERM environment variable
if (const char * term = std::getenv("TERM")) {
if (std::strcmp(term, "dumb") == 0) {
return false;
}
}
// Check if stdout and stderr are connected to a terminal
// We check both because log messages can go to either
bool stdout_is_tty = isatty(fileno(stdout));
bool stderr_is_tty = isatty(fileno(stderr));
return stdout_is_tty || stderr_is_tty;
}
static int64_t t_us() {
return std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
}
@@ -368,7 +391,7 @@ struct common_log * common_log_main() {
static std::once_flag init_flag;
std::call_once(init_flag, [&]() {
// Set default to auto-detect colors
log.set_colors(tty_can_use_colors());
log.set_colors(common_log_should_use_colors_auto());
});
return &log;
@@ -399,7 +422,7 @@ void common_log_set_file(struct common_log * log, const char * file) {
void common_log_set_colors(struct common_log * log, log_colors colors) {
if (colors == LOG_COLORS_AUTO) {
log->set_colors(tty_can_use_colors());
log->set_colors(common_log_should_use_colors_auto());
return;
}
@@ -419,23 +442,3 @@ void common_log_set_prefix(struct common_log * log, bool prefix) {
void common_log_set_timestamps(struct common_log * log, bool timestamps) {
log->set_timestamps(timestamps);
}
static int common_get_verbosity(enum ggml_log_level level) {
switch (level) {
case GGML_LOG_LEVEL_DEBUG: return LOG_LEVEL_DEBUG;
case GGML_LOG_LEVEL_INFO: return LOG_LEVEL_INFO;
case GGML_LOG_LEVEL_WARN: return LOG_LEVEL_WARN;
case GGML_LOG_LEVEL_ERROR: return LOG_LEVEL_ERROR;
case GGML_LOG_LEVEL_CONT: return LOG_LEVEL_INFO; // same as INFO
case GGML_LOG_LEVEL_NONE:
default:
return LOG_LEVEL_OUTPUT;
}
}
void common_log_default_callback(enum ggml_log_level level, const char * text, void * /*user_data*/) {
auto verbosity = common_get_verbosity(level);
if (verbosity <= common_log_verbosity_thold) {
common_log_add(common_log_main(), level, "%s", text);
}
}

View File

@@ -21,14 +21,8 @@
# define LOG_ATTRIBUTE_FORMAT(...) __attribute__((format(printf, __VA_ARGS__)))
#endif
#define LOG_LEVEL_DEBUG 4
#define LOG_LEVEL_INFO 3
#define LOG_LEVEL_WARN 2
#define LOG_LEVEL_ERROR 1
#define LOG_LEVEL_OUTPUT 0 // output data from tools
#define LOG_DEFAULT_DEBUG LOG_LEVEL_DEBUG
#define LOG_DEFAULT_LLAMA LOG_LEVEL_INFO
#define LOG_DEFAULT_DEBUG 1
#define LOG_DEFAULT_LLAMA 0
enum log_colors {
LOG_COLORS_AUTO = -1,
@@ -42,8 +36,6 @@ extern int common_log_verbosity_thold;
void common_log_set_verbosity_thold(int verbosity); // not thread-safe
void common_log_default_callback(enum ggml_log_level level, const char * text, void * user_data);
// the common_log uses an internal worker thread to print/write log messages
// when the worker thread is paused, incoming log messages are discarded
struct common_log;
@@ -73,11 +65,10 @@ void common_log_add(struct common_log * log, enum ggml_log_level level, const ch
// 0.00.090.578 I llm_load_tensors: offloading 32 repeating layers to GPU
// 0.00.090.579 I llm_load_tensors: offloading non-repeating layers to GPU
//
// I - info (stdout, V = 0)
// W - warning (stderr, V = 0)
// E - error (stderr, V = 0)
// D - debug (stderr, V = LOG_DEFAULT_DEBUG)
// I - info (stdout, V = LOG_DEFAULT_INFO)
// W - warning (stderr, V = LOG_DEFAULT_WARN)
// E - error (stderr, V = LOG_DEFAULT_ERROR)
// O - output (stdout, V = LOG_DEFAULT_OUTPUT)
//
void common_log_set_file (struct common_log * log, const char * file); // not thread-safe
@@ -102,14 +93,14 @@ void common_log_set_timestamps(struct common_log * log, bool timestamps); // w
} \
} while (0)
#define LOG(...) LOG_TMPL(GGML_LOG_LEVEL_NONE, LOG_LEVEL_OUTPUT, __VA_ARGS__)
#define LOGV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_NONE, verbosity, __VA_ARGS__)
#define LOG(...) LOG_TMPL(GGML_LOG_LEVEL_NONE, 0, __VA_ARGS__)
#define LOGV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_NONE, verbosity, __VA_ARGS__)
#define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, LOG_LEVEL_DEBUG, __VA_ARGS__)
#define LOG_INF(...) LOG_TMPL(GGML_LOG_LEVEL_INFO, LOG_LEVEL_INFO, __VA_ARGS__)
#define LOG_WRN(...) LOG_TMPL(GGML_LOG_LEVEL_WARN, LOG_LEVEL_WARN, __VA_ARGS__)
#define LOG_ERR(...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, LOG_LEVEL_ERROR, __VA_ARGS__)
#define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT, LOG_LEVEL_INFO, __VA_ARGS__) // same as INFO
#define LOG_INF(...) LOG_TMPL(GGML_LOG_LEVEL_INFO, 0, __VA_ARGS__)
#define LOG_WRN(...) LOG_TMPL(GGML_LOG_LEVEL_WARN, 0, __VA_ARGS__)
#define LOG_ERR(...) LOG_TMPL(GGML_LOG_LEVEL_ERROR, 0, __VA_ARGS__)
#define LOG_DBG(...) LOG_TMPL(GGML_LOG_LEVEL_DEBUG, LOG_DEFAULT_DEBUG, __VA_ARGS__)
#define LOG_CNT(...) LOG_TMPL(GGML_LOG_LEVEL_CONT, 0, __VA_ARGS__)
#define LOG_INFV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_INFO, verbosity, __VA_ARGS__)
#define LOG_WRNV(verbosity, ...) LOG_TMPL(GGML_LOG_LEVEL_WARN, verbosity, __VA_ARGS__)

View File

@@ -3,10 +3,9 @@
#include "common.h"
#include "log.h"
#include <algorithm>
#include <cmath>
#include <cstring>
#include <unordered_map>
#include <algorithm>
// the ring buffer works similarly to std::deque, but with a fixed capacity
// TODO: deduplicate with llama-impl.h
@@ -113,13 +112,6 @@ struct common_sampler {
llama_token_data_array cur_p;
void reset() {
prev.clear();
llama_sampler_reset(grmr);
llama_sampler_reset(chain);
}
void set_logits(struct llama_context * ctx, int idx) {
const auto * logits = llama_get_logits_ith(ctx, idx);
@@ -136,12 +128,6 @@ struct common_sampler {
cur_p = { cur.data(), cur.size(), -1, false };
}
common_time_meas tm() {
return common_time_meas(t_total_us, params.no_perf);
}
mutable int64_t t_total_us = 0;
};
std::string common_params_sampling::print() const {
@@ -312,8 +298,6 @@ void common_sampler_free(struct common_sampler * gsmpl) {
}
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
const auto tm = gsmpl->tm();
if (accept_grammar) {
llama_sampler_accept(gsmpl->grmr, token);
}
@@ -324,7 +308,9 @@ void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, boo
}
void common_sampler_reset(struct common_sampler * gsmpl) {
gsmpl->reset();
llama_sampler_reset(gsmpl->grmr);
llama_sampler_reset(gsmpl->chain);
}
struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
@@ -341,54 +327,16 @@ struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl) {
// TODO: measure grammar performance
const double t_sampling_ms = gsmpl ? 1e-3*gsmpl->t_total_us : 0;
llama_perf_sampler_data data_smpl;
llama_perf_context_data data_ctx;
memset(&data_smpl, 0, sizeof(data_smpl));
memset(&data_ctx, 0, sizeof(data_ctx));
if (gsmpl) {
auto & data = data_smpl;
data = llama_perf_sampler(gsmpl->chain);
// note: the sampling time includes the samplers time + extra time spent in common/sampling
LOG_INF("%s: sampling time = %10.2f ms\n", __func__, t_sampling_ms);
LOG_INF("%s: samplers time = %10.2f ms / %5d tokens\n", __func__, data.t_sample_ms, data.n_sample);
llama_perf_sampler_print(gsmpl->chain);
}
if (ctx) {
auto & data = data_ctx;
data = llama_perf_context(ctx);
const double t_end_ms = 1e-3 * ggml_time_us();
const double t_total_ms = t_end_ms - data.t_start_ms;
const double t_unacc_ms = t_total_ms - (t_sampling_ms + data.t_p_eval_ms + data.t_eval_ms);
const double t_unacc_pc = 100.0 * t_unacc_ms / t_total_ms;
LOG_INF("%s: load time = %10.2f ms\n", __func__, data.t_load_ms);
LOG_INF("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
__func__, data.t_p_eval_ms, data.n_p_eval, data.t_p_eval_ms / data.n_p_eval, 1e3 / data.t_p_eval_ms * data.n_p_eval);
LOG_INF("%s: eval time = %10.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
__func__, data.t_eval_ms, data.n_eval, data.t_eval_ms / data.n_eval, 1e3 / data.t_eval_ms * data.n_eval);
LOG_INF("%s: total time = %10.2f ms / %5d tokens\n", __func__, (t_end_ms - data.t_start_ms), (data.n_p_eval + data.n_eval));
LOG_INF("%s: unaccounted time = %10.2f ms / %5.1f %% (total - sampling - prompt eval - eval) / (total)\n", __func__, t_unacc_ms, t_unacc_pc);
LOG_INF("%s: graphs reused = %10d\n", __func__, data.n_reused);
llama_perf_context_print(ctx);
llama_memory_breakdown_print(ctx);
}
}
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
llama_synchronize(ctx);
// start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
const auto tm = gsmpl->tm();
gsmpl->set_logits(ctx, idx);
auto & grmr = gsmpl->grmr;
@@ -480,8 +428,6 @@ uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
// helpers
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort) {
const auto tm = gsmpl->tm();
auto * res = &gsmpl->cur_p;
if (do_sort && !res->sorted) {

View File

@@ -83,7 +83,6 @@ extern "C" {
LLAMA_ROPE_TYPE_NORM = 0,
LLAMA_ROPE_TYPE_NEOX = GGML_ROPE_TYPE_NEOX,
LLAMA_ROPE_TYPE_MROPE = GGML_ROPE_TYPE_MROPE,
LLAMA_ROPE_TYPE_IMROPE = GGML_ROPE_TYPE_IMROPE,
LLAMA_ROPE_TYPE_VISION = GGML_ROPE_TYPE_VISION,
};
@@ -246,21 +245,6 @@ extern "C" {
LLAMA_KV_OVERRIDE_TYPE_STR,
};
enum llama_model_meta_key {
LLAMA_MODEL_META_KEY_SAMPLING_SEQUENCE,
LLAMA_MODEL_META_KEY_SAMPLING_TOP_K,
LLAMA_MODEL_META_KEY_SAMPLING_TOP_P,
LLAMA_MODEL_META_KEY_SAMPLING_MIN_P,
LLAMA_MODEL_META_KEY_SAMPLING_XTC_PROBABILITY,
LLAMA_MODEL_META_KEY_SAMPLING_XTC_THRESHOLD,
LLAMA_MODEL_META_KEY_SAMPLING_TEMP,
LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_LAST_N,
LLAMA_MODEL_META_KEY_SAMPLING_PENALTY_REPEAT,
LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT,
LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_TAU,
LLAMA_MODEL_META_KEY_SAMPLING_MIROSTAT_ETA,
};
struct llama_model_kv_override {
enum llama_model_kv_override_type tag;
@@ -476,11 +460,7 @@ extern "C" {
LLAMA_API bool llama_supports_gpu_offload(void);
LLAMA_API bool llama_supports_rpc (void);
// NOTE: After creating a llama_context, it is recommended to query the actual values using these functions
// In some cases the requested values via llama_context_params may differ from the actual values used by the context
// ref: https://github.com/ggml-org/llama.cpp/pull/17046#discussion_r2503085732
LLAMA_API uint32_t llama_n_ctx (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_ctx_seq (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_batch (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_ubatch (const struct llama_context * ctx);
LLAMA_API uint32_t llama_n_seq_max (const struct llama_context * ctx);
@@ -501,7 +481,6 @@ extern "C" {
LLAMA_API int32_t llama_model_n_ctx_train(const struct llama_model * model);
LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_embd_inp (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_head (const struct llama_model * model);
LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model);
@@ -533,9 +512,6 @@ extern "C" {
// Get the number of metadata key/value pairs
LLAMA_API int32_t llama_model_meta_count(const struct llama_model * model);
// Get sampling metadata key name. Returns nullptr if the key is invalid
LLAMA_API const char * llama_model_meta_key_str(enum llama_model_meta_key key);
// Get metadata key name by index
LLAMA_API int32_t llama_model_meta_key_by_index(const struct llama_model * model, int32_t i, char * buf, size_t buf_size);
@@ -608,7 +584,7 @@ extern "C" {
LLAMA_API int32_t llama_adapter_meta_val_str_by_index(const struct llama_adapter_lora * adapter, int32_t i, char * buf, size_t buf_size);
// Manually free a LoRA adapter
// NOTE: loaded adapters will be free when the associated model is deleted
// Note: loaded adapters will be free when the associated model is deleted
LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter);
// Get the invocation tokens if the current lora is an alora
@@ -1134,6 +1110,8 @@ extern "C" {
// // sample from the logits of the last token in the batch
// const llama_token id = llama_sampler_sample(smpl, ctx, -1);
//
// // accepting the token updates the internal state of certain samplers (e.g. grammar, repetition, etc.)
// llama_sampler_accept(smpl, id);
// ...
// }
//

View File

@@ -32,9 +32,6 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_QWEN2VL, "qwen2vl" },
{ LLM_ARCH_QWEN3, "qwen3" },
{ LLM_ARCH_QWEN3MOE, "qwen3moe" },
{ LLM_ARCH_QWEN3NEXT, "qwen3next" },
{ LLM_ARCH_QWEN3VL, "qwen3vl" },
{ LLM_ARCH_QWEN3VLMOE, "qwen3vlmoe" },
{ LLM_ARCH_PHI2, "phi2" },
{ LLM_ARCH_PHI3, "phi3" },
{ LLM_ARCH_PHIMOE, "phimoe" },
@@ -92,7 +89,6 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_BAILINGMOE2, "bailingmoe2" },
{ LLM_ARCH_DOTS1, "dots1" },
{ LLM_ARCH_ARCEE, "arcee" },
{ LLM_ARCH_AFMOE, "afmoe" },
{ LLM_ARCH_ERNIE4_5, "ernie4_5" },
{ LLM_ARCH_ERNIE4_5_MOE, "ernie4_5-moe" },
{ LLM_ARCH_HUNYUAN_MOE, "hunyuan-moe" },
@@ -108,40 +104,23 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_SEED_OSS, "seed_oss" },
{ LLM_ARCH_GROVEMOE, "grovemoe" },
{ LLM_ARCH_APERTUS, "apertus" },
{ LLM_ARCH_MINIMAX_M2, "minimax-m2" },
{ LLM_ARCH_COGVLM, "cogvlm" },
{ LLM_ARCH_RND1, "rnd1" },
{ LLM_ARCH_PANGU_EMBED, "pangu-embedded" },
{ LLM_ARCH_MISTRAL3, "mistral3" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_GENERAL_TYPE, "general.type" },
{ LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" },
{ LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" },
{ LLM_KV_GENERAL_ALIGNMENT, "general.alignment" },
{ LLM_KV_GENERAL_FILE_TYPE, "general.file_type" },
{ LLM_KV_GENERAL_SAMPLING_SEQUENCE, "general.sampling.sequence" },
{ LLM_KV_GENERAL_SAMPLING_TOP_K, "general.sampling.top_k" },
{ LLM_KV_GENERAL_SAMPLING_TOP_P, "general.sampling.top_p" },
{ LLM_KV_GENERAL_SAMPLING_MIN_P, "general.sampling.min_p" },
{ LLM_KV_GENERAL_SAMPLING_XTC_PROBABILITY, "general.sampling.xtc_probability" },
{ LLM_KV_GENERAL_SAMPLING_XTC_THRESHOLD, "general.sampling.xtc_threshold" },
{ LLM_KV_GENERAL_SAMPLING_TEMP, "general.sampling.temp" },
{ LLM_KV_GENERAL_SAMPLING_PENALTY_LAST_N, "general.sampling.penalty_last_n" },
{ LLM_KV_GENERAL_SAMPLING_PENALTY_REPEAT, "general.sampling.penalty_repeat" },
{ LLM_KV_GENERAL_SAMPLING_MIROSTAT, "general.sampling.mirostat" },
{ LLM_KV_GENERAL_SAMPLING_MIROSTAT_TAU, "general.sampling.mirostat_tau" },
{ LLM_KV_GENERAL_SAMPLING_MIROSTAT_ETA, "general.sampling.mirostat_eta" },
{ LLM_KV_GENERAL_NAME, "general.name" },
{ LLM_KV_GENERAL_AUTHOR, "general.author" },
{ LLM_KV_GENERAL_VERSION, "general.version" },
{ LLM_KV_GENERAL_URL, "general.url" },
{ LLM_KV_GENERAL_DESCRIPTION, "general.description" },
{ LLM_KV_GENERAL_LICENSE, "general.license" },
{ LLM_KV_GENERAL_SOURCE_URL, "general.source.url" },
{ LLM_KV_GENERAL_SOURCE_HF_REPO, "general.source.huggingface.repository" },
{ LLM_KV_GENERAL_TYPE, "general.type" },
{ LLM_KV_GENERAL_ARCHITECTURE, "general.architecture" },
{ LLM_KV_GENERAL_QUANTIZATION_VERSION, "general.quantization_version" },
{ LLM_KV_GENERAL_ALIGNMENT, "general.alignment" },
{ LLM_KV_GENERAL_FILE_TYPE, "general.file_type" },
{ LLM_KV_GENERAL_NAME, "general.name" },
{ LLM_KV_GENERAL_AUTHOR, "general.author" },
{ LLM_KV_GENERAL_VERSION, "general.version" },
{ LLM_KV_GENERAL_URL, "general.url" },
{ LLM_KV_GENERAL_DESCRIPTION, "general.description" },
{ LLM_KV_GENERAL_LICENSE, "general.license" },
{ LLM_KV_GENERAL_SOURCE_URL, "general.source.url" },
{ LLM_KV_GENERAL_SOURCE_HF_REPO, "general.source.huggingface.repository" },
{ LLM_KV_VOCAB_SIZE, "%s.vocab_size" },
{ LLM_KV_CONTEXT_LENGTH, "%s.context_length" },
@@ -167,7 +146,6 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_EXPERTS_PER_GROUP, "%s.experts_per_group" },
{ LLM_KV_MOE_EVERY_N_LAYERS, "%s.moe_every_n_layers" },
{ LLM_KV_NEXTN_PREDICT_LAYERS, "%s.nextn_predict_layers" },
{ LLM_KV_NUM_DEEPSTACK_LAYERS, "%s.n_deepstack_layers" },
{ LLM_KV_POOLING_TYPE, "%s.pooling_type" },
{ LLM_KV_LOGIT_SCALE, "%s.logit_scale" },
{ LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },
@@ -206,7 +184,6 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
{ LLM_KV_ATTENTION_OUTPUT_SCALE, "%s.attention.output_scale" },
{ LLM_KV_ATTENTION_TEMPERATURE_LENGTH, "%s.attention.temperature_length" },
{ LLM_KV_ATTENTION_TEMPERATURE_SCALE, "%s.attention.temperature_scale" },
{ LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION, "%s.attention.block_skip_connection" },
{ LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
{ LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
@@ -352,36 +329,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_AFMOE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_GATE, "blk.%d.attn_gate" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
},
},
{
LLM_ARCH_LLAMA4,
{
@@ -834,77 +781,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_QWEN3NEXT,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
{ LLM_TENSOR_SSM_A_NOSCAN, "blk.%d.ssm_a" },
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
{ LLM_TENSOR_SSM_BETA_ALPHA, "blk.%d.ssm_ba" },
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
{ LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
},
},
{
LLM_ARCH_QWEN3VL,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_QWEN3VLMOE,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_PHI2,
{
@@ -2292,7 +2168,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_SHORTCONV_INPROJ, "blk.%d.shortconv.in_proj" },
{ LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" },
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "token_embd_norm" }, // note: wrong tensor name
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
}
},
@@ -2314,7 +2190,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_SHORTCONV_INPROJ, "blk.%d.shortconv.in_proj" },
{ LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" },
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "token_embd_norm" }, // note: wrong tensor name
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
@@ -2456,110 +2332,6 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP_CHEXPS, "blk.%d.ffn_up_chexps" },
},
},
{
LLM_ARCH_MINIMAX_M2,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" },
},
},
{
LLM_ARCH_PANGU_EMBED,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_COGVLM,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_VISEXP_ATTN_QKV, "blk.%d.vis_attn_qkv" },
{ LLM_TENSOR_VISEXP_ATTN_OUT, "blk.%d.vis_attn_output" },
{ LLM_TENSOR_VISEXP_FFN_GATE, "blk.%d.vis_gate" },
{ LLM_TENSOR_VISEXP_FFN_DOWN, "blk.%d.vis_down" },
{ LLM_TENSOR_VISEXP_FFN_UP, "blk.%d.vis_up" },
},
},
{
LLM_ARCH_RND1,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_MISTRAL3,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
{ LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
{ LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
{ LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_UNKNOWN,
{
@@ -2568,21 +2340,11 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
},
};
// declare information about the model weight tensors:
// - the layer in which the tensor is going to be used. this is needed in order to assign the correct buffer type for the weight
// - the operator which is going to use the weight. this is needed to determine if the respective backend supports the operator
//
// for example, input layers are usually assigned to CPU/host buffer types
//
// a mismatch between the declared information and the actual layer/op in which the tensor is used can lead to sub-optimal
// assignment of the buffer types and extra overhead during computation
// example: https://github.com/ggml-org/llama.cpp/pull/17548
//
static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_POS_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_TOKEN_EMBD_NORM, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_TOKEN_TYPES, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_TOKEN_EMBD_NORM, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL}},
{LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
@@ -2599,7 +2361,6 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_ATTN_V, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_ATTN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
@@ -2637,7 +2398,6 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_SSM_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_SSM_DT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_SSM_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_SSM_BETA_ALPHA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_A1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
@@ -2659,7 +2419,6 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_FFN_ACT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_DIV}},
{LLM_TENSOR_SSM_CONV1D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}},
{LLM_TENSOR_SSM_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_SCAN}},
{LLM_TENSOR_SSM_A_NOSCAN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // a version of SSM_A used for MUL instead of SSM_SCAN
{LLM_TENSOR_SSM_DT_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_SSM_B_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
{LLM_TENSOR_SSM_C_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}},
@@ -2750,11 +2509,6 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_SHORTCONV_CONV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}},
{LLM_TENSOR_SHORTCONV_INPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_SHORTCONV_OUTPROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_VISEXP_ATTN_QKV, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_VISEXP_ATTN_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_VISEXP_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_VISEXP_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_VISEXP_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
// NextN/MTP tensors are currently ignored (reserved for future MTP support)
// These tensors only exist in the last layer(s) and are treated as output tensors
{LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
@@ -2838,7 +2592,6 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
case LLM_ARCH_LFM2:
case LLM_ARCH_LFM2MOE:
case LLM_ARCH_NEMOTRON_H:
case LLM_ARCH_QWEN3NEXT:
return true;
default:
return false;
@@ -2850,7 +2603,6 @@ bool llm_arch_is_diffusion(const llm_arch & arch) {
case LLM_ARCH_DREAM:
case LLM_ARCH_LLADA:
case LLM_ARCH_LLADA_MOE:
case LLM_ARCH_RND1:
return true;
default:
return false;

View File

@@ -36,9 +36,6 @@ enum llm_arch {
LLM_ARCH_QWEN2VL,
LLM_ARCH_QWEN3,
LLM_ARCH_QWEN3MOE,
LLM_ARCH_QWEN3NEXT,
LLM_ARCH_QWEN3VL,
LLM_ARCH_QWEN3VLMOE,
LLM_ARCH_PHI2,
LLM_ARCH_PHI3,
LLM_ARCH_PHIMOE,
@@ -96,7 +93,6 @@ enum llm_arch {
LLM_ARCH_BAILINGMOE2,
LLM_ARCH_DOTS1,
LLM_ARCH_ARCEE,
LLM_ARCH_AFMOE,
LLM_ARCH_ERNIE4_5,
LLM_ARCH_ERNIE4_5_MOE,
LLM_ARCH_HUNYUAN_MOE,
@@ -112,11 +108,6 @@ enum llm_arch {
LLM_ARCH_SEED_OSS,
LLM_ARCH_GROVEMOE,
LLM_ARCH_APERTUS,
LLM_ARCH_MINIMAX_M2,
LLM_ARCH_COGVLM,
LLM_ARCH_RND1,
LLM_ARCH_PANGU_EMBED,
LLM_ARCH_MISTRAL3,
LLM_ARCH_UNKNOWN,
};
@@ -126,18 +117,6 @@ enum llm_kv {
LLM_KV_GENERAL_QUANTIZATION_VERSION,
LLM_KV_GENERAL_ALIGNMENT,
LLM_KV_GENERAL_FILE_TYPE,
LLM_KV_GENERAL_SAMPLING_SEQUENCE,
LLM_KV_GENERAL_SAMPLING_TOP_K,
LLM_KV_GENERAL_SAMPLING_TOP_P,
LLM_KV_GENERAL_SAMPLING_MIN_P,
LLM_KV_GENERAL_SAMPLING_XTC_PROBABILITY,
LLM_KV_GENERAL_SAMPLING_XTC_THRESHOLD,
LLM_KV_GENERAL_SAMPLING_TEMP,
LLM_KV_GENERAL_SAMPLING_PENALTY_LAST_N,
LLM_KV_GENERAL_SAMPLING_PENALTY_REPEAT,
LLM_KV_GENERAL_SAMPLING_MIROSTAT,
LLM_KV_GENERAL_SAMPLING_MIROSTAT_TAU,
LLM_KV_GENERAL_SAMPLING_MIROSTAT_ETA,
LLM_KV_GENERAL_NAME,
LLM_KV_GENERAL_AUTHOR,
LLM_KV_GENERAL_VERSION,
@@ -171,7 +150,6 @@ enum llm_kv {
LLM_KV_EXPERTS_PER_GROUP,
LLM_KV_MOE_EVERY_N_LAYERS,
LLM_KV_NEXTN_PREDICT_LAYERS,
LLM_KV_NUM_DEEPSTACK_LAYERS,
LLM_KV_POOLING_TYPE,
LLM_KV_LOGIT_SCALE,
LLM_KV_DECODER_START_TOKEN_ID,
@@ -210,7 +188,6 @@ enum llm_kv {
LLM_KV_ATTENTION_SCALE,
LLM_KV_ATTENTION_OUTPUT_SCALE,
LLM_KV_ATTENTION_TEMPERATURE_LENGTH,
LLM_KV_ATTENTION_TEMPERATURE_SCALE,
LLM_KV_ATTENTION_BLOCK_SKIP_CONNECTION,
LLM_KV_ATTENTION_KEY_LENGTH_MLA,
LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
@@ -331,7 +308,6 @@ enum llm_tensor {
LLM_TENSOR_ATTN_POST_NORM,
LLM_TENSOR_ATTN_ROT_EMBD,
LLM_TENSOR_ATTN_SINKS,
LLM_TENSOR_ATTN_GATE,
LLM_TENSOR_FFN_GATE_INP,
LLM_TENSOR_FFN_GATE_INP_SHEXP,
LLM_TENSOR_FFN_NORM,
@@ -381,13 +357,11 @@ enum llm_tensor {
LLM_TENSOR_SSM_DT,
LLM_TENSOR_SSM_DT_NORM,
LLM_TENSOR_SSM_A,
LLM_TENSOR_SSM_A_NOSCAN, // qwen3next special case with MUL instead of SSM_SCAN
LLM_TENSOR_SSM_B_NORM,
LLM_TENSOR_SSM_C_NORM,
LLM_TENSOR_SSM_D,
LLM_TENSOR_SSM_NORM,
LLM_TENSOR_SSM_OUT,
LLM_TENSOR_SSM_BETA_ALPHA, // qwen3next
LLM_TENSOR_TIME_MIX_W0,
LLM_TENSOR_TIME_MIX_W1,
LLM_TENSOR_TIME_MIX_W2,
@@ -484,11 +458,6 @@ enum llm_tensor {
LLM_TENSOR_SHORTCONV_CONV,
LLM_TENSOR_SHORTCONV_INPROJ,
LLM_TENSOR_SHORTCONV_OUTPROJ,
LLM_TENSOR_VISEXP_ATTN_QKV,
LLM_TENSOR_VISEXP_ATTN_OUT,
LLM_TENSOR_VISEXP_FFN_GATE,
LLM_TENSOR_VISEXP_FFN_DOWN,
LLM_TENSOR_VISEXP_FFN_UP,
LLM_TENSOR_NEXTN_EH_PROJ,
LLM_TENSOR_NEXTN_EMBED_TOKENS,
LLM_TENSOR_NEXTN_ENORM,

View File

@@ -215,7 +215,6 @@ bool llama_batch_allocr::init(
/*.n_seq_tokens =*/ (uint32_t) 1,
/*.n_seqs =*/ (uint32_t) batch.n_tokens,
/*.n_seqs_unq =*/ (uint32_t) this->seq_id_unq.size(),
/*.n_pos =*/ n_pos_per_embd,
/*.token =*/ batch.token,
/*.embd =*/ batch.embd,
/*.pos =*/ batch.pos,
@@ -252,72 +251,46 @@ bool llama_batch_allocr::init(
// consistency checks
//
if (n_pos_per_embd > 1) {
// M-RoPE case: allow position to "jump" forward only (non-continuous positions are allowed)
for (uint32_t s = 0; s < n_seq_max; ++s) {
if (seq_pos[s].empty()) {
continue;
}
for (uint32_t s = 0; s < n_seq_max; ++s) {
if (seq_pos[s].empty()) {
continue;
}
const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
if (p0 >= 0) {
bool ok = true;
if (batch.token) {
if (p0 >= 0 && p0 >= seq_pos_min(s)) {
LLAMA_LOG_ERROR(
"%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
" - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
" - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
" for M-RoPE, it is required that the position satisfies: X < Y\n",
__func__, s, s, p0, s, seq_pos_min(s));
return false;
}
} else {
// embedding inputs can have overlapping positions
if (p0 >= 0 && p0 > seq_pos_min(s)) {
LLAMA_LOG_ERROR(
"%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
" - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
" - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
" for M-RoPE, it is required that the position satisfies: X <= Y\n",
__func__, s, s, p0, s, seq_pos_min(s));
return false;
}
}
}
} else {
for (uint32_t s = 0; s < n_seq_max; ++s) {
if (seq_pos[s].empty()) {
continue;
}
const llama_pos p0 = memory ? memory->seq_pos_max(s) : -1;
if (p0 >= 0) {
bool ok = true;
if (seq_pos_min(s) != p0 + 1) {
ok = false;
}
} else {
assert(batch.embd);
if (!ok) {
LLAMA_LOG_ERROR(
"%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
" - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
" - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
" it is required that the sequence positions remain consecutive: Y = X + 1\n",
__func__, s, s, p0, s, seq_pos_min(s));
return false;
// for embeddings (typically used as vision input), we allow them to have repeating positions
// ref: https://github.com/ggml-org/llama.cpp/issues/13694#issuecomment-2983871762
if (seq_pos_min(s) != p0 && seq_pos_min(s) != p0 + 1) {
ok = false;
}
}
if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
if (!ok) {
LLAMA_LOG_ERROR(
"%s: the tokens of sequence %d in the input batch have inconsistent sequence positions:\n"
" - the last position stored in the memory module of the context (i.e. the KV cache) for sequence %d is X = %d\n"
" - the tokens for sequence %d in the input batch have a starting position of Y = %d\n"
" it is required that the sequence positions remain consecutive: Y = X + 1\n",
__func__, s, s, p0, s, seq_pos_min(s));
return false;
}
}
if (seq_pos_max(s) - seq_pos_min(s) + 1 > (int) seq_pos[s].size()) {
LLAMA_LOG_ERROR("%s: sequence %d positions are not continuous\n", __func__, s);
return false;
}
}
if (memory) {
@@ -416,7 +389,6 @@ llama_ubatch llama_batch_allocr::ubatch_reserve(uint32_t n_seq_tokens, uint32_t
/*.n_seq_tokens =*/ n_seq_tokens,
/*.n_seqs =*/ n_seqs,
/*.n_seqs_unq =*/ n_seqs,
/*.n_pos =*/ n_pos_per_embd,
/*.token =*/ udata->token.data(),
/*.embd =*/ nullptr,
@@ -683,8 +655,10 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
auto udata = std::make_shared<llama_ubatch::data_t>();
const int32_t n_pos_cur = batch.embd ? n_pos_per_embd : 1;
const int64_t n_embd_all = batch.embd ? (int64_t) n_tokens*n_embd : 0;
const int64_t n_pos_all = (int64_t) n_tokens*n_pos_per_embd;
const int64_t n_pos_all = (int64_t) n_tokens*n_pos_cur;
udata->token .resize(n_tokens);
udata->embd .resize(n_embd_all);
@@ -706,13 +680,8 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
memcpy(udata->embd.data() + i*n_embd, batch.embd + (int64_t) idxs[i]*n_embd, n_embd*sizeof(float));
}
for (size_t j = 0; j < (size_t)n_pos_per_embd; ++j) {
// if we are using M-RoPE
// if the current batch is text, we need to broadcast the same position across all RoPE sections
// otherwise, the input batch is image embeddings, we copy the positions as-is
// if we are not using M-RoPE, there is only one position per token (this loop runs only once)
size_t src_off = batch.token ? 0 : j*batch.n_tokens;
udata->pos[j*n_tokens + i] = batch.pos[src_off + idxs[i]];
for (int j = 0; j < n_pos_cur; ++j) {
udata->pos[j*n_tokens + i] = batch.pos[j*batch.n_tokens + idxs[i]];
}
udata->n_seq_id[i] = batch.n_seq_id[idxs[i]];
@@ -741,7 +710,6 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
/*.n_seq_tokens =*/ n_tokens/n_seqs,
/*.n_seqs =*/ n_seqs,
/*.n_seqs_unq =*/ (uint32_t) udata->seq_id_unq.size(),
/*.n_pos =*/ n_pos_per_embd,
/*.token =*/ batch.token ? udata->token.data() : nullptr,
/*.embd =*/ batch.embd ? udata->embd.data() : nullptr,

View File

@@ -17,16 +17,6 @@ struct llama_ubatch {
return b_equal_seqs != 0;
}
// typical for M-RoPE cases:
// 0 - sequantial position of the tokens/embeddings in the sequence
// 1 - y position in the image
// 2 - x position in the image
// 3 - other
bool is_pos_2d() const {
// TODO @ngxson : we may need to check for model arch when more models use >1 positions
return n_pos >= 3;
}
uint32_t b_equal_seqs; // note: this is a boolean, but we use an int32_t for alignment
// otherwise address sanitizer complains
// TODO: whole_seqs for embeddings?
@@ -35,7 +25,6 @@ struct llama_ubatch {
uint32_t n_seq_tokens; // tokens per sequence set
uint32_t n_seqs; // sequence sets in the ubatch
uint32_t n_seqs_unq; // unique sequence ids in the ubatch
uint32_t n_pos; // number of position inputs for each token/embedding
// seq_id_unq: unique sequence ids in the ubatch
// seq_idx: indices of the unique sequence ids in the ubatch in [0, n_seqs_unq)
@@ -44,7 +33,7 @@ struct llama_ubatch {
// // size | idx | val
llama_token * token; // [n_tokens] | i | id, token
float * embd; // [n_embd, n_tokens] | i | embd
llama_pos * pos; // [n_tokens*n_pos] | i | pos
llama_pos * pos; // [n_tokens] | i | pos
int32_t * n_seq_id; // [n_tokens] | i | -
llama_seq_id ** seq_id; // [n_tokens] | s | s0, s1, seq_id
llama_seq_id * seq_id_unq; // [n_seqs_unq] | s | seq_id

View File

@@ -73,7 +73,6 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
{ "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 },
{ "seed_oss", LLM_CHAT_TEMPLATE_SEED_OSS },
{ "grok-2", LLM_CHAT_TEMPLATE_GROK_2 },
{ "pangu-embedded", LLM_CHAT_TEMPLATE_PANGU_EMBED },
};
llm_chat_template llm_chat_template_from_str(const std::string & name) {
@@ -214,8 +213,6 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
return LLM_CHAT_TEMPLATE_SEED_OSS;
} else if (tmpl_contains("'Assistant: ' + message['content'] + '<|separator|>")) {
return LLM_CHAT_TEMPLATE_GROK_2;
} else if (tmpl_contains(LU8("[unused9]系统:[unused10]"))) {
return LLM_CHAT_TEMPLATE_PANGU_EMBED;
}
return LLM_CHAT_TEMPLATE_UNKNOWN;
}
@@ -816,35 +813,6 @@ int32_t llm_chat_apply_template(
if (add_ass) {
ss << "Assistant:";
}
}else if (tmpl == LLM_CHAT_TEMPLATE_PANGU_EMBED) {
// [unused9]系统xxx[unused10]
// [unused9]用户xxx[unused10]
// [unused9]助手xxx[unused10]
// ...
for (size_t i = 0; i < chat.size(); ++i) {
const auto & msg = chat[i];
const std::string & role = msg->role;
const std::string & content = msg->content;
if (i == 0 && role != "system") {
ss << "[unused9]系统:[unused10]";
}
if (role == "system") {
ss << "[unused9]系统:" << content << "[unused10]";
} else if (role == "user") {
ss << "[unused9]用户:" << content << "[unused10]";
} else if (role == "assistant") {
ss << "[unused9]助手:" << content << "[unused10]";
} else if (role == "tool") {
ss << "[unused9]工具:" << content << "[unused10]";
} else if (role == "function") {
ss << "[unused9]方法:" << content << "[unused10]";
}
}
if (add_ass) {
ss << "[unused9]助手:";
}
} else {
// template not supported
return -1;

Some files were not shown because too many files have changed in this diff Show More